From 77244f54961758040562525d98df26853e4af0fe Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 16 Mar 2020 18:17:42 +0300 Subject: [PATCH] avg/max pooling3d bp fixed (#323) Signed-off-by: raver119 --- .../declarable/helpers/cpu/convolutions.cpp | 30 +++++++++---------- .../layers_tests/DeclarableOpsTests8.cpp | 17 ----------- 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index c1dd5dd56..f852bed23 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -1530,13 +1530,13 @@ namespace sd { const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4; if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_3D { + auto func = PRAGMA_THREADS_FOR_2D { Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; T sum, valO, *pIn, *pgI; - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { for (int oh = 0; oh < oH; ++oh) { for (int ow = 0; ow < oW; ++ow) { @@ -1618,17 +1618,17 @@ namespace sd { } }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); } /*************************************************************************/ else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_3D { + auto func = PRAGMA_THREADS_FOR_2D { Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; T sum, valO, *pIn, *pgI; - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { for (int oh = 0; oh < oH; ++oh) { for (int ow = 0; ow < oW; ++ow) { @@ -1679,17 +1679,17 @@ namespace sd { } }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); } /*************************************************************************/ else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_3D { + auto func = PRAGMA_THREADS_FOR_2D { Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; T sum, valO, *pIn, *pgI; - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { for (int oh = 0; oh < oH; ++oh) { for (int ow = 0; ow < oW; ++ow) { @@ -1761,7 +1761,7 @@ namespace sd { } }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); } else { nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 276c74e78..589adebcb 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -3513,23 +3513,6 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4_119) { // ASSERT_TRUE(exp.equalsTo(out)); } -//////////////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_5) { - - auto x = NDArrayFactory::create('f', {8, 32, 64, 64}); - x.linspace(1); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); -// ASSERT_TRUE(exp.equalsTo(out)); - - -} - //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) {