parent
bac130bd78
commit
77244f5496
|
@ -1530,13 +1530,13 @@ namespace sd {
|
|||
const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4;
|
||||
|
||||
if(poolingMode == 0) { // max
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
auto func = PRAGMA_THREADS_FOR_2D {
|
||||
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
||||
T sum, valO, *pIn, *pgI;
|
||||
|
||||
for (int b = start_x; b < stop_x; b += inc_x) {
|
||||
for (int c = start_y; c < stop_y; c += inc_y) {
|
||||
for (int od = start_z; od < stop_z; od += inc_z) {
|
||||
for (int b = start_x; b < stop_x; b++) {
|
||||
for (int c = start_y; c < stop_y; c++) {
|
||||
for (int od = 0; od < oD; od++) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
|
||||
|
@ -1618,17 +1618,17 @@ namespace sd {
|
|||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||
}
|
||||
/*************************************************************************/
|
||||
else if(poolingMode == 1) { // avg
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
auto func = PRAGMA_THREADS_FOR_2D {
|
||||
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
||||
T sum, valO, *pIn, *pgI;
|
||||
|
||||
for (int b = start_x; b < stop_x; b += inc_x) {
|
||||
for (int c = start_y; c < stop_y; c += inc_y) {
|
||||
for (int od = start_z; od < stop_z; od += inc_z) {
|
||||
for (int b = start_x; b < stop_x; b++) {
|
||||
for (int c = start_y; c < stop_y; c++) {
|
||||
for (int od = 0; od < oD; od++) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
|
||||
|
@ -1679,17 +1679,17 @@ namespace sd {
|
|||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||
}
|
||||
/*************************************************************************/
|
||||
else if(poolingMode == 2) { // pnorm
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
auto func = PRAGMA_THREADS_FOR_2D {
|
||||
Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW;
|
||||
T sum, valO, *pIn, *pgI;
|
||||
|
||||
for (int b = start_x; b < stop_x; b += inc_x) {
|
||||
for (int c = start_y; c < stop_y; c += inc_y) {
|
||||
for (int od = start_z; od < stop_z; od += inc_z) {
|
||||
for (int b = start_x; b < stop_x; b++) {
|
||||
for (int c = start_y; c < stop_y; c++) {
|
||||
for (int od = 0; od < oD; od++) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
|
||||
|
@ -1761,7 +1761,7 @@ namespace sd {
|
|||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||
}
|
||||
else {
|
||||
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
||||
|
|
|
@ -3513,23 +3513,6 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4_119) {
|
|||
// ASSERT_TRUE(exp.equalsTo(out));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_5) {
|
||||
|
||||
auto x = NDArrayFactory::create<TypeParam>('f', {8, 32, 64, 64});
|
||||
x.linspace(1);
|
||||
|
||||
sd::ops::lrn op;
|
||||
auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
|
||||
auto out = results.at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results.status());
|
||||
// ASSERT_TRUE(exp.isSameShape(out));
|
||||
// ASSERT_TRUE(exp.equalsTo(out));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) {
|
||||
|
||||
|
|
Loading…
Reference in New Issue