Shyrma reshape empty (#338)

* - start working on reshape op which operates with empty shapes

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct reshaping for empty arrays

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove unnecessary check in Loopkind

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2020-03-31 07:41:16 +03:00 committed by GitHub
parent bf0ddbc06c
commit 29e61579c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 730 additions and 854 deletions

View File

@ -1944,7 +1944,7 @@ void NDArray::tilei(const std::vector<Nd4jLong>& reps) {
Nd4jLong NDArray::sizeAt(const int dim) const { Nd4jLong NDArray::sizeAt(const int dim) const {
if (dim >= this->rankOf() || dim < -this->rankOf()) if (dim >= this->rankOf() || dim < -this->rankOf())
throw std::runtime_error("Bad size index requested"); throw std::runtime_error("NDArray::sizeAt: bad size index requested");
if (dim >= 0) if (dim >= 0)
return shape::shapeOf(_shapeInfo)[dim]; return shape::shapeOf(_shapeInfo)[dim];

View File

@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';; const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';;
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && shape::rank(xShapeInfo) == 2 && xEws == 1 && xOrder == 'c' && xRank == 2 && if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 &&
tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
return SMALLARR2DX; return SMALLARR2DX;
if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))

View File

@ -40,106 +40,14 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
return Status::OK(); //No op return Status::OK(); //No op
} }
if (block.width() == 1) { REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf());
auto arguments = block.getIArguments(); if (Environment::getInstance()->isDebugAndVerbose())
int argsSize = arguments->size(); nd4j_printv("Reshape: new shape", z->getShapeAsVector());
z->assign(x->reshape(z->ordering(), z->getShapeAsVector()));
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = 'c'; //x->ordering();
e = 0;
}
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
std::vector<Nd4jLong> shapeNew;
int e2 = e;
for (; e < (int) arguments->size(); e++) {
if (arguments->at(e) == -1){
Nd4jLong shapeLength = 1;
for(; e2 < e; e2++){
shapeLength *= arguments->at(e2);
}
for(e2 = e + 1; e2 < arguments->size(); e2++){
shapeLength *= arguments->at(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew.push_back(realShape);
}
else{
shapeNew.push_back(arguments->at(e));
}
}
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
auto xr = x->reshape(order, shapeNew);
z->assign(xr);
STORE_RESULT(*z);
return Status::OK(); return Status::OK();
} else if (block.width() == 2) {
auto s = INPUT_VARIABLE(1);
char order = 'c';
if (block.numI() > 0)
order = (char) -INT_ARG(0);
std::vector<Nd4jLong> shapeNew(s->lengthOf());
for (int e = 0; e < (int) s->lengthOf(); e++) {
auto dim = s->e<Nd4jLong >(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= s->e<Nd4jLong>(e2);
}
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= s->e<Nd4jLong>(e2);
}
Nd4jLong realShape = x->lengthOf() / shapeLength;
shapeNew[e] = realShape;
}
else{
shapeNew[e] = dim;
}
}
if (Environment::getInstance()->isDebugAndVerbose()) {
nd4j_printv("Reshape: new shape", shapeNew);
}
if (s->isScalar()) {
// just a scalar
z->assign(x);
} else {
// in some cases we might go away with simple memcpy call instead of assign call
if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) {
z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset());
} else {
auto xr = x->reshape(order, shapeNew);
z->assign(xr);
}
}
return Status::OK();
}
return ND4J_STATUS_BAD_INPUT;
} }
@ -151,117 +59,73 @@ DECLARE_TYPES(reshape) {
} }
DECLARE_SHAPE_FN(reshape) { DECLARE_SHAPE_FN(reshape) {
auto inp = inputShape->at(0);
// we can launch op using Int arguments const auto x = INPUT_VARIABLE(0);
if (inputShape->size() == 1) {
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
std::vector<int> *arguments = block.getIArguments();
int e = 1;
char order = (char) -(*arguments)[0];
if (order != 'c' && order != 'f') {
order = shape::order(inp);
e = 0;
}
std::vector<int> reshapeArgs;
std::vector<Nd4jLong> shapeNew; std::vector<Nd4jLong> shapeNew;
char orderNew = 'c';
int e2 = e; if (block.width() == 1) {
for (; e < (int) arguments->size(); e++) { reshapeArgs = *block.getIArguments();
if ((int) arguments->at(e) == -1){ if(!reshapeArgs.empty()) {
orderNew = (char) -reshapeArgs[0];
Nd4jLong shapeLength = 1; if(orderNew == 'c' || orderNew == 'f')
for(; e2 < e; e2 ++){ reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case
shapeLength *= arguments->at(e2);
} }
for(e2 = e + 1; e2 < arguments->size(); e2++){ }
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); else {
shapeLength *= arguments->at(e2); reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c';
} }
if(shapeLength == 0){ REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
//Edge case for empty:
Nd4jLong xLen = x->lengthOf();
if(x->isEmpty()) {
xLen = 1;
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
if(x->sizeAt(i) != 0)
xLen *= x->sizeAt(i);
}
for (uint i = 0; i < reshapeArgs.size(); ++i) {
if (reshapeArgs[i] == -1) {
uint shapeLength = 1, numOfZeros = 0;
for(uint j = 0; j < i; ++j)
if(reshapeArgs[j] != 0)
shapeLength *= reshapeArgs[j];
else
++numOfZeros;
for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
if(reshapeArgs[j] != 0)
shapeLength *= reshapeArgs[j];
else
++numOfZeros;
}
const auto dim = xLen / shapeLength;
if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
shapeNew.push_back(0); shapeNew.push_back(0);
} else { else
//Standard case shapeNew.push_back(dim);
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew.push_back(realShape);
}
}
else{
shapeNew.push_back(arguments->at(e));
} }
else
shapeNew.push_back(reshapeArgs[i]);
} }
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew))); auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
} else { REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
// or, with second input "as shape"
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
// special case here return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
if (y->isEmpty()) {
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
}
//Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) {
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
Nd4jLong prod = 1;
bool hasNegs = false;
for (auto v:shapeOf) {
if (v < 0) {
hasNegs = true;
v = 0;
}
prod *= v;
}
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
// if there are -1s - we turn them into zeros
if (hasNegs) {
for (int e = 0; e < shapeOf.size(); e++)
if (shapeOf[e] < 0)
shapeOf[e] = 0;
}
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
return SHAPELIST(CONSTANT(newShape));
}
std::vector<Nd4jLong> shapeNew(y->lengthOf());
for (int e = 0; e < (int) y->lengthOf(); e++) {
auto dim = y->e<Nd4jLong>(e);
if (dim == -1){
Nd4jLong shapeLength = 1;
for(int e2 = 0; e2 < e; e2++){
shapeLength *= y->e<Nd4jLong>(e2);
}
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= y->e<Nd4jLong>(e2);
}
if(shapeLength == 0){
//Edge case for empty:
shapeNew[e] = 0;
} else {
Nd4jLong realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape;
}
}else {
shapeNew[e] = dim;
}
}
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
}
} }
} }
} }

