/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by raver119 on 29/10/17. // #include #if NOT_EXCLUDED(OP_reshape) #include namespace nd4j { namespace ops { ////////////////////////////////////////////////////////////////////////// // here iArgs is a vector with (optional) negative of order as first element: // ({-order, dim1, dim2, dim3, ...}) CUSTOM_OP_IMPL(reshape, 1, 1, true, 0, -2) { auto x = INPUT_VARIABLE(0); if (block.width() == 1) { auto arguments = block.getIArguments(); int argsSize = arguments->size(); //Special case: empty.reshape() -> return empty if (x->isEmpty()) { REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); return ND4J_STATUS_OK; //No op } int e = 1; char order = (char) -(*arguments)[0]; if (order != 'c' && order != 'f') { order = 'c'; //x->ordering(); e = 0; } REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension"); std::vector shapeNew; int e2 = e; for (; e < (int) arguments->size(); e++) { if (arguments->at(e) == -1){ Nd4jLong shapeLength = 1; for(; e2 < e; e2++){ shapeLength *= arguments->at(e2); } for(e2 = e + 1; e2 < arguments->size(); e2++){ shapeLength *= arguments->at(e2); } Nd4jLong realShape = x->lengthOf() / shapeLength; shapeNew.push_back(realShape); } else{ shapeNew.push_back(arguments->at(e)); } } auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); if (Environment::getInstance()->isDebugAndVerbose()) { nd4j_printv("Reshape: new shape", shapeNew); } if (block.isInplace()) { if (x->reshapei(order, shapeNew)) { STORE_RESULT(*x); return ND4J_STATUS_OK; } } else { auto ret = OUTPUT_VARIABLE(0); auto xr = x->reshape(order, shapeNew); ret->assign(xr); STORE_RESULT(*ret); return Status::OK(); } } else if (block.width() == 2) { auto s = INPUT_VARIABLE(1); //Special case: empty.reshape(-1) -> return empty if (x->isEmpty()) { //REQUIRE_TRUE(s->lengthOf() == 1 && s->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); return Status::OK(); //No op } char order = 'c'; if (block.numI() > 0) order = (char) -INT_ARG(0); std::vector shapeNew(s->lengthOf()); for (int e = 0; e < (int) s->lengthOf(); e++) { auto dim = s->e(e); if (dim == -1){ Nd4jLong shapeLength = 1; for(int e2 = 0; e2 < e; e2++){ shapeLength *= s->e(e2); } for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){ REQUIRE_TRUE(s->e(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); shapeLength *= s->e(e2); } Nd4jLong realShape = x->lengthOf() / shapeLength; shapeNew[e] = realShape; } else{ shapeNew[e] = dim; } } if (Environment::getInstance()->isDebugAndVerbose()) { nd4j_printv("Reshape: new shape", shapeNew); } if (block.isInplace()) { if (x->reshapei(order, shapeNew)) { STORE_RESULT(*x); return Status::OK(); } } else { auto ret = OUTPUT_VARIABLE(0); if (s->isEmpty()) { // just a scalar ret->assign(x); } else { auto xr = x->reshape(order, shapeNew); ret->assign(xr); } return Status::OK(); } } return ND4J_STATUS_BAD_INPUT; } DECLARE_TYPES(reshape) { getOpDescriptor() ->setAllowedInputTypes(0, nd4j::DataType::ANY) ->setAllowedInputTypes(1, {ALL_INTS}) ->setSameMode(true); } DECLARE_SHAPE_FN(reshape) { auto inp = inputShape->at(0); // we can launch op using Int arguments if (inputShape->size() == 1) { std::vector *arguments = block.getIArguments(); int e = 1; char order = (char) -(*arguments)[0]; if (order != 'c' && order != 'f') { order = shape::order(inp); e = 0; } // //Special case: empty.reshape(-1) -> return empty // if (INPUT_VARIABLE(0)->isEmpty()) { // // // auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp)); // return SHAPELIST(newShape); // } std::vector shapeNew; int e2 = e; for (; e < (int) arguments->size(); e++) { if ((int) arguments->at(e) == -1){ Nd4jLong shapeLength = 1; for(; e2 < e; e2 ++){ shapeLength *= arguments->at(e2); } for(e2 = e + 1; e2 < arguments->size(); e2++){ REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); shapeLength *= arguments->at(e2); } if(shapeLength == 0){ //Edge case for empty: shapeNew.push_back(0); } else { //Standard case Nd4jLong realShape = shape::length(inp) / shapeLength; shapeNew.push_back(realShape); } } else{ shapeNew.push_back(arguments->at(e)); } } return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew))); } else { // or, with second input "as shape" auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); // special case here if (y->isEmpty()) { REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array"); return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp))); } //Special case: empty.reshape(-1) -> return empty if (x->isEmpty()) { //REQUIRE_TRUE(y->lengthOf() == 1 && y->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); auto shapeOf = y->getBufferAsVector(); Nd4jLong prod = 1; for (auto v:shapeOf) prod *= v; REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well"); auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data()); return SHAPELIST(CONSTANT(newShape)); } std::vector shapeNew(y->lengthOf()); for (int e = 0; e < (int) y->lengthOf(); e++) { auto dim = y->e(e); if (dim == -1){ Nd4jLong shapeLength = 1; for(int e2 = 0; e2 < e; e2++){ shapeLength *= y->e(e2); } for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){ REQUIRE_TRUE(y->e(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); shapeLength *= y->e(e2); } if(shapeLength == 0){ //Edge case for empty: shapeNew[e] = 0; } else { Nd4jLong realShape = shape::length(inp) / shapeLength; shapeNew[e] = realShape; } }else { shapeNew[e] = dim; } } return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew)); } } } } #endif