Shyrma reshape empty (#338)
* - start working on reshape op which operates with empty shapes

  Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct reshaping for empty arrays

  Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove unnecessary check in LoopKind

  Signed-off-by: Yurii <iuriish@yahoo.com>
parent bf0ddbc06c
commit 29e61579c1
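The heart of this change is the rule the new shape function uses to resolve a -1 dimension when the input may be empty: dimensions equal to 0 contribute nothing to the effective length, a single -1 is inferred from the remaining dimensions, and for an empty input the inferred dimension collapses to 0 when it comes out as 1 or when no explicit zeros are present. A minimal standalone sketch of that rule follows; the helper name and types are hypothetical and not code from this diff.

// Hypothetical helper mirroring the inference rule added in DECLARE_SHAPE_FN(reshape):
// 0-sized dims contribute nothing to the length, a single -1 is inferred from the rest,
// and an empty input stays empty unless the inferred dim is a real extent.
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<long long> inferReshape(const std::vector<long long>& inShape,
                                    const std::vector<long long>& args) {
    bool inputEmpty = false;
    long long inLen = 1;                       // product of the non-zero input dims
    for (auto d : inShape) {
        if (d == 0) inputEmpty = true;
        else        inLen *= d;
    }

    std::vector<long long> out;
    for (std::size_t i = 0; i < args.size(); ++i) {
        if (args[i] != -1) { out.push_back(args[i]); continue; }

        long long known = 1, zeros = 0;
        for (std::size_t j = 0; j < args.size(); ++j) {
            if (j == i) continue;
            if (args[j] == -1) throw std::runtime_error("only one -1 is allowed");
            if (args[j] == 0)  ++zeros;
            else               known *= args[j];
        }

        const long long dim = inLen / known;
        // for an empty input, an inferred 1 (or a shape without explicit zeros) becomes 0
        out.push_back(inputEmpty && (dim == 1 || zeros == 0) ? 0 : dim);
    }
    return out;
}

// e.g. inferReshape({2, 0}, {2, 0, -1}) yields {2, 0, 0} and inferReshape({10, 0}, {2, 0, -1})
// yields {2, 0, 5}, matching the expectations of the Reshape15/Reshape20 tests added below.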
@@ -1944,7 +1944,7 @@ void NDArray::tilei(const std::vector<Nd4jLong>& reps) {
 Nd4jLong NDArray::sizeAt(const int dim) const {

     if (dim >= this->rankOf() || dim < -this->rankOf())
-        throw std::runtime_error("Bad size index requested");
+        throw std::runtime_error("NDArray::sizeAt: bad size index requested");

     if (dim >= 0)
         return shape::shapeOf(_shapeInfo)[dim];
@@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
     const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
     const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';;

-    if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && shape::rank(xShapeInfo) == 2 && xEws == 1 && xOrder == 'c' && xRank == 2 &&
+    if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 &&
        tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
         return SMALLARR2DX;
     if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
@@ -40,106 +40,14 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
         return Status::OK(); //No op
     }

-    if (block.width() == 1) {
-        auto arguments = block.getIArguments();
-        int argsSize = arguments->size();
-
-        int e = 1;
-        char order = (char) -(*arguments)[0];
-        if (order != 'c' && order != 'f') {
-            order = 'c'; //x->ordering();
-            e = 0;
-        }
-
-        REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
-
-        std::vector<Nd4jLong> shapeNew;
-        int e2 = e;
-        for (; e < (int) arguments->size(); e++) {
-            if (arguments->at(e) == -1){
-                Nd4jLong shapeLength = 1;
-                for(; e2 < e; e2++){
-                    shapeLength *= arguments->at(e2);
-                }
-                for(e2 = e + 1; e2 < arguments->size(); e2++){
-                    shapeLength *= arguments->at(e2);
-                }
-                Nd4jLong realShape = x->lengthOf() / shapeLength;
-                shapeNew.push_back(realShape);
-            }
-            else{
-                shapeNew.push_back(arguments->at(e));
-            }
-        }
-
-        auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
-        REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
-
-        if (Environment::getInstance()->isDebugAndVerbose()) {
-            nd4j_printv("Reshape: new shape", shapeNew);
-        }
-
-        auto xr = x->reshape(order, shapeNew);
-        z->assign(xr);
-        STORE_RESULT(*z);
-
-        return Status::OK();
-
-    } else if (block.width() == 2) {
-
-        auto s = INPUT_VARIABLE(1);
-
-        char order = 'c';
-        if (block.numI() > 0)
-            order = (char) -INT_ARG(0);
-
-        std::vector<Nd4jLong> shapeNew(s->lengthOf());
-
-        for (int e = 0; e < (int) s->lengthOf(); e++) {
-            auto dim = s->e<Nd4jLong >(e);
-            if (dim == -1){
-                Nd4jLong shapeLength = 1;
-                for(int e2 = 0; e2 < e; e2++){
-                    shapeLength *= s->e<Nd4jLong>(e2);
-                }
-                for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
-                    REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
-                    shapeLength *= s->e<Nd4jLong>(e2);
-                }
-                Nd4jLong realShape = x->lengthOf() / shapeLength;
-                shapeNew[e] = realShape;
-            }
-            else{
-                shapeNew[e] = dim;
-            }
-        }
-
-        if (Environment::getInstance()->isDebugAndVerbose()) {
-            nd4j_printv("Reshape: new shape", shapeNew);
-        }
-
-        if (s->isScalar()) {
-            // just a scalar
-            z->assign(x);
-        } else {
-            // in some cases we might go away with simple memcpy call instead of assign call
-            if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) {
-                z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset());
-            } else {
-                auto xr = x->reshape(order, shapeNew);
-                z->assign(xr);
-            }
-        }
-
-        return Status::OK();
-
-    }
-
-    return ND4J_STATUS_BAD_INPUT;
+    REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf());
+
+    if (Environment::getInstance()->isDebugAndVerbose())
+        nd4j_printv("Reshape: new shape", z->getShapeAsVector());
+
+    z->assign(x->reshape(z->ordering(), z->getShapeAsVector()));
+
+    return Status::OK();
 }
@@ -151,117 +59,73 @@ DECLARE_TYPES(reshape) {
 }

 DECLARE_SHAPE_FN(reshape) {
-    auto inp = inputShape->at(0);
-
-    // we can launch op using Int arguments
-    if (inputShape->size() == 1) {
-        REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
-        std::vector<int> *arguments = block.getIArguments();
-
-        int e = 1;
-        char order = (char) -(*arguments)[0];
-        if (order != 'c' && order != 'f') {
-            order = shape::order(inp);
-            e = 0;
-        }
-
-        std::vector<Nd4jLong> shapeNew;
-
-        int e2 = e;
-        for (; e < (int) arguments->size(); e++) {
-            if ((int) arguments->at(e) == -1){
-                Nd4jLong shapeLength = 1;
-                for(; e2 < e; e2 ++){
-                    shapeLength *= arguments->at(e2);
-                }
-                for(e2 = e + 1; e2 < arguments->size(); e2++){
-                    REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
-                    shapeLength *= arguments->at(e2);
-                }
-
-                if(shapeLength == 0){
-                    //Edge case for empty:
-                    shapeNew.push_back(0);
-                } else {
-                    //Standard case
-                    Nd4jLong realShape = shape::length(inp) / shapeLength;
-                    shapeNew.push_back(realShape);
-                }
-            }
-            else{
-                shapeNew.push_back(arguments->at(e));
-            }
-        }
-
-        return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
-    } else {
-        // or, with second input "as shape"
-        auto x = INPUT_VARIABLE(0);
-        auto y = INPUT_VARIABLE(1);
-
-        // special case here
-        if (y->isEmpty()) {
-            REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
-            return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
-        }
-        //Special case: empty.reshape(-1) -> return empty
-        if (x->isEmpty()) {
-            //REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
-            auto shapeOf = y->getBufferAsVector<Nd4jLong>();
-            Nd4jLong prod = 1;
-            bool hasNegs = false;
-            for (auto v:shapeOf) {
-                if (v < 0) {
-                    hasNegs = true;
-                    v = 0;
-                }
-
-                prod *= v;
-            }
-
-            REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
-
-            // if there are -1s - we turn them into zeros
-            if (hasNegs) {
-                for (int e = 0; e < shapeOf.size(); e++)
-                    if (shapeOf[e] < 0)
-                        shapeOf[e] = 0;
-            }
-
-            auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
-            return SHAPELIST(CONSTANT(newShape));
-        }
-
-        std::vector<Nd4jLong> shapeNew(y->lengthOf());
-
-        for (int e = 0; e < (int) y->lengthOf(); e++) {
-            auto dim = y->e<Nd4jLong>(e);
-            if (dim == -1){
-                Nd4jLong shapeLength = 1;
-                for(int e2 = 0; e2 < e; e2++){
-                    shapeLength *= y->e<Nd4jLong>(e2);
-                }
-                for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
-                    REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
-                    shapeLength *= y->e<Nd4jLong>(e2);
-                }
-
-                if(shapeLength == 0){
-                    //Edge case for empty:
-                    shapeNew[e] = 0;
-                } else {
-                    Nd4jLong realShape = shape::length(inp) / shapeLength;
-                    shapeNew[e] = realShape;
-                }
-            }else {
-                shapeNew[e] = dim;
-            }
-        }
-
-        return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
-    }
+    const auto x = INPUT_VARIABLE(0);
+
+    std::vector<int> reshapeArgs;
+    std::vector<Nd4jLong> shapeNew;
+    char orderNew = 'c';
+
+    if (block.width() == 1) {
+        reshapeArgs = *block.getIArguments();
+        if(!reshapeArgs.empty()) {
+            orderNew = (char) -reshapeArgs[0];
+            if(orderNew == 'c' || orderNew == 'f')
+                reshapeArgs.erase(reshapeArgs.begin());    // remove first element being order in this case
+        }
+    }
+    else {
+        reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
+        orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c';
+    }
+
+    REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
+
+    Nd4jLong xLen = x->lengthOf();
+    if(x->isEmpty()) {
+        xLen = 1;
+        for (uint i = 0; i < x->rankOf(); ++i)             // take into account possible empty shapes
+            if(x->sizeAt(i) != 0)
+                xLen *= x->sizeAt(i);
+    }
+
+    for (uint i = 0; i < reshapeArgs.size(); ++i) {
+
+        if (reshapeArgs[i] == -1) {
+
+            uint shapeLength = 1, numOfZeros = 0;
+
+            for(uint j = 0; j < i; ++j)
+                if(reshapeArgs[j] != 0)
+                    shapeLength *= reshapeArgs[j];
+                else
+                    ++numOfZeros;
+
+            for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
+                REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
+                if(reshapeArgs[j] != 0)
+                    shapeLength *= reshapeArgs[j];
+                else
+                    ++numOfZeros;
+            }
+
+            const auto dim = xLen / shapeLength;
+            if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
+                shapeNew.push_back(0);
+            else
+                shapeNew.push_back(dim);
+        }
+        else
+            shapeNew.push_back(reshapeArgs[i]);
+    }
+
+    auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
+    REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
+
+    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
 }

 }
 }
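Taken together, the two hunks above move all shape handling into DECLARE_SHAPE_FN: the kernel now only checks that the lengths match and assigns, because the output array already arrives with the computed shape. A hedged usage sketch, assuming the libnd4j test context (NDArrayFactory, sd::ops::reshape) that the tests below rely on; this snippet is illustration only, not part of the diff.

// Sketch only: same call pattern as the tests added later in this commit.
void reshapeEmptySketch() {
    auto x     = NDArrayFactory::create<float>('c', {2, 0});
    auto shape = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});

    sd::ops::reshape op;
    auto result = op.evaluate({&x, &shape}, {}, {});   // shape fn resolves -1 against the empty input

    // result.at(0) now has shape {2, 0, 0}; the kernel itself only validated
    // lengths and assigned, since the output was already created with that shape.
}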
@@ -46,7 +46,6 @@ TEST_F(ArrayOptionsTests, TestShape_Basic_1) {
     ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
 }

-
 TEST_F(ArrayOptionsTests, TestShape_Basic_2) {
     shape[5] = 258;

@@ -1676,31 +1676,6 @@ TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) {
     delete block;
 }

-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Reshapeas1) {
-    const std::vector<Nd4jLong> xShape = { 5,4,3 };
-    const std::vector<Nd4jLong> yShape = { 3,5,4 };
-
-    auto x = NDArrayFactory::create_<float>('f', xShape);
-    auto y = NDArrayFactory::create_<float>('f', yShape);
-
-    auto variableSpace = new VariableSpace();
-    variableSpace->putVariable(-1, x);
-    variableSpace->putVariable(-2, y);
-    auto block = new Context(1, variableSpace, true);
-    block->fillInputs({ -1, -2 });
-
-    sd::ops::reshapeas reshape;
-
-    reshape.execute(block);
-
-    ASSERT_TRUE(x->isSameShape(y));
-
-    delete variableSpace;
-    delete block;
-}
-
 TEST_F(DeclarableOpsTests1, Test_Cast_1) {
     // TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected
     auto x = NDArrayFactory::create<float>('c', { 5, 5 });
@@ -1848,113 +1823,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {

 #endif

-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Reshape2) {
-    const std::vector<Nd4jLong> xShape = { 5,4,3 };
-    const std::vector<Nd4jLong> yShape = { 3,5,4 };
-
-    auto x = NDArrayFactory::create_<float>('c', xShape);
-    auto y = NDArrayFactory::create_<float>('c', yShape);
-
-    auto variableSpace = new VariableSpace();
-    variableSpace->putVariable(-1, x);
-    variableSpace->putVariable(1, new Variable());
-
-    auto block = new Context(1, variableSpace, false);
-    block->fillInputs({ -1 });
-    std::vector<int>* arguments = block->getIArguments();
-    arguments->push_back(-y->ordering());
-    arguments->push_back(3);
-    arguments->push_back(5);
-    arguments->push_back(4);
-
-    sd::ops::reshape reshape;
-
-    Nd4jStatus status = reshape.execute(block);
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
-
-    ASSERT_TRUE(result->isSameShape(y));
-
-    delete y;
-    delete block;
-    delete variableSpace;
-}
-
-TEST_F(DeclarableOpsTests1, Reshape3) {
-    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 });
-
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(x.isSameShape(z));
-}
-
-TEST_F(DeclarableOpsTests1, Reshape4) {
-    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({ &x }, {}, { 3, 4, 5 });
-
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(x.isSameShape(z));
-}
-
-TEST_F(DeclarableOpsTests1, Reshape5) {
-    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({ &x }, {}, { 5, 4, 3 });
-
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-}
-
-TEST_F(DeclarableOpsTests1, Reshape6) {
-    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
-    auto exp = NDArrayFactory::create<float>('c', { 4, 15 });
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({ &x }, {}, { 4, -1 });
-
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(z->isSameShape(exp));
-}
-
-TEST_F(DeclarableOpsTests1, Reshape7) {
-    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
-    auto exp = NDArrayFactory::create<float>('c', { 60 });
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({ &x }, {}, { -1 });
-
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(z->isSameShape(exp));
-}
-
 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests1, Transpose1) {
@@ -1983,7 +1851,6 @@ TEST_F(DeclarableOpsTests1, Transpose1) {
     delete variableSpace;
 }

-
 //////////////////////////////////////////////////////////////////////
 // not-in-place
 TEST_F(DeclarableOpsTests1, Permute1) {
@@ -52,22 +52,6 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) {
     ASSERT_EQ(exp, *z);

 }

-TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) {
-    auto x = NDArrayFactory::create<double>('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0});
-    auto e = NDArrayFactory::create<double>('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});
-
-    auto r = x.reshape('c', {3, 2});;
-    r.streamline('f');
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&x}, {3, 2});
-    ASSERT_EQ(Status::OK(), result.status());
-
-    auto z = result.at(0);
-
-}
-
 TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) {
@@ -2073,7 +2057,363 @@ TEST_F(DeclarableOpsTests14, Stack_21) {

     ASSERT_TRUE(outStack->isSameShape(outConcat));
     ASSERT_TRUE(outStack->equalsTo(outConcat));
 }

+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests14, Reshape1) {
+    const std::vector<Nd4jLong> xShape = { 5,4,3 };
+    const std::vector<Nd4jLong> yShape = { 3,5,4 };
+
+    auto x = NDArrayFactory::create_<float>('f', xShape);
+    auto y = NDArrayFactory::create_<float>('f', yShape);
+
+    auto variableSpace = new VariableSpace();
+    variableSpace->putVariable(-1, x);
+    variableSpace->putVariable(-2, y);
+    auto block = new Context(1, variableSpace, true);
+    block->fillInputs({ -1, -2 });
+
+    sd::ops::reshapeas reshape;
+
+    reshape.execute(block);
+
+    ASSERT_TRUE(x->isSameShape(y));
+
+    delete variableSpace;
+    delete block;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests14, Reshape2) {
+    const std::vector<Nd4jLong> xShape = { 5,4,3 };
+    const std::vector<Nd4jLong> yShape = { 3,5,4 };
+
+    auto x = NDArrayFactory::create_<float>('c', xShape);
+    auto y = NDArrayFactory::create_<float>('c', yShape);
+
+    auto variableSpace = new VariableSpace();
+    variableSpace->putVariable(-1, x);
+    variableSpace->putVariable(1, new Variable());
+
+    auto block = new Context(1, variableSpace, false);
+    block->fillInputs({ -1 });
+    std::vector<int>* arguments = block->getIArguments();
+    arguments->push_back(-y->ordering());
+    arguments->push_back(3);
+    arguments->push_back(5);
+    arguments->push_back(4);
+
+    sd::ops::reshape reshape;
+
+    Nd4jStatus status = reshape.execute(block);
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
+
+    ASSERT_TRUE(result->isSameShape(y));
+
+    delete y;
+    delete block;
+    delete variableSpace;
+}
+
+TEST_F(DeclarableOpsTests14, Reshape3) {
+    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 });
+
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(x.isSameShape(z));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape4) {
+    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({ &x }, {}, { 3, 4, 5 });
+
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(x.isSameShape(z));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape5) {
+    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({ &x }, {}, { 5, 4, 3 });
+
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+}
+
+TEST_F(DeclarableOpsTests14, Reshape6) {
+    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
+    auto exp = NDArrayFactory::create<float>('c', { 4, 15 });
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({ &x }, {}, { 4, -1 });
+
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(z->isSameShape(exp));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape7) {
+    auto x = NDArrayFactory::create<float>('c', { 3, 4, 5 });
+    auto exp = NDArrayFactory::create<float>('c', { 60 });
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({ &x }, {}, { -1 });
+
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(z->isSameShape(exp));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape8) {
+    auto x = NDArrayFactory::create<double>('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0});
+    auto e = NDArrayFactory::create<double>('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0});
+
+    auto r = x.reshape('c', {3, 2});;
+    r.streamline('f');
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&x}, {3, 2});
+    ASSERT_EQ(Status::OK(), result.status());
+
+    auto z = result.at(0);
+}
+
+TEST_F(DeclarableOpsTests14, Reshape9) {
+    auto array = NDArrayFactory::create<float>(119.f);
+    auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&array}, {}, {1, 1});
+    ASSERT_EQ(Status::OK(), result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_EQ(e, *z);
+}
+
+TEST_F(DeclarableOpsTests14, Reshape10) {
+    auto array = NDArrayFactory::create<float>(119.f);
+    auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
+    auto z = NDArrayFactory::create<float>('c', {1, 1});
+
+    sd::ops::reshape op;
+    auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
+    ASSERT_EQ(Status::OK(), result);
+    ASSERT_EQ(e, z);
+}
+
+TEST_F(DeclarableOpsTests14, Reshape11) {
+    auto x = NDArrayFactory::create<double>('c', {4, 3});
+    auto exp = NDArrayFactory::create<double>('c', {4, 3});
+
+    x.linspace(1);
+    exp.linspace(1);
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&x}, {-99, 4, 3});
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape12) {
+    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
+    auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
+    auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&x, &shape});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape13) {
+    auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
+    auto exp = NDArrayFactory::create<float>(119.f);
+    auto empty = NDArrayFactory::empty_<int>();
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&vector, empty}, {}, {});
+
+    ASSERT_EQ(Status::OK(), result.status());
+
+    ASSERT_EQ(exp, *result.at(0));
+
+    delete empty;
+}
+
+TEST_F(DeclarableOpsTests14, Reshape14) {
+    auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
+    auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
+    auto e = NDArrayFactory::create<float>('c', {10, 0});
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&x, &y}, {}, {});
+    ASSERT_EQ(Status::OK(), result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(e.isSameShape(z));
+    ASSERT_EQ(e, *z);
+}
+
+TEST_F(DeclarableOpsTests14, Reshape15) {
+
+    auto x0 = NDArrayFactory::create<float>('c', {2, 0});
+    auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
+
+    auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
+    auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
+
+    auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
+    auto e1 = NDArrayFactory::create<float>('c', {0, 1});
+
+    sd::ops::reshape op;
+    auto result0 = op.evaluate({&x0, &shape0}, {}, {});
+    ASSERT_EQ(Status::OK(), result0.status());
+    auto z0 = result0.at(0);
+    ASSERT_EQ(e0, *z0);
+
+    auto result1 = op.evaluate({&x1, &shape1}, {}, {});
+    ASSERT_EQ(Status::OK(), result1.status());
+    auto z1 = result1.at(0);
+    ASSERT_EQ(e1, *z1);
+}
+
+TEST_F(DeclarableOpsTests14, Reshape16) {
+    auto x = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
+    auto shape = NDArrayFactory::create<int>('c', {1, 3}, {1, 2, 2});
+
+    auto exp = NDArrayFactory::create<int>('c', {1, 2, 2}, {1, 2, 3, 4});
+
+    sd::ops::reshape op;
+
+    auto result = op.evaluate({&x, &shape}, {}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape17) {
+    auto x = NDArrayFactory::create<float>(2.0f);
+    auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape18) {
+    auto x = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
+    auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&x}, {}, {-99, 3});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape19) {
+    auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
+    auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
+
+    sd::ops::reshape op;
+    auto result = op.evaluate({&x}, {}, {-99, 1, 3});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+
+    auto z = result.at(0);
+
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+}
+
+TEST_F(DeclarableOpsTests14, Reshape20) {
+
+    NDArray x1('c', {2,0}, sd::DataType::FLOAT32);
+    NDArray x2('c', {10,0}, sd::DataType::FLOAT32);
+    NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32);
+    NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32);
+    NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32);
+    NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32);
+    NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32);
+
+    sd::ops::reshape op;
+
+    auto result = op.evaluate({&x1}, {}, {2, -1});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({2,0}));
+
+    result = op.evaluate({&x2}, {}, {2, 0, -1});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({2,0,5}));
+
+    result = op.evaluate({&x2}, {}, {5, 2, -1});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({5,2,0}));
+
+    result = op.evaluate({&x2}, {}, {-1, 2, 0});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({5,2,0}));
+
+    result = op.evaluate({&x3}, {}, {2, 0, -1});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({2,0,10}));
+
+    result = op.evaluate({&x4}, {}, {2, -1, 0});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({2,5,0}));
+
+    result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10}));
+
+    result = op.evaluate({&x6}, {}, {-1, 2, 0});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0}));
+
+    result = op.evaluate({&x7}, {}, {-1, 0});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({2, 0}));
+
+    result = op.evaluate({&x7}, {}, {10,0,50,100});
+    ASSERT_EQ(ND4J_STATUS_OK, result.status());
+    ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100}));
+}
+
@@ -710,30 +710,6 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
     ASSERT_NE(*resultA0.at(0), *resultB0.at(0));
 }

-TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) {
-    auto array = NDArrayFactory::create<float>(119.f);
-    auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&array}, {}, {1, 1});
-    ASSERT_EQ(Status::OK(), result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_EQ(e, *z);
-}
-
-TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) {
-    auto array = NDArrayFactory::create<float>(119.f);
-    auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
-    auto z = NDArrayFactory::create<float>('c', {1, 1});
-
-    sd::ops::reshape op;
-    auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
-    ASSERT_EQ(Status::OK(), result);
-    ASSERT_EQ(e, z);
-}
-
 TEST_F(DeclarableOpsTests15, test_rank_1) {
     auto array = NDArrayFactory::create<float>('c', {4, 64});
     auto e = NDArrayFactory::create<int>('c', {}, {2});
@@ -789,24 +789,6 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
     ASSERT_TRUE(exp.equalsTo(z));

 }

-TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
-    auto x = NDArrayFactory::create<double>('c', {4, 3});
-    auto exp = NDArrayFactory::create<double>('c', {4, 3});
-
-    x.linspace(1);
-    exp.linspace(1);
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&x}, {-99, 4, 3});
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(exp.isSameShape(z));
-    ASSERT_TRUE(exp.equalsTo(z));
-}
-
 TEST_F(DeclarableOpsTests4, Test_Split_1) {
@@ -1209,23 +1191,6 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) {
     ASSERT_TRUE(exp.equalsTo(z));

 }

-TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
-    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
-    auto shape = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 2});
-    auto exp = NDArrayFactory::create<double>('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&x, &shape});
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(exp.isSameShape(z));
-    ASSERT_TRUE(exp.equalsTo(z));
-}
-
 TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
@@ -140,37 +140,6 @@ TEST_F(EmptyTests, Test_Concat_4) {
     ASSERT_EQ(exp, *z);
 }

-TEST_F(EmptyTests, Test_Reshape_1) {
-    auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
-    auto exp = NDArrayFactory::create<float>(119.f);
-    auto empty = NDArrayFactory::empty_<int>();
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&vector, empty}, {}, {});
-
-    ASSERT_EQ(Status::OK(), result.status());
-
-    ASSERT_EQ(exp, *result.at(0));
-
-    delete empty;
-}
-
-TEST_F(EmptyTests, Test_Reshape_3) {
-    auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
-    auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
-    auto e = NDArrayFactory::create<float>('c', {10, 0});
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&x, &y}, {}, {});
-    ASSERT_EQ(Status::OK(), result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(e.isSameShape(z));
-    ASSERT_EQ(e, *z);
-}
-
 TEST_F(EmptyTests, Test_dup_1) {
     auto empty = NDArrayFactory::empty<int>();
     auto dup = new NDArray(empty.dup());
@@ -256,41 +225,6 @@ TEST_F(EmptyTests, test_shaped_empty_4) {
     ASSERT_EQ(shapeOf, array.getShapeAsVector());
 }

-TEST_F(EmptyTests, test_empty_reshape_1) {
-    /*
-    INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
-    INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
-
-    INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
-    INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
-    INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
-
-    assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
-    assertArrayEquals(new long[]{0, 1}, out1.shape());
-    assertArrayEquals(new long[]{10, 0}, out2.shape());
-    */
-    auto x0 = NDArrayFactory::create<float>('c', {2, 0});
-    auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
-
-    auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
-    auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
-
-    auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
-    auto e1 = NDArrayFactory::create<float>('c', {0, 1});
-
-    sd::ops::reshape op;
-    auto result0 = op.evaluate({&x0, &shape0}, {}, {});
-    ASSERT_EQ(Status::OK(), result0.status());
-    auto z0 = result0.at(0);
-    ASSERT_EQ(e0, *z0);
-
-    auto result1 = op.evaluate({&x1, &shape1}, {}, {});
-    ASSERT_EQ(Status::OK(), result1.status());
-    auto z1 = result1.at(0);
-    ASSERT_EQ(e1, *z1);
-}
-
 TEST_F(EmptyTests, test_empty_matmul_1) {
     auto x = NDArrayFactory::create<float>('c', {0, 1});
@@ -669,24 +669,6 @@ TEST_F(ParityOpsTests, Test_Select_3) {

 }

-TEST_F(ParityOpsTests, Test_Reshape_TF_1) {
-    auto x = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
-    auto shape = NDArrayFactory::create<int>('c', {1, 3}, {1, 2, 2});
-
-    auto exp = NDArrayFactory::create<int>('c', {1, 2, 2}, {1, 2, 3, 4});
-
-    sd::ops::reshape op;
-
-    auto result = op.evaluate({&x, &shape}, {}, {});
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(exp.isSameShape(z));
-    ASSERT_TRUE(exp.equalsTo(z));
-}
-
 TEST_F(ParityOpsTests, Test_Bias_Add_1) {
     auto x = NDArrayFactory::create<float>('c', {10, 5});
     x.assign(0.0);
@@ -182,24 +182,6 @@ TEST_F(ScalarTests, Test_Squeeze_1) {

 }

-TEST_F(ScalarTests, Test_Reshape_1) {
-    auto x = NDArrayFactory::create<float>(2.0f);
-    auto exp = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.0f});
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(exp.isSameShape(z));
-    ASSERT_TRUE(exp.equalsTo(z));
-}
-
 TEST_F(ScalarTests, Test_Permute_1) {
     auto x = NDArrayFactory::create<float>(3.0f);
     auto exp = NDArrayFactory::create<float>(3.0f);
@@ -168,39 +168,6 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {

 }

-TEST_F(SingleDimTests, Test_Reshape_1) {
-    auto x = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
-    auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&x}, {}, {-99, 3});
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(exp.isSameShape(z));
-    ASSERT_TRUE(exp.equalsTo(z));
-}
-
-TEST_F(SingleDimTests, Test_Reshape_2) {
-    auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
-    auto exp = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
-
-    sd::ops::reshape op;
-    auto result = op.evaluate({&x}, {}, {-99, 1, 3});
-    ASSERT_EQ(ND4J_STATUS_OK, result.status());
-
-    auto z = result.at(0);
-
-    ASSERT_TRUE(exp.isSameShape(z));
-    ASSERT_TRUE(exp.equalsTo(z));
-}
-
 TEST_F(SingleDimTests, Test_Permute_1) {
     auto x = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
     auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});