[WIP] bias_add NHWC loop (#149)

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* bias_add nhwc 4D

Signed-off-by: raver119 <raver119@gmail.com>

* bias_add nhwc 4D

Signed-off-by: raver119 <raver119@gmail.com>

* bias_add nhwc 4D

Signed-off-by: raver119 <raver119@gmail.com>

* bias_add nhwc 4D

Signed-off-by: raver119 <raver119@gmail.com>

* disable test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-24 20:56:49 +03:00 committed by GitHub
parent fc760de348
commit 9b329d2601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 8 deletions

View File

@ -83,15 +83,28 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output,
const Nd4jLong xStrideH = isNCHW ? input.stridesOf()[2] : input.stridesOf()[1];
const Nd4jLong xStrideW = isNCHW ? input.stridesOf()[3] : input.stridesOf()[2];
auto func = PRAGMA_THREADS_FOR_3D {
for (uint b = start_x; b < stop_x; b += inc_x)
for (uint c = start_y; c < stop_y; c += inc_y)
for (uint h = start_z; h < stop_z; h += inc_z)
for (uint w = 0; w < oW; ++w)
z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + static_cast<X>(y[c * yStrideC]);
};
if (isNCHW) {
samediff::Threads::parallel_for(func, 0, bS, 1, 0, C, 1, 0, oH, 1);
auto func = PRAGMA_THREADS_FOR_3D {
for (uint b = start_x; b < stop_x; b += inc_x)
for (uint c = start_y; c < stop_y; c += inc_y)
for (uint h = start_z; h < stop_z; h += inc_z)
for (uint w = 0; w < oW; ++w)
z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + static_cast<X>(y[c * yStrideC]);
};
samediff::Threads::parallel_for(func, 0, bS, 1, 0, C, 1, 0, oH, 1);
} else {
auto func = PRAGMA_THREADS_FOR_3D {
for (uint b = start_x; b < stop_x; b++)
for (uint h = start_y; h < stop_y; h++)
for (uint w = start_z; w < stop_z; w++)
for (uint c = 0; c < C; c++)
z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + y[c * yStrideC];
};
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1, 0, oW, 1);
}
}
}
else if(output.rankOf() == 5) {

View File

@ -60,6 +60,36 @@ public:
}
};
/*
TEST_F(PlaygroundTests, test_s_0) {
auto x = NDArrayFactory::create<float>('c', {32, 112, 112, 16});
auto y = NDArrayFactory::create<float>('c', {16});
auto z = x.ulike();
std::vector<Nd4jLong> values;
Context ctx(1);
ctx.setInputArray(0, &x);
ctx.setInputArray(1, &y);
ctx.setOutputArray(0, &z);
nd4j::ops::biasadd op;
for (int e = 0; e < 10000; e++) {
auto timeStart = std::chrono::system_clock::now();
op.execute(&ctx);
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
values.emplace_back(outerTime);
}
std::sort(values.begin(), values.end());
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
}
*/
/*
TEST_F(PlaygroundTests, test_s_1) {
auto x0 = NDArrayFactory::create<float>('c', {32, 7, 7, 176});