From 04209693f50b0e66464ccc3681d34023cc583ae6 Mon Sep 17 00:00:00 2001 From: agibsonccc Date: Fri, 5 Feb 2021 22:57:57 +0900 Subject: [PATCH] Update reshape.cpp --- .../ops/declarable/generic/shape/reshape.cpp | 269 ++++++++---------- 1 file changed, 121 insertions(+), 148 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index be581bfc0..bca23c1cc 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -21,178 +21,151 @@ // #include - #if NOT_EXCLUDED(OP_reshape) - - #include - - namespace sd { - namespace ops { +#if NOT_EXCLUDED(OP_reshape) +#include +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// // here iArgs is a vector with (optional) negative of order as first element: // ({-order, dim1, dim2, dim3, ...}) - CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { +CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //Special case: empty.reshape() -> return empty - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); - return Status::OK(); //No op - } + // Special case: empty.reshape() -> return empty + if (x->isEmpty()) { + REQUIRE_TRUE(z->isEmpty(), 0, + "Reshape: when input is empty, output must also be empty"); + return Status::OK(); // No op + } - REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf()); + REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, + "Reshape: lengths before and after reshape should match, but " + "got %i vs %i", + x->lengthOf(), z->lengthOf()); - if (Environment::getInstance().isDebugAndVerbose()) - nd4j_printv("Reshape: new shape", z->getShapeAsVector()); + if (Environment::getInstance().isDebugAndVerbose()) + nd4j_printv("Reshape: new shape", z->getShapeAsVector()); - z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); + z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); - return Status::OK(); - } + return Status::OK(); +} +DECLARE_TYPES(reshape) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} - DECLARE_TYPES(reshape) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } +DECLARE_SHAPE_FN(reshape) { - DECLARE_SHAPE_FN(reshape) { + const auto x = INPUT_VARIABLE(0); - const auto x = INPUT_VARIABLE(0); + std::vector reshapeArgs; + std::vector shapeNew; + char orderNew = 'c'; + /** + * NOTE: The value here is negative as a flag. + * A negative value signifies 1 of 3 values: + * -1 -> dynamic shape + * -99 -> c ordering + * -102 -> f ordering + * + */ + if (block.width() == 1) { + reshapeArgs = *block.getIArguments(); + if (!reshapeArgs.empty()) { + char potentialOrdering = (char)-reshapeArgs[0]; + orderNew = potentialOrdering; + if (potentialOrdering != 'c' && potentialOrdering != 'f') { + throw std::runtime_error( + "reshape:: Value passed in must be -99 or -102 for the ordering if " + "an int array is present. -99 represents c ordering and -102 " + "represents f ordering. This number is negative for the long array " + "case to flag the difference between an ordering and a dimension " + "being specified."); + } - std::vector reshapeArgs; - std::vector shapeNew; - char orderNew = 'c'; - /** - * NOTE: The value here is negative as a flag. - * A negative value signifies 1 of 3 values: - * -1 -> dynamic shape - * -99 -> c ordering - * -102 -> f ordering - * - */ - if (block.width() == 1) { - reshapeArgs = *block.getIArguments(); - if(!reshapeArgs.empty()) { - char potentialOrdering = (char) -reshapeArgs[0]; - orderNew = potentialOrdering; - if(potentialOrdering != 'c' && potentialOrdering != 'f') { - throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering. This number is negative for the long array case to flag the difference between an ordering and a dimension being specified."); - } + nd4j_debug("Reshape Ordering is %c int ordering is %d\n", orderNew, + -reshapeArgs[0]); - nd4j_debug("Reshape Ordering is %c int ordering is %d\n",orderNew,-reshapeArgs[0]); + if (orderNew == 'c' || orderNew == 'f') + reshapeArgs.erase( + reshapeArgs + .begin()); // remove first element being order in this case + } + } else { + reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); + if (block.numI() > 0) { + // Note here that the ordering for this case can not be negative. + // Negative is used in the long array case to be used as a flag to + // differntiate between a 99 or 102 shaped array and + // the ordering. You can't have a -99 or -102 shaped array. + char potentialOrdering = (char)reshapeArgs[0]; + if (potentialOrdering != 'c' && potentialOrdering != 'f') { + throw std::runtime_error( + "reshape:: Value passed in must be -99 or -102 for the ordering if " + "an int array is present. -99 represents c ordering and -102 " + "represents f ordering."); + } - if(orderNew == 'c' || orderNew == 'f') - reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case - } - } - else { - reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); - if(block.numI() > 0) { - //Note here that the ordering for this case can not be negative. - // Negative is used in the long array case to be used as a flag to differntiate between a 99 or 102 shaped array and - //the ordering. You can't have a -99 or -102 shaped array. - char potentialOrdering = (char) reshapeArgs[0]; - if(potentialOrdering != 'c' && potentialOrdering != 'f') { - throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering."); - } + orderNew = potentialOrdering; + } else + orderNew = 'c'; + } - orderNew = potentialOrdering; - } - else - orderNew = 'c'; - } + REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, + "Reshape buffer should have at least 1 dimension !"); - REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); + Nd4jLong newShapeLen = 1; + int pos = -1; + bool newShapeEmpty = false; - // Nd4jLong xLen = x->lengthOf(); - // if(x->isEmpty()) { - // xLen = 1; - // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - // if(x->sizeAt(i) != 0) - // xLen *= x->sizeAt(i); - // } + for (int i = 0; i < reshapeArgs.size(); ++i) { + const int dim = reshapeArgs[i]; + if (dim == -1) { + REQUIRE_TRUE(pos == -1, 0, + "Reshape : Only one unknown dimension (-1) is allowed."); + pos = i; + shapeNew.push_back(1); + } else if (dim == 0) { + shapeNew.push_back(0); + newShapeEmpty = true; + } else { + shapeNew.push_back(dim); + newShapeLen *= dim; + } + } - // for (uint i = 0; i < reshapeArgs.size(); ++i) { + if (pos != -1) { - // if (reshapeArgs[i] == -1) { + Nd4jLong xLen = x->lengthOf(); + if (x->isEmpty()) { + xLen = 1; + for (uint i = 0; i < x->rankOf(); + ++i) // take into account possible empty shapes + if (x->sizeAt(i) > 0 || !newShapeEmpty) + xLen *= x->sizeAt(i); + } - // uint shapeLength = 1, numOfZeros = 0; + shapeNew[pos] = xLen / newShapeLen; + } - // for(uint j = 0; j < i; ++j) - // if(reshapeArgs[j] != 0) - // shapeLength *= reshapeArgs[j]; - // else - // ++numOfZeros; + auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); + REQUIRE_TRUE(x->lengthOf() == len, 0, + "Reshape: lengths before and after reshape should match, but " + "got %i vs %i", + x->lengthOf(), len); - // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { - // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - // if(reshapeArgs[j] != 0) - // shapeLength *= reshapeArgs[j]; - // else - // ++numOfZeros; - // } + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( + x->dataType(), orderNew, shapeNew)); +} - // const auto dim = xLen / shapeLength; +} // namespace ops +} // namespace sd - // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) - // shapeNew.push_back(0); - // else - // shapeNew.push_back(dim); - // } - // else - // shapeNew.push_back(reshapeArgs[i]); - // } - - Nd4jLong newShapeLen = 1; - int pos = -1; - bool newShapeEmpty = false; - - for (int i = 0; i < reshapeArgs.size(); ++i) { - - const int dim = reshapeArgs[i]; - - if (dim == -1) { - REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - pos = i; - shapeNew.push_back(1); - } - else if (dim == 0) { - shapeNew.push_back(0); - newShapeEmpty = true; - } - else { - shapeNew.push_back(dim); - newShapeLen *= dim; - } - } - - if (pos != -1) { - - Nd4jLong xLen = x->lengthOf(); - if(x->isEmpty()) { - xLen = 1; - for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - if(x->sizeAt(i) > 0 || !newShapeEmpty) - xLen *= x->sizeAt(i); - } - - shapeNew[pos] = xLen / newShapeLen; - } - - auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); - REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(x->dataType(), orderNew, shapeNew)); - } - - - - } - } - - #endif \ No newline at end of file +#endif \ No newline at end of file