[WIP] bias_add NHWC loop (#149)
* one more test Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * bias_add nhwc 4D Signed-off-by: raver119 <raver119@gmail.com> * bias_add nhwc 4D Signed-off-by: raver119 <raver119@gmail.com> * bias_add nhwc 4D Signed-off-by: raver119 <raver119@gmail.com> * bias_add nhwc 4D Signed-off-by: raver119 <raver119@gmail.com> * disable test Signed-off-by: raver119 <raver119@gmail.com>master
parent
fc760de348
commit
9b329d2601
|
@ -83,15 +83,28 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output,
|
||||||
const Nd4jLong xStrideH = isNCHW ? input.stridesOf()[2] : input.stridesOf()[1];
|
const Nd4jLong xStrideH = isNCHW ? input.stridesOf()[2] : input.stridesOf()[1];
|
||||||
const Nd4jLong xStrideW = isNCHW ? input.stridesOf()[3] : input.stridesOf()[2];
|
const Nd4jLong xStrideW = isNCHW ? input.stridesOf()[3] : input.stridesOf()[2];
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
if (isNCHW) {
|
||||||
for (uint b = start_x; b < stop_x; b += inc_x)
|
|
||||||
for (uint c = start_y; c < stop_y; c += inc_y)
|
|
||||||
for (uint h = start_z; h < stop_z; h += inc_z)
|
|
||||||
for (uint w = 0; w < oW; ++w)
|
|
||||||
z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + static_cast<X>(y[c * yStrideC]);
|
|
||||||
};
|
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, C, 1, 0, oH, 1);
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
for (uint b = start_x; b < stop_x; b += inc_x)
|
||||||
|
for (uint c = start_y; c < stop_y; c += inc_y)
|
||||||
|
for (uint h = start_z; h < stop_z; h += inc_z)
|
||||||
|
for (uint w = 0; w < oW; ++w)
|
||||||
|
z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + static_cast<X>(y[c * yStrideC]);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, bS, 1, 0, C, 1, 0, oH, 1);
|
||||||
|
} else {
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
for (uint b = start_x; b < stop_x; b++)
|
||||||
|
for (uint h = start_y; h < stop_y; h++)
|
||||||
|
for (uint w = start_z; w < stop_z; w++)
|
||||||
|
for (uint c = 0; c < C; c++)
|
||||||
|
z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + y[c * yStrideC];
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1, 0, oW, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(output.rankOf() == 5) {
|
else if(output.rankOf() == 5) {
|
||||||
|
|
|
@ -60,6 +60,36 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(PlaygroundTests, test_s_0) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {32, 112, 112, 16});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {16});
|
||||||
|
auto z = x.ulike();
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> values;
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setInputArray(1, &y);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
nd4j::ops::biasadd op;
|
||||||
|
|
||||||
|
|
||||||
|
for (int e = 0; e < 10000; e++) {
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
|
||||||
|
op.execute(&ctx);
|
||||||
|
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
|
||||||
|
values.emplace_back(outerTime);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(values.begin(), values.end());
|
||||||
|
|
||||||
|
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||||
|
}
|
||||||
|
*/
|
||||||
/*
|
/*
|
||||||
TEST_F(PlaygroundTests, test_s_1) {
|
TEST_F(PlaygroundTests, test_s_1) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {32, 7, 7, 176});
|
auto x0 = NDArrayFactory::create<float>('c', {32, 7, 7, 176});
|
||||||
|
|
Loading…
Reference in New Issue