View File

@ -46,7 +46,6 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_1) {
ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
} }
TEST_F(ArrayOptionsTests, TestShape_Basic_2) { TEST_F(ArrayOptionsTests, TestShape_Basic_2) {
shape[5] = 258; shape[5] = 258;

View File

@ -1676,31 +1676,6 @@ TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) {
delete block; delete block;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Reshapeas1) {
const std::vector<Nd4jLong> xShape = { 5,4,3 };
const std::vector<Nd4jLong> yShape = { 3,5,4 };
auto x = NDArrayFactory::create_<float>('f', xShape);
auto y = NDArrayFactory::create_<float>('f', yShape);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(-2, y);
auto block = new Context(1, variableSpace, true);
block->fillInputs({ -1, -2 });
sd::ops::reshapeas reshape;
reshape.execute(block);
ASSERT_TRUE(x->isSameShape(y));
delete variableSpace;
delete block;
}
TEST_F(DeclarableOpsTests1, Test_Cast_1) { TEST_F(DeclarableOpsTests1, Test_Cast_1) {
// TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected // TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected
auto x = NDArrayFactory::create<float>('c', { 5, 5 }); auto x = NDArrayFactory::create<float>('c', { 5, 5 });
@ -1848,113 +1823,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
#endif #endif
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Reshape2) {
const std::vector<Nd4jLong> xShape = { 5,4,3 };
const std::vector<Nd4jLong> yShape = { 3,5,4 };
auto x = NDArrayFactory::create_<float>('c', xShape);
auto y = NDArrayFactory::create_<float>('c', yShape);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(1, new Variable());
auto block = new Context(1, variableSpace, false);
block->fillInputs({ -1 });
std::vector<int>* arguments = block->getIArguments();
arguments->push_back(-y->ordering());
arguments->push_back(3);
arguments->push_back(5);
arguments->push_back(4);
sd::ops::reshape reshape;
Nd4jStatus status = reshape.execute(block);
ASSERT_EQ(ND4J_STATUS_OK, status);
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
ASSERT_TRUE(result->isSameShape(y));
delete y;
delete block;
delete variableSpace;
}
TEST_F(DeclarableOpsTests1, Reshape3) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(x.isSameShape(z));
}
TEST_F(DeclarableOpsTests1, Reshape4) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 3, 4, 5 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(x.isSameShape(z));
}
TEST_F(DeclarableOpsTests1, Reshape5) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 5, 4, 3 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
}
TEST_F(DeclarableOpsTests1, Reshape6) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
auto exp = NDArrayFactory::create<float>('c', { 4, 15 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 4, -1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(exp));
}
TEST_F(DeclarableOpsTests1, Reshape7) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
auto exp = NDArrayFactory::create<float>('c', { 60 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { -1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(exp));
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Transpose1) { TEST_F(DeclarableOpsTests1, Transpose1) {
@ -1983,7 +1851,6 @@ TEST_F(DeclarableOpsTests1, Transpose1) {
delete variableSpace; delete variableSpace;
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// not-in-place // not-in-place
TEST_F(DeclarableOpsTests1, Permute1) { TEST_F(DeclarableOpsTests1, Permute1) {

View File

@ -52,22 +52,6 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) {
ASSERT_EQ(exp, *z); ASSERT_EQ(exp, *z);
}
TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) {
auto x = NDArrayFactory::create<double>('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0});
auto e = NDArrayFactory::create<double>('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});
auto r = x.reshape('c', {3, 2});;
r.streamline('f');
sd::ops::reshape op;
auto result = op.evaluate({&x}, {3, 2});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
} }
TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) { TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) {
@ -2073,7 +2057,363 @@ TEST_F(DeclarableOpsTests14, Stack_21) {
ASSERT_TRUE(outStack->isSameShape(outConcat)); ASSERT_TRUE(outStack->isSameShape(outConcat));
ASSERT_TRUE(outStack->equalsTo(outConcat)); ASSERT_TRUE(outStack->equalsTo(outConcat));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Reshape1) {
const std::vector<Nd4jLong> xShape = { 5,4,3 };
const std::vector<Nd4jLong> yShape = { 3,5,4 };
auto x = NDArrayFactory::create_<float>('f', xShape);
auto y = NDArrayFactory::create_<float>('f', yShape);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(-2, y);
auto block = new Context(1, variableSpace, true);
block->fillInputs({ -1, -2 });
sd::ops::reshapeas reshape;
reshape.execute(block);
ASSERT_TRUE(x->isSameShape(y));
delete variableSpace;
delete block;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Reshape2) {
const std::vector<Nd4jLong> xShape = { 5,4,3 };
const std::vector<Nd4jLong> yShape = { 3,5,4 };
auto x = NDArrayFactory::create_<float>('c', xShape);
auto y = NDArrayFactory::create_<float>('c', yShape);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(1, new Variable());
auto block = new Context(1, variableSpace, false);
block->fillInputs({ -1 });
std::vector<int>* arguments = block->getIArguments();
arguments->push_back(-y->ordering());
arguments->push_back(3);
arguments->push_back(5);
arguments->push_back(4);
sd::ops::reshape reshape;
Nd4jStatus status = reshape.execute(block);
ASSERT_EQ(ND4J_STATUS_OK, status);
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
ASSERT_TRUE(result->isSameShape(y));
delete y;
delete block;
delete variableSpace;
}
TEST_F(DeclarableOpsTests14, Reshape3) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(x.isSameShape(z));
}
TEST_F(DeclarableOpsTests14, Reshape4) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 3, 4, 5 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(x.isSameShape(z));
}
TEST_F(DeclarableOpsTests14, Reshape5) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 5, 4, 3 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
}
TEST_F(DeclarableOpsTests14, Reshape6) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
auto exp = NDArrayFactory::create<float>('c', { 4, 15 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { 4, -1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(exp));
}
TEST_F(DeclarableOpsTests14, Reshape7) {
auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
auto exp = NDArrayFactory::create<float>('c', { 60 });
sd::ops::reshape op;
auto result = op.evaluate({ &x }, {}, { -1 });
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(exp));
}
TEST_F(DeclarableOpsTests14, Reshape8) {
auto x = NDArrayFactory::create<double>('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0});
auto e = NDArrayFactory::create<double>('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});
auto r = x.reshape('c', {3, 2});;
r.streamline('f');
sd::ops::reshape op;
auto result = op.evaluate({&x}, {3, 2});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
}
TEST_F(DeclarableOpsTests14, Reshape9) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
sd::ops::reshape op;
auto result = op.evaluate({&array}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_EQ(e, *z);
}
TEST_F(DeclarableOpsTests14, Reshape10) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
auto z = NDArrayFactory::create<float>('c', {1, 1});
sd::ops::reshape op;
auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests14, Reshape11) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('c', {4, 3});
x.linspace(1);
exp.linspace(1);
sd::ops::reshape op;
auto result = op.evaluate({&x}, {-99, 4, 3});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape12) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
sd::ops::reshape op;
auto result = op.evaluate({&x, &shape});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape13) {
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
auto exp = NDArrayFactory::create<float>(119.f);
auto empty = NDArrayFactory::empty_<int>();
sd::ops::reshape op;
auto result = op.evaluate({&vector, empty}, {}, {});
ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(exp, *result.at(0));
delete empty;
}
TEST_F(DeclarableOpsTests14, Reshape14) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
auto e = NDArrayFactory::create<float>('c', {10, 0});
sd::ops::reshape op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_EQ(e, *z);
} }
TEST_F(DeclarableOpsTests14, Reshape15) {
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
sd::ops::reshape op;
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
ASSERT_EQ(Status::OK(), result0.status());
auto z0 = result0.at(0);
ASSERT_EQ(e0, *z0);
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
ASSERT_EQ(Status::OK(), result1.status());
auto z1 = result1.at(0);
ASSERT_EQ(e1, *z1);
}
TEST_F(DeclarableOpsTests14, Reshape16) {
auto x = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
auto shape = NDArrayFactory::create<int>('c', {1, 3}, {1, 2, 2});
auto exp = NDArrayFactory::create<int>('c', {1, 2, 2}, {1, 2, 3, 4});
sd::ops::reshape op;
auto result = op.evaluate({&x, &shape}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape17) {
auto x = NDArrayFactory::create<float>(2.0f);
auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape18) {
auto x = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 3});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape19) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests14, Reshape20) {
NDArray x1('c', {2,0}, sd::DataType::FLOAT32);
NDArray x2('c', {10,0}, sd::DataType::FLOAT32);
NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32);
NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32);
NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32);
NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32);
NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32);
sd::ops::reshape op;
auto result = op.evaluate({&x1}, {}, {2, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0}));
result = op.evaluate({&x2}, {}, {2, 0, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0,5}));
result = op.evaluate({&x2}, {}, {5, 2, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({5,2,0}));
result = op.evaluate({&x2}, {}, {-1, 2, 0});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({5,2,0}));
result = op.evaluate({&x3}, {}, {2, 0, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0,10}));
result = op.evaluate({&x4}, {}, {2, -1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,5,0}));
result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10}));
result = op.evaluate({&x6}, {}, {-1, 2, 0});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0}));
result = op.evaluate({&x7}, {}, {-1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2, 0}));
result = op.evaluate({&x7}, {}, {10,0,50,100});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100}));
}

