Shyrma reshape empty (#338)

* - start working on reshape op which operates with empty shapes

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct reshaping for empty arrays

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove unnecessary check in Loopkind

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2020-03-31 07:41:16 +03:00 committed by GitHub
parent bf0ddbc06c
commit 29e61579c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 730 additions and 854 deletions

View File

@ -1944,7 +1944,7 @@ void NDArray::tilei(const std::vector<Nd4jLong>& reps) {
Nd4jLong NDArray::sizeAt(const int dim) const { Nd4jLong NDArray::sizeAt(const int dim) const {
if (dim >= this->rankOf() || dim < -this->rankOf()) if (dim >= this->rankOf() || dim < -this->rankOf())
throw std::runtime_error("Bad size index requested"); throw std::runtime_error("NDArray::sizeAt: bad size index requested");
if (dim >= 0) if (dim >= 0)
return shape::shapeOf(_shapeInfo)[dim]; return shape::shapeOf(_shapeInfo)[dim];

View File

@ -35,16 +35,16 @@ namespace sd {
class ND4J_EXPORT LoopKind { class ND4J_EXPORT LoopKind {
public: public:
enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D }; enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
}; };
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -59,8 +59,8 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
int temp; int temp;
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo);
if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c')) if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c'))
return EWS1; return EWS1;
@ -160,7 +160,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const N
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c'; const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c';
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo);
if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c')) if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c'))
return EWS1; return EWS1;
@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';;
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && shape::rank(xShapeInfo) == 2 && xEws == 1 && xOrder == 'c' && xRank == 2 && if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 &&
tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
return SMALLARR2DX; return SMALLARR2DX;
if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
@ -233,18 +233,18 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) { LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) {
// both tad shapes are the same, but strides and ews may be different // both tad shapes are the same, but strides and ews may be different
const int tadRank = shape::rank(xTadShapeInfo); const int tadRank = shape::rank(xTadShapeInfo);
const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo); const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo);
const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo);
const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo);
const char xTadOrder = shape::order(xTadShapeInfo); const char xTadOrder = shape::order(xTadShapeInfo);
const char yTadOrder = shape::order(xTadShapeInfo); const char yTadOrder = shape::order(xTadShapeInfo);
const char zOrder = shape::order(zShapeInfo); const char zOrder = shape::order(zShapeInfo);
int position; int position;
const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c'; const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c';
const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c'; const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c';
@ -265,7 +265,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, c
return RANK4; return RANK4;
if(tadRank == 5 && zEws > 0 && zVectorOrC) if(tadRank == 5 && zEws > 0 && zVectorOrC)
return RANK5; return RANK5;
return COMMON; return COMMON;
} }

View File

@ -35,111 +35,19 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
//Special case: empty.reshape(<other empty shape>) -> return empty //Special case: empty.reshape(<other empty shape>) -> return empty
if (x->isEmpty()) { if (x->isEmpty()) {
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return Status::OK(); //No op return Status::OK(); //No op
}
if (block.width() == 1) {
auto arguments = block.getIArguments();
int argsSize = arguments->size();
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = 'c'; //x->ordering();
e = 0;
}
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
std::vector<Nd4jLong> shapeNew;
int e2 = e;
for (; e < (int) arguments->size(); e++) {
if (arguments->at(e) == -1){
Nd4jLong shapeLength = 1;
for(; e2 < e; e2++){
shapeLength *= arguments->at(e2);
}
for(e2 = e + 1; e2 < arguments->size(); e2++){
shapeLength *= arguments->at(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew.push_back(realShape);
}
else{
shapeNew.push_back(arguments->at(e));
}
}
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
auto xr = x->reshape(order, shapeNew);
z->assign(xr);
STORE_RESULT(*z);
return Status::OK();
} else if (block.width() == 2) {
auto s = INPUT_VARIABLE(1);
char order = 'c';
if (block.numI() > 0)
order = (char) -INT_ARG(0);
std::vector<Nd4jLong> shapeNew(s->lengthOf());
for (int e = 0; e < (int) s->lengthOf(); e++) {
auto dim = s->e<Nd4jLong >(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= s->e<Nd4jLong>(e2);
}
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= s->e<Nd4jLong>(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew[e] = realShape;
}
else{
shapeNew[e] = dim;
}
}
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
if (s->isScalar()) {
// just a scalar
z->assign(x);
} else {
// in some cases we might go away with simple memcpy call instead of assign call
if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) {
z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset());
} else {
auto xr = x->reshape(order, shapeNew);
z->assign(xr);
}
}
return Status::OK();
} }
return ND4J_STATUS_BAD_INPUT; REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf());
if (Environment::getInstance()->isDebugAndVerbose())
nd4j_printv("Reshape: new shape", z->getShapeAsVector());
z->assign(x->reshape(z->ordering(), z->getShapeAsVector()));
return Status::OK();
} }
@ -151,117 +59,73 @@ DECLARE_TYPES(reshape) {
} }
DECLARE_SHAPE_FN(reshape) { DECLARE_SHAPE_FN(reshape) {
auto inp = inputShape->at(0);
// we can launch op using Int arguments const auto x = INPUT_VARIABLE(0);
if (inputShape->size() == 1) {
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
std::vector<int> *arguments = block.getIArguments();
int e = 1; std::vector<int> reshapeArgs;
char order = (char) -(*arguments)[0]; std::vector<Nd4jLong> shapeNew;
if (order != 'c' && order != 'f') { char orderNew = 'c';
order = shape::order(inp);
e = 0; if (block.width() == 1) {
reshapeArgs = *block.getIArguments();
if(!reshapeArgs.empty()) {
orderNew = (char) -reshapeArgs[0];
if(orderNew == 'c' || orderNew == 'f')
reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case
} }
std::vector<Nd4jLong> shapeNew;
int e2 = e;
for (; e < (int) arguments->size(); e++) {
if ((int) arguments->at(e) == -1){
Nd4jLong shapeLength = 1;
for(; e2 < e; e2 ++){
shapeLength *= arguments->at(e2);
}
for(e2 = e + 1; e2 < arguments->size(); e2++){
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= arguments->at(e2);
}
if(shapeLength == 0){
//Edge case for empty:
shapeNew.push_back(0);
} else {
//Standard case
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew.push_back(realShape);
}
}
else{
shapeNew.push_back(arguments->at(e));
}
}
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
} else {
// or, with second input "as shape"
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
// special case here
if (y->isEmpty()) {
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
}
//Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) {
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
Nd4jLong prod = 1;
bool hasNegs = false;
for (auto v:shapeOf) {
if (v < 0) {
hasNegs = true;
v = 0;
}
prod *= v;
}
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
// if there are -1s - we turn them into zeros
if (hasNegs) {
for (int e = 0; e < shapeOf.size(); e++)
if (shapeOf[e] < 0)
shapeOf[e] = 0;
}
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
return SHAPELIST(CONSTANT(newShape));
}
std::vector<Nd4jLong> shapeNew(y->lengthOf());
for (int e = 0; e < (int) y->lengthOf(); e++) {
auto dim = y->e<Nd4jLong>(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= y->e<Nd4jLong>(e2);
}
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= y->e<Nd4jLong>(e2);
}
if(shapeLength == 0){
//Edge case for empty:
shapeNew[e] = 0;
} else {
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape;
}
}else {
shapeNew[e] = dim;
}
}
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
} }
else {
reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c';
}
REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
Nd4jLong xLen = x->lengthOf();
if(x->isEmpty()) {
xLen = 1;
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
if(x->sizeAt(i) != 0)
xLen *= x->sizeAt(i);
}
for (uint i = 0; i < reshapeArgs.size(); ++i) {
if (reshapeArgs[i] == -1) {
uint shapeLength = 1, numOfZeros = 0;
for(uint j = 0; j < i; ++j)
if(reshapeArgs[j] != 0)
shapeLength *= reshapeArgs[j];
else
++numOfZeros;
for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
if(reshapeArgs[j] != 0)
shapeLength *= reshapeArgs[j];
else
++numOfZeros;
}
const auto dim = xLen / shapeLength;
if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
shapeNew.push_back(0);
else
shapeNew.push_back(dim);
}
else
shapeNew.push_back(reshapeArgs[i]);
}
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
} }
} }
} }

View File

