diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 43c6fe2ad..1caae85a4 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -1944,7 +1944,7 @@ void NDArray::tilei(const std::vector& reps) { Nd4jLong NDArray::sizeAt(const int dim) const { if (dim >= this->rankOf() || dim < -this->rankOf()) - throw std::runtime_error("Bad size index requested"); + throw std::runtime_error("NDArray::sizeAt: bad size index requested"); if (dim >= 0) return shape::shapeOf(_shapeInfo)[dim]; diff --git a/libnd4j/include/helpers/LoopKind.h b/libnd4j/include/helpers/LoopKind.h index 95e9238ad..e3ca932b3 100644 --- a/libnd4j/include/helpers/LoopKind.h +++ b/libnd4j/include/helpers/LoopKind.h @@ -35,16 +35,16 @@ namespace sd { class ND4J_EXPORT LoopKind { - + public: enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D }; static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); - + }; ////////////////////////////////////////////////////////////////////////////// @@ -59,8 +59,8 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd int temp; const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; - const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); + const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; + const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c')) return EWS1; @@ -160,7 +160,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const N const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c'; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; - const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); + const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c')) return EWS1; @@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';; - if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && shape::rank(xShapeInfo) == 2 && xEws == 1 && xOrder == 'c' && xRank == 2 && + if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 && tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) return SMALLARR2DX; if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) @@ -233,18 +233,18 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const ////////////////////////////////////////////////////////////////////////////// LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) { - // both tad shapes are the same, but strides and ews may be different + // both tad shapes are the same, but strides and ews may be different const int tadRank = shape::rank(xTadShapeInfo); const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo); - const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); const char xTadOrder = shape::order(xTadShapeInfo); const char yTadOrder = shape::order(xTadShapeInfo); const char zOrder = shape::order(zShapeInfo); - + int position; const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c'; const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c'; @@ -265,7 +265,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, c return RANK4; if(tadRank == 5 && zEws > 0 && zVectorOrC) return RANK5; - return COMMON; + return COMMON; } diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index ace58a0b8..5ac7686e2 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -35,111 +35,19 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { auto z = OUTPUT_VARIABLE(0); //Special case: empty.reshape() -> return empty - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); - return Status::OK(); //No op - } - - if (block.width() == 1) { - - auto arguments = block.getIArguments(); - int argsSize = arguments->size(); - - - - int e = 1; - char order = (char) -(*arguments)[0]; - if (order != 'c' && order != 'f') { - order = 'c'; //x->ordering(); - e = 0; - } - - REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension"); - - std::vector shapeNew; - int e2 = e; - for (; e < (int) arguments->size(); e++) { - if (arguments->at(e) == -1){ - Nd4jLong shapeLength = 1; - for(; e2 < e; e2++){ - shapeLength *= arguments->at(e2); - } - for(e2 = e + 1; e2 < arguments->size(); e2++){ - shapeLength *= arguments->at(e2); - } - Nd4jLong realShape = x->lengthOf() / shapeLength; - shapeNew.push_back(realShape); - } - else{ - shapeNew.push_back(arguments->at(e)); - } - - } - - auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); - REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); - - if (Environment::getInstance()->isDebugAndVerbose()) { - nd4j_printv("Reshape: new shape", shapeNew); - } - - auto xr = x->reshape(order, shapeNew); - z->assign(xr); - STORE_RESULT(*z); - - return Status::OK(); - - } else if (block.width() == 2) { - - auto s = INPUT_VARIABLE(1); - - char order = 'c'; - if (block.numI() > 0) - order = (char) -INT_ARG(0); - - std::vector shapeNew(s->lengthOf()); - - for (int e = 0; e < (int) s->lengthOf(); e++) { - auto dim = s->e(e); - if (dim == -1){ - Nd4jLong shapeLength = 1; - for(int e2 = 0; e2 < e; e2++){ - shapeLength *= s->e(e2); - } - for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){ - REQUIRE_TRUE(s->e(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - shapeLength *= s->e(e2); - } - Nd4jLong realShape = x->lengthOf() / shapeLength; - shapeNew[e] = realShape; - } - else{ - shapeNew[e] = dim; - } - } - - if (Environment::getInstance()->isDebugAndVerbose()) { - nd4j_printv("Reshape: new shape", shapeNew); - } - - if (s->isScalar()) { - // just a scalar - z->assign(x); - } else { - // in some cases we might go away with simple memcpy call instead of assign call - if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) { - z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset()); - } else { - auto xr = x->reshape(order, shapeNew); - z->assign(xr); - } - } - - return Status::OK(); - + if (x->isEmpty()) { + REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); + return Status::OK(); //No op } - return ND4J_STATUS_BAD_INPUT; + REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf()); + + if (Environment::getInstance()->isDebugAndVerbose()) + nd4j_printv("Reshape: new shape", z->getShapeAsVector()); + + z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); + + return Status::OK(); } @@ -151,117 +59,73 @@ DECLARE_TYPES(reshape) { } DECLARE_SHAPE_FN(reshape) { - auto inp = inputShape->at(0); - // we can launch op using Int arguments - if (inputShape->size() == 1) { - REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined"); - std::vector *arguments = block.getIArguments(); + const auto x = INPUT_VARIABLE(0); - int e = 1; - char order = (char) -(*arguments)[0]; - if (order != 'c' && order != 'f') { - order = shape::order(inp); - e = 0; + std::vector reshapeArgs; + std::vector shapeNew; + char orderNew = 'c'; + + if (block.width() == 1) { + reshapeArgs = *block.getIArguments(); + if(!reshapeArgs.empty()) { + orderNew = (char) -reshapeArgs[0]; + if(orderNew == 'c' || orderNew == 'f') + reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case } - - std::vector shapeNew; - - int e2 = e; - for (; e < (int) arguments->size(); e++) { - if ((int) arguments->at(e) == -1){ - - Nd4jLong shapeLength = 1; - for(; e2 < e; e2 ++){ - shapeLength *= arguments->at(e2); - } - for(e2 = e + 1; e2 < arguments->size(); e2++){ - REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - shapeLength *= arguments->at(e2); - } - - if(shapeLength == 0){ - //Edge case for empty: - shapeNew.push_back(0); - } else { - //Standard case - Nd4jLong realShape = shape::length(inp) / shapeLength; - shapeNew.push_back(realShape); - } - } - else{ - shapeNew.push_back(arguments->at(e)); - } - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew))); - } else { - // or, with second input "as shape" - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - // special case here - if (y->isEmpty()) { - REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array"); - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp))); - } - //Special case: empty.reshape(-1) -> return empty - if (x->isEmpty()) { - //REQUIRE_TRUE(y->lengthOf() == 1 && y->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); - auto shapeOf = y->getBufferAsVector(); - Nd4jLong prod = 1; - bool hasNegs = false; - for (auto v:shapeOf) { - if (v < 0) { - hasNegs = true; - v = 0; - } - - prod *= v; - } - - REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well"); - - // if there are -1s - we turn them into zeros - if (hasNegs) { - for (int e = 0; e < shapeOf.size(); e++) - if (shapeOf[e] < 0) - shapeOf[e] = 0; - } - - auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data()); - return SHAPELIST(CONSTANT(newShape)); - } - - std::vector shapeNew(y->lengthOf()); - - for (int e = 0; e < (int) y->lengthOf(); e++) { - auto dim = y->e(e); - if (dim == -1){ - Nd4jLong shapeLength = 1; - for(int e2 = 0; e2 < e; e2++){ - shapeLength *= y->e(e2); - } - for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){ - REQUIRE_TRUE(y->e(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - shapeLength *= y->e(e2); - } - - if(shapeLength == 0){ - //Edge case for empty: - shapeNew[e] = 0; - } else { - Nd4jLong realShape = shape::length(inp) / shapeLength; - shapeNew[e] = realShape; - } - }else { - shapeNew[e] = dim; - } - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew)); } + else { + reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); + orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c'; + } + + REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); + + Nd4jLong xLen = x->lengthOf(); + if(x->isEmpty()) { + xLen = 1; + for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes + if(x->sizeAt(i) != 0) + xLen *= x->sizeAt(i); + } + + for (uint i = 0; i < reshapeArgs.size(); ++i) { + + if (reshapeArgs[i] == -1) { + + uint shapeLength = 1, numOfZeros = 0; + + for(uint j = 0; j < i; ++j) + if(reshapeArgs[j] != 0) + shapeLength *= reshapeArgs[j]; + else + ++numOfZeros; + + for(uint j = i + 1; j < reshapeArgs.size(); ++j) { + REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + if(reshapeArgs[j] != 0) + shapeLength *= reshapeArgs[j]; + else + ++numOfZeros; + } + + const auto dim = xLen / shapeLength; + + if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) + shapeNew.push_back(0); + else + shapeNew.push_back(dim); + } + else + shapeNew.push_back(reshapeArgs[i]); + } + + auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); + REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew)); } + } } diff --git a/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp b/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp index 18551909c..97893ca5b 100644 --- a/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp @@ -40,16 +40,15 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_0) { TEST_F(ArrayOptionsTests, TestShape_Basic_1) { shape[5] = 2; - + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); } - TEST_F(ArrayOptionsTests, TestShape_Basic_2) { shape[5] = 258; - + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); @@ -58,7 +57,7 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_2) { TEST_F(ArrayOptionsTests, TestShape_Basic_3) { ASSERT_EQ(0, shape::extra(shape)); - + ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 28240cc10..8a03d4abc 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -166,7 +166,7 @@ TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) { auto z = result.at(0); ASSERT_TRUE(z->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -180,7 +180,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) { auto z = result.at(0); ASSERT_TRUE(z->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -198,7 +198,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) { ASSERT_TRUE(z1->equalsTo(exp1)); ASSERT_TRUE(z2->equalsTo(exp2)); - + } ////////////////////////////////////////////////////////////////////// @@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests1, AXpY_Test_1) { auto z = result.at(0); ASSERT_TRUE(z->equalsTo(exp)); - + } TEST_F(DeclarableOpsTests1, BasicInitialization3) { @@ -258,7 +258,7 @@ TEST_F(DeclarableOpsTests1, TestTensorMmul1) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - + } TEST_F(DeclarableOpsTests1, TestTensorDot2) { @@ -278,7 +278,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - + } TEST_F(DeclarableOpsTests1, TestTensorDot3) { @@ -298,7 +298,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - + } TEST_F(DeclarableOpsTests1, TestTensorDot4) { @@ -318,7 +318,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////// @@ -338,7 +338,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot5) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -360,7 +360,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot6) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -381,7 +381,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot7) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot8) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -431,7 +431,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot9) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -452,7 +452,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot10) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -474,7 +474,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot11) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -495,7 +495,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot12) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -516,7 +516,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot13) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -537,7 +537,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot14) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -558,7 +558,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot15) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -579,7 +579,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot16) { ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// @@ -786,7 +786,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) { ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } TEST_F(DeclarableOpsTests1, TestRng1) { @@ -1046,7 +1046,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1071,7 +1071,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1093,7 +1093,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) { ASSERT_TRUE(res.at(0)->equalsTo(&exp)); ASSERT_TRUE(exp.equalsTo(&z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1147,7 +1147,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1402,7 +1402,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1421,7 +1421,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1437,7 +1437,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1463,7 +1463,7 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { ASSERT_TRUE(z.equalsTo(&exp)); - + } ////////////////////////////////////////////////////////////////////// @@ -1676,31 +1676,6 @@ TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) { delete block; } -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, Reshapeas1) { - const std::vector xShape = { 5,4,3 }; - const std::vector yShape = { 3,5,4 }; - - auto x = NDArrayFactory::create_('f', xShape); - auto y = NDArrayFactory::create_('f', yShape); - - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); - - sd::ops::reshapeas reshape; - - reshape.execute(block); - - ASSERT_TRUE(x->isSameShape(y)); - - delete variableSpace; - delete block; -} - TEST_F(DeclarableOpsTests1, Test_Cast_1) { // TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected auto x = NDArrayFactory::create('c', { 5, 5 }); @@ -1715,7 +1690,7 @@ TEST_F(DeclarableOpsTests1, Test_Cast_1) { auto z = result.at(0); ASSERT_TRUE(yExp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1848,113 +1823,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) { #endif -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, Reshape2) { - const std::vector xShape = { 5,4,3 }; - const std::vector yShape = { 3,5,4 }; - - auto x = NDArrayFactory::create_('c', xShape); - auto y = NDArrayFactory::create_('c', yShape); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1 }); - std::vector* arguments = block->getIArguments(); - arguments->push_back(-y->ordering()); - arguments->push_back(3); - arguments->push_back(5); - arguments->push_back(4); - - sd::ops::reshape reshape; - - Nd4jStatus status = reshape.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - - ASSERT_TRUE(result->isSameShape(y)); - - delete y; - delete block; - delete variableSpace; -} - -TEST_F(DeclarableOpsTests1, Reshape3) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(x.isSameShape(z)); - - -} - -TEST_F(DeclarableOpsTests1, Reshape4) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 3, 4, 5 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(x.isSameShape(z)); - - -} - -TEST_F(DeclarableOpsTests1, Reshape5) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 5, 4, 3 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - -} - -TEST_F(DeclarableOpsTests1, Reshape6) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 4, 15 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 4, -1 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(z->isSameShape(exp)); - - - -} - -TEST_F(DeclarableOpsTests1, Reshape7) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 60 }); - - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { -1 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(z->isSameShape(exp)); - - - -} ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Transpose1) { @@ -1983,7 +1851,6 @@ TEST_F(DeclarableOpsTests1, Transpose1) { delete variableSpace; } - ////////////////////////////////////////////////////////////////////// // not-in-place TEST_F(DeclarableOpsTests1, Permute1) { @@ -2259,7 +2126,7 @@ TEST_F(DeclarableOpsTests1, IsMax1) { //res->printIndexedBuffer("IS_MAX"); ASSERT_TRUE(exp.equalsTo(res)); - + } ////////////////////////////////////////////////////////////////////// @@ -2281,7 +2148,7 @@ TEST_F(DeclarableOpsTests1, IsMax2) { //res->printIndexedBuffer("IS_MAX"); ASSERT_TRUE(exp.equalsTo(res)); - + } ////////////////////////////////////////////////////////////////////// @@ -2303,7 +2170,7 @@ TEST_F(DeclarableOpsTests1, IsMax3) { //res->printIndexedBuffer("IS_MAX"); ASSERT_TRUE(exp.equalsTo(res)); - + } ////////////////////////////////////////////////////////////////////// @@ -2352,7 +2219,7 @@ TEST_F(DeclarableOpsTests1, IsMax4) { // ASSERT_TRUE(expState.equalsTo(state)); // ASSERT_TRUE(expOut.equalsTo(output)); -// +// // } ////////////////////////////////////////////////////////////////// @@ -2386,7 +2253,7 @@ TEST_F(DeclarableOpsTests1, sru_test1) { ASSERT_TRUE(expState.equalsTo(state)); ASSERT_TRUE(expOut.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2438,7 +2305,7 @@ TEST_F(DeclarableOpsTests1, sru_bp) { ASSERT_TRUE(expGradB.equalsTo(gradB)); ASSERT_TRUE(expGradInit.equalsTo(gradInit)); - + } ////////////////////////////////////////////////////////////////// @@ -2474,7 +2341,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) { ASSERT_TRUE(expState.equalsTo(state)); ASSERT_TRUE(expOut.equalsTo(output)); - + } TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { @@ -2527,7 +2394,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { ASSERT_TRUE(expGradB.equalsTo(gradB)); ASSERT_TRUE(expGradInit.equalsTo(gradInit)); - + } TEST_F(DeclarableOpsTests1, ArgMax1) { @@ -2547,7 +2414,7 @@ TEST_F(DeclarableOpsTests1, ArgMax1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -2568,7 +2435,7 @@ TEST_F(DeclarableOpsTests1, ArgMax2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -2590,7 +2457,7 @@ TEST_F(DeclarableOpsTests1, ArgMax3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, ArgMax4) { @@ -2611,7 +2478,7 @@ TEST_F(DeclarableOpsTests1, ArgMax4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -2633,7 +2500,7 @@ TEST_F(DeclarableOpsTests1, ArgMax5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, ArgMax6) { @@ -2676,7 +2543,7 @@ TEST_F(DeclarableOpsTests1, ArgMin1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -2697,7 +2564,7 @@ TEST_F(DeclarableOpsTests1, SquareTests1) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_1) { @@ -2717,7 +2584,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_2) { @@ -2736,7 +2603,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_2) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_3) { @@ -2756,7 +2623,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_4) { @@ -2775,7 +2642,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_5) { @@ -2796,7 +2663,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests1, OneHotTests_6) { @@ -2809,7 +2676,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_6) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests1, OneHotTests_7) { @@ -2822,7 +2689,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_7) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests1, FillAs_1) { @@ -2840,7 +2707,7 @@ TEST_F(DeclarableOpsTests1, FillAs_1) { ASSERT_NEAR(scalar, result.at(0)->meanNumber().e(0), 1e-5f); - + } ////////////////////////////////////////////////////////////////////// @@ -2866,7 +2733,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) { ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - + } @@ -2893,7 +2760,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) { ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - + } @@ -2913,7 +2780,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) { ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - + } ////////////////////////////////////////////////////////////////////// @@ -2931,7 +2798,7 @@ TEST_F(DeclarableOpsTests1, softmax_test1) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2947,7 +2814,7 @@ TEST_F(DeclarableOpsTests1, softmax_test2) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2963,7 +2830,7 @@ TEST_F(DeclarableOpsTests1, softmax_test3) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2979,7 +2846,7 @@ TEST_F(DeclarableOpsTests1, softmax_test4) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2995,7 +2862,7 @@ TEST_F(DeclarableOpsTests1, softmax_test5) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3011,7 +2878,7 @@ TEST_F(DeclarableOpsTests1, softmax_test6) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3027,7 +2894,7 @@ TEST_F(DeclarableOpsTests1, softmax_test7) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3043,7 +2910,7 @@ TEST_F(DeclarableOpsTests1, softmax_test8) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3059,7 +2926,7 @@ TEST_F(DeclarableOpsTests1, softmax_test9) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test10) { @@ -3074,7 +2941,7 @@ TEST_F(DeclarableOpsTests1, softmax_test10) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test11) { @@ -3089,7 +2956,7 @@ TEST_F(DeclarableOpsTests1, softmax_test11) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3108,7 +2975,7 @@ TEST_F(DeclarableOpsTests1, softmax_test12) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_1) { @@ -3132,7 +2999,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3157,7 +3024,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2) { ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); - + } ////////////////////////////////////////////////////////////////////// @@ -3183,7 +3050,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3209,7 +3076,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3234,7 +3101,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// @@ -3260,7 +3127,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6) { ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); - + } @@ -3288,7 +3155,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -3316,7 +3183,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// @@ -3341,7 +3208,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } TEST_F(DeclarableOpsTests1, Reverse_10) { @@ -3357,7 +3224,7 @@ TEST_F(DeclarableOpsTests1, Reverse_10) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -3380,7 +3247,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3402,7 +3269,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3423,7 +3290,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } ////////////////////////////////////////////////////////////////////// @@ -3444,7 +3311,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } TEST_F(DeclarableOpsTests1, Test_Expose_1) { @@ -3463,7 +3330,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) { ASSERT_TRUE(input0.equalsTo(z0)); ASSERT_TRUE(input1.equalsTo(z1)); - + } TEST_F(DeclarableOpsTests1, Test_Expose_2) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index b3a710be9..db49c12f2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -51,23 +51,7 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) { ASSERT_EQ(exp, *z); - -} -TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) { - auto x = NDArrayFactory::create('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); - auto e = NDArrayFactory::create('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); - - auto r = x.reshape('c', {3, 2});; - r.streamline('f'); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {3, 2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - } TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) { @@ -108,7 +92,7 @@ TEST_F(DeclarableOpsTests14, Multiply_test) { ASSERT_EQ(e, r); ASSERT_EQ(e, *f); - + } } @@ -124,7 +108,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) { auto z = result.at(0); ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { @@ -139,7 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { auto z = result.at(0); ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) { @@ -193,7 +177,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) { ASSERT_EQ(e, *result.at(0)); - + } TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { @@ -210,7 +194,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { ASSERT_EQ(e, *result.at(0)); - + } TEST_F(DeclarableOpsTests14, test_empty_fill_1) { @@ -224,7 +208,7 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) { auto z = result.at(0); ASSERT_EQ(y, *z); - + } TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) { @@ -259,7 +243,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) { auto out = res2.at(0); ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); - + } TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { @@ -271,7 +255,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { auto out = res2.at(0); ASSERT_EQ(out->e(0), -DataTypeUtils::infOrMax()); - + } TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { @@ -286,7 +270,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { ASSERT_EQ(res2.status(), Status::OK()); auto out = res2.at(0); ASSERT_EQ(out->e(0), 0.f); - + } TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { @@ -303,7 +287,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { // out->printShapeInfo("ReduceMean empty shape with keep dims"); // out->printIndexedBuffer("ReduceMean scalar"); ASSERT_TRUE(std::isnan(out->e(0))); - + } TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { @@ -324,7 +308,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { ASSERT_TRUE(exp.isSameShape(z)); - + } TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { @@ -345,7 +329,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { ASSERT_TRUE(exp.isSameShape(z)); - + } TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { @@ -363,7 +347,7 @@ TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, test_empty_argmax_2) { @@ -391,7 +375,7 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) { ASSERT_TRUE(x.isSameShape(z)); ASSERT_EQ(x, *z); - + } ////////////////////////////////////////////////////////////////////// @@ -409,7 +393,7 @@ TEST_F(DeclarableOpsTests14, repeat_1) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -427,7 +411,7 @@ TEST_F(DeclarableOpsTests14, repeat_2) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -445,7 +429,7 @@ TEST_F(DeclarableOpsTests14, repeat_3) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -463,7 +447,7 @@ TEST_F(DeclarableOpsTests14, repeat_4) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -481,7 +465,7 @@ TEST_F(DeclarableOpsTests14, repeat_5) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { @@ -502,7 +486,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { ASSERT_EQ(e, res); - + } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { @@ -523,7 +507,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { ASSERT_EQ(e, res); - + } /////////////////////////////////////////////////////////////////////// @@ -639,7 +623,7 @@ TEST_F(DeclarableOpsTests14, matmul_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -661,7 +645,7 @@ TEST_F(DeclarableOpsTests14, matmul_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -682,7 +666,7 @@ TEST_F(DeclarableOpsTests14, matmul_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -704,7 +688,7 @@ TEST_F(DeclarableOpsTests14, matmul_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -726,7 +710,7 @@ TEST_F(DeclarableOpsTests14, matmul_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -747,7 +731,7 @@ TEST_F(DeclarableOpsTests14, matmul_test6) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -770,7 +754,7 @@ TEST_F(DeclarableOpsTests14, matmul_test7) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -795,7 +779,7 @@ TEST_F(DeclarableOpsTests14, matmul_test8) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -820,7 +804,7 @@ TEST_F(DeclarableOpsTests14, matmul_test9) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test10) { @@ -876,7 +860,7 @@ TEST_F(DeclarableOpsTests14, matmul_test11) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test12) { @@ -894,7 +878,7 @@ TEST_F(DeclarableOpsTests14, matmul_test12) { ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -914,7 +898,7 @@ TEST_F(DeclarableOpsTests14, matmul_test13) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test14) { @@ -933,7 +917,7 @@ TEST_F(DeclarableOpsTests14, matmul_test14) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test15) { @@ -952,7 +936,7 @@ TEST_F(DeclarableOpsTests14, matmul_test15) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test16) { @@ -971,7 +955,7 @@ TEST_F(DeclarableOpsTests14, matmul_test16) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, matmul_test17) { @@ -985,7 +969,7 @@ TEST_F(DeclarableOpsTests14, matmul_test17) { ASSERT_EQ(exp, *result.at(0)); - + } @@ -1007,7 +991,7 @@ TEST_F(DeclarableOpsTests14, matmul_test18) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1027,7 +1011,7 @@ TEST_F(DeclarableOpsTests14, matmul_test19) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1048,7 +1032,7 @@ TEST_F(DeclarableOpsTests14, matmul_test20) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1069,7 +1053,7 @@ TEST_F(DeclarableOpsTests14, matmul_test21) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1090,7 +1074,7 @@ TEST_F(DeclarableOpsTests14, matmul_test22) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1111,7 +1095,7 @@ TEST_F(DeclarableOpsTests14, matmul_test23) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1135,7 +1119,7 @@ TEST_F(DeclarableOpsTests14, matmul_test24) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1156,7 +1140,7 @@ TEST_F(DeclarableOpsTests14, matmul_test25) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1177,7 +1161,7 @@ TEST_F(DeclarableOpsTests14, matmul_test26) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1198,7 +1182,7 @@ TEST_F(DeclarableOpsTests14, matmul_test27) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1220,7 +1204,7 @@ TEST_F(DeclarableOpsTests14, matmul_test28) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1242,7 +1226,7 @@ TEST_F(DeclarableOpsTests14, matmul_test29) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test30) { @@ -1262,7 +1246,7 @@ TEST_F(DeclarableOpsTests14, matmul_test30) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test31) { @@ -1282,7 +1266,7 @@ TEST_F(DeclarableOpsTests14, matmul_test31) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test32) { @@ -1299,7 +1283,7 @@ TEST_F(DeclarableOpsTests14, matmul_test32) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test33) { @@ -1319,7 +1303,7 @@ TEST_F(DeclarableOpsTests14, matmul_test33) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test34) { @@ -1336,7 +1320,7 @@ TEST_F(DeclarableOpsTests14, matmul_test34) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ///////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test35) { @@ -1353,7 +1337,7 @@ TEST_F(DeclarableOpsTests14, matmul_test35) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test36) { @@ -1370,7 +1354,7 @@ TEST_F(DeclarableOpsTests14, matmul_test36) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test37) { @@ -1617,7 +1601,7 @@ TEST_F(DeclarableOpsTests14, Stack_1) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1645,7 +1629,7 @@ TEST_F(DeclarableOpsTests14, Stack_2) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1673,7 +1657,7 @@ TEST_F(DeclarableOpsTests14, Stack_3) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1700,7 +1684,7 @@ TEST_F(DeclarableOpsTests14, Stack_4) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1727,7 +1711,7 @@ TEST_F(DeclarableOpsTests14, Stack_5) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1754,7 +1738,7 @@ TEST_F(DeclarableOpsTests14, Stack_6) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1778,7 +1762,7 @@ TEST_F(DeclarableOpsTests14, Stack_7) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1801,7 +1785,7 @@ TEST_F(DeclarableOpsTests14, Stack_8) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1824,7 +1808,7 @@ TEST_F(DeclarableOpsTests14, Stack_9) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1850,7 +1834,7 @@ TEST_F(DeclarableOpsTests14, Stack_10) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } TEST_F(DeclarableOpsTests14, Stack_11) { @@ -1872,7 +1856,7 @@ TEST_F(DeclarableOpsTests14, Stack_11) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1895,7 +1879,7 @@ TEST_F(DeclarableOpsTests14, Stack_12) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1917,7 +1901,7 @@ TEST_F(DeclarableOpsTests14, Stack_13) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1941,7 +1925,7 @@ TEST_F(DeclarableOpsTests14, Stack_14) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, Stack_15) { @@ -1959,7 +1943,7 @@ TEST_F(DeclarableOpsTests14, Stack_15) { ASSERT_TRUE(exp.isSameShape(z)); - + } @@ -1978,7 +1962,7 @@ TEST_F(DeclarableOpsTests14, Stack_16) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, Stack_17) { @@ -1999,7 +1983,7 @@ TEST_F(DeclarableOpsTests14, Stack_17) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests14, Stack_18) { @@ -2018,8 +2002,8 @@ TEST_F(DeclarableOpsTests14, Stack_18) { auto out = res2.at(0); ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); - - + + } TEST_F(DeclarableOpsTests14, Stack_19) { @@ -2033,7 +2017,7 @@ TEST_F(DeclarableOpsTests14, Stack_19) { auto z = result.at(0); ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, Stack_20) { @@ -2047,7 +2031,7 @@ TEST_F(DeclarableOpsTests14, Stack_20) { auto z = result.at(0); ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests14, Stack_21) { @@ -2073,7 +2057,363 @@ TEST_F(DeclarableOpsTests14, Stack_21) { ASSERT_TRUE(outStack->isSameShape(outConcat)); ASSERT_TRUE(outStack->equalsTo(outConcat)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Reshape1) { + const std::vector xShape = { 5,4,3 }; + const std::vector yShape = { 3,5,4 }; + + auto x = NDArrayFactory::create_('f', xShape); + auto y = NDArrayFactory::create_('f', yShape); + + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reshapeas reshape; + + reshape.execute(block); + + ASSERT_TRUE(x->isSameShape(y)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Reshape2) { + const std::vector xShape = { 5,4,3 }; + const std::vector yShape = { 3,5,4 }; + + auto x = NDArrayFactory::create_('c', xShape); + auto y = NDArrayFactory::create_('c', yShape); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, new Variable()); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({ -1 }); + std::vector* arguments = block->getIArguments(); + arguments->push_back(-y->ordering()); + arguments->push_back(3); + arguments->push_back(5); + arguments->push_back(4); + + sd::ops::reshape reshape; + + Nd4jStatus status = reshape.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + + ASSERT_TRUE(result->isSameShape(y)); + + delete y; + delete block; + delete variableSpace; +} + +TEST_F(DeclarableOpsTests14, Reshape3) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x.isSameShape(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape4) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 3, 4, 5 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x.isSameShape(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape5) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 5, 4, 3 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); +} + +TEST_F(DeclarableOpsTests14, Reshape6) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto exp = NDArrayFactory::create('c', { 4, 15 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 4, -1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(exp)); +} + +TEST_F(DeclarableOpsTests14, Reshape7) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto exp = NDArrayFactory::create('c', { 60 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { -1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(exp)); +} + +TEST_F(DeclarableOpsTests14, Reshape8) { + auto x = NDArrayFactory::create('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); + auto e = NDArrayFactory::create('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + + auto r = x.reshape('c', {3, 2});; + r.streamline('f'); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {3, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +} + +TEST_F(DeclarableOpsTests14, Reshape9) { + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + + sd::ops::reshape op; + auto result = op.evaluate({&array}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); +} + +TEST_F(DeclarableOpsTests14, Reshape10) { + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + auto z = NDArrayFactory::create('c', {1, 1}); + + sd::ops::reshape op; + auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests14, Reshape11) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('c', {4, 3}); + + x.linspace(1); + exp.linspace(1); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {-99, 4, 3}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape12) { + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); + auto exp = NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + + sd::ops::reshape op; + auto result = op.evaluate({&x, &shape}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape13) { + auto vector = NDArrayFactory::create('c', {1}, {119.0f}); + auto exp = NDArrayFactory::create(119.f); + auto empty = NDArrayFactory::empty_(); + + sd::ops::reshape op; + auto result = op.evaluate({&vector, empty}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(exp, *result.at(0)); + + delete empty; +} + +TEST_F(DeclarableOpsTests14, Reshape14) { + auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); + auto y = NDArrayFactory::create('c', {2}, {10, 0}); + auto e = NDArrayFactory::create('c', {10, 0}); + + sd::ops::reshape op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_EQ(e, *z); } +TEST_F(DeclarableOpsTests14, Reshape15) { + auto x0 = NDArrayFactory::create('c', {2, 0}); + auto x1 = NDArrayFactory::create('c', {0, 1, 2}); + + auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); + auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); + + auto e0 = NDArrayFactory::create('c', {2, 0, 0}); + auto e1 = NDArrayFactory::create('c', {0, 1}); + + sd::ops::reshape op; + auto result0 = op.evaluate({&x0, &shape0}, {}, {}); + ASSERT_EQ(Status::OK(), result0.status()); + auto z0 = result0.at(0); + ASSERT_EQ(e0, *z0); + + auto result1 = op.evaluate({&x1, &shape1}, {}, {}); + ASSERT_EQ(Status::OK(), result1.status()); + auto z1 = result1.at(0); + ASSERT_EQ(e1, *z1); +} + +TEST_F(DeclarableOpsTests14, Reshape16) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); + + auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); + + sd::ops::reshape op; + + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape17) { + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape18) { + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, Reshape19) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + + +TEST_F(DeclarableOpsTests14, Reshape20) { + + NDArray x1('c', {2,0}, sd::DataType::FLOAT32); + NDArray x2('c', {10,0}, sd::DataType::FLOAT32); + NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32); + NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32); + NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32); + NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32); + NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32); + + sd::ops::reshape op; + + auto result = op.evaluate({&x1}, {}, {2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0})); + + result = op.evaluate({&x2}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,5})); + + result = op.evaluate({&x2}, {}, {5, 2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); + + result = op.evaluate({&x2}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); + + result = op.evaluate({&x3}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,10})); + + result = op.evaluate({&x4}, {}, {2, -1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,5,0})); + + result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10})); + + result = op.evaluate({&x6}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0})); + + result = op.evaluate({&x7}, {}, {-1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2, 0})); + + result = op.evaluate({&x7}, {}, {10,0,50,100}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100})); +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index f983d27a3..5fffa73c5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -89,7 +89,7 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { sd::ops::standardize_bp op; auto result = op.evaluate({&x, &eps}, {0}); ASSERT_EQ(Status::OK(), result.status()); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { @@ -108,7 +108,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { auto out = result.at(0); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { @@ -126,7 +126,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { auto out = result.at(0); // out->printIndexedBuffer("Adjusted Constrast"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { @@ -144,7 +144,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { auto out = result.at(0); // out->printIndexedBuffer("Adjusted Constrast"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { @@ -162,7 +162,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { auto out = result.at(0); // out->printIndexedBuffer("Adjusted Constrast"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { @@ -177,7 +177,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { auto out = result.at(0); // out->printIndexedBuffer("Adjusted Constrast"); ASSERT_TRUE(e.equalsTo(out)); - + } /* @@ -308,7 +308,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { // out->printBuffer("Adjusted Constrast6"); // e.printBuffer("Adjusted Expected 6"); // ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { @@ -415,7 +415,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { auto diff = e - *out; // diff.printBuffer("Adjusted subtract 7"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_1) { @@ -429,7 +429,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) { auto out = result.at(0); // out->printIndexedBuffer("Casted result"); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_2) { @@ -444,7 +444,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) { auto out = result.at(0); ASSERT_TRUE(e.equalsTo(out)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_3) { @@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) { // e.printIndexedBuffer("Double to int64"); auto res = result.at(0); ASSERT_EQ(*res, e); - + } @@ -508,7 +508,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_5) { // res->printIndexedBuffer("BITCAST5"); ASSERT_TRUE(e.equalsTo(res)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_6) { @@ -528,7 +528,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_6) { // res->printIndexedBuffer("BITCAST6"); ASSERT_TRUE(e.equalsTo(res)); - + } TEST_F(DeclarableOpsTests15, Test_BitCast_7) { auto x = NDArrayFactory::create('c', {4, 4}, { @@ -547,7 +547,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) { // res->printIndexedBuffer("BITCAST7"); ASSERT_TRUE(e.equalsTo(res)); - + } TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { @@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) { sd::ops::layer_norm op; auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); ASSERT_EQ(Status::OK(), result.status()); - + } TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { @@ -649,7 +649,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { sd::ops::layer_norm_bp op; auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); ASSERT_EQ(Status::OK(), result.status()); - + } ////////////////////////////////////////////////////////////////////// @@ -710,30 +710,6 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) { ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); } -TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) { - auto array = NDArrayFactory::create(119.f); - auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); - - sd::ops::reshape op; - auto result = op.evaluate({&array}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_EQ(e, *z); -} - -TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) { - auto array = NDArrayFactory::create(119.f); - auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); - auto z = NDArrayFactory::create('c', {1, 1}); - - sd::ops::reshape op; - auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); -} - TEST_F(DeclarableOpsTests15, test_rank_1) { auto array = NDArrayFactory::create('c', {4, 64}); auto e = NDArrayFactory::create('c', {}, {2}); @@ -757,7 +733,7 @@ TEST_F(DeclarableOpsTests15, test_rank_2) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { @@ -800,7 +776,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_2) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - + } TEST_F(DeclarableOpsTests15, test_lstmBlock_3) { @@ -969,7 +945,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) { sd::ops::rgb_to_grs op; auto result = op.evaluate({ &rgbs }, {}, {}); ASSERT_EQ(Status::THROW(), result.status()); - + } catch (std::exception& e) { nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); } @@ -1063,7 +1039,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) { ASSERT_EQ(Status::OK(), result.status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1074,7 +1050,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) { sd::ops::rgb_to_yuv op; auto result = op.evaluate({ &rgbs }, {}, {}); ASSERT_EQ(Status::THROW(), result.status()); - + } catch (std::exception & e) { nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); @@ -1109,7 +1085,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1168,7 +1144,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) { ASSERT_EQ(Status::OK(), result.status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1179,7 +1155,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) { sd::ops::yuv_to_rgb op; auto result = op.evaluate({ &yuv }, {}, {}); ASSERT_EQ(Status::THROW(), result.status()); - + } catch (std::exception & e) { nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); @@ -1423,7 +1399,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test8) { ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); - + } TEST_F(DeclarableOpsTests15, Pow_BP_Test9) { @@ -1515,7 +1491,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { ASSERT_NEAR(dLdyB->e(i), dLdyExpB.e(i), 0.00001); } - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { @@ -1532,10 +1508,10 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); - + ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); @@ -1554,10 +1530,10 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); - + ASSERT_TRUE(B.isSameShape(*dLdAbp)); ASSERT_TRUE(B.equalsTo(*dLdAbp)); @@ -1606,7 +1582,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); @@ -1632,7 +1608,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); @@ -1655,7 +1631,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); @@ -1706,7 +1682,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - + auto* dLdAbp = resultsBP.at(0); auto* dLdBbp = resultsBP.at(1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 69dec8359..1e877ecc6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -789,24 +789,6 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) { ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(DeclarableOpsTests4, Test_Reshape_Again) { - auto x = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('c', {4, 3}); - - x.linspace(1); - exp.linspace(1); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {-99, 4, 3}); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests4, Test_Split_1) { @@ -1209,23 +1191,6 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) { ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) { - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); - auto exp = NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - - sd::ops::reshape op; - auto result = op.evaluate({&x, &shape}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index ab6bad3c4..e6aeb43d4 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -140,37 +140,6 @@ TEST_F(EmptyTests, Test_Concat_4) { ASSERT_EQ(exp, *z); } -TEST_F(EmptyTests, Test_Reshape_1) { - auto vector = NDArrayFactory::create('c', {1}, {119.0f}); - auto exp = NDArrayFactory::create(119.f); - auto empty = NDArrayFactory::empty_(); - - sd::ops::reshape op; - auto result = op.evaluate({&vector, empty}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(exp, *result.at(0)); - - delete empty; -} - -TEST_F(EmptyTests, Test_Reshape_3) { - auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); - auto y = NDArrayFactory::create('c', {2}, {10, 0}); - auto e = NDArrayFactory::create('c', {10, 0}); - - sd::ops::reshape op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_EQ(e, *z); - -} - TEST_F(EmptyTests, Test_dup_1) { auto empty = NDArrayFactory::empty(); auto dup = new NDArray(empty.dup()); @@ -256,41 +225,6 @@ TEST_F(EmptyTests, test_shaped_empty_4) { ASSERT_EQ(shapeOf, array.getShapeAsVector()); } -TEST_F(EmptyTests, test_empty_reshape_1) { - /* - INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); - INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); - - INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0]; - INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0]; - INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0]; - - assertArrayEquals(new long[]{2, 0, 0}, out0.shape()); - assertArrayEquals(new long[]{0, 1}, out1.shape()); - assertArrayEquals(new long[]{10, 0}, out2.shape()); - */ - auto x0 = NDArrayFactory::create('c', {2, 0}); - auto x1 = NDArrayFactory::create('c', {0, 1, 2}); - - auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); - auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); - - auto e0 = NDArrayFactory::create('c', {2, 0, 0}); - auto e1 = NDArrayFactory::create('c', {0, 1}); - - sd::ops::reshape op; - auto result0 = op.evaluate({&x0, &shape0}, {}, {}); - ASSERT_EQ(Status::OK(), result0.status()); - auto z0 = result0.at(0); - ASSERT_EQ(e0, *z0); - - auto result1 = op.evaluate({&x1, &shape1}, {}, {}); - ASSERT_EQ(Status::OK(), result1.status()); - auto z1 = result1.at(0); - ASSERT_EQ(e1, *z1); - -} - TEST_F(EmptyTests, test_empty_matmul_1) { auto x = NDArrayFactory::create('c', {0, 1}); diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index ce24c8a9b..089b4a92f 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -48,7 +48,7 @@ TEST_F(ParityOpsTests, TestZeroAs1) { ASSERT_TRUE(z->isSameShape(&x)); ASSERT_TRUE(z->equalsTo(&exp)); - + } TEST_F(ParityOpsTests, TestMaximum1) { @@ -66,7 +66,7 @@ TEST_F(ParityOpsTests, TestMaximum1) { ASSERT_TRUE(y.equalsTo(z)); - + } @@ -86,7 +86,7 @@ TEST_F(ParityOpsTests, TestMinimum1) { ASSERT_TRUE(y.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestTear1) { @@ -106,7 +106,7 @@ TEST_F(ParityOpsTests, TestTear1) { for (int e = 0; e < result.size(); e++) ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); - + } TEST_F(ParityOpsTests, TestUnstack1) { @@ -126,7 +126,7 @@ TEST_F(ParityOpsTests, TestUnstack1) { for (int e = 0; e < result.size(); e++) ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); - + } @@ -148,7 +148,7 @@ TEST_F(ParityOpsTests, TestUnstack2) { for (int e = 0; e < result.size(); e++) ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); - + } TEST_F(ParityOpsTests, TestUnstack3) { @@ -166,7 +166,7 @@ TEST_F(ParityOpsTests, TestUnstack3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -185,7 +185,7 @@ TEST_F(ParityOpsTests, TestUnstack4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack5) { @@ -203,7 +203,7 @@ TEST_F(ParityOpsTests, TestUnstack5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack6) { @@ -221,7 +221,7 @@ TEST_F(ParityOpsTests, TestUnstack6) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack7) { @@ -239,7 +239,7 @@ TEST_F(ParityOpsTests, TestUnstack7) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack8) { @@ -257,7 +257,7 @@ TEST_F(ParityOpsTests, TestUnstack8) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, TestUnstack9) { @@ -275,7 +275,7 @@ TEST_F(ParityOpsTests, TestUnstack9) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -293,7 +293,7 @@ TEST_F(ParityOpsTests, TestUnstack10) { ASSERT_TRUE(exp.isSameShape(result.at(1))); ASSERT_TRUE(exp.isSameShape(result.at(2))); - + } //////////////////////////////////////////////////////////////////////// @@ -310,7 +310,7 @@ TEST_F(ParityOpsTests, TestUnstack11) { ASSERT_TRUE(exp.isSameShape(result.at(0))); ASSERT_TRUE(exp.isSameShape(result.at(1))); - + } //////////////////////////////////////////////////////////////////////// @@ -325,7 +325,7 @@ TEST_F(ParityOpsTests, TestUnstack12) { ASSERT_TRUE(result.size() == 0); - + } TEST_F(ParityOpsTests, TestUnstack13) { @@ -361,7 +361,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest1) { ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.equalsTo(z)); - + } @@ -380,7 +380,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest2) { ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.equalsTo(z)); - + } @@ -399,7 +399,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest3) { ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.equalsTo(z)); - + } TEST_F(ParityOpsTests, ExpandDimsTest4) { @@ -417,7 +417,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest4) { ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.equalsTo(z)); - + } @@ -434,7 +434,7 @@ TEST_F(ParityOpsTests, Test_Shape_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -452,7 +452,7 @@ TEST_F(ParityOpsTests, Test_Equals_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -470,7 +470,7 @@ TEST_F(ParityOpsTests, Test_NotEquals_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Less_1) { @@ -487,7 +487,7 @@ TEST_F(ParityOpsTests, Test_Less_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_LessEquals_1) { @@ -504,7 +504,7 @@ TEST_F(ParityOpsTests, Test_LessEquals_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_GreaterEquals_1) { @@ -521,7 +521,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_GreaterEquals_2) { @@ -538,7 +538,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Greater_1) { @@ -555,7 +555,7 @@ TEST_F(ParityOpsTests, Test_Greater_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Where_1) { @@ -575,7 +575,7 @@ TEST_F(ParityOpsTests, Test_Where_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Where_2) { @@ -593,7 +593,7 @@ TEST_F(ParityOpsTests, Test_Where_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -612,7 +612,7 @@ TEST_F(ParityOpsTests, Test_Where_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Select_1) { @@ -630,7 +630,7 @@ TEST_F(ParityOpsTests, Test_Select_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Select_2) { @@ -648,7 +648,7 @@ TEST_F(ParityOpsTests, Test_Select_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Select_3) { @@ -666,25 +666,7 @@ TEST_F(ParityOpsTests, Test_Select_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - -} -TEST_F(ParityOpsTests, Test_Reshape_TF_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); - - auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); - - sd::ops::reshape op; - - auto result = op.evaluate({&x, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - } TEST_F(ParityOpsTests, Test_Bias_Add_1) { @@ -702,7 +684,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) { for (int e = 0; e < tads.size(); e++) { ASSERT_TRUE(bias.equalsTo(tads.at(e))); } - + } TEST_F(ParityOpsTests, Test_Scatter_Add_1) { @@ -718,7 +700,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_2) { @@ -735,7 +717,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_3) { @@ -751,7 +733,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_4) { @@ -767,7 +749,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_5) { @@ -784,7 +766,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_5) { // z->printBuffer(); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_6) { @@ -800,7 +782,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, Test_Scatter_Add_7) { @@ -816,7 +798,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_7) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -864,7 +846,7 @@ TEST_F(ParityOpsTests, scatterMax_test1) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test2) { @@ -880,7 +862,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test3) { @@ -897,7 +879,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test4) { @@ -913,7 +895,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test5) { @@ -929,7 +911,7 @@ TEST_F(ParityOpsTests, scatterMax_test5) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMax_test6) { @@ -945,7 +927,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -963,7 +945,7 @@ TEST_F(ParityOpsTests, scatterMin_test1) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMin_test2) { @@ -979,7 +961,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMin_test3) { @@ -995,7 +977,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ParityOpsTests, scatterMin_test4) { @@ -1012,7 +994,7 @@ TEST_F(ParityOpsTests, scatterMin_test4) { // z->printBuffer(); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1044,7 +1026,7 @@ TEST_F(ParityOpsTests, scatterND_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1064,7 +1046,7 @@ TEST_F(ParityOpsTests, scatterND_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1088,7 +1070,7 @@ TEST_F(ParityOpsTests, scatterND_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1107,7 +1089,7 @@ TEST_F(ParityOpsTests, scatterND_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1127,7 +1109,7 @@ TEST_F(ParityOpsTests, scatterND_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1154,7 +1136,7 @@ TEST_F(ParityOpsTests, scatterND_test6) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1181,7 +1163,7 @@ TEST_F(ParityOpsTests, scatterND_test7) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1202,7 +1184,7 @@ TEST_F(ParityOpsTests, scatterND_test8) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1236,7 +1218,7 @@ TEST_F(ParityOpsTests, scatterND_add_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1260,7 +1242,7 @@ TEST_F(ParityOpsTests, scatterND_add_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1283,7 +1265,7 @@ TEST_F(ParityOpsTests, scatterND_add_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1310,7 +1292,7 @@ TEST_F(ParityOpsTests, scatterND_add_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1346,7 +1328,7 @@ TEST_F(ParityOpsTests, scatterND_add_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1379,7 +1361,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1404,7 +1386,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1427,7 +1409,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1454,7 +1436,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1490,7 +1472,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1511,7 +1493,7 @@ TEST_F(ParityOpsTests, scatterND_update_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1535,7 +1517,7 @@ TEST_F(ParityOpsTests, scatterND_update_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1559,7 +1541,7 @@ TEST_F(ParityOpsTests, scatterND_update_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1586,7 +1568,7 @@ TEST_F(ParityOpsTests, scatterND_update_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1622,7 +1604,7 @@ TEST_F(ParityOpsTests, scatterND_update_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////// @@ -1655,7 +1637,7 @@ TEST_F(ParityOpsTests, scatter_update_1) { ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.equalsTo(x)); - + } ////////////////////////////////////////////////////////////////////// @@ -1674,7 +1656,7 @@ TEST_F(ParityOpsTests, scatter_update_2) { ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.equalsTo(x)); - + } ////////////////////////////////////////////////////////////////////// @@ -1693,7 +1675,7 @@ TEST_F(ParityOpsTests, scatter_update_3) { ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.equalsTo(x)); - + } ////////////////////////////////////////////////////////////////////// @@ -1712,5 +1694,5 @@ TEST_F(ParityOpsTests, scatter_update_4) { ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.equalsTo(x)); - + } diff --git a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp index 898af1722..937ca4675 100644 --- a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp @@ -103,7 +103,7 @@ TEST_F(ScalarTests, Test_Concat_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -124,7 +124,7 @@ TEST_F(ScalarTests, Test_Concat_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -146,7 +146,7 @@ TEST_F(ScalarTests, Test_Concat_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ScalarTests, Test_ExpandDims_1) { @@ -163,7 +163,7 @@ TEST_F(ScalarTests, Test_ExpandDims_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ScalarTests, Test_Squeeze_1) { @@ -179,27 +179,9 @@ TEST_F(ScalarTests, Test_Squeeze_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } - -TEST_F(ScalarTests, Test_Reshape_1) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - -} - - TEST_F(ScalarTests, Test_Permute_1) { auto x = NDArrayFactory::create(3.0f); auto exp = NDArrayFactory::create(3.0f); @@ -213,7 +195,7 @@ TEST_F(ScalarTests, Test_Permute_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(ScalarTests, Test_Concat_Scalar_1) { diff --git a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp index 636206957..cc13f3529 100644 --- a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp @@ -77,7 +77,7 @@ TEST_F(SingleDimTests, Test_Concat_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(SingleDimTests, Test_Reduce_1) { @@ -111,7 +111,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -129,7 +129,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -149,7 +149,7 @@ TEST_F(SingleDimTests, Test_Squeeze_1) { ASSERT_EQ(exp.rankOf(), z->rankOf()); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(SingleDimTests, Test_Squeeze_2) { @@ -165,42 +165,9 @@ TEST_F(SingleDimTests, Test_Squeeze_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } -TEST_F(SingleDimTests, Test_Reshape_1) { - auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - -} - -TEST_F(SingleDimTests, Test_Reshape_2) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - -} - - TEST_F(SingleDimTests, Test_Permute_1) { auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); @@ -214,5 +181,5 @@ TEST_F(SingleDimTests, Test_Permute_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } \ No newline at end of file