View File

@ -710,30 +710,6 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); ASSERT_NE(*resultA0.at(0), *resultB0.at(0));
} }
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
sd::ops::reshape op;
auto result = op.evaluate({&array}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_EQ(e, *z);
}
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
auto z = NDArrayFactory::create<float>('c', {1, 1});
sd::ops::reshape op;
auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests15, test_rank_1) { TEST_F(DeclarableOpsTests15, test_rank_1) {
auto array = NDArrayFactory::create<float>('c', {4, 64}); auto array = NDArrayFactory::create<float>('c', {4, 64});
auto e = NDArrayFactory::create<int>('c', {}, {2}); auto e = NDArrayFactory::create<int>('c', {}, {2});

View File

@ -789,24 +789,6 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('c', {4, 3});
x.linspace(1);
exp.linspace(1);
sd::ops::reshape op;
auto result = op.evaluate({&x}, {-99, 4, 3});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_Split_1) { TEST_F(DeclarableOpsTests4, Test_Split_1) {
@ -1209,23 +1191,6 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
sd::ops::reshape op;
auto result = op.evaluate({&x, &shape});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {

View File

@ -140,37 +140,6 @@ TEST_F(EmptyTests, Test_Concat_4) {
ASSERT_EQ(exp, *z); ASSERT_EQ(exp, *z);
} }
TEST_F(EmptyTests, Test_Reshape_1) {
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
auto exp = NDArrayFactory::create<float>(119.f);
auto empty = NDArrayFactory::empty_<int>();
sd::ops::reshape op;
auto result = op.evaluate({&vector, empty}, {}, {});
ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(exp, *result.at(0));
delete empty;
}
TEST_F(EmptyTests, Test_Reshape_3) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
auto e = NDArrayFactory::create<float>('c', {10, 0});
sd::ops::reshape op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_EQ(e, *z);
}
TEST_F(EmptyTests, Test_dup_1) { TEST_F(EmptyTests, Test_dup_1) {
auto empty = NDArrayFactory::empty<int>(); auto empty = NDArrayFactory::empty<int>();
auto dup = new NDArray(empty.dup()); auto dup = new NDArray(empty.dup());
@ -256,41 +225,6 @@ TEST_F(EmptyTests, test_shaped_empty_4) {
ASSERT_EQ(shapeOf, array.getShapeAsVector()); ASSERT_EQ(shapeOf, array.getShapeAsVector());
} }
TEST_F(EmptyTests, test_empty_reshape_1) {
/*
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
assertArrayEquals(new long[]{0, 1}, out1.shape());
assertArrayEquals(new long[]{10, 0}, out2.shape());
*/
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
sd::ops::reshape op;
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
ASSERT_EQ(Status::OK(), result0.status());
auto z0 = result0.at(0);
ASSERT_EQ(e0, *z0);
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
ASSERT_EQ(Status::OK(), result1.status());
auto z1 = result1.at(0);
ASSERT_EQ(e1, *z1);
}
TEST_F(EmptyTests, test_empty_matmul_1) { TEST_F(EmptyTests, test_empty_matmul_1) {
auto x = NDArrayFactory::create<float>('c', {0, 1}); auto x = NDArrayFactory::create<float>('c', {0, 1});

View File

@ -669,24 +669,6 @@ TEST_F(ParityOpsTests, Test_Select_3) {
} }
TEST_F(ParityOpsTests, Test_Reshape_TF_1) {
auto x = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
auto shape = NDArrayFactory::create<int>('c', {1, 3}, {1, 2, 2});
auto exp = NDArrayFactory::create<int>('c', {1, 2, 2}, {1, 2, 3, 4});
sd::ops::reshape op;
auto result = op.evaluate({&x, &shape}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(ParityOpsTests, Test_Bias_Add_1) { TEST_F(ParityOpsTests, Test_Bias_Add_1) {
auto x = NDArrayFactory::create<float>('c', {10, 5}); auto x = NDArrayFactory::create<float>('c', {10, 5});
x.assign(0.0); x.assign(0.0);

View File

@ -182,24 +182,6 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
} }
TEST_F(ScalarTests, Test_Reshape_1) {
auto x = NDArrayFactory::create<float>(2.0f);
auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(ScalarTests, Test_Permute_1) { TEST_F(ScalarTests, Test_Permute_1) {
auto x = NDArrayFactory::create<float>(3.0f); auto x = NDArrayFactory::create<float>(3.0f);
auto exp = NDArrayFactory::create<float>(3.0f); auto exp = NDArrayFactory::create<float>(3.0f);

View File

@ -168,39 +168,6 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
} }
TEST_F(SingleDimTests, Test_Reshape_1) {
auto x = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 3});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(SingleDimTests, Test_Reshape_2) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(SingleDimTests, Test_Permute_1) { TEST_F(SingleDimTests, Test_Permute_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});