@ -40,16 +40,15 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_0) {
TEST_F(ArrayOptionsTests, TestShape_Basic_1) { TEST_F(ArrayOptionsTests, TestShape_Basic_1) {
shape[5] = 2; shape[5] = 2;
ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
} }
TEST_F(ArrayOptionsTests, TestShape_Basic_2) { TEST_F(ArrayOptionsTests, TestShape_Basic_2) {
shape[5] = 258; shape[5] = 258;
ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
@ -58,7 +57,7 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_2) {
TEST_F(ArrayOptionsTests, TestShape_Basic_3) { TEST_F(ArrayOptionsTests, TestShape_Basic_3) {
ASSERT_EQ(0, shape::extra(shape)); ASSERT_EQ(0, shape::extra(shape));
ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape));
} }

View File

@ -166,7 +166,7 @@ TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(z->equalsTo(exp)); ASSERT_TRUE(z->equalsTo(exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -180,7 +180,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(z->equalsTo(exp)); ASSERT_TRUE(z->equalsTo(exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -198,7 +198,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) {
ASSERT_TRUE(z1->equalsTo(exp1)); ASSERT_TRUE(z1->equalsTo(exp1));
ASSERT_TRUE(z2->equalsTo(exp2)); ASSERT_TRUE(z2->equalsTo(exp2));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests1, AXpY_Test_1) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(z->equalsTo(exp)); ASSERT_TRUE(z->equalsTo(exp));
} }
TEST_F(DeclarableOpsTests1, BasicInitialization3) { TEST_F(DeclarableOpsTests1, BasicInitialization3) {
@ -258,7 +258,7 @@ TEST_F(DeclarableOpsTests1, TestTensorMmul1) {
ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.isSameShape(out));
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
} }
TEST_F(DeclarableOpsTests1, TestTensorDot2) { TEST_F(DeclarableOpsTests1, TestTensorDot2) {
@ -278,7 +278,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) {
ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.isSameShape(out));
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
} }
TEST_F(DeclarableOpsTests1, TestTensorDot3) { TEST_F(DeclarableOpsTests1, TestTensorDot3) {
@ -298,7 +298,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) {
ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.isSameShape(out));
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
} }
TEST_F(DeclarableOpsTests1, TestTensorDot4) { TEST_F(DeclarableOpsTests1, TestTensorDot4) {
@ -318,7 +318,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.isSameShape(out));
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -338,7 +338,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot5) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -360,7 +360,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot6) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -381,7 +381,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot7) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot8) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -431,7 +431,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot9) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -452,7 +452,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot10) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -474,7 +474,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot11) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -495,7 +495,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot12) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -516,7 +516,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot13) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -537,7 +537,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot14) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -558,7 +558,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot15) {
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -579,7 +579,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot16) {
ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result)); ASSERT_TRUE(exp.equalsTo(result));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -786,7 +786,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) {
ASSERT_TRUE(res.at(0)->equalsTo(&exp)); ASSERT_TRUE(res.at(0)->equalsTo(&exp));
} }
TEST_F(DeclarableOpsTests1, TestRng1) { TEST_F(DeclarableOpsTests1, TestRng1) {
@ -1046,7 +1046,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) {
ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
ASSERT_TRUE(res.at(0)->equalsTo(&exp)); ASSERT_TRUE(res.at(0)->equalsTo(&exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1071,7 +1071,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) {
ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
ASSERT_TRUE(res.at(0)->equalsTo(&exp)); ASSERT_TRUE(res.at(0)->equalsTo(&exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1093,7 +1093,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) {
ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
ASSERT_TRUE(res.at(0)->equalsTo(&exp)); ASSERT_TRUE(res.at(0)->equalsTo(&exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) {
ASSERT_TRUE(res.at(0)->equalsTo(&exp)); ASSERT_TRUE(res.at(0)->equalsTo(&exp));
ASSERT_TRUE(exp.equalsTo(&z)); ASSERT_TRUE(exp.equalsTo(&z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1147,7 +1147,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) {
ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
ASSERT_TRUE(res.at(0)->equalsTo(&exp)); ASSERT_TRUE(res.at(0)->equalsTo(&exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1402,7 +1402,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_EQ(res.status(), ND4J_STATUS_OK);
ASSERT_TRUE(res.at(0)->equalsTo(exp)); ASSERT_TRUE(res.at(0)->equalsTo(exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1421,7 +1421,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_EQ(res.status(), ND4J_STATUS_OK);
ASSERT_TRUE(res.at(0)->equalsTo(exp)); ASSERT_TRUE(res.at(0)->equalsTo(exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1437,7 +1437,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_EQ(res.status(), ND4J_STATUS_OK);
ASSERT_TRUE(res.at(0)->equalsTo(exp)); ASSERT_TRUE(res.at(0)->equalsTo(exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1463,7 +1463,7 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {
ASSERT_TRUE(z.equalsTo(&exp)); ASSERT_TRUE(z.equalsTo(&exp));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1676,31 +1676,6 @@ TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) {
delete block; delete block;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Reshapeas1) {
const std::vector<Nd4jLong> xShape = { 5,4,3 };
const std::vector<Nd4jLong> yShape = { 3,5,4 };
auto x = NDArrayFactory::create_<float>('f', xShape);
auto y = NDArrayFactory::create_<float>('f', yShape);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(-2, y);
auto block = new Context(1, variableSpace, true);
block->fillInputs({ -1, -2 });
sd::ops::reshapeas reshape;
reshape.execute(block);
ASSERT_TRUE(x->isSameShape(y));
delete variableSpace;
delete block;
}
TEST_F(DeclarableOpsTests1, Test_Cast_1) { TEST_F(DeclarableOpsTests1, Test_Cast_1) {
// TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected // TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected
auto x = NDArrayFactory::create<float>('c', { 5, 5 }); auto x = NDArrayFactory::create<float>('c', { 5, 5 });
@ -1715,7 +1690,7 @@ TEST_F(DeclarableOpsTests1, Test_Cast_1) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(yExp.equalsTo(z)); ASSERT_TRUE(yExp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1848,113 +1823,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
#endif #endif
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Reshape2) {
const std::vector<Nd4jLong> xShape = { 5,4,3 };
const std::vector<Nd4jLong> yShape = { 3,5,4 };
auto x = NDArrayFactory::create_<float>('c', xShape);
auto y = NDArrayFactory::create_<float>('c', yShape);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(1, new Variable());
auto block = new Context(1, variableSpace, false);
block->fillInputs({ -1 });
std::vector<int>* arguments = block->getIArguments();
arguments->push_back(-y->ordering());
arguments->push_back(3);
arguments->push_back(5);
arguments->push_back(4);
sd::ops::reshape reshape;
Nd4jStatus status = reshape.execute(block);
ASSERT_EQ(ND4J_STATUS_OK, status);
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
ASSERT_TRUE(result->isSameShape(y));
delete y;
delete block;
delete variableSpace;
}
TEST_F(DeclarableOpsTests1, Reshape3) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(x.isSameShape(z));
}
TEST_F(DeclarableOpsTests1, Reshape4) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 3, 4, 5 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(x.isSameShape(z));
}
TEST_F(DeclarableOpsTests1, Reshape5) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 5, 4, 3 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
}
TEST_F(DeclarableOpsTests1, Reshape6) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
auto exp = NDArrayFactory::create<float>('c', { 4, 15 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 4, -1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(exp));
}
TEST_F(DeclarableOpsTests1, Reshape7) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
auto exp = NDArrayFactory::create<float>('c', { 60 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { -1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(exp));
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Transpose1) { TEST_F(DeclarableOpsTests1, Transpose1) {
@ -1983,7 +1851,6 @@ TEST_F(DeclarableOpsTests1, Transpose1) {
delete variableSpace; delete variableSpace;
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// not-in-place // not-in-place
TEST_F(DeclarableOpsTests1, Permute1) { TEST_F(DeclarableOpsTests1, Permute1) {
@ -2259,7 +2126,7 @@ TEST_F(DeclarableOpsTests1, IsMax1) {
//res->printIndexedBuffer("IS_MAX"); //res->printIndexedBuffer("IS_MAX");
ASSERT_TRUE(exp.equalsTo(res)); ASSERT_TRUE(exp.equalsTo(res));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2281,7 +2148,7 @@ TEST_F(DeclarableOpsTests1, IsMax2) {
//res->printIndexedBuffer("IS_MAX"); //res->printIndexedBuffer("IS_MAX");
ASSERT_TRUE(exp.equalsTo(res)); ASSERT_TRUE(exp.equalsTo(res));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2303,7 +2170,7 @@ TEST_F(DeclarableOpsTests1, IsMax3) {
//res->printIndexedBuffer("IS_MAX"); //res->printIndexedBuffer("IS_MAX");
ASSERT_TRUE(exp.equalsTo(res)); ASSERT_TRUE(exp.equalsTo(res));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2352,7 +2219,7 @@ TEST_F(DeclarableOpsTests1, IsMax4) {
// ASSERT_TRUE(expState.equalsTo(state)); // ASSERT_TRUE(expState.equalsTo(state));
// ASSERT_TRUE(expOut.equalsTo(output)); // ASSERT_TRUE(expOut.equalsTo(output));
// //
// } // }
////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////
@ -2386,7 +2253,7 @@ TEST_F(DeclarableOpsTests1, sru_test1) {
ASSERT_TRUE(expState.equalsTo(state)); ASSERT_TRUE(expState.equalsTo(state));
ASSERT_TRUE(expOut.equalsTo(output)); ASSERT_TRUE(expOut.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2438,7 +2305,7 @@ TEST_F(DeclarableOpsTests1, sru_bp) {
ASSERT_TRUE(expGradB.equalsTo(gradB)); ASSERT_TRUE(expGradB.equalsTo(gradB));
ASSERT_TRUE(expGradInit.equalsTo(gradInit)); ASSERT_TRUE(expGradInit.equalsTo(gradInit));
} }
////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////
@ -2474,7 +2341,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) {
ASSERT_TRUE(expState.equalsTo(state)); ASSERT_TRUE(expState.equalsTo(state));
ASSERT_TRUE(expOut.equalsTo(output)); ASSERT_TRUE(expOut.equalsTo(output));
} }
TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
@ -2527,7 +2394,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
ASSERT_TRUE(expGradB.equalsTo(gradB)); ASSERT_TRUE(expGradB.equalsTo(gradB));
ASSERT_TRUE(expGradInit.equalsTo(gradInit)); ASSERT_TRUE(expGradInit.equalsTo(gradInit));
} }
TEST_F(DeclarableOpsTests1, ArgMax1) { TEST_F(DeclarableOpsTests1, ArgMax1) {
@ -2547,7 +2414,7 @@ TEST_F(DeclarableOpsTests1, ArgMax1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -2568,7 +2435,7 @@ TEST_F(DeclarableOpsTests1, ArgMax2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -2590,7 +2457,7 @@ TEST_F(DeclarableOpsTests1, ArgMax3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests1, ArgMax4) { TEST_F(DeclarableOpsTests1, ArgMax4) {
@ -2611,7 +2478,7 @@ TEST_F(DeclarableOpsTests1, ArgMax4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -2633,7 +2500,7 @@ TEST_F(DeclarableOpsTests1, ArgMax5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests1, ArgMax6) { TEST_F(DeclarableOpsTests1, ArgMax6) {
@ -2676,7 +2543,7 @@ TEST_F(DeclarableOpsTests1, ArgMin1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -2697,7 +2564,7 @@ TEST_F(DeclarableOpsTests1, SquareTests1) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests1, OneHotTests_1) { TEST_F(DeclarableOpsTests1, OneHotTests_1) {
@ -2717,7 +2584,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests1, OneHotTests_2) { TEST_F(DeclarableOpsTests1, OneHotTests_2) {
@ -2736,7 +2603,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_2) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests1, OneHotTests_3) { TEST_F(DeclarableOpsTests1, OneHotTests_3) {
@ -2756,7 +2623,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests1, OneHotTests_4) { TEST_F(DeclarableOpsTests1, OneHotTests_4) {
@ -2775,7 +2642,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests1, OneHotTests_5) { TEST_F(DeclarableOpsTests1, OneHotTests_5) {
@ -2796,7 +2663,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests1, OneHotTests_6) { TEST_F(DeclarableOpsTests1, OneHotTests_6) {
@ -2809,7 +2676,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_6) {
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests1, OneHotTests_7) { TEST_F(DeclarableOpsTests1, OneHotTests_7) {
@ -2822,7 +2689,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_7) {
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests1, FillAs_1) { TEST_F(DeclarableOpsTests1, FillAs_1) {
@ -2840,7 +2707,7 @@ TEST_F(DeclarableOpsTests1, FillAs_1) {
ASSERT_NEAR(scalar, result.at(0)->meanNumber().e<float>(0), 1e-5f); ASSERT_NEAR(scalar, result.at(0)->meanNumber().e<float>(0), 1e-5f);
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2866,7 +2733,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) {
ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.isSameShape(array));
ASSERT_TRUE(exp.equalsTo(array)); ASSERT_TRUE(exp.equalsTo(array));
} }
@ -2893,7 +2760,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) {
ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.isSameShape(array));
ASSERT_TRUE(exp.equalsTo(array)); ASSERT_TRUE(exp.equalsTo(array));
} }
@ -2913,7 +2780,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) {
ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.isSameShape(array));
ASSERT_TRUE(exp.equalsTo(array)); ASSERT_TRUE(exp.equalsTo(array));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2931,7 +2798,7 @@ TEST_F(DeclarableOpsTests1, softmax_test1) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2947,7 +2814,7 @@ TEST_F(DeclarableOpsTests1, softmax_test2) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2963,7 +2830,7 @@ TEST_F(DeclarableOpsTests1, softmax_test3) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2979,7 +2846,7 @@ TEST_F(DeclarableOpsTests1, softmax_test4) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2995,7 +2862,7 @@ TEST_F(DeclarableOpsTests1, softmax_test5) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3011,7 +2878,7 @@ TEST_F(DeclarableOpsTests1, softmax_test6) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3027,7 +2894,7 @@ TEST_F(DeclarableOpsTests1, softmax_test7) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3043,7 +2910,7 @@ TEST_F(DeclarableOpsTests1, softmax_test8) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3059,7 +2926,7 @@ TEST_F(DeclarableOpsTests1, softmax_test9) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, softmax_test10) { TEST_F(DeclarableOpsTests1, softmax_test10) {
@ -3074,7 +2941,7 @@ TEST_F(DeclarableOpsTests1, softmax_test10) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, softmax_test11) { TEST_F(DeclarableOpsTests1, softmax_test11) {
@ -3089,7 +2956,7 @@ TEST_F(DeclarableOpsTests1, softmax_test11) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3108,7 +2975,7 @@ TEST_F(DeclarableOpsTests1, softmax_test12) {
ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z)); ASSERT_TRUE(expOutput.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Reverse_1) { TEST_F(DeclarableOpsTests1, Reverse_1) {
@ -3132,7 +2999,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3157,7 +3024,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2) {
ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.isSameShapeStrict(input));
ASSERT_TRUE(expected.equalsTo(&input)); ASSERT_TRUE(expected.equalsTo(&input));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3183,7 +3050,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3209,7 +3076,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3234,7 +3101,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -3260,7 +3127,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6) {
ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.isSameShapeStrict(input));
ASSERT_TRUE(expected.equalsTo(&input)); ASSERT_TRUE(expected.equalsTo(&input));
} }
@ -3288,7 +3155,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
@ -3316,7 +3183,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -3341,7 +3208,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
TEST_F(DeclarableOpsTests1, Reverse_10) { TEST_F(DeclarableOpsTests1, Reverse_10) {
@ -3357,7 +3224,7 @@ TEST_F(DeclarableOpsTests1, Reverse_10) {
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3380,7 +3247,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3402,7 +3269,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3423,7 +3290,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -3444,7 +3311,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
TEST_F(DeclarableOpsTests1, Test_Expose_1) { TEST_F(DeclarableOpsTests1, Test_Expose_1) {
@ -3463,7 +3330,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) {
ASSERT_TRUE(input0.equalsTo(z0)); ASSERT_TRUE(input0.equalsTo(z0));
ASSERT_TRUE(input1.equalsTo(z1)); ASSERT_TRUE(input1.equalsTo(z1));
} }
TEST_F(DeclarableOpsTests1, Test_Expose_2) { TEST_F(DeclarableOpsTests1, Test_Expose_2) {

View File

@ -51,23 +51,7 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) {
ASSERT_EQ(exp, *z); ASSERT_EQ(exp, *z);
}
TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) {
auto x = NDArrayFactory::create<double>('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0});
auto e = NDArrayFactory::create<double>('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});
auto r = x.reshape('c', {3, 2});;
r.streamline('f');
sd::ops::reshape op;
auto result = op.evaluate({&x}, {3, 2});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
} }
TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) { TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) {
@ -108,7 +92,7 @@ TEST_F(DeclarableOpsTests14, Multiply_test) {
ASSERT_EQ(e, r); ASSERT_EQ(e, r);
ASSERT_EQ(e, *f); ASSERT_EQ(e, *f);
} }
} }
@ -124,7 +108,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) {
@ -139,7 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) { TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) {
@ -193,7 +177,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) {
ASSERT_EQ(e, *result.at(0)); ASSERT_EQ(e, *result.at(0));
} }
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
@ -210,7 +194,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
ASSERT_EQ(e, *result.at(0)); ASSERT_EQ(e, *result.at(0));
} }
TEST_F(DeclarableOpsTests14, test_empty_fill_1) { TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
@ -224,7 +208,7 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_EQ(y, *z); ASSERT_EQ(y, *z);
} }
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) { TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
@ -259,7 +243,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
auto out = res2.at(0); auto out = res2.at(0);
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>()); ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
} }
TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
@ -271,7 +255,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
auto out = res2.at(0); auto out = res2.at(0);
ASSERT_EQ(out->e<float>(0), -DataTypeUtils::infOrMax<float>()); ASSERT_EQ(out->e<float>(0), -DataTypeUtils::infOrMax<float>());
} }
TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
@ -286,7 +270,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
ASSERT_EQ(res2.status(), Status::OK()); ASSERT_EQ(res2.status(), Status::OK());
auto out = res2.at(0); auto out = res2.at(0);
ASSERT_EQ(out->e<float>(0), 0.f); ASSERT_EQ(out->e<float>(0), 0.f);
} }
TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
@ -303,7 +287,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
// out->printShapeInfo("ReduceMean empty shape with keep dims"); // out->printShapeInfo("ReduceMean empty shape with keep dims");
// out->printIndexedBuffer("ReduceMean scalar"); // out->printIndexedBuffer("ReduceMean scalar");
ASSERT_TRUE(std::isnan(out->e<float>(0))); ASSERT_TRUE(std::isnan(out->e<float>(0)));
} }
TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) {
@ -324,7 +308,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
} }
TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) {
@ -345,7 +329,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
} }
TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
@ -363,7 +347,7 @@ TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests14, test_empty_argmax_2) { TEST_F(DeclarableOpsTests14, test_empty_argmax_2) {
@ -391,7 +375,7 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
ASSERT_TRUE(x.isSameShape(z)); ASSERT_TRUE(x.isSameShape(z));
ASSERT_EQ(x, *z); ASSERT_EQ(x, *z);
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -409,7 +393,7 @@ TEST_F(DeclarableOpsTests14, repeat_1) {
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -427,7 +411,7 @@ TEST_F(DeclarableOpsTests14, repeat_2) {
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -445,7 +429,7 @@ TEST_F(DeclarableOpsTests14, repeat_3) {
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -463,7 +447,7 @@ TEST_F(DeclarableOpsTests14, repeat_4) {
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -481,7 +465,7 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
} }
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) {
@ -502,7 +486,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) {
ASSERT_EQ(e, res); ASSERT_EQ(e, res);
} }
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) {
@ -523,7 +507,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) {
ASSERT_EQ(e, res); ASSERT_EQ(e, res);
} }
/////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////
@ -639,7 +623,7 @@ TEST_F(DeclarableOpsTests14, matmul_test1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -661,7 +645,7 @@ TEST_F(DeclarableOpsTests14, matmul_test2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -682,7 +666,7 @@ TEST_F(DeclarableOpsTests14, matmul_test3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -704,7 +688,7 @@ TEST_F(DeclarableOpsTests14, matmul_test4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -726,7 +710,7 @@ TEST_F(DeclarableOpsTests14, matmul_test5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -747,7 +731,7 @@ TEST_F(DeclarableOpsTests14, matmul_test6) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -770,7 +754,7 @@ TEST_F(DeclarableOpsTests14, matmul_test7) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -795,7 +779,7 @@ TEST_F(DeclarableOpsTests14, matmul_test8) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -820,7 +804,7 @@ TEST_F(DeclarableOpsTests14, matmul_test9) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, matmul_test10) { TEST_F(DeclarableOpsTests14, matmul_test10) {
@ -876,7 +860,7 @@ TEST_F(DeclarableOpsTests14, matmul_test11) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, matmul_test12) { TEST_F(DeclarableOpsTests14, matmul_test12) {
@ -894,7 +878,7 @@ TEST_F(DeclarableOpsTests14, matmul_test12) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -914,7 +898,7 @@ TEST_F(DeclarableOpsTests14, matmul_test13) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, matmul_test14) { TEST_F(DeclarableOpsTests14, matmul_test14) {
@ -933,7 +917,7 @@ TEST_F(DeclarableOpsTests14, matmul_test14) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, matmul_test15) { TEST_F(DeclarableOpsTests14, matmul_test15) {
@ -952,7 +936,7 @@ TEST_F(DeclarableOpsTests14, matmul_test15) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, matmul_test16) { TEST_F(DeclarableOpsTests14, matmul_test16) {
@ -971,7 +955,7 @@ TEST_F(DeclarableOpsTests14, matmul_test16) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, matmul_test17) { TEST_F(DeclarableOpsTests14, matmul_test17) {
@ -985,7 +969,7 @@ TEST_F(DeclarableOpsTests14, matmul_test17) {
ASSERT_EQ(exp, *result.at(0)); ASSERT_EQ(exp, *result.at(0));
} }
@ -1007,7 +991,7 @@ TEST_F(DeclarableOpsTests14, matmul_test18) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1027,7 +1011,7 @@ TEST_F(DeclarableOpsTests14, matmul_test19) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1048,7 +1032,7 @@ TEST_F(DeclarableOpsTests14, matmul_test20) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1069,7 +1053,7 @@ TEST_F(DeclarableOpsTests14, matmul_test21) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1090,7 +1074,7 @@ TEST_F(DeclarableOpsTests14, matmul_test22) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1111,7 +1095,7 @@ TEST_F(DeclarableOpsTests14, matmul_test23) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1135,7 +1119,7 @@ TEST_F(DeclarableOpsTests14, matmul_test24) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1156,7 +1140,7 @@ TEST_F(DeclarableOpsTests14, matmul_test25) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1177,7 +1161,7 @@ TEST_F(DeclarableOpsTests14, matmul_test26) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1198,7 +1182,7 @@ TEST_F(DeclarableOpsTests14, matmul_test27) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -1220,7 +1204,7 @@ TEST_F(DeclarableOpsTests14, matmul_test28) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -1242,7 +1226,7 @@ TEST_F(DeclarableOpsTests14, matmul_test29) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test30) { TEST_F(DeclarableOpsTests14, matmul_test30) {
@ -1262,7 +1246,7 @@ TEST_F(DeclarableOpsTests14, matmul_test30) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test31) { TEST_F(DeclarableOpsTests14, matmul_test31) {
@ -1282,7 +1266,7 @@ TEST_F(DeclarableOpsTests14, matmul_test31) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test32) { TEST_F(DeclarableOpsTests14, matmul_test32) {
@ -1299,7 +1283,7 @@ TEST_F(DeclarableOpsTests14, matmul_test32) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test33) { TEST_F(DeclarableOpsTests14, matmul_test33) {
@ -1319,7 +1303,7 @@ TEST_F(DeclarableOpsTests14, matmul_test33) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test34) { TEST_F(DeclarableOpsTests14, matmul_test34) {
@ -1336,7 +1320,7 @@ TEST_F(DeclarableOpsTests14, matmul_test34) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test35) { TEST_F(DeclarableOpsTests14, matmul_test35) {
@ -1353,7 +1337,7 @@ TEST_F(DeclarableOpsTests14, matmul_test35) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test36) { TEST_F(DeclarableOpsTests14, matmul_test36) {
@ -1370,7 +1354,7 @@ TEST_F(DeclarableOpsTests14, matmul_test36) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test37) { TEST_F(DeclarableOpsTests14, matmul_test37) {
@ -1617,7 +1601,7 @@ TEST_F(DeclarableOpsTests14, Stack_1) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -1645,7 +1629,7 @@ TEST_F(DeclarableOpsTests14, Stack_2) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -1673,7 +1657,7 @@ TEST_F(DeclarableOpsTests14, Stack_3) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1700,7 +1684,7 @@ TEST_F(DeclarableOpsTests14, Stack_4) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1727,7 +1711,7 @@ TEST_F(DeclarableOpsTests14, Stack_5) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1754,7 +1738,7 @@ TEST_F(DeclarableOpsTests14, Stack_6) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -1778,7 +1762,7 @@ TEST_F(DeclarableOpsTests14, Stack_7) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1801,7 +1785,7 @@ TEST_F(DeclarableOpsTests14, Stack_8) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1824,7 +1808,7 @@ TEST_F(DeclarableOpsTests14, Stack_9) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1850,7 +1834,7 @@ TEST_F(DeclarableOpsTests14, Stack_10) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
TEST_F(DeclarableOpsTests14, Stack_11) { TEST_F(DeclarableOpsTests14, Stack_11) {
@ -1872,7 +1856,7 @@ TEST_F(DeclarableOpsTests14, Stack_11) {
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -1895,7 +1879,7 @@ TEST_F(DeclarableOpsTests14, Stack_12) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1917,7 +1901,7 @@ TEST_F(DeclarableOpsTests14, Stack_13) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1941,7 +1925,7 @@ TEST_F(DeclarableOpsTests14, Stack_14) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, Stack_15) { TEST_F(DeclarableOpsTests14, Stack_15) {
@ -1959,7 +1943,7 @@ TEST_F(DeclarableOpsTests14, Stack_15) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
} }
@ -1978,7 +1962,7 @@ TEST_F(DeclarableOpsTests14, Stack_16) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, Stack_17) { TEST_F(DeclarableOpsTests14, Stack_17) {
@ -1999,7 +1983,7 @@ TEST_F(DeclarableOpsTests14, Stack_17) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests14, Stack_18) { TEST_F(DeclarableOpsTests14, Stack_18) {
@ -2018,8 +2002,8 @@ TEST_F(DeclarableOpsTests14, Stack_18) {
auto out = res2.at(0); auto out = res2.at(0);
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>()); ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
} }
TEST_F(DeclarableOpsTests14, Stack_19) { TEST_F(DeclarableOpsTests14, Stack_19) {
@ -2033,7 +2017,7 @@ TEST_F(DeclarableOpsTests14, Stack_19) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests14, Stack_20) { TEST_F(DeclarableOpsTests14, Stack_20) {
@ -2047,7 +2031,7 @@ TEST_F(DeclarableOpsTests14, Stack_20) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests14, Stack_21) { TEST_F(DeclarableOpsTests14, Stack_21) {
@ -2073,7 +2057,363 @@ TEST_F(DeclarableOpsTests14, Stack_21) {
ASSERT_TRUE(outStack->isSameShape(outConcat)); ASSERT_TRUE(outStack->isSameShape(outConcat));
ASSERT_TRUE(outStack->equalsTo(outConcat)); ASSERT_TRUE(outStack->equalsTo(outConcat));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Reshape1) {
const std::vector<Nd4jLong> xShape = { 5,4,3 };
const std::vector<Nd4jLong> yShape = { 3,5,4 };
auto x = NDArrayFactory::create_<float>('f', xShape);
auto y = NDArrayFactory::create_<float>('f', yShape);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(-2, y);
auto block = new Context(1, variableSpace, true);
block->fillInputs({ -1, -2 });
sd::ops::reshapeas reshape;
reshape.execute(block);
ASSERT_TRUE(x->isSameShape(y));
delete variableSpace;
delete block;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Reshape2) {
const std::vector<Nd4jLong> xShape = { 5,4,3 };
const std::vector<Nd4jLong> yShape = { 3,5,4 };
auto x = NDArrayFactory::create_<float>('c', xShape);
auto y = NDArrayFactory::create_<float>('c', yShape);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(1, new Variable());
auto block = new Context(1, variableSpace, false);
block->fillInputs({ -1 });
std::vector<int>* arguments = block->getIArguments();
arguments->push_back(-y->ordering());
arguments->push_back(3);
arguments->push_back(5);
arguments->push_back(4);
sd::ops::reshape reshape;
Nd4jStatus status = reshape.execute(block);
ASSERT_EQ(ND4J_STATUS_OK, status);
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
ASSERT_TRUE(result->isSameShape(y));
delete y;
delete block;
delete variableSpace;
}
TEST_F(DeclarableOpsTests14, Reshape3) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(x.isSameShape(z));
}
TEST_F(DeclarableOpsTests14, Reshape4) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 3, 4, 5 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(x.isSameShape(z));
}
TEST_F(DeclarableOpsTests14, Reshape5) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 5, 4, 3 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
}
TEST_F(DeclarableOpsTests14, Reshape6) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
auto exp = NDArrayFactory::create<float>('c', { 4, 15 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 4, -1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(exp));
}
TEST_F(DeclarableOpsTests14, Reshape7) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
auto exp = NDArrayFactory::create<float>('c', { 60 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { -1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(exp));
}
TEST_F(DeclarableOpsTests14, Reshape8) {
auto x = NDArrayFactory::create<double>('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0});
auto e = NDArrayFactory::create<double>('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});
auto r = x.reshape('c', {3, 2});;
r.streamline('f');
sd::ops::reshape op;
auto result = op.evaluate({&x}, {3, 2});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
}
TEST_F(DeclarableOpsTests14, Reshape9) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
sd::ops::reshape op;
auto result = op.evaluate({&array}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_EQ(e, *z);
}
TEST_F(DeclarableOpsTests14, Reshape10) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
auto z = NDArrayFactory::create<float>('c', {1, 1});
sd::ops::reshape op;
auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests14, Reshape11) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('c', {4, 3});
x.linspace(1);
exp.linspace(1);
sd::ops::reshape op;
auto result = op.evaluate({&x}, {-99, 4, 3});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape12) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
sd::ops::reshape op;
auto result = op.evaluate({&x, &shape});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape13) {
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
auto exp = NDArrayFactory::create<float>(119.f);
auto empty = NDArrayFactory::empty_<int>();
sd::ops::reshape op;
auto result = op.evaluate({&vector, empty}, {}, {});
ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(exp, *result.at(0));
delete empty;
}
TEST_F(DeclarableOpsTests14, Reshape14) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
auto e = NDArrayFactory::create<float>('c', {10, 0});
sd::ops::reshape op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests14, Reshape15) {
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
sd::ops::reshape op;
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
ASSERT_EQ(Status::OK(), result0.status());
auto z0 = result0.at(0);
ASSERT_EQ(e0, *z0);
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
ASSERT_EQ(Status::OK(), result1.status());
auto z1 = result1.at(0);
ASSERT_EQ(e1, *z1);
}
TEST_F(DeclarableOpsTests14, Reshape16) {
auto x = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
auto shape = NDArrayFactory::create<int>('c', {1, 3}, {1, 2, 2});
auto exp = NDArrayFactory::create<int>('c', {1, 2, 2}, {1, 2, 3, 4});
sd::ops::reshape op;
auto result = op.evaluate({&x, &shape}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape17) {
auto x = NDArrayFactory::create<float>(2.0f);
auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape18) {
auto x = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 3});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape19) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape20) {
NDArray x1('c', {2,0}, sd::DataType::FLOAT32);
NDArray x2('c', {10,0}, sd::DataType::FLOAT32);
NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32);
NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32);
NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32);
NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32);
NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32);
sd::ops::reshape op;
auto result = op.evaluate({&x1}, {}, {2, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0}));
result = op.evaluate({&x2}, {}, {2, 0, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0,5}));
result = op.evaluate({&x2}, {}, {5, 2, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({5,2,0}));
result = op.evaluate({&x2}, {}, {-1, 2, 0});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({5,2,0}));
result = op.evaluate({&x3}, {}, {2, 0, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0,10}));
result = op.evaluate({&x4}, {}, {2, -1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,5,0}));
result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10}));
result = op.evaluate({&x6}, {}, {-1, 2, 0});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0}));
result = op.evaluate({&x7}, {}, {-1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2, 0}));
result = op.evaluate({&x7}, {}, {10,0,50,100});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100}));
}

View File

@ -89,7 +89,7 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
sd::ops::standardize_bp op; sd::ops::standardize_bp op;
auto result = op.evaluate({&x, &eps}, {0}); auto result = op.evaluate({&x, &eps}, {0});
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
} }
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
@ -108,7 +108,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
auto out = result.at(0); auto out = result.at(0);
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
} }
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
@ -126,7 +126,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
auto out = result.at(0); auto out = result.at(0);
// out->printIndexedBuffer("Adjusted Constrast"); // out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
} }
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
@ -144,7 +144,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
auto out = result.at(0); auto out = result.at(0);
// out->printIndexedBuffer("Adjusted Constrast"); // out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
} }
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
@ -162,7 +162,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
auto out = result.at(0); auto out = result.at(0);
// out->printIndexedBuffer("Adjusted Constrast"); // out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
} }
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) {
@ -177,7 +177,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) {
auto out = result.at(0); auto out = result.at(0);
// out->printIndexedBuffer("Adjusted Constrast"); // out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
} }
/* /*
@ -308,7 +308,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) {
// out->printBuffer("Adjusted Constrast6"); // out->printBuffer("Adjusted Constrast6");
// e.printBuffer("Adjusted Expected 6"); // e.printBuffer("Adjusted Expected 6");
// ASSERT_TRUE(e.equalsTo(out)); // ASSERT_TRUE(e.equalsTo(out));
} }
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
@ -415,7 +415,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
auto diff = e - *out; auto diff = e - *out;
// diff.printBuffer("Adjusted subtract 7"); // diff.printBuffer("Adjusted subtract 7");
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
} }
TEST_F(DeclarableOpsTests15, Test_BitCast_1) { TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
@ -429,7 +429,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
auto out = result.at(0); auto out = result.at(0);
// out->printIndexedBuffer("Casted result"); // out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
} }
TEST_F(DeclarableOpsTests15, Test_BitCast_2) { TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
@ -444,7 +444,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
auto out = result.at(0); auto out = result.at(0);
ASSERT_TRUE(e.equalsTo(out)); ASSERT_TRUE(e.equalsTo(out));
} }
TEST_F(DeclarableOpsTests15, Test_BitCast_3) { TEST_F(DeclarableOpsTests15, Test_BitCast_3) {
@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) {
// e.printIndexedBuffer("Double to int64"); // e.printIndexedBuffer("Double to int64");
auto res = result.at(0); auto res = result.at(0);
ASSERT_EQ(*res, e); ASSERT_EQ(*res, e);
} }
@ -508,7 +508,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_5) {
// res->printIndexedBuffer("BITCAST5"); // res->printIndexedBuffer("BITCAST5");
ASSERT_TRUE(e.equalsTo(res)); ASSERT_TRUE(e.equalsTo(res));
} }
TEST_F(DeclarableOpsTests15, Test_BitCast_6) { TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
@ -528,7 +528,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
// res->printIndexedBuffer("BITCAST6"); // res->printIndexedBuffer("BITCAST6");
ASSERT_TRUE(e.equalsTo(res)); ASSERT_TRUE(e.equalsTo(res));
} }
TEST_F(DeclarableOpsTests15, Test_BitCast_7) { TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
auto x = NDArrayFactory::create<float16>('c', {4, 4}, { auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
@ -547,7 +547,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
// res->printIndexedBuffer("BITCAST7"); // res->printIndexedBuffer("BITCAST7");
ASSERT_TRUE(e.equalsTo(res)); ASSERT_TRUE(e.equalsTo(res));
} }
TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
sd::ops::layer_norm op; sd::ops::layer_norm op;
auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false});
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
} }
TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
@ -649,7 +649,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
sd::ops::layer_norm_bp op; sd::ops::layer_norm_bp op;
auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false});
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -710,30 +710,6 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); ASSERT_NE(*resultA0.at(0), *resultB0.at(0));
} }
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
sd::ops::reshape op;
auto result = op.evaluate({&array}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_EQ(e, *z);
}
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
auto z = NDArrayFactory::create<float>('c', {1, 1});
sd::ops::reshape op;
auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests15, test_rank_1) { TEST_F(DeclarableOpsTests15, test_rank_1) {
auto array = NDArrayFactory::create<float>('c', {4, 64}); auto array = NDArrayFactory::create<float>('c', {4, 64});
auto e = NDArrayFactory::create<int>('c', {}, {2}); auto e = NDArrayFactory::create<int>('c', {}, {2});
@ -757,7 +733,7 @@ TEST_F(DeclarableOpsTests15, test_rank_2) {
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
@ -800,7 +776,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_2) {
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0); auto z = result.at(0);
} }
TEST_F(DeclarableOpsTests15, test_lstmBlock_3) { TEST_F(DeclarableOpsTests15, test_lstmBlock_3) {
@ -969,7 +945,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) {
sd::ops::rgb_to_grs op; sd::ops::rgb_to_grs op;
auto result = op.evaluate({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
ASSERT_EQ(Status::THROW(), result.status()); ASSERT_EQ(Status::THROW(), result.status());
} catch (std::exception& e) { } catch (std::exception& e) {
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
} }
@ -1063,7 +1039,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) {
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1074,7 +1050,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) {
sd::ops::rgb_to_yuv op; sd::ops::rgb_to_yuv op;
auto result = op.evaluate({ &rgbs }, {}, {}); auto result = op.evaluate({ &rgbs }, {}, {});
ASSERT_EQ(Status::THROW(), result.status()); ASSERT_EQ(Status::THROW(), result.status());
} }
catch (std::exception & e) { catch (std::exception & e) {
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
@ -1109,7 +1085,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1168,7 +1144,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) {
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1179,7 +1155,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) {
sd::ops::yuv_to_rgb op; sd::ops::yuv_to_rgb op;
auto result = op.evaluate({ &yuv }, {}, {}); auto result = op.evaluate({ &yuv }, {}, {});
ASSERT_EQ(Status::THROW(), result.status()); ASSERT_EQ(Status::THROW(), result.status());
} }
catch (std::exception & e) { catch (std::exception & e) {
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
@ -1423,7 +1399,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test8) {
ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); ASSERT_TRUE(dLdyExp.equalsTo(dLdy));
} }
TEST_F(DeclarableOpsTests15, Pow_BP_Test9) { TEST_F(DeclarableOpsTests15, Pow_BP_Test9) {
@ -1515,7 +1491,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
ASSERT_NEAR(dLdyB->e<float>(i), dLdyExpB.e<float>(i), 0.00001); ASSERT_NEAR(dLdyB->e<float>(i), dLdyExpB.e<float>(i), 0.00001);
} }
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
@ -1532,10 +1508,10 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
auto* dLdAbp = resultsBP.at(0); auto* dLdAbp = resultsBP.at(0);
auto* dLdBbp = resultsBP.at(1); auto* dLdBbp = resultsBP.at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
@ -1554,10 +1530,10 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) {
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
auto* dLdAbp = resultsBP.at(0); auto* dLdAbp = resultsBP.at(0);
auto* dLdBbp = resultsBP.at(1); auto* dLdBbp = resultsBP.at(1);
ASSERT_TRUE(B.isSameShape(*dLdAbp)); ASSERT_TRUE(B.isSameShape(*dLdAbp));
ASSERT_TRUE(B.equalsTo(*dLdAbp)); ASSERT_TRUE(B.equalsTo(*dLdAbp));
@ -1606,7 +1582,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) {
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
auto* dLdAbp = resultsBP.at(0); auto* dLdAbp = resultsBP.at(0);
auto* dLdBbp = resultsBP.at(1); auto* dLdBbp = resultsBP.at(1);
@ -1632,7 +1608,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) {
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
auto* dLdAbp = resultsBP.at(0); auto* dLdAbp = resultsBP.at(0);
auto* dLdBbp = resultsBP.at(1); auto* dLdBbp = resultsBP.at(1);
@ -1655,7 +1631,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) {
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
auto* dLdAbp = resultsBP.at(0); auto* dLdAbp = resultsBP.at(0);
auto* dLdBbp = resultsBP.at(1); auto* dLdBbp = resultsBP.at(1);
@ -1706,7 +1682,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) {
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
auto* dLdAbp = resultsBP.at(0); auto* dLdAbp = resultsBP.at(0);
auto* dLdBbp = resultsBP.at(1); auto* dLdBbp = resultsBP.at(1);

View File

@ -789,24 +789,6 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('c', {4, 3});
x.linspace(1);
exp.linspace(1);
sd::ops::reshape op;
auto result = op.evaluate({&x}, {-99, 4, 3});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_Split_1) { TEST_F(DeclarableOpsTests4, Test_Split_1) {
@ -1209,23 +1191,6 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
sd::ops::reshape op;
auto result = op.evaluate({&x, &shape});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {

View File

@ -140,37 +140,6 @@ TEST_F(EmptyTests, Test_Concat_4) {
ASSERT_EQ(exp, *z); ASSERT_EQ(exp, *z);
} }
TEST_F(EmptyTests, Test_Reshape_1) {
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
auto exp = NDArrayFactory::create<float>(119.f);
auto empty = NDArrayFactory::empty_<int>();
sd::ops::reshape op;
auto result = op.evaluate({&vector, empty}, {}, {});
ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(exp, *result.at(0));
delete empty;
}
TEST_F(EmptyTests, Test_Reshape_3) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
auto e = NDArrayFactory::create<float>('c', {10, 0});
sd::ops::reshape op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_EQ(e, *z);
}
TEST_F(EmptyTests, Test_dup_1) { TEST_F(EmptyTests, Test_dup_1) {
auto empty = NDArrayFactory::empty<int>(); auto empty = NDArrayFactory::empty<int>();
auto dup = new NDArray(empty.dup()); auto dup = new NDArray(empty.dup());
@ -256,41 +225,6 @@ TEST_F(EmptyTests, test_shaped_empty_4) {
ASSERT_EQ(shapeOf, array.getShapeAsVector()); ASSERT_EQ(shapeOf, array.getShapeAsVector());
} }
TEST_F(EmptyTests, test_empty_reshape_1) {
/*
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
assertArrayEquals(new long[]{0, 1}, out1.shape());
assertArrayEquals(new long[]{10, 0}, out2.shape());
*/
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
sd::ops::reshape op;
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
ASSERT_EQ(Status::OK(), result0.status());
auto z0 = result0.at(0);
ASSERT_EQ(e0, *z0);
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
ASSERT_EQ(Status::OK(), result1.status());
auto z1 = result1.at(0);
ASSERT_EQ(e1, *z1);
}
TEST_F(EmptyTests, test_empty_matmul_1) { TEST_F(EmptyTests, test_empty_matmul_1) {
auto x = NDArrayFactory::create<float>('c', {0, 1}); auto x = NDArrayFactory::create<float>('c', {0, 1});

View File

@ -48,7 +48,7 @@ TEST_F(ParityOpsTests, TestZeroAs1) {
ASSERT_TRUE(z->isSameShape(&x)); ASSERT_TRUE(z->isSameShape(&x));
ASSERT_TRUE(z->equalsTo(&exp)); ASSERT_TRUE(z->equalsTo(&exp));
} }
TEST_F(ParityOpsTests, TestMaximum1) { TEST_F(ParityOpsTests, TestMaximum1) {
@ -66,7 +66,7 @@ TEST_F(ParityOpsTests, TestMaximum1) {
ASSERT_TRUE(y.equalsTo(z)); ASSERT_TRUE(y.equalsTo(z));
} }
@ -86,7 +86,7 @@ TEST_F(ParityOpsTests, TestMinimum1) {
ASSERT_TRUE(y.equalsTo(z)); ASSERT_TRUE(y.equalsTo(z));
} }
TEST_F(ParityOpsTests, TestTear1) { TEST_F(ParityOpsTests, TestTear1) {
@ -106,7 +106,7 @@ TEST_F(ParityOpsTests, TestTear1) {
for (int e = 0; e < result.size(); e++) for (int e = 0; e < result.size(); e++)
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
} }
TEST_F(ParityOpsTests, TestUnstack1) { TEST_F(ParityOpsTests, TestUnstack1) {
@ -126,7 +126,7 @@ TEST_F(ParityOpsTests, TestUnstack1) {
for (int e = 0; e < result.size(); e++) for (int e = 0; e < result.size(); e++)
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
} }
@ -148,7 +148,7 @@ TEST_F(ParityOpsTests, TestUnstack2) {
for (int e = 0; e < result.size(); e++) for (int e = 0; e < result.size(); e++)
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
} }
TEST_F(ParityOpsTests, TestUnstack3) { TEST_F(ParityOpsTests, TestUnstack3) {
@ -166,7 +166,7 @@ TEST_F(ParityOpsTests, TestUnstack3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -185,7 +185,7 @@ TEST_F(ParityOpsTests, TestUnstack4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, TestUnstack5) { TEST_F(ParityOpsTests, TestUnstack5) {
@ -203,7 +203,7 @@ TEST_F(ParityOpsTests, TestUnstack5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, TestUnstack6) { TEST_F(ParityOpsTests, TestUnstack6) {
@ -221,7 +221,7 @@ TEST_F(ParityOpsTests, TestUnstack6) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, TestUnstack7) { TEST_F(ParityOpsTests, TestUnstack7) {
@ -239,7 +239,7 @@ TEST_F(ParityOpsTests, TestUnstack7) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, TestUnstack8) { TEST_F(ParityOpsTests, TestUnstack8) {
@ -257,7 +257,7 @@ TEST_F(ParityOpsTests, TestUnstack8) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, TestUnstack9) { TEST_F(ParityOpsTests, TestUnstack9) {
@ -275,7 +275,7 @@ TEST_F(ParityOpsTests, TestUnstack9) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -293,7 +293,7 @@ TEST_F(ParityOpsTests, TestUnstack10) {
ASSERT_TRUE(exp.isSameShape(result.at(1))); ASSERT_TRUE(exp.isSameShape(result.at(1)));
ASSERT_TRUE(exp.isSameShape(result.at(2))); ASSERT_TRUE(exp.isSameShape(result.at(2)));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -310,7 +310,7 @@ TEST_F(ParityOpsTests, TestUnstack11) {
ASSERT_TRUE(exp.isSameShape(result.at(0))); ASSERT_TRUE(exp.isSameShape(result.at(0)));
ASSERT_TRUE(exp.isSameShape(result.at(1))); ASSERT_TRUE(exp.isSameShape(result.at(1)));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -325,7 +325,7 @@ TEST_F(ParityOpsTests, TestUnstack12) {
ASSERT_TRUE(result.size() == 0); ASSERT_TRUE(result.size() == 0);
} }
TEST_F(ParityOpsTests, TestUnstack13) { TEST_F(ParityOpsTests, TestUnstack13) {
@ -361,7 +361,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest1) {
ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.isSameShape(z));
ASSERT_TRUE(reshaped.equalsTo(z)); ASSERT_TRUE(reshaped.equalsTo(z));
} }
@ -380,7 +380,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest2) {
ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.isSameShape(z));
ASSERT_TRUE(reshaped.equalsTo(z)); ASSERT_TRUE(reshaped.equalsTo(z));
} }
@ -399,7 +399,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest3) {
ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.isSameShape(z));
ASSERT_TRUE(reshaped.equalsTo(z)); ASSERT_TRUE(reshaped.equalsTo(z));
} }
TEST_F(ParityOpsTests, ExpandDimsTest4) { TEST_F(ParityOpsTests, ExpandDimsTest4) {
@ -417,7 +417,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest4) {
ASSERT_TRUE(reshaped.isSameShape(z)); ASSERT_TRUE(reshaped.isSameShape(z));
ASSERT_TRUE(reshaped.equalsTo(z)); ASSERT_TRUE(reshaped.equalsTo(z));
} }
@ -434,7 +434,7 @@ TEST_F(ParityOpsTests, Test_Shape_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -452,7 +452,7 @@ TEST_F(ParityOpsTests, Test_Equals_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -470,7 +470,7 @@ TEST_F(ParityOpsTests, Test_NotEquals_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Less_1) { TEST_F(ParityOpsTests, Test_Less_1) {
@ -487,7 +487,7 @@ TEST_F(ParityOpsTests, Test_Less_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_LessEquals_1) { TEST_F(ParityOpsTests, Test_LessEquals_1) {
@ -504,7 +504,7 @@ TEST_F(ParityOpsTests, Test_LessEquals_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_GreaterEquals_1) { TEST_F(ParityOpsTests, Test_GreaterEquals_1) {
@ -521,7 +521,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_GreaterEquals_2) { TEST_F(ParityOpsTests, Test_GreaterEquals_2) {
@ -538,7 +538,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Greater_1) { TEST_F(ParityOpsTests, Test_Greater_1) {
@ -555,7 +555,7 @@ TEST_F(ParityOpsTests, Test_Greater_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Where_1) { TEST_F(ParityOpsTests, Test_Where_1) {
@ -575,7 +575,7 @@ TEST_F(ParityOpsTests, Test_Where_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Where_2) { TEST_F(ParityOpsTests, Test_Where_2) {
@ -593,7 +593,7 @@ TEST_F(ParityOpsTests, Test_Where_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -612,7 +612,7 @@ TEST_F(ParityOpsTests, Test_Where_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Select_1) { TEST_F(ParityOpsTests, Test_Select_1) {
@ -630,7 +630,7 @@ TEST_F(ParityOpsTests, Test_Select_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Select_2) { TEST_F(ParityOpsTests, Test_Select_2) {
@ -648,7 +648,7 @@ TEST_F(ParityOpsTests, Test_Select_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Select_3) { TEST_F(ParityOpsTests, Test_Select_3) {
@ -666,25 +666,7 @@ TEST_F(ParityOpsTests, Test_Select_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(ParityOpsTests, Test_Reshape_TF_1) {
auto x = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
auto shape = NDArrayFactory::create<int>('c', {1, 3}, {1, 2, 2});
auto exp = NDArrayFactory::create<int>('c', {1, 2, 2}, {1, 2, 3, 4});
sd::ops::reshape op;
auto result = op.evaluate({&x, &shape}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Bias_Add_1) { TEST_F(ParityOpsTests, Test_Bias_Add_1) {
@ -702,7 +684,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) {
for (int e = 0; e < tads.size(); e++) { for (int e = 0; e < tads.size(); e++) {
ASSERT_TRUE(bias.equalsTo(tads.at(e))); ASSERT_TRUE(bias.equalsTo(tads.at(e)));
} }
} }
TEST_F(ParityOpsTests, Test_Scatter_Add_1) { TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
@ -718,7 +700,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Scatter_Add_2) { TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
@ -735,7 +717,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Scatter_Add_3) { TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
@ -751,7 +733,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Scatter_Add_4) { TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
@ -767,7 +749,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Scatter_Add_5) { TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
@ -784,7 +766,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
// z->printBuffer(); // z->printBuffer();
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Scatter_Add_6) { TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
@ -800,7 +782,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, Test_Scatter_Add_7) { TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
@ -816,7 +798,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -864,7 +846,7 @@ TEST_F(ParityOpsTests, scatterMax_test1) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, scatterMax_test2) { TEST_F(ParityOpsTests, scatterMax_test2) {
@ -880,7 +862,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, scatterMax_test3) { TEST_F(ParityOpsTests, scatterMax_test3) {
@ -897,7 +879,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, scatterMax_test4) { TEST_F(ParityOpsTests, scatterMax_test4) {
@ -913,7 +895,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, scatterMax_test5) { TEST_F(ParityOpsTests, scatterMax_test5) {
@ -929,7 +911,7 @@ TEST_F(ParityOpsTests, scatterMax_test5) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, scatterMax_test6) { TEST_F(ParityOpsTests, scatterMax_test6) {
@ -945,7 +927,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -963,7 +945,7 @@ TEST_F(ParityOpsTests, scatterMin_test1) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, scatterMin_test2) { TEST_F(ParityOpsTests, scatterMin_test2) {
@ -979,7 +961,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, scatterMin_test3) { TEST_F(ParityOpsTests, scatterMin_test3) {
@ -995,7 +977,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ParityOpsTests, scatterMin_test4) { TEST_F(ParityOpsTests, scatterMin_test4) {
@ -1012,7 +994,7 @@ TEST_F(ParityOpsTests, scatterMin_test4) {
// z->printBuffer(); // z->printBuffer();
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1044,7 +1026,7 @@ TEST_F(ParityOpsTests, scatterND_test1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1064,7 +1046,7 @@ TEST_F(ParityOpsTests, scatterND_test2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1088,7 +1070,7 @@ TEST_F(ParityOpsTests, scatterND_test3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1107,7 +1089,7 @@ TEST_F(ParityOpsTests, scatterND_test4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1127,7 +1109,7 @@ TEST_F(ParityOpsTests, scatterND_test5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1154,7 +1136,7 @@ TEST_F(ParityOpsTests, scatterND_test6) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1181,7 +1163,7 @@ TEST_F(ParityOpsTests, scatterND_test7) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1202,7 +1184,7 @@ TEST_F(ParityOpsTests, scatterND_test8) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1236,7 +1218,7 @@ TEST_F(ParityOpsTests, scatterND_add_test1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1260,7 +1242,7 @@ TEST_F(ParityOpsTests, scatterND_add_test2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1283,7 +1265,7 @@ TEST_F(ParityOpsTests, scatterND_add_test3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1310,7 +1292,7 @@ TEST_F(ParityOpsTests, scatterND_add_test4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1346,7 +1328,7 @@ TEST_F(ParityOpsTests, scatterND_add_test5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1379,7 +1361,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1404,7 +1386,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1427,7 +1409,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1454,7 +1436,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1490,7 +1472,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -1511,7 +1493,7 @@ TEST_F(ParityOpsTests, scatterND_update_test1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1535,7 +1517,7 @@ TEST_F(ParityOpsTests, scatterND_update_test2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1559,7 +1541,7 @@ TEST_F(ParityOpsTests, scatterND_update_test3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1586,7 +1568,7 @@ TEST_F(ParityOpsTests, scatterND_update_test4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1622,7 +1604,7 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1655,7 +1637,7 @@ TEST_F(ParityOpsTests, scatter_update_1) {
ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.isSameShape(x));
ASSERT_TRUE(exp.equalsTo(x)); ASSERT_TRUE(exp.equalsTo(x));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1674,7 +1656,7 @@ TEST_F(ParityOpsTests, scatter_update_2) {
ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.isSameShape(x));
ASSERT_TRUE(exp.equalsTo(x)); ASSERT_TRUE(exp.equalsTo(x));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1693,7 +1675,7 @@ TEST_F(ParityOpsTests, scatter_update_3) {
ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.isSameShape(x));
ASSERT_TRUE(exp.equalsTo(x)); ASSERT_TRUE(exp.equalsTo(x));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1712,5 +1694,5 @@ TEST_F(ParityOpsTests, scatter_update_4) {
ASSERT_TRUE(exp.isSameShape(x)); ASSERT_TRUE(exp.isSameShape(x));
ASSERT_TRUE(exp.equalsTo(x)); ASSERT_TRUE(exp.equalsTo(x));
} }

View File

@ -103,7 +103,7 @@ TEST_F(ScalarTests, Test_Concat_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -124,7 +124,7 @@ TEST_F(ScalarTests, Test_Concat_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -146,7 +146,7 @@ TEST_F(ScalarTests, Test_Concat_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ScalarTests, Test_ExpandDims_1) { TEST_F(ScalarTests, Test_ExpandDims_1) {
@ -163,7 +163,7 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ScalarTests, Test_Squeeze_1) { TEST_F(ScalarTests, Test_Squeeze_1) {
@ -179,27 +179,9 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ScalarTests, Test_Reshape_1) {
auto x = NDArrayFactory::create<float>(2.0f);
auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(ScalarTests, Test_Permute_1) { TEST_F(ScalarTests, Test_Permute_1) {
auto x = NDArrayFactory::create<float>(3.0f); auto x = NDArrayFactory::create<float>(3.0f);
auto exp = NDArrayFactory::create<float>(3.0f); auto exp = NDArrayFactory::create<float>(3.0f);
@ -213,7 +195,7 @@ TEST_F(ScalarTests, Test_Permute_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(ScalarTests, Test_Concat_Scalar_1) { TEST_F(ScalarTests, Test_Concat_Scalar_1) {

View File

@ -77,7 +77,7 @@ TEST_F(SingleDimTests, Test_Concat_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(SingleDimTests, Test_Reduce_1) { TEST_F(SingleDimTests, Test_Reduce_1) {
@ -111,7 +111,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -129,7 +129,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -149,7 +149,7 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
ASSERT_EQ(exp.rankOf(), z->rankOf()); ASSERT_EQ(exp.rankOf(), z->rankOf());
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(SingleDimTests, Test_Squeeze_2) { TEST_F(SingleDimTests, Test_Squeeze_2) {
@ -165,42 +165,9 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(SingleDimTests, Test_Reshape_1) {
auto x = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 3});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(SingleDimTests, Test_Reshape_2) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(SingleDimTests, Test_Permute_1) { TEST_F(SingleDimTests, Test_Permute_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
@ -214,5 +181,5 @@ TEST_F(SingleDimTests, Test_Permute_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }