diff --git a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp index 204ceaf81..a910a854c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp @@ -83,15 +83,28 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output, const Nd4jLong xStrideH = isNCHW ? input.stridesOf()[2] : input.stridesOf()[1]; const Nd4jLong xStrideW = isNCHW ? input.stridesOf()[3] : input.stridesOf()[2]; - auto func = PRAGMA_THREADS_FOR_3D { - for (uint b = start_x; b < stop_x; b += inc_x) - for (uint c = start_y; c < stop_y; c += inc_y) - for (uint h = start_z; h < stop_z; h += inc_z) - for (uint w = 0; w < oW; ++w) - z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + static_cast(y[c * yStrideC]); - }; + if (isNCHW) { - samediff::Threads::parallel_for(func, 0, bS, 1, 0, C, 1, 0, oH, 1); + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b += inc_x) + for (uint c = start_y; c < stop_y; c += inc_y) + for (uint h = start_z; h < stop_z; h += inc_z) + for (uint w = 0; w < oW; ++w) + z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + static_cast(y[c * yStrideC]); + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, C, 1, 0, oH, 1); + } else { + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b++) + for (uint h = start_y; h < stop_y; h++) + for (uint w = start_z; w < stop_z; w++) + for (uint c = 0; c < C; c++) + z[b * zStrideB + c * zStrideC + h * zStrideH + w * zStrideW] = x[b * xStrideB + c * xStrideC + h * xStrideH + w * xStrideW] + y[c * yStrideC]; + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1, 0, oW, 1); + } } } else if(output.rankOf() == 5) { diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 410ec53a7..43baf007d 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -60,6 +60,36 @@ public: } }; +/* +TEST_F(PlaygroundTests, test_s_0) { + auto x = NDArrayFactory::create('c', {32, 112, 112, 16}); + auto y = NDArrayFactory::create('c', {16}); + auto z = x.ulike(); + + std::vector values; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + nd4j::ops::biasadd op; + + + for (int e = 0; e < 10000; e++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +} +*/ /* TEST_F(PlaygroundTests, test_s_1) { auto x0 = NDArrayFactory::create('c', {32, 7, 7, 176});