parent
bac130bd78
commit
77244f5496
|
@ -1530,13 +1530,13 @@ namespace sd {
|
||||||
const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4;
|
const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4;
|
||||||
|
|
||||||
if(poolingMode == 0) { // max
|
if(poolingMode == 0) { // max
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
||||||
T sum, valO, *pIn, *pgI;
|
T sum, valO, *pIn, *pgI;
|
||||||
|
|
||||||
for (int b = start_x; b < stop_x; b += inc_x) {
|
for (int b = start_x; b < stop_x; b++) {
|
||||||
for (int c = start_y; c < stop_y; c += inc_y) {
|
for (int c = start_y; c < stop_y; c++) {
|
||||||
for (int od = start_z; od < stop_z; od += inc_z) {
|
for (int od = 0; od < oD; od++) {
|
||||||
for (int oh = 0; oh < oH; ++oh) {
|
for (int oh = 0; oh < oH; ++oh) {
|
||||||
for (int ow = 0; ow < oW; ++ow) {
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
|
|
||||||
|
@ -1618,17 +1618,17 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 1) { // avg
|
else if(poolingMode == 1) { // avg
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
||||||
T sum, valO, *pIn, *pgI;
|
T sum, valO, *pIn, *pgI;
|
||||||
|
|
||||||
for (int b = start_x; b < stop_x; b += inc_x) {
|
for (int b = start_x; b < stop_x; b++) {
|
||||||
for (int c = start_y; c < stop_y; c += inc_y) {
|
for (int c = start_y; c < stop_y; c++) {
|
||||||
for (int od = start_z; od < stop_z; od += inc_z) {
|
for (int od = 0; od < oD; od++) {
|
||||||
for (int oh = 0; oh < oH; ++oh) {
|
for (int oh = 0; oh < oH; ++oh) {
|
||||||
for (int ow = 0; ow < oW; ++ow) {
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
|
|
||||||
|
@ -1679,17 +1679,17 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 2) { // pnorm
|
else if(poolingMode == 2) { // pnorm
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
||||||
T sum, valO, *pIn, *pgI;
|
T sum, valO, *pIn, *pgI;
|
||||||
|
|
||||||
for (int b = start_x; b < stop_x; b += inc_x) {
|
for (int b = start_x; b < stop_x; b++) {
|
||||||
for (int c = start_y; c < stop_y; c += inc_y) {
|
for (int c = start_y; c < stop_y; c++) {
|
||||||
for (int od = start_z; od < stop_z; od += inc_z) {
|
for (int od = 0; od < oD; od++) {
|
||||||
for (int oh = 0; oh < oH; ++oh) {
|
for (int oh = 0; oh < oH; ++oh) {
|
||||||
for (int ow = 0; ow < oW; ++ow) {
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
|
|
||||||
|
@ -1761,7 +1761,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
||||||
|
|
|
@ -3513,23 +3513,6 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4_119) {
|
||||||
// ASSERT_TRUE(exp.equalsTo(out));
|
// ASSERT_TRUE(exp.equalsTo(out));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
|
||||||
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_5) {
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<TypeParam>('f', {8, 32, 64, 64});
|
|
||||||
x.linspace(1);
|
|
||||||
|
|
||||||
sd::ops::lrn op;
|
|
||||||
auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
|
|
||||||
auto out = results.at(0);
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results.status());
|
|
||||||
// ASSERT_TRUE(exp.isSameShape(out));
|
|
||||||
// ASSERT_TRUE(exp.equalsTo(out));
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) {
|
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) {
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue