avg/max pooling3d bp fixed (#323)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-03-16 18:17:42 +03:00 committed by GitHub
parent bac130bd78
commit 77244f5496
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 32 deletions

View File

@ -1530,13 +1530,13 @@ namespace sd {
const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4;
if(poolingMode == 0) { // max
auto func = PRAGMA_THREADS_FOR_3D {
auto func = PRAGMA_THREADS_FOR_2D {
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
T sum, valO, *pIn, *pgI;
for (int b = start_x; b < stop_x; b += inc_x) {
for (int c = start_y; c < stop_y; c += inc_y) {
for (int od = start_z; od < stop_z; od += inc_z) {
for (int b = start_x; b < stop_x; b++) {
for (int c = start_y; c < stop_y; c++) {
for (int od = 0; od < oD; od++) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
@ -1618,17 +1618,17 @@ namespace sd {
}
};
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
/*************************************************************************/
else if(poolingMode == 1) { // avg
auto func = PRAGMA_THREADS_FOR_3D {
auto func = PRAGMA_THREADS_FOR_2D {
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
T sum, valO, *pIn, *pgI;
for (int b = start_x; b < stop_x; b += inc_x) {
for (int c = start_y; c < stop_y; c += inc_y) {
for (int od = start_z; od < stop_z; od += inc_z) {
for (int b = start_x; b < stop_x; b++) {
for (int c = start_y; c < stop_y; c++) {
for (int od = 0; od < oD; od++) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
@ -1679,17 +1679,17 @@ namespace sd {
}
};
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
/*************************************************************************/
else if(poolingMode == 2) { // pnorm
auto func = PRAGMA_THREADS_FOR_3D {
auto func = PRAGMA_THREADS_FOR_2D {
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
T sum, valO, *pIn, *pgI;
for (int b = start_x; b < stop_x; b += inc_x) {
for (int c = start_y; c < stop_y; c += inc_y) {
for (int od = start_z; od < stop_z; od += inc_z) {
for (int b = start_x; b < stop_x; b++) {
for (int c = start_y; c < stop_y; c++) {
for (int od = 0; od < oD; od++) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
@ -1761,7 +1761,7 @@ namespace sd {
}
};
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
else {
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);

View File

@ -3513,23 +3513,6 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4_119) {
// ASSERT_TRUE(exp.equalsTo(out));
}
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_5) {
auto x = NDArrayFactory::create<TypeParam>('f', {8, 32, 64, 64});
x.linspace(1);
sd::ops::lrn op;
auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
auto out = results.at(0);
ASSERT_EQ(Status::OK(), results.status());
// ASSERT_TRUE(exp.isSameShape(out));
// ASSERT_TRUE(exp.equalsTo(out));
}
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) {