Shyrma reshape empty (#338)
* - start working on reshape op which operates with empty shapes Signed-off-by: Yurii <iuriish@yahoo.com> * - correct reshaping for empty arrays Signed-off-by: Yurii <iuriish@yahoo.com> * - remove unnecessary check in Loopkind Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
bf0ddbc06c
commit
29e61579c1
|
@ -1944,7 +1944,7 @@ void NDArray::tilei(const std::vector<Nd4jLong>& reps) {
|
||||||
Nd4jLong NDArray::sizeAt(const int dim) const {
|
Nd4jLong NDArray::sizeAt(const int dim) const {
|
||||||
|
|
||||||
if (dim >= this->rankOf() || dim < -this->rankOf())
|
if (dim >= this->rankOf() || dim < -this->rankOf())
|
||||||
throw std::runtime_error("Bad size index requested");
|
throw std::runtime_error("NDArray::sizeAt: bad size index requested");
|
||||||
|
|
||||||
if (dim >= 0)
|
if (dim >= 0)
|
||||||
return shape::shapeOf(_shapeInfo)[dim];
|
return shape::shapeOf(_shapeInfo)[dim];
|
||||||
|
|
|
@ -35,16 +35,16 @@ namespace sd {
|
||||||
|
|
||||||
|
|
||||||
class ND4J_EXPORT LoopKind {
|
class ND4J_EXPORT LoopKind {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };
|
enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D };
|
||||||
|
|
||||||
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
|
||||||
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
||||||
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
|
||||||
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
|
||||||
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -59,8 +59,8 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
|
||||||
|
|
||||||
int temp;
|
int temp;
|
||||||
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
|
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
|
||||||
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
||||||
const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo);
|
const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c'))
|
if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c'))
|
||||||
return EWS1;
|
return EWS1;
|
||||||
|
@ -160,7 +160,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const N
|
||||||
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
|
const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c';
|
||||||
const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c';
|
const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c';
|
||||||
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';
|
||||||
const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo);
|
const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c'))
|
if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c'))
|
||||||
return EWS1;
|
return EWS1;
|
||||||
|
@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
|
||||||
const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
|
const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
|
||||||
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';;
|
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';;
|
||||||
|
|
||||||
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && shape::rank(xShapeInfo) == 2 && xEws == 1 && xOrder == 'c' && xRank == 2 &&
|
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 &&
|
||||||
tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
||||||
return SMALLARR2DX;
|
return SMALLARR2DX;
|
||||||
if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
||||||
|
@ -233,18 +233,18 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) {
|
LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) {
|
||||||
|
|
||||||
// both tad shapes are the same, but strides and ews may be different
|
// both tad shapes are the same, but strides and ews may be different
|
||||||
|
|
||||||
const int tadRank = shape::rank(xTadShapeInfo);
|
const int tadRank = shape::rank(xTadShapeInfo);
|
||||||
|
|
||||||
const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo);
|
const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo);
|
||||||
const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo);
|
const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo);
|
||||||
const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo);
|
const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo);
|
||||||
|
|
||||||
const char xTadOrder = shape::order(xTadShapeInfo);
|
const char xTadOrder = shape::order(xTadShapeInfo);
|
||||||
const char yTadOrder = shape::order(xTadShapeInfo);
|
const char yTadOrder = shape::order(xTadShapeInfo);
|
||||||
const char zOrder = shape::order(zShapeInfo);
|
const char zOrder = shape::order(zShapeInfo);
|
||||||
|
|
||||||
int position;
|
int position;
|
||||||
const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c';
|
const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c';
|
||||||
const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c';
|
const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c';
|
||||||
|
@ -265,7 +265,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, c
|
||||||
return RANK4;
|
return RANK4;
|
||||||
if(tadRank == 5 && zEws > 0 && zVectorOrC)
|
if(tadRank == 5 && zEws > 0 && zVectorOrC)
|
||||||
return RANK5;
|
return RANK5;
|
||||||
return COMMON;
|
return COMMON;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,111 +35,19 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||||
if (x->isEmpty()) {
|
if (x->isEmpty()) {
|
||||||
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||||
return Status::OK(); //No op
|
return Status::OK(); //No op
|
||||||
}
|
|
||||||
|
|
||||||
if (block.width() == 1) {
|
|
||||||
|
|
||||||
auto arguments = block.getIArguments();
|
|
||||||
int argsSize = arguments->size();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int e = 1;
|
|
||||||
char order = (char) -(*arguments)[0];
|
|
||||||
if (order != 'c' && order != 'f') {
|
|
||||||
order = 'c'; //x->ordering();
|
|
||||||
e = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew;
|
|
||||||
int e2 = e;
|
|
||||||
for (; e < (int) arguments->size(); e++) {
|
|
||||||
if (arguments->at(e) == -1){
|
|
||||||
Nd4jLong shapeLength = 1;
|
|
||||||
for(; e2 < e; e2++){
|
|
||||||
shapeLength *= arguments->at(e2);
|
|
||||||
}
|
|
||||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
|
||||||
shapeLength *= arguments->at(e2);
|
|
||||||
}
|
|
||||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
|
||||||
shapeNew.push_back(realShape);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
shapeNew.push_back(arguments->at(e));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
|
||||||
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
|
||||||
|
|
||||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
|
||||||
nd4j_printv("Reshape: new shape", shapeNew);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto xr = x->reshape(order, shapeNew);
|
|
||||||
z->assign(xr);
|
|
||||||
STORE_RESULT(*z);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
|
|
||||||
} else if (block.width() == 2) {
|
|
||||||
|
|
||||||
auto s = INPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
char order = 'c';
|
|
||||||
if (block.numI() > 0)
|
|
||||||
order = (char) -INT_ARG(0);
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew(s->lengthOf());
|
|
||||||
|
|
||||||
for (int e = 0; e < (int) s->lengthOf(); e++) {
|
|
||||||
auto dim = s->e<Nd4jLong >(e);
|
|
||||||
if (dim == -1){
|
|
||||||
Nd4jLong shapeLength = 1;
|
|
||||||
for(int e2 = 0; e2 < e; e2++){
|
|
||||||
shapeLength *= s->e<Nd4jLong>(e2);
|
|
||||||
}
|
|
||||||
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
|
||||||
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
|
||||||
shapeLength *= s->e<Nd4jLong>(e2);
|
|
||||||
}
|
|
||||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
|
||||||
shapeNew[e] = realShape;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
shapeNew[e] = dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
|
||||||
nd4j_printv("Reshape: new shape", shapeNew);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (s->isScalar()) {
|
|
||||||
// just a scalar
|
|
||||||
z->assign(x);
|
|
||||||
} else {
|
|
||||||
// in some cases we might go away with simple memcpy call instead of assign call
|
|
||||||
if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) {
|
|
||||||
z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset());
|
|
||||||
} else {
|
|
||||||
auto xr = x->reshape(order, shapeNew);
|
|
||||||
z->assign(xr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ND4J_STATUS_BAD_INPUT;
|
REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf());
|
||||||
|
|
||||||
|
if (Environment::getInstance()->isDebugAndVerbose())
|
||||||
|
nd4j_printv("Reshape: new shape", z->getShapeAsVector());
|
||||||
|
|
||||||
|
z->assign(x->reshape(z->ordering(), z->getShapeAsVector()));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -151,117 +59,73 @@ DECLARE_TYPES(reshape) {
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(reshape) {
|
DECLARE_SHAPE_FN(reshape) {
|
||||||
auto inp = inputShape->at(0);
|
|
||||||
|
|
||||||
// we can launch op using Int arguments
|
const auto x = INPUT_VARIABLE(0);
|
||||||
if (inputShape->size() == 1) {
|
|
||||||
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
|
|
||||||
std::vector<int> *arguments = block.getIArguments();
|
|
||||||
|
|
||||||
int e = 1;
|
std::vector<int> reshapeArgs;
|
||||||
char order = (char) -(*arguments)[0];
|
std::vector<Nd4jLong> shapeNew;
|
||||||
if (order != 'c' && order != 'f') {
|
char orderNew = 'c';
|
||||||
order = shape::order(inp);
|
|
||||||
e = 0;
|
if (block.width() == 1) {
|
||||||
|
reshapeArgs = *block.getIArguments();
|
||||||
|
if(!reshapeArgs.empty()) {
|
||||||
|
orderNew = (char) -reshapeArgs[0];
|
||||||
|
if(orderNew == 'c' || orderNew == 'f')
|
||||||
|
reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew;
|
|
||||||
|
|
||||||
int e2 = e;
|
|
||||||
for (; e < (int) arguments->size(); e++) {
|
|
||||||
if ((int) arguments->at(e) == -1){
|
|
||||||
|
|
||||||
Nd4jLong shapeLength = 1;
|
|
||||||
for(; e2 < e; e2 ++){
|
|
||||||
shapeLength *= arguments->at(e2);
|
|
||||||
}
|
|
||||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
|
||||||
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
|
||||||
shapeLength *= arguments->at(e2);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(shapeLength == 0){
|
|
||||||
//Edge case for empty:
|
|
||||||
shapeNew.push_back(0);
|
|
||||||
} else {
|
|
||||||
//Standard case
|
|
||||||
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
|
||||||
shapeNew.push_back(realShape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
shapeNew.push_back(arguments->at(e));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
|
|
||||||
} else {
|
|
||||||
// or, with second input "as shape"
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
|
||||||
auto y = INPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
// special case here
|
|
||||||
if (y->isEmpty()) {
|
|
||||||
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
|
|
||||||
}
|
|
||||||
//Special case: empty.reshape(-1) -> return empty
|
|
||||||
if (x->isEmpty()) {
|
|
||||||
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
|
||||||
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
|
||||||
Nd4jLong prod = 1;
|
|
||||||
bool hasNegs = false;
|
|
||||||
for (auto v:shapeOf) {
|
|
||||||
if (v < 0) {
|
|
||||||
hasNegs = true;
|
|
||||||
v = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
prod *= v;
|
|
||||||
}
|
|
||||||
|
|
||||||
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
|
||||||
|
|
||||||
// if there are -1s - we turn them into zeros
|
|
||||||
if (hasNegs) {
|
|
||||||
for (int e = 0; e < shapeOf.size(); e++)
|
|
||||||
if (shapeOf[e] < 0)
|
|
||||||
shapeOf[e] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
|
||||||
return SHAPELIST(CONSTANT(newShape));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew(y->lengthOf());
|
|
||||||
|
|
||||||
for (int e = 0; e < (int) y->lengthOf(); e++) {
|
|
||||||
auto dim = y->e<Nd4jLong>(e);
|
|
||||||
if (dim == -1){
|
|
||||||
Nd4jLong shapeLength = 1;
|
|
||||||
for(int e2 = 0; e2 < e; e2++){
|
|
||||||
shapeLength *= y->e<Nd4jLong>(e2);
|
|
||||||
}
|
|
||||||
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
|
|
||||||
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
|
||||||
shapeLength *= y->e<Nd4jLong>(e2);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(shapeLength == 0){
|
|
||||||
//Edge case for empty:
|
|
||||||
shapeNew[e] = 0;
|
|
||||||
} else {
|
|
||||||
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
|
||||||
shapeNew[e] = realShape;
|
|
||||||
}
|
|
||||||
}else {
|
|
||||||
shapeNew[e] = dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
|
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
|
||||||
|
orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c';
|
||||||
|
}
|
||||||
|
|
||||||
|
REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
|
||||||
|
|
||||||
|
Nd4jLong xLen = x->lengthOf();
|
||||||
|
if(x->isEmpty()) {
|
||||||
|
xLen = 1;
|
||||||
|
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||||
|
if(x->sizeAt(i) != 0)
|
||||||
|
xLen *= x->sizeAt(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint i = 0; i < reshapeArgs.size(); ++i) {
|
||||||
|
|
||||||
|
if (reshapeArgs[i] == -1) {
|
||||||
|
|
||||||
|
uint shapeLength = 1, numOfZeros = 0;
|
||||||
|
|
||||||
|
for(uint j = 0; j < i; ++j)
|
||||||
|
if(reshapeArgs[j] != 0)
|
||||||
|
shapeLength *= reshapeArgs[j];
|
||||||
|
else
|
||||||
|
++numOfZeros;
|
||||||
|
|
||||||
|
for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
|
||||||
|
REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||||
|
if(reshapeArgs[j] != 0)
|
||||||
|
shapeLength *= reshapeArgs[j];
|
||||||
|
else
|
||||||
|
++numOfZeros;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto dim = xLen / shapeLength;
|
||||||
|
|
||||||
|
if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
|
||||||
|
shapeNew.push_back(0);
|
||||||
|
else
|
||||||
|
shapeNew.push_back(dim);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
shapeNew.push_back(reshapeArgs[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||||
|
REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
||||||
|
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,16 +40,15 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_0) {
|
||||||
|
|
||||||
TEST_F(ArrayOptionsTests, TestShape_Basic_1) {
|
TEST_F(ArrayOptionsTests, TestShape_Basic_1) {
|
||||||
shape[5] = 2;
|
shape[5] = 2;
|
||||||
|
|
||||||
|
|
||||||
ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
|
ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
|
||||||
ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
|
ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(ArrayOptionsTests, TestShape_Basic_2) {
|
TEST_F(ArrayOptionsTests, TestShape_Basic_2) {
|
||||||
shape[5] = 258;
|
shape[5] = 258;
|
||||||
|
|
||||||
ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
|
ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
|
||||||
|
|
||||||
ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
|
ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
|
||||||
|
@ -58,7 +57,7 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_2) {
|
||||||
|
|
||||||
TEST_F(ArrayOptionsTests, TestShape_Basic_3) {
|
TEST_F(ArrayOptionsTests, TestShape_Basic_3) {
|
||||||
ASSERT_EQ(0, shape::extra(shape));
|
ASSERT_EQ(0, shape::extra(shape));
|
||||||
|
|
||||||
ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape));
|
ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -166,7 +166,7 @@ TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(z->equalsTo(exp));
|
ASSERT_TRUE(z->equalsTo(exp));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -180,7 +180,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(z->equalsTo(exp));
|
ASSERT_TRUE(z->equalsTo(exp));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -198,7 +198,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) {
|
||||||
|
|
||||||
ASSERT_TRUE(z1->equalsTo(exp1));
|
ASSERT_TRUE(z1->equalsTo(exp1));
|
||||||
ASSERT_TRUE(z2->equalsTo(exp2));
|
ASSERT_TRUE(z2->equalsTo(exp2));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests1, AXpY_Test_1) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(z->equalsTo(exp));
|
ASSERT_TRUE(z->equalsTo(exp));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, BasicInitialization3) {
|
TEST_F(DeclarableOpsTests1, BasicInitialization3) {
|
||||||
|
@ -258,7 +258,7 @@ TEST_F(DeclarableOpsTests1, TestTensorMmul1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(out));
|
ASSERT_TRUE(exp.isSameShape(out));
|
||||||
ASSERT_TRUE(exp.equalsTo(out));
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, TestTensorDot2) {
|
TEST_F(DeclarableOpsTests1, TestTensorDot2) {
|
||||||
|
@ -278,7 +278,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(out));
|
ASSERT_TRUE(exp.isSameShape(out));
|
||||||
ASSERT_TRUE(exp.equalsTo(out));
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, TestTensorDot3) {
|
TEST_F(DeclarableOpsTests1, TestTensorDot3) {
|
||||||
|
@ -298,7 +298,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(out));
|
ASSERT_TRUE(exp.isSameShape(out));
|
||||||
ASSERT_TRUE(exp.equalsTo(out));
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, TestTensorDot4) {
|
TEST_F(DeclarableOpsTests1, TestTensorDot4) {
|
||||||
|
@ -318,7 +318,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(out));
|
ASSERT_TRUE(exp.isSameShape(out));
|
||||||
ASSERT_TRUE(exp.equalsTo(out));
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
@ -338,7 +338,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot5) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -360,7 +360,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot6) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -381,7 +381,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot7) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot8) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -431,7 +431,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot9) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -452,7 +452,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot10) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -474,7 +474,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot11) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -495,7 +495,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot12) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -516,7 +516,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot13) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -537,7 +537,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot14) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -558,7 +558,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot15) {
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -579,7 +579,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot16) {
|
||||||
ASSERT_TRUE(exp.isSameShape(result));
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
ASSERT_TRUE(exp.equalsTo(result));
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
@ -786,7 +786,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) {
|
||||||
|
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, TestRng1) {
|
TEST_F(DeclarableOpsTests1, TestRng1) {
|
||||||
|
@ -1046,7 +1046,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) {
|
||||||
ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1071,7 +1071,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) {
|
||||||
ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1093,7 +1093,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) {
|
||||||
ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) {
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
||||||
ASSERT_TRUE(exp.equalsTo(&z));
|
ASSERT_TRUE(exp.equalsTo(&z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1147,7 +1147,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) {
|
||||||
ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res.status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(&exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1402,7 +1402,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
|
||||||
ASSERT_EQ(res.status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res.status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1421,7 +1421,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
|
||||||
ASSERT_EQ(res.status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res.status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1437,7 +1437,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
|
||||||
ASSERT_EQ(res.status(), ND4J_STATUS_OK);
|
ASSERT_EQ(res.status(), ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1463,7 +1463,7 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {
|
||||||
|
|
||||||
ASSERT_TRUE(z.equalsTo(&exp));
|
ASSERT_TRUE(z.equalsTo(&exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1676,31 +1676,6 @@ TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) {
|
||||||
delete block;
|
delete block;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, Reshapeas1) {
|
|
||||||
const std::vector<Nd4jLong> xShape = { 5,4,3 };
|
|
||||||
const std::vector<Nd4jLong> yShape = { 3,5,4 };
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create_<float>('f', xShape);
|
|
||||||
auto y = NDArrayFactory::create_<float>('f', yShape);
|
|
||||||
|
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
|
||||||
variableSpace->putVariable(-1, x);
|
|
||||||
variableSpace->putVariable(-2, y);
|
|
||||||
auto block = new Context(1, variableSpace, true);
|
|
||||||
block->fillInputs({ -1, -2 });
|
|
||||||
|
|
||||||
sd::ops::reshapeas reshape;
|
|
||||||
|
|
||||||
reshape.execute(block);
|
|
||||||
|
|
||||||
ASSERT_TRUE(x->isSameShape(y));
|
|
||||||
|
|
||||||
delete variableSpace;
|
|
||||||
delete block;
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Test_Cast_1) {
|
TEST_F(DeclarableOpsTests1, Test_Cast_1) {
|
||||||
// TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected
|
// TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected
|
||||||
auto x = NDArrayFactory::create<float>('c', { 5, 5 });
|
auto x = NDArrayFactory::create<float>('c', { 5, 5 });
|
||||||
|
@ -1715,7 +1690,7 @@ TEST_F(DeclarableOpsTests1, Test_Cast_1) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(yExp.equalsTo(z));
|
ASSERT_TRUE(yExp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1848,113 +1823,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, Reshape2) {
|
|
||||||
const std::vector<Nd4jLong> xShape = { 5,4,3 };
|
|
||||||
const std::vector<Nd4jLong> yShape = { 3,5,4 };
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create_<float>('c', xShape);
|
|
||||||
auto y = NDArrayFactory::create_<float>('c', yShape);
|
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
|
||||||
variableSpace->putVariable(-1, x);
|
|
||||||
variableSpace->putVariable(1, new Variable());
|
|
||||||
|
|
||||||
auto block = new Context(1, variableSpace, false);
|
|
||||||
block->fillInputs({ -1 });
|
|
||||||
std::vector<int>* arguments = block->getIArguments();
|
|
||||||
arguments->push_back(-y->ordering());
|
|
||||||
arguments->push_back(3);
|
|
||||||
arguments->push_back(5);
|
|
||||||
arguments->push_back(4);
|
|
||||||
|
|
||||||
sd::ops::reshape reshape;
|
|
||||||
|
|
||||||
Nd4jStatus status = reshape.execute(block);
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
|
||||||
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
|
|
||||||
|
|
||||||
ASSERT_TRUE(result->isSameShape(y));
|
|
||||||
|
|
||||||
delete y;
|
|
||||||
delete block;
|
|
||||||
delete variableSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Reshape3) {
|
|
||||||
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 });
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(x.isSameShape(z));
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Reshape4) {
|
|
||||||
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({ &x }, {}, { 3, 4, 5 });
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(x.isSameShape(z));
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Reshape5) {
|
|
||||||
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({ &x }, {}, { 5, 4, 3 });
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Reshape6) {
|
|
||||||
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', { 4, 15 });
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({ &x }, {}, { 4, -1 });
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(z->isSameShape(exp));
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Reshape7) {
|
|
||||||
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', { 60 });
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({ &x }, {}, { -1 });
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(z->isSameShape(exp));
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, Transpose1) {
|
TEST_F(DeclarableOpsTests1, Transpose1) {
|
||||||
|
@ -1983,7 +1851,6 @@ TEST_F(DeclarableOpsTests1, Transpose1) {
|
||||||
delete variableSpace;
|
delete variableSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
// not-in-place
|
// not-in-place
|
||||||
TEST_F(DeclarableOpsTests1, Permute1) {
|
TEST_F(DeclarableOpsTests1, Permute1) {
|
||||||
|
@ -2259,7 +2126,7 @@ TEST_F(DeclarableOpsTests1, IsMax1) {
|
||||||
//res->printIndexedBuffer("IS_MAX");
|
//res->printIndexedBuffer("IS_MAX");
|
||||||
ASSERT_TRUE(exp.equalsTo(res));
|
ASSERT_TRUE(exp.equalsTo(res));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2281,7 +2148,7 @@ TEST_F(DeclarableOpsTests1, IsMax2) {
|
||||||
//res->printIndexedBuffer("IS_MAX");
|
//res->printIndexedBuffer("IS_MAX");
|
||||||
ASSERT_TRUE(exp.equalsTo(res));
|
ASSERT_TRUE(exp.equalsTo(res));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2303,7 +2170,7 @@ TEST_F(DeclarableOpsTests1, IsMax3) {
|
||||||
//res->printIndexedBuffer("IS_MAX");
|
//res->printIndexedBuffer("IS_MAX");
|
||||||
ASSERT_TRUE(exp.equalsTo(res));
|
ASSERT_TRUE(exp.equalsTo(res));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2352,7 +2219,7 @@ TEST_F(DeclarableOpsTests1, IsMax4) {
|
||||||
// ASSERT_TRUE(expState.equalsTo(state));
|
// ASSERT_TRUE(expState.equalsTo(state));
|
||||||
// ASSERT_TRUE(expOut.equalsTo(output));
|
// ASSERT_TRUE(expOut.equalsTo(output));
|
||||||
|
|
||||||
//
|
//
|
||||||
// }
|
// }
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////
|
||||||
|
@ -2386,7 +2253,7 @@ TEST_F(DeclarableOpsTests1, sru_test1) {
|
||||||
ASSERT_TRUE(expState.equalsTo(state));
|
ASSERT_TRUE(expState.equalsTo(state));
|
||||||
ASSERT_TRUE(expOut.equalsTo(output));
|
ASSERT_TRUE(expOut.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2438,7 +2305,7 @@ TEST_F(DeclarableOpsTests1, sru_bp) {
|
||||||
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
||||||
ASSERT_TRUE(expGradInit.equalsTo(gradInit));
|
ASSERT_TRUE(expGradInit.equalsTo(gradInit));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////
|
||||||
|
@ -2474,7 +2341,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) {
|
||||||
ASSERT_TRUE(expState.equalsTo(state));
|
ASSERT_TRUE(expState.equalsTo(state));
|
||||||
ASSERT_TRUE(expOut.equalsTo(output));
|
ASSERT_TRUE(expOut.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
|
TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
|
||||||
|
@ -2527,7 +2394,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
|
||||||
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
||||||
ASSERT_TRUE(expGradInit.equalsTo(gradInit));
|
ASSERT_TRUE(expGradInit.equalsTo(gradInit));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, ArgMax1) {
|
TEST_F(DeclarableOpsTests1, ArgMax1) {
|
||||||
|
@ -2547,7 +2414,7 @@ TEST_F(DeclarableOpsTests1, ArgMax1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2568,7 +2435,7 @@ TEST_F(DeclarableOpsTests1, ArgMax2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2590,7 +2457,7 @@ TEST_F(DeclarableOpsTests1, ArgMax3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, ArgMax4) {
|
TEST_F(DeclarableOpsTests1, ArgMax4) {
|
||||||
|
@ -2611,7 +2478,7 @@ TEST_F(DeclarableOpsTests1, ArgMax4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2633,7 +2500,7 @@ TEST_F(DeclarableOpsTests1, ArgMax5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, ArgMax6) {
|
TEST_F(DeclarableOpsTests1, ArgMax6) {
|
||||||
|
@ -2676,7 +2543,7 @@ TEST_F(DeclarableOpsTests1, ArgMin1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2697,7 +2564,7 @@ TEST_F(DeclarableOpsTests1, SquareTests1) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, OneHotTests_1) {
|
TEST_F(DeclarableOpsTests1, OneHotTests_1) {
|
||||||
|
@ -2717,7 +2584,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, OneHotTests_2) {
|
TEST_F(DeclarableOpsTests1, OneHotTests_2) {
|
||||||
|
@ -2736,7 +2603,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_2) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, OneHotTests_3) {
|
TEST_F(DeclarableOpsTests1, OneHotTests_3) {
|
||||||
|
@ -2756,7 +2623,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, OneHotTests_4) {
|
TEST_F(DeclarableOpsTests1, OneHotTests_4) {
|
||||||
|
@ -2775,7 +2642,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, OneHotTests_5) {
|
TEST_F(DeclarableOpsTests1, OneHotTests_5) {
|
||||||
|
@ -2796,7 +2663,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, OneHotTests_6) {
|
TEST_F(DeclarableOpsTests1, OneHotTests_6) {
|
||||||
|
@ -2809,7 +2676,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_6) {
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, OneHotTests_7) {
|
TEST_F(DeclarableOpsTests1, OneHotTests_7) {
|
||||||
|
@ -2822,7 +2689,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_7) {
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, FillAs_1) {
|
TEST_F(DeclarableOpsTests1, FillAs_1) {
|
||||||
|
@ -2840,7 +2707,7 @@ TEST_F(DeclarableOpsTests1, FillAs_1) {
|
||||||
|
|
||||||
ASSERT_NEAR(scalar, result.at(0)->meanNumber().e<float>(0), 1e-5f);
|
ASSERT_NEAR(scalar, result.at(0)->meanNumber().e<float>(0), 1e-5f);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2866,7 +2733,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(array));
|
ASSERT_TRUE(exp.isSameShape(array));
|
||||||
ASSERT_TRUE(exp.equalsTo(array));
|
ASSERT_TRUE(exp.equalsTo(array));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2893,7 +2760,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(array));
|
ASSERT_TRUE(exp.isSameShape(array));
|
||||||
ASSERT_TRUE(exp.equalsTo(array));
|
ASSERT_TRUE(exp.equalsTo(array));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2913,7 +2780,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(array));
|
ASSERT_TRUE(exp.isSameShape(array));
|
||||||
ASSERT_TRUE(exp.equalsTo(array));
|
ASSERT_TRUE(exp.equalsTo(array));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2931,7 +2798,7 @@ TEST_F(DeclarableOpsTests1, softmax_test1) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2947,7 +2814,7 @@ TEST_F(DeclarableOpsTests1, softmax_test2) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2963,7 +2830,7 @@ TEST_F(DeclarableOpsTests1, softmax_test3) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2979,7 +2846,7 @@ TEST_F(DeclarableOpsTests1, softmax_test4) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2995,7 +2862,7 @@ TEST_F(DeclarableOpsTests1, softmax_test5) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3011,7 +2878,7 @@ TEST_F(DeclarableOpsTests1, softmax_test6) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3027,7 +2894,7 @@ TEST_F(DeclarableOpsTests1, softmax_test7) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3043,7 +2910,7 @@ TEST_F(DeclarableOpsTests1, softmax_test8) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3059,7 +2926,7 @@ TEST_F(DeclarableOpsTests1, softmax_test9) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, softmax_test10) {
|
TEST_F(DeclarableOpsTests1, softmax_test10) {
|
||||||
|
@ -3074,7 +2941,7 @@ TEST_F(DeclarableOpsTests1, softmax_test10) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, softmax_test11) {
|
TEST_F(DeclarableOpsTests1, softmax_test11) {
|
||||||
|
@ -3089,7 +2956,7 @@ TEST_F(DeclarableOpsTests1, softmax_test11) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3108,7 +2975,7 @@ TEST_F(DeclarableOpsTests1, softmax_test12) {
|
||||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
ASSERT_TRUE(expOutput.equalsTo(z));
|
ASSERT_TRUE(expOutput.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, Reverse_1) {
|
TEST_F(DeclarableOpsTests1, Reverse_1) {
|
||||||
|
@ -3132,7 +2999,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3157,7 +3024,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(input));
|
ASSERT_TRUE(expected.isSameShapeStrict(input));
|
||||||
ASSERT_TRUE(expected.equalsTo(&input));
|
ASSERT_TRUE(expected.equalsTo(&input));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3183,7 +3050,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3209,7 +3076,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3234,7 +3101,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3260,7 +3127,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(input));
|
ASSERT_TRUE(expected.isSameShapeStrict(input));
|
||||||
ASSERT_TRUE(expected.equalsTo(&input));
|
ASSERT_TRUE(expected.equalsTo(&input));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -3288,7 +3155,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -3316,7 +3183,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3341,7 +3208,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Reverse_10) {
|
TEST_F(DeclarableOpsTests1, Reverse_10) {
|
||||||
|
@ -3357,7 +3224,7 @@ TEST_F(DeclarableOpsTests1, Reverse_10) {
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3380,7 +3247,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3402,7 +3269,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3423,7 +3290,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3444,7 +3311,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Test_Expose_1) {
|
TEST_F(DeclarableOpsTests1, Test_Expose_1) {
|
||||||
|
@ -3463,7 +3330,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) {
|
||||||
ASSERT_TRUE(input0.equalsTo(z0));
|
ASSERT_TRUE(input0.equalsTo(z0));
|
||||||
ASSERT_TRUE(input1.equalsTo(z1));
|
ASSERT_TRUE(input1.equalsTo(z1));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, Test_Expose_2) {
|
TEST_F(DeclarableOpsTests1, Test_Expose_2) {
|
||||||
|
|
|
@ -51,23 +51,7 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) {
|
||||||
|
|
||||||
ASSERT_EQ(exp, *z);
|
ASSERT_EQ(exp, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) {
|
|
||||||
auto x = NDArrayFactory::create<double>('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0});
|
|
||||||
auto e = NDArrayFactory::create<double>('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});
|
|
||||||
|
|
||||||
auto r = x.reshape('c', {3, 2});;
|
|
||||||
r.streamline('f');
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&x}, {3, 2});
|
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) {
|
TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) {
|
||||||
|
@ -108,7 +92,7 @@ TEST_F(DeclarableOpsTests14, Multiply_test) {
|
||||||
ASSERT_EQ(e, r);
|
ASSERT_EQ(e, r);
|
||||||
ASSERT_EQ(e, *f);
|
ASSERT_EQ(e, *f);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,7 +108,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) {
|
TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) {
|
||||||
|
@ -139,7 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) {
|
TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) {
|
||||||
|
@ -193,7 +177,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) {
|
||||||
|
|
||||||
ASSERT_EQ(e, *result.at(0));
|
ASSERT_EQ(e, *result.at(0));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
|
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
|
||||||
|
@ -210,7 +194,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
|
||||||
|
|
||||||
ASSERT_EQ(e, *result.at(0));
|
ASSERT_EQ(e, *result.at(0));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
|
TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
|
||||||
|
@ -224,7 +208,7 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_EQ(y, *z);
|
ASSERT_EQ(y, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
|
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
|
||||||
|
@ -259,7 +243,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
|
||||||
auto out = res2.at(0);
|
auto out = res2.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
|
TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
|
||||||
|
@ -271,7 +255,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
|
||||||
auto out = res2.at(0);
|
auto out = res2.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(out->e<float>(0), -DataTypeUtils::infOrMax<float>());
|
ASSERT_EQ(out->e<float>(0), -DataTypeUtils::infOrMax<float>());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
|
TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
|
||||||
|
@ -286,7 +270,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
|
||||||
ASSERT_EQ(res2.status(), Status::OK());
|
ASSERT_EQ(res2.status(), Status::OK());
|
||||||
auto out = res2.at(0);
|
auto out = res2.at(0);
|
||||||
ASSERT_EQ(out->e<float>(0), 0.f);
|
ASSERT_EQ(out->e<float>(0), 0.f);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
|
TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
|
||||||
|
@ -303,7 +287,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
|
||||||
// out->printShapeInfo("ReduceMean empty shape with keep dims");
|
// out->printShapeInfo("ReduceMean empty shape with keep dims");
|
||||||
// out->printIndexedBuffer("ReduceMean scalar");
|
// out->printIndexedBuffer("ReduceMean scalar");
|
||||||
ASSERT_TRUE(std::isnan(out->e<float>(0)));
|
ASSERT_TRUE(std::isnan(out->e<float>(0)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) {
|
TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) {
|
||||||
|
@ -324,7 +308,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) {
|
TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) {
|
||||||
|
@ -345,7 +329,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
|
TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
|
||||||
|
@ -363,7 +347,7 @@ TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, test_empty_argmax_2) {
|
TEST_F(DeclarableOpsTests14, test_empty_argmax_2) {
|
||||||
|
@ -391,7 +375,7 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
|
||||||
ASSERT_TRUE(x.isSameShape(z));
|
ASSERT_TRUE(x.isSameShape(z));
|
||||||
ASSERT_EQ(x, *z);
|
ASSERT_EQ(x, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -409,7 +393,7 @@ TEST_F(DeclarableOpsTests14, repeat_1) {
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -427,7 +411,7 @@ TEST_F(DeclarableOpsTests14, repeat_2) {
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -445,7 +429,7 @@ TEST_F(DeclarableOpsTests14, repeat_3) {
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -463,7 +447,7 @@ TEST_F(DeclarableOpsTests14, repeat_4) {
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -481,7 +465,7 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
/////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) {
|
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) {
|
||||||
|
@ -502,7 +486,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) {
|
||||||
|
|
||||||
ASSERT_EQ(e, res);
|
ASSERT_EQ(e, res);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
/////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) {
|
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) {
|
||||||
|
@ -523,7 +507,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) {
|
||||||
|
|
||||||
ASSERT_EQ(e, res);
|
ASSERT_EQ(e, res);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////
|
||||||
|
@ -639,7 +623,7 @@ TEST_F(DeclarableOpsTests14, matmul_test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -661,7 +645,7 @@ TEST_F(DeclarableOpsTests14, matmul_test2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -682,7 +666,7 @@ TEST_F(DeclarableOpsTests14, matmul_test3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -704,7 +688,7 @@ TEST_F(DeclarableOpsTests14, matmul_test4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -726,7 +710,7 @@ TEST_F(DeclarableOpsTests14, matmul_test5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -747,7 +731,7 @@ TEST_F(DeclarableOpsTests14, matmul_test6) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -770,7 +754,7 @@ TEST_F(DeclarableOpsTests14, matmul_test7) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -795,7 +779,7 @@ TEST_F(DeclarableOpsTests14, matmul_test8) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -820,7 +804,7 @@ TEST_F(DeclarableOpsTests14, matmul_test9) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test10) {
|
TEST_F(DeclarableOpsTests14, matmul_test10) {
|
||||||
|
@ -876,7 +860,7 @@ TEST_F(DeclarableOpsTests14, matmul_test11) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test12) {
|
TEST_F(DeclarableOpsTests14, matmul_test12) {
|
||||||
|
@ -894,7 +878,7 @@ TEST_F(DeclarableOpsTests14, matmul_test12) {
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -914,7 +898,7 @@ TEST_F(DeclarableOpsTests14, matmul_test13) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test14) {
|
TEST_F(DeclarableOpsTests14, matmul_test14) {
|
||||||
|
@ -933,7 +917,7 @@ TEST_F(DeclarableOpsTests14, matmul_test14) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test15) {
|
TEST_F(DeclarableOpsTests14, matmul_test15) {
|
||||||
|
@ -952,7 +936,7 @@ TEST_F(DeclarableOpsTests14, matmul_test15) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test16) {
|
TEST_F(DeclarableOpsTests14, matmul_test16) {
|
||||||
|
@ -971,7 +955,7 @@ TEST_F(DeclarableOpsTests14, matmul_test16) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test17) {
|
TEST_F(DeclarableOpsTests14, matmul_test17) {
|
||||||
|
@ -985,7 +969,7 @@ TEST_F(DeclarableOpsTests14, matmul_test17) {
|
||||||
|
|
||||||
ASSERT_EQ(exp, *result.at(0));
|
ASSERT_EQ(exp, *result.at(0));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1007,7 +991,7 @@ TEST_F(DeclarableOpsTests14, matmul_test18) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1027,7 +1011,7 @@ TEST_F(DeclarableOpsTests14, matmul_test19) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1048,7 +1032,7 @@ TEST_F(DeclarableOpsTests14, matmul_test20) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1069,7 +1053,7 @@ TEST_F(DeclarableOpsTests14, matmul_test21) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1090,7 +1074,7 @@ TEST_F(DeclarableOpsTests14, matmul_test22) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1111,7 +1095,7 @@ TEST_F(DeclarableOpsTests14, matmul_test23) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1135,7 +1119,7 @@ TEST_F(DeclarableOpsTests14, matmul_test24) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1156,7 +1140,7 @@ TEST_F(DeclarableOpsTests14, matmul_test25) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1177,7 +1161,7 @@ TEST_F(DeclarableOpsTests14, matmul_test26) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1198,7 +1182,7 @@ TEST_F(DeclarableOpsTests14, matmul_test27) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1220,7 +1204,7 @@ TEST_F(DeclarableOpsTests14, matmul_test28) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1242,7 +1226,7 @@ TEST_F(DeclarableOpsTests14, matmul_test29) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test30) {
|
TEST_F(DeclarableOpsTests14, matmul_test30) {
|
||||||
|
@ -1262,7 +1246,7 @@ TEST_F(DeclarableOpsTests14, matmul_test30) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test31) {
|
TEST_F(DeclarableOpsTests14, matmul_test31) {
|
||||||
|
@ -1282,7 +1266,7 @@ TEST_F(DeclarableOpsTests14, matmul_test31) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test32) {
|
TEST_F(DeclarableOpsTests14, matmul_test32) {
|
||||||
|
@ -1299,7 +1283,7 @@ TEST_F(DeclarableOpsTests14, matmul_test32) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
/////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test33) {
|
TEST_F(DeclarableOpsTests14, matmul_test33) {
|
||||||
|
@ -1319,7 +1303,7 @@ TEST_F(DeclarableOpsTests14, matmul_test33) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test34) {
|
TEST_F(DeclarableOpsTests14, matmul_test34) {
|
||||||
|
@ -1336,7 +1320,7 @@ TEST_F(DeclarableOpsTests14, matmul_test34) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test35) {
|
TEST_F(DeclarableOpsTests14, matmul_test35) {
|
||||||
|
@ -1353,7 +1337,7 @@ TEST_F(DeclarableOpsTests14, matmul_test35) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test36) {
|
TEST_F(DeclarableOpsTests14, matmul_test36) {
|
||||||
|
@ -1370,7 +1354,7 @@ TEST_F(DeclarableOpsTests14, matmul_test36) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests14, matmul_test37) {
|
TEST_F(DeclarableOpsTests14, matmul_test37) {
|
||||||
|
@ -1617,7 +1601,7 @@ TEST_F(DeclarableOpsTests14, Stack_1) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1645,7 +1629,7 @@ TEST_F(DeclarableOpsTests14, Stack_2) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1673,7 +1657,7 @@ TEST_F(DeclarableOpsTests14, Stack_3) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1700,7 +1684,7 @@ TEST_F(DeclarableOpsTests14, Stack_4) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1727,7 +1711,7 @@ TEST_F(DeclarableOpsTests14, Stack_5) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1754,7 +1738,7 @@ TEST_F(DeclarableOpsTests14, Stack_6) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1778,7 +1762,7 @@ TEST_F(DeclarableOpsTests14, Stack_7) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1801,7 +1785,7 @@ TEST_F(DeclarableOpsTests14, Stack_8) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1824,7 +1808,7 @@ TEST_F(DeclarableOpsTests14, Stack_9) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1850,7 +1834,7 @@ TEST_F(DeclarableOpsTests14, Stack_10) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Stack_11) {
|
TEST_F(DeclarableOpsTests14, Stack_11) {
|
||||||
|
@ -1872,7 +1856,7 @@ TEST_F(DeclarableOpsTests14, Stack_11) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1895,7 +1879,7 @@ TEST_F(DeclarableOpsTests14, Stack_12) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1917,7 +1901,7 @@ TEST_F(DeclarableOpsTests14, Stack_13) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1941,7 +1925,7 @@ TEST_F(DeclarableOpsTests14, Stack_14) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Stack_15) {
|
TEST_F(DeclarableOpsTests14, Stack_15) {
|
||||||
|
@ -1959,7 +1943,7 @@ TEST_F(DeclarableOpsTests14, Stack_15) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1978,7 +1962,7 @@ TEST_F(DeclarableOpsTests14, Stack_16) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Stack_17) {
|
TEST_F(DeclarableOpsTests14, Stack_17) {
|
||||||
|
@ -1999,7 +1983,7 @@ TEST_F(DeclarableOpsTests14, Stack_17) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Stack_18) {
|
TEST_F(DeclarableOpsTests14, Stack_18) {
|
||||||
|
@ -2018,8 +2002,8 @@ TEST_F(DeclarableOpsTests14, Stack_18) {
|
||||||
auto out = res2.at(0);
|
auto out = res2.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Stack_19) {
|
TEST_F(DeclarableOpsTests14, Stack_19) {
|
||||||
|
@ -2033,7 +2017,7 @@ TEST_F(DeclarableOpsTests14, Stack_19) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Stack_20) {
|
TEST_F(DeclarableOpsTests14, Stack_20) {
|
||||||
|
@ -2047,7 +2031,7 @@ TEST_F(DeclarableOpsTests14, Stack_20) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests14, Stack_21) {
|
TEST_F(DeclarableOpsTests14, Stack_21) {
|
||||||
|
@ -2073,7 +2057,363 @@ TEST_F(DeclarableOpsTests14, Stack_21) {
|
||||||
|
|
||||||
ASSERT_TRUE(outStack->isSameShape(outConcat));
|
ASSERT_TRUE(outStack->isSameShape(outConcat));
|
||||||
ASSERT_TRUE(outStack->equalsTo(outConcat));
|
ASSERT_TRUE(outStack->equalsTo(outConcat));
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape1) {
|
||||||
|
const std::vector<Nd4jLong> xShape = { 5,4,3 };
|
||||||
|
const std::vector<Nd4jLong> yShape = { 3,5,4 };
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create_<float>('f', xShape);
|
||||||
|
auto y = NDArrayFactory::create_<float>('f', yShape);
|
||||||
|
|
||||||
|
|
||||||
|
auto variableSpace = new VariableSpace();
|
||||||
|
variableSpace->putVariable(-1, x);
|
||||||
|
variableSpace->putVariable(-2, y);
|
||||||
|
auto block = new Context(1, variableSpace, true);
|
||||||
|
block->fillInputs({ -1, -2 });
|
||||||
|
|
||||||
|
sd::ops::reshapeas reshape;
|
||||||
|
|
||||||
|
reshape.execute(block);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->isSameShape(y));
|
||||||
|
|
||||||
|
delete variableSpace;
|
||||||
|
delete block;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape2) {
|
||||||
|
const std::vector<Nd4jLong> xShape = { 5,4,3 };
|
||||||
|
const std::vector<Nd4jLong> yShape = { 3,5,4 };
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create_<float>('c', xShape);
|
||||||
|
auto y = NDArrayFactory::create_<float>('c', yShape);
|
||||||
|
|
||||||
|
auto variableSpace = new VariableSpace();
|
||||||
|
variableSpace->putVariable(-1, x);
|
||||||
|
variableSpace->putVariable(1, new Variable());
|
||||||
|
|
||||||
|
auto block = new Context(1, variableSpace, false);
|
||||||
|
block->fillInputs({ -1 });
|
||||||
|
std::vector<int>* arguments = block->getIArguments();
|
||||||
|
arguments->push_back(-y->ordering());
|
||||||
|
arguments->push_back(3);
|
||||||
|
arguments->push_back(5);
|
||||||
|
arguments->push_back(4);
|
||||||
|
|
||||||
|
sd::ops::reshape reshape;
|
||||||
|
|
||||||
|
Nd4jStatus status = reshape.execute(block);
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_TRUE(result->isSameShape(y));
|
||||||
|
|
||||||
|
delete y;
|
||||||
|
delete block;
|
||||||
|
delete variableSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 });
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.isSameShape(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape4) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({ &x }, {}, { 3, 4, 5 });
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.isSameShape(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape5) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({ &x }, {}, { 5, 4, 3 });
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape6) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', { 4, 15 });
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({ &x }, {}, { 4, -1 });
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(z->isSameShape(exp));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape7) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', { 60 });
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({ &x }, {}, { -1 });
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(z->isSameShape(exp));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape8) {
|
||||||
|
auto x = NDArrayFactory::create<double>('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0});
|
||||||
|
auto e = NDArrayFactory::create<double>('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});
|
||||||
|
|
||||||
|
auto r = x.reshape('c', {3, 2});;
|
||||||
|
r.streamline('f');
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&x}, {3, 2});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape9) {
|
||||||
|
auto array = NDArrayFactory::create<float>(119.f);
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&array}, {}, {1, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape10) {
|
||||||
|
auto array = NDArrayFactory::create<float>(119.f);
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {1, 1});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape11) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {4, 3});
|
||||||
|
|
||||||
|
x.linspace(1);
|
||||||
|
exp.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&x}, {-99, 4, 3});
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape12) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||||
|
auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&x, &shape});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape13) {
|
||||||
|
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
|
||||||
|
auto exp = NDArrayFactory::create<float>(119.f);
|
||||||
|
auto empty = NDArrayFactory::empty_<int>();
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&vector, empty}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, *result.at(0));
|
||||||
|
|
||||||
|
delete empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape14) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {10, 0});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape15) {
|
||||||
|
|
||||||
|
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
|
||||||
|
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
|
||||||
|
|
||||||
|
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
|
||||||
|
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
|
||||||
|
|
||||||
|
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
|
||||||
|
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result0.status());
|
||||||
|
auto z0 = result0.at(0);
|
||||||
|
ASSERT_EQ(e0, *z0);
|
||||||
|
|
||||||
|
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result1.status());
|
||||||
|
auto z1 = result1.at(0);
|
||||||
|
ASSERT_EQ(e1, *z1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape16) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto shape = NDArrayFactory::create<int>('c', {1, 3}, {1, 2, 2});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<int>('c', {1, 2, 2}, {1, 2, 3, 4});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x, &shape}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape17) {
|
||||||
|
auto x = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape18) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&x}, {}, {-99, 3});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape19) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, Reshape20) {
|
||||||
|
|
||||||
|
NDArray x1('c', {2,0}, sd::DataType::FLOAT32);
|
||||||
|
NDArray x2('c', {10,0}, sd::DataType::FLOAT32);
|
||||||
|
NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32);
|
||||||
|
NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32);
|
||||||
|
NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32);
|
||||||
|
NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32);
|
||||||
|
NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
sd::ops::reshape op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x1}, {}, {2, -1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({2,0}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x2}, {}, {2, 0, -1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({2,0,5}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x2}, {}, {5, 2, -1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({5,2,0}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x2}, {}, {-1, 2, 0});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({5,2,0}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x3}, {}, {2, 0, -1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({2,0,10}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x4}, {}, {2, -1, 0});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({2,5,0}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x6}, {}, {-1, 2, 0});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x7}, {}, {-1, 0});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({2, 0}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x7}, {}, {10,0,50,100});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100}));
|
||||||
|
}
|
||||||
|
|
|
@ -89,7 +89,7 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
|
||||||
sd::ops::standardize_bp op;
|
sd::ops::standardize_bp op;
|
||||||
auto result = op.evaluate({&x, &eps}, {0});
|
auto result = op.evaluate({&x, &eps}, {0});
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
|
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
|
||||||
|
@ -108,7 +108,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
|
||||||
auto out = result.at(0);
|
auto out = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.equalsTo(out));
|
ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
|
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
|
||||||
|
@ -126,7 +126,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
|
||||||
auto out = result.at(0);
|
auto out = result.at(0);
|
||||||
// out->printIndexedBuffer("Adjusted Constrast");
|
// out->printIndexedBuffer("Adjusted Constrast");
|
||||||
ASSERT_TRUE(e.equalsTo(out));
|
ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
|
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
|
||||||
|
@ -144,7 +144,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
|
||||||
auto out = result.at(0);
|
auto out = result.at(0);
|
||||||
// out->printIndexedBuffer("Adjusted Constrast");
|
// out->printIndexedBuffer("Adjusted Constrast");
|
||||||
ASSERT_TRUE(e.equalsTo(out));
|
ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
|
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
|
||||||
|
@ -162,7 +162,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
|
||||||
auto out = result.at(0);
|
auto out = result.at(0);
|
||||||
// out->printIndexedBuffer("Adjusted Constrast");
|
// out->printIndexedBuffer("Adjusted Constrast");
|
||||||
ASSERT_TRUE(e.equalsTo(out));
|
ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) {
|
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) {
|
||||||
|
@ -177,7 +177,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) {
|
||||||
auto out = result.at(0);
|
auto out = result.at(0);
|
||||||
// out->printIndexedBuffer("Adjusted Constrast");
|
// out->printIndexedBuffer("Adjusted Constrast");
|
||||||
ASSERT_TRUE(e.equalsTo(out));
|
ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -308,7 +308,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) {
|
||||||
// out->printBuffer("Adjusted Constrast6");
|
// out->printBuffer("Adjusted Constrast6");
|
||||||
// e.printBuffer("Adjusted Expected 6");
|
// e.printBuffer("Adjusted Expected 6");
|
||||||
// ASSERT_TRUE(e.equalsTo(out));
|
// ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
|
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
|
||||||
|
@ -415,7 +415,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
|
||||||
auto diff = e - *out;
|
auto diff = e - *out;
|
||||||
// diff.printBuffer("Adjusted subtract 7");
|
// diff.printBuffer("Adjusted subtract 7");
|
||||||
ASSERT_TRUE(e.equalsTo(out));
|
ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
|
TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
|
||||||
|
@ -429,7 +429,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
|
||||||
auto out = result.at(0);
|
auto out = result.at(0);
|
||||||
// out->printIndexedBuffer("Casted result");
|
// out->printIndexedBuffer("Casted result");
|
||||||
ASSERT_TRUE(e.equalsTo(out));
|
ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
|
TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
|
||||||
|
@ -444,7 +444,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
|
||||||
auto out = result.at(0);
|
auto out = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.equalsTo(out));
|
ASSERT_TRUE(e.equalsTo(out));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_BitCast_3) {
|
TEST_F(DeclarableOpsTests15, Test_BitCast_3) {
|
||||||
|
@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) {
|
||||||
// e.printIndexedBuffer("Double to int64");
|
// e.printIndexedBuffer("Double to int64");
|
||||||
auto res = result.at(0);
|
auto res = result.at(0);
|
||||||
ASSERT_EQ(*res, e);
|
ASSERT_EQ(*res, e);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -508,7 +508,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_5) {
|
||||||
|
|
||||||
// res->printIndexedBuffer("BITCAST5");
|
// res->printIndexedBuffer("BITCAST5");
|
||||||
ASSERT_TRUE(e.equalsTo(res));
|
ASSERT_TRUE(e.equalsTo(res));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
|
TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
|
||||||
|
@ -528,7 +528,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
|
||||||
|
|
||||||
// res->printIndexedBuffer("BITCAST6");
|
// res->printIndexedBuffer("BITCAST6");
|
||||||
ASSERT_TRUE(e.equalsTo(res));
|
ASSERT_TRUE(e.equalsTo(res));
|
||||||
|
|
||||||
}
|
}
|
||||||
TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
|
TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
|
||||||
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
|
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
|
||||||
|
@ -547,7 +547,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
|
||||||
|
|
||||||
// res->printIndexedBuffer("BITCAST7");
|
// res->printIndexedBuffer("BITCAST7");
|
||||||
ASSERT_TRUE(e.equalsTo(res));
|
ASSERT_TRUE(e.equalsTo(res));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
|
TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
|
||||||
|
@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
|
||||||
sd::ops::layer_norm op;
|
sd::ops::layer_norm op;
|
||||||
auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false});
|
auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false});
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
|
TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
|
||||||
|
@ -649,7 +649,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
|
||||||
sd::ops::layer_norm_bp op;
|
sd::ops::layer_norm_bp op;
|
||||||
auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false});
|
auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false});
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -710,30 +710,6 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
|
||||||
ASSERT_NE(*resultA0.at(0), *resultB0.at(0));
|
ASSERT_NE(*resultA0.at(0), *resultB0.at(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) {
|
|
||||||
auto array = NDArrayFactory::create<float>(119.f);
|
|
||||||
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&array}, {}, {1, 1});
|
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) {
|
|
||||||
auto array = NDArrayFactory::create<float>(119.f);
|
|
||||||
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {1, 1});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
|
|
||||||
ASSERT_EQ(Status::OK(), result);
|
|
||||||
ASSERT_EQ(e, z);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, test_rank_1) {
|
TEST_F(DeclarableOpsTests15, test_rank_1) {
|
||||||
auto array = NDArrayFactory::create<float>('c', {4, 64});
|
auto array = NDArrayFactory::create<float>('c', {4, 64});
|
||||||
auto e = NDArrayFactory::create<int>('c', {}, {2});
|
auto e = NDArrayFactory::create<int>('c', {}, {2});
|
||||||
|
@ -757,7 +733,7 @@ TEST_F(DeclarableOpsTests15, test_rank_2) {
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
|
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
|
||||||
|
@ -800,7 +776,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_2) {
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, test_lstmBlock_3) {
|
TEST_F(DeclarableOpsTests15, test_lstmBlock_3) {
|
||||||
|
@ -969,7 +945,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) {
|
||||||
sd::ops::rgb_to_grs op;
|
sd::ops::rgb_to_grs op;
|
||||||
auto result = op.evaluate({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
ASSERT_EQ(Status::THROW(), result.status());
|
ASSERT_EQ(Status::THROW(), result.status());
|
||||||
|
|
||||||
} catch (std::exception& e) {
|
} catch (std::exception& e) {
|
||||||
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
||||||
}
|
}
|
||||||
|
@ -1063,7 +1039,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) {
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1074,7 +1050,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) {
|
||||||
sd::ops::rgb_to_yuv op;
|
sd::ops::rgb_to_yuv op;
|
||||||
auto result = op.evaluate({ &rgbs }, {}, {});
|
auto result = op.evaluate({ &rgbs }, {}, {});
|
||||||
ASSERT_EQ(Status::THROW(), result.status());
|
ASSERT_EQ(Status::THROW(), result.status());
|
||||||
|
|
||||||
}
|
}
|
||||||
catch (std::exception & e) {
|
catch (std::exception & e) {
|
||||||
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
||||||
|
@ -1109,7 +1085,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1168,7 +1144,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) {
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1179,7 +1155,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) {
|
||||||
sd::ops::yuv_to_rgb op;
|
sd::ops::yuv_to_rgb op;
|
||||||
auto result = op.evaluate({ &yuv }, {}, {});
|
auto result = op.evaluate({ &yuv }, {}, {});
|
||||||
ASSERT_EQ(Status::THROW(), result.status());
|
ASSERT_EQ(Status::THROW(), result.status());
|
||||||
|
|
||||||
}
|
}
|
||||||
catch (std::exception & e) {
|
catch (std::exception & e) {
|
||||||
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
||||||
|
@ -1423,7 +1399,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test8) {
|
||||||
ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
|
ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
|
||||||
ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
|
ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
|
||||||
ASSERT_TRUE(dLdyExp.equalsTo(dLdy));
|
ASSERT_TRUE(dLdyExp.equalsTo(dLdy));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test9) {
|
TEST_F(DeclarableOpsTests15, Pow_BP_Test9) {
|
||||||
|
@ -1515,7 +1491,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
|
||||||
ASSERT_NEAR(dLdyB->e<float>(i), dLdyExpB.e<float>(i), 0.00001);
|
ASSERT_NEAR(dLdyB->e<float>(i), dLdyExpB.e<float>(i), 0.00001);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
|
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
|
||||||
|
@ -1532,10 +1508,10 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
|
||||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
|
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
||||||
|
|
||||||
auto* dLdAbp = resultsBP.at(0);
|
auto* dLdAbp = resultsBP.at(0);
|
||||||
auto* dLdBbp = resultsBP.at(1);
|
auto* dLdBbp = resultsBP.at(1);
|
||||||
|
|
||||||
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
|
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
|
||||||
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
|
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
|
||||||
|
|
||||||
|
@ -1554,10 +1530,10 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) {
|
||||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
||||||
|
|
||||||
auto* dLdAbp = resultsBP.at(0);
|
auto* dLdAbp = resultsBP.at(0);
|
||||||
auto* dLdBbp = resultsBP.at(1);
|
auto* dLdBbp = resultsBP.at(1);
|
||||||
|
|
||||||
ASSERT_TRUE(B.isSameShape(*dLdAbp));
|
ASSERT_TRUE(B.isSameShape(*dLdAbp));
|
||||||
ASSERT_TRUE(B.equalsTo(*dLdAbp));
|
ASSERT_TRUE(B.equalsTo(*dLdAbp));
|
||||||
|
|
||||||
|
@ -1606,7 +1582,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) {
|
||||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
||||||
|
|
||||||
auto* dLdAbp = resultsBP.at(0);
|
auto* dLdAbp = resultsBP.at(0);
|
||||||
auto* dLdBbp = resultsBP.at(1);
|
auto* dLdBbp = resultsBP.at(1);
|
||||||
|
|
||||||
|
@ -1632,7 +1608,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) {
|
||||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
||||||
|
|
||||||
auto* dLdAbp = resultsBP.at(0);
|
auto* dLdAbp = resultsBP.at(0);
|
||||||
auto* dLdBbp = resultsBP.at(1);
|
auto* dLdBbp = resultsBP.at(1);
|
||||||
|
|
||||||
|
@ -1655,7 +1631,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) {
|
||||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
|
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
||||||
|
|
||||||
auto* dLdAbp = resultsBP.at(0);
|
auto* dLdAbp = resultsBP.at(0);
|
||||||
auto* dLdBbp = resultsBP.at(1);
|
auto* dLdBbp = resultsBP.at(1);
|
||||||
|
|
||||||
|
@ -1706,7 +1682,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) {
|
||||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
|
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status());
|
||||||
|
|
||||||
auto* dLdAbp = resultsBP.at(0);
|
auto* dLdAbp = resultsBP.at(0);
|
||||||
auto* dLdBbp = resultsBP.at(1);
|
auto* dLdBbp = resultsBP.at(1);
|
||||||
|
|
||||||
|
|
|
@ -789,24 +789,6 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
|
||||||
auto exp = NDArrayFactory::create<double>('c', {4, 3});
|
|
||||||
|
|
||||||
x.linspace(1);
|
|
||||||
exp.linspace(1);
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&x}, {-99, 4, 3});
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Split_1) {
|
TEST_F(DeclarableOpsTests4, Test_Split_1) {
|
||||||
|
@ -1209,23 +1191,6 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) {
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
|
||||||
auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
|
|
||||||
auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&x, &shape});
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
|
TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
|
||||||
|
|
|
@ -140,37 +140,6 @@ TEST_F(EmptyTests, Test_Concat_4) {
|
||||||
ASSERT_EQ(exp, *z);
|
ASSERT_EQ(exp, *z);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Reshape_1) {
|
|
||||||
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
|
|
||||||
auto exp = NDArrayFactory::create<float>(119.f);
|
|
||||||
auto empty = NDArrayFactory::empty_<int>();
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&vector, empty}, {}, {});
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
|
||||||
|
|
||||||
ASSERT_EQ(exp, *result.at(0));
|
|
||||||
|
|
||||||
delete empty;
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Reshape_3) {
|
|
||||||
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
|
|
||||||
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
|
|
||||||
auto e = NDArrayFactory::create<float>('c', {10, 0});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&x, &y}, {}, {});
|
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
|
||||||
ASSERT_EQ(e, *z);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_dup_1) {
|
TEST_F(EmptyTests, Test_dup_1) {
|
||||||
auto empty = NDArrayFactory::empty<int>();
|
auto empty = NDArrayFactory::empty<int>();
|
||||||
auto dup = new NDArray(empty.dup());
|
auto dup = new NDArray(empty.dup());
|
||||||
|
@ -256,41 +225,6 @@ TEST_F(EmptyTests, test_shaped_empty_4) {
|
||||||
ASSERT_EQ(shapeOf, array.getShapeAsVector());
|
ASSERT_EQ(shapeOf, array.getShapeAsVector());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, test_empty_reshape_1) {
|
|
||||||
/*
|
|
||||||
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
|
|
||||||
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
|
|
||||||
|
|
||||||
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
|
|
||||||
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
|
|
||||||
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
|
|
||||||
|
|
||||||
assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
|
|
||||||
assertArrayEquals(new long[]{0, 1}, out1.shape());
|
|
||||||
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
|
||||||
*/
|
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
|
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
|
|
||||||
|
|
||||||
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
|
|
||||||
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
|
|
||||||
|
|
||||||
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
|
|
||||||
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
|
|
||||||
ASSERT_EQ(Status::OK(), result0.status());
|
|
||||||
auto z0 = result0.at(0);
|
|
||||||
ASSERT_EQ(e0, *z0);
|
|
||||||
|
|
||||||
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
|
|
||||||
ASSERT_EQ(Status::OK(), result1.status());
|
|
||||||
auto z1 = result1.at(0);
|
|
||||||
ASSERT_EQ(e1, *z1);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(EmptyTests, test_empty_matmul_1) {
|
TEST_F(EmptyTests, test_empty_matmul_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {0, 1});
|
auto x = NDArrayFactory::create<float>('c', {0, 1});
|
||||||
|
|
|
@ -48,7 +48,7 @@ TEST_F(ParityOpsTests, TestZeroAs1) {
|
||||||
ASSERT_TRUE(z->isSameShape(&x));
|
ASSERT_TRUE(z->isSameShape(&x));
|
||||||
ASSERT_TRUE(z->equalsTo(&exp));
|
ASSERT_TRUE(z->equalsTo(&exp));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestMaximum1) {
|
TEST_F(ParityOpsTests, TestMaximum1) {
|
||||||
|
@ -66,7 +66,7 @@ TEST_F(ParityOpsTests, TestMaximum1) {
|
||||||
|
|
||||||
ASSERT_TRUE(y.equalsTo(z));
|
ASSERT_TRUE(y.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ TEST_F(ParityOpsTests, TestMinimum1) {
|
||||||
|
|
||||||
ASSERT_TRUE(y.equalsTo(z));
|
ASSERT_TRUE(y.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestTear1) {
|
TEST_F(ParityOpsTests, TestTear1) {
|
||||||
|
@ -106,7 +106,7 @@ TEST_F(ParityOpsTests, TestTear1) {
|
||||||
for (int e = 0; e < result.size(); e++)
|
for (int e = 0; e < result.size(); e++)
|
||||||
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
|
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestUnstack1) {
|
TEST_F(ParityOpsTests, TestUnstack1) {
|
||||||
|
@ -126,7 +126,7 @@ TEST_F(ParityOpsTests, TestUnstack1) {
|
||||||
for (int e = 0; e < result.size(); e++)
|
for (int e = 0; e < result.size(); e++)
|
||||||
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
|
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -148,7 +148,7 @@ TEST_F(ParityOpsTests, TestUnstack2) {
|
||||||
for (int e = 0; e < result.size(); e++)
|
for (int e = 0; e < result.size(); e++)
|
||||||
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
|
ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e)));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestUnstack3) {
|
TEST_F(ParityOpsTests, TestUnstack3) {
|
||||||
|
@ -166,7 +166,7 @@ TEST_F(ParityOpsTests, TestUnstack3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -185,7 +185,7 @@ TEST_F(ParityOpsTests, TestUnstack4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestUnstack5) {
|
TEST_F(ParityOpsTests, TestUnstack5) {
|
||||||
|
@ -203,7 +203,7 @@ TEST_F(ParityOpsTests, TestUnstack5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestUnstack6) {
|
TEST_F(ParityOpsTests, TestUnstack6) {
|
||||||
|
@ -221,7 +221,7 @@ TEST_F(ParityOpsTests, TestUnstack6) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestUnstack7) {
|
TEST_F(ParityOpsTests, TestUnstack7) {
|
||||||
|
@ -239,7 +239,7 @@ TEST_F(ParityOpsTests, TestUnstack7) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestUnstack8) {
|
TEST_F(ParityOpsTests, TestUnstack8) {
|
||||||
|
@ -257,7 +257,7 @@ TEST_F(ParityOpsTests, TestUnstack8) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestUnstack9) {
|
TEST_F(ParityOpsTests, TestUnstack9) {
|
||||||
|
@ -275,7 +275,7 @@ TEST_F(ParityOpsTests, TestUnstack9) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -293,7 +293,7 @@ TEST_F(ParityOpsTests, TestUnstack10) {
|
||||||
ASSERT_TRUE(exp.isSameShape(result.at(1)));
|
ASSERT_TRUE(exp.isSameShape(result.at(1)));
|
||||||
ASSERT_TRUE(exp.isSameShape(result.at(2)));
|
ASSERT_TRUE(exp.isSameShape(result.at(2)));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -310,7 +310,7 @@ TEST_F(ParityOpsTests, TestUnstack11) {
|
||||||
ASSERT_TRUE(exp.isSameShape(result.at(0)));
|
ASSERT_TRUE(exp.isSameShape(result.at(0)));
|
||||||
ASSERT_TRUE(exp.isSameShape(result.at(1)));
|
ASSERT_TRUE(exp.isSameShape(result.at(1)));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -325,7 +325,7 @@ TEST_F(ParityOpsTests, TestUnstack12) {
|
||||||
|
|
||||||
ASSERT_TRUE(result.size() == 0);
|
ASSERT_TRUE(result.size() == 0);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, TestUnstack13) {
|
TEST_F(ParityOpsTests, TestUnstack13) {
|
||||||
|
@ -361,7 +361,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest1) {
|
||||||
ASSERT_TRUE(reshaped.isSameShape(z));
|
ASSERT_TRUE(reshaped.isSameShape(z));
|
||||||
ASSERT_TRUE(reshaped.equalsTo(z));
|
ASSERT_TRUE(reshaped.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -380,7 +380,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest2) {
|
||||||
ASSERT_TRUE(reshaped.isSameShape(z));
|
ASSERT_TRUE(reshaped.isSameShape(z));
|
||||||
ASSERT_TRUE(reshaped.equalsTo(z));
|
ASSERT_TRUE(reshaped.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -399,7 +399,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest3) {
|
||||||
ASSERT_TRUE(reshaped.isSameShape(z));
|
ASSERT_TRUE(reshaped.isSameShape(z));
|
||||||
ASSERT_TRUE(reshaped.equalsTo(z));
|
ASSERT_TRUE(reshaped.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, ExpandDimsTest4) {
|
TEST_F(ParityOpsTests, ExpandDimsTest4) {
|
||||||
|
@ -417,7 +417,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest4) {
|
||||||
ASSERT_TRUE(reshaped.isSameShape(z));
|
ASSERT_TRUE(reshaped.isSameShape(z));
|
||||||
ASSERT_TRUE(reshaped.equalsTo(z));
|
ASSERT_TRUE(reshaped.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -434,7 +434,7 @@ TEST_F(ParityOpsTests, Test_Shape_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -452,7 +452,7 @@ TEST_F(ParityOpsTests, Test_Equals_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -470,7 +470,7 @@ TEST_F(ParityOpsTests, Test_NotEquals_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Less_1) {
|
TEST_F(ParityOpsTests, Test_Less_1) {
|
||||||
|
@ -487,7 +487,7 @@ TEST_F(ParityOpsTests, Test_Less_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_LessEquals_1) {
|
TEST_F(ParityOpsTests, Test_LessEquals_1) {
|
||||||
|
@ -504,7 +504,7 @@ TEST_F(ParityOpsTests, Test_LessEquals_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_GreaterEquals_1) {
|
TEST_F(ParityOpsTests, Test_GreaterEquals_1) {
|
||||||
|
@ -521,7 +521,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_GreaterEquals_2) {
|
TEST_F(ParityOpsTests, Test_GreaterEquals_2) {
|
||||||
|
@ -538,7 +538,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Greater_1) {
|
TEST_F(ParityOpsTests, Test_Greater_1) {
|
||||||
|
@ -555,7 +555,7 @@ TEST_F(ParityOpsTests, Test_Greater_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Where_1) {
|
TEST_F(ParityOpsTests, Test_Where_1) {
|
||||||
|
@ -575,7 +575,7 @@ TEST_F(ParityOpsTests, Test_Where_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Where_2) {
|
TEST_F(ParityOpsTests, Test_Where_2) {
|
||||||
|
@ -593,7 +593,7 @@ TEST_F(ParityOpsTests, Test_Where_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -612,7 +612,7 @@ TEST_F(ParityOpsTests, Test_Where_3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Select_1) {
|
TEST_F(ParityOpsTests, Test_Select_1) {
|
||||||
|
@ -630,7 +630,7 @@ TEST_F(ParityOpsTests, Test_Select_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Select_2) {
|
TEST_F(ParityOpsTests, Test_Select_2) {
|
||||||
|
@ -648,7 +648,7 @@ TEST_F(ParityOpsTests, Test_Select_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Select_3) {
|
TEST_F(ParityOpsTests, Test_Select_3) {
|
||||||
|
@ -666,25 +666,7 @@ TEST_F(ParityOpsTests, Test_Select_3) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Reshape_TF_1) {
|
|
||||||
auto x = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
|
|
||||||
auto shape = NDArrayFactory::create<int>('c', {1, 3}, {1, 2, 2});
|
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<int>('c', {1, 2, 2}, {1, 2, 3, 4});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
|
|
||||||
auto result = op.evaluate({&x, &shape}, {}, {});
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Bias_Add_1) {
|
TEST_F(ParityOpsTests, Test_Bias_Add_1) {
|
||||||
|
@ -702,7 +684,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) {
|
||||||
for (int e = 0; e < tads.size(); e++) {
|
for (int e = 0; e < tads.size(); e++) {
|
||||||
ASSERT_TRUE(bias.equalsTo(tads.at(e)));
|
ASSERT_TRUE(bias.equalsTo(tads.at(e)));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
|
TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
|
||||||
|
@ -718,7 +700,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
|
TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
|
||||||
|
@ -735,7 +717,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
|
TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
|
||||||
|
@ -751,7 +733,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
|
TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
|
||||||
|
@ -767,7 +749,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
|
TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
|
||||||
|
@ -784,7 +766,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
|
||||||
// z->printBuffer();
|
// z->printBuffer();
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
|
TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
|
||||||
|
@ -800,7 +782,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
|
TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
|
||||||
|
@ -816,7 +798,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
@ -864,7 +846,7 @@ TEST_F(ParityOpsTests, scatterMax_test1) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, scatterMax_test2) {
|
TEST_F(ParityOpsTests, scatterMax_test2) {
|
||||||
|
@ -880,7 +862,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, scatterMax_test3) {
|
TEST_F(ParityOpsTests, scatterMax_test3) {
|
||||||
|
@ -897,7 +879,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, scatterMax_test4) {
|
TEST_F(ParityOpsTests, scatterMax_test4) {
|
||||||
|
@ -913,7 +895,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, scatterMax_test5) {
|
TEST_F(ParityOpsTests, scatterMax_test5) {
|
||||||
|
@ -929,7 +911,7 @@ TEST_F(ParityOpsTests, scatterMax_test5) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, scatterMax_test6) {
|
TEST_F(ParityOpsTests, scatterMax_test6) {
|
||||||
|
@ -945,7 +927,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -963,7 +945,7 @@ TEST_F(ParityOpsTests, scatterMin_test1) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, scatterMin_test2) {
|
TEST_F(ParityOpsTests, scatterMin_test2) {
|
||||||
|
@ -979,7 +961,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, scatterMin_test3) {
|
TEST_F(ParityOpsTests, scatterMin_test3) {
|
||||||
|
@ -995,7 +977,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, scatterMin_test4) {
|
TEST_F(ParityOpsTests, scatterMin_test4) {
|
||||||
|
@ -1012,7 +994,7 @@ TEST_F(ParityOpsTests, scatterMin_test4) {
|
||||||
// z->printBuffer();
|
// z->printBuffer();
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1044,7 +1026,7 @@ TEST_F(ParityOpsTests, scatterND_test1) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1064,7 +1046,7 @@ TEST_F(ParityOpsTests, scatterND_test2) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1088,7 +1070,7 @@ TEST_F(ParityOpsTests, scatterND_test3) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1107,7 +1089,7 @@ TEST_F(ParityOpsTests, scatterND_test4) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1127,7 +1109,7 @@ TEST_F(ParityOpsTests, scatterND_test5) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1154,7 +1136,7 @@ TEST_F(ParityOpsTests, scatterND_test6) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1181,7 +1163,7 @@ TEST_F(ParityOpsTests, scatterND_test7) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1202,7 +1184,7 @@ TEST_F(ParityOpsTests, scatterND_test8) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1236,7 +1218,7 @@ TEST_F(ParityOpsTests, scatterND_add_test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1260,7 +1242,7 @@ TEST_F(ParityOpsTests, scatterND_add_test2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1283,7 +1265,7 @@ TEST_F(ParityOpsTests, scatterND_add_test3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1310,7 +1292,7 @@ TEST_F(ParityOpsTests, scatterND_add_test4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1346,7 +1328,7 @@ TEST_F(ParityOpsTests, scatterND_add_test5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1379,7 +1361,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1404,7 +1386,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1427,7 +1409,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1454,7 +1436,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1490,7 +1472,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1511,7 +1493,7 @@ TEST_F(ParityOpsTests, scatterND_update_test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1535,7 +1517,7 @@ TEST_F(ParityOpsTests, scatterND_update_test2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1559,7 +1541,7 @@ TEST_F(ParityOpsTests, scatterND_update_test3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1586,7 +1568,7 @@ TEST_F(ParityOpsTests, scatterND_update_test4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1622,7 +1604,7 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1655,7 +1637,7 @@ TEST_F(ParityOpsTests, scatter_update_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(x));
|
ASSERT_TRUE(exp.isSameShape(x));
|
||||||
ASSERT_TRUE(exp.equalsTo(x));
|
ASSERT_TRUE(exp.equalsTo(x));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1674,7 +1656,7 @@ TEST_F(ParityOpsTests, scatter_update_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(x));
|
ASSERT_TRUE(exp.isSameShape(x));
|
||||||
ASSERT_TRUE(exp.equalsTo(x));
|
ASSERT_TRUE(exp.equalsTo(x));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1693,7 +1675,7 @@ TEST_F(ParityOpsTests, scatter_update_3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(x));
|
ASSERT_TRUE(exp.isSameShape(x));
|
||||||
ASSERT_TRUE(exp.equalsTo(x));
|
ASSERT_TRUE(exp.equalsTo(x));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1712,5 +1694,5 @@ TEST_F(ParityOpsTests, scatter_update_4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(x));
|
ASSERT_TRUE(exp.isSameShape(x));
|
||||||
ASSERT_TRUE(exp.equalsTo(x));
|
ASSERT_TRUE(exp.equalsTo(x));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,7 +103,7 @@ TEST_F(ScalarTests, Test_Concat_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -124,7 +124,7 @@ TEST_F(ScalarTests, Test_Concat_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ TEST_F(ScalarTests, Test_Concat_3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarTests, Test_ExpandDims_1) {
|
TEST_F(ScalarTests, Test_ExpandDims_1) {
|
||||||
|
@ -163,7 +163,7 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarTests, Test_Squeeze_1) {
|
TEST_F(ScalarTests, Test_Squeeze_1) {
|
||||||
|
@ -179,27 +179,9 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(ScalarTests, Test_Reshape_1) {
|
|
||||||
auto x = NDArrayFactory::create<float>(2.0f);
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(ScalarTests, Test_Permute_1) {
|
TEST_F(ScalarTests, Test_Permute_1) {
|
||||||
auto x = NDArrayFactory::create<float>(3.0f);
|
auto x = NDArrayFactory::create<float>(3.0f);
|
||||||
auto exp = NDArrayFactory::create<float>(3.0f);
|
auto exp = NDArrayFactory::create<float>(3.0f);
|
||||||
|
@ -213,7 +195,7 @@ TEST_F(ScalarTests, Test_Permute_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarTests, Test_Concat_Scalar_1) {
|
TEST_F(ScalarTests, Test_Concat_Scalar_1) {
|
||||||
|
|
|
@ -77,7 +77,7 @@ TEST_F(SingleDimTests, Test_Concat_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Reduce_1) {
|
TEST_F(SingleDimTests, Test_Reduce_1) {
|
||||||
|
@ -111,7 +111,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -149,7 +149,7 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
|
||||||
ASSERT_EQ(exp.rankOf(), z->rankOf());
|
ASSERT_EQ(exp.rankOf(), z->rankOf());
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Squeeze_2) {
|
TEST_F(SingleDimTests, Test_Squeeze_2) {
|
||||||
|
@ -165,42 +165,9 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Reshape_1) {
|
|
||||||
auto x = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&x}, {}, {-99, 3});
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Reshape_2) {
|
|
||||||
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
|
|
||||||
|
|
||||||
sd::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
|
|
||||||
auto z = result.at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Permute_1) {
|
TEST_F(SingleDimTests, Test_Permute_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
|
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
|
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
|
||||||
|
@ -214,5 +181,5 @@ TEST_F(SingleDimTests, Test_Permute_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue