Update reshape.cpp
This commit is contained in:
		
							parent
							
								
									e9a1a3d3f1
								
							
						
					
					
						commit
						04209693f5
					
				| @ -21,28 +21,30 @@ | ||||
| //
 | ||||
| 
 | ||||
| #include <system/op_boilerplate.h> | ||||
|         #if NOT_EXCLUDED(OP_reshape) | ||||
| 
 | ||||
|         #include <ops/declarable/CustomOperations.h> | ||||
| 
 | ||||
|         namespace sd { | ||||
|         namespace ops  { | ||||
| #if NOT_EXCLUDED(OP_reshape) | ||||
| #include <ops/declarable/CustomOperations.h> | ||||
| namespace sd { | ||||
| namespace ops { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| // here iArgs is a vector with (optional) negative of order as first element:
 | ||||
| // ({-order, dim1, dim2, dim3, ...})
 | ||||
|         CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { | ||||
| CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { | ||||
| 
 | ||||
|   auto x = INPUT_VARIABLE(0); | ||||
|   auto z = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|         //Special case: empty.reshape(<other empty shape>) -> return empty
 | ||||
|   // Special case: empty.reshape(<other empty shape>) -> return empty
 | ||||
|   if (x->isEmpty()) { | ||||
|         REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); | ||||
|         return Status::OK();    //No op
 | ||||
|     REQUIRE_TRUE(z->isEmpty(), 0, | ||||
|                  "Reshape: when input is empty, output must also be empty"); | ||||
|     return Status::OK(); // No op
 | ||||
|   } | ||||
| 
 | ||||
|         REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf()); | ||||
|   REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, | ||||
|                "Reshape: lengths before and after reshape should match, but " | ||||
|                "got %i vs %i", | ||||
|                x->lengthOf(), z->lengthOf()); | ||||
| 
 | ||||
|   if (Environment::getInstance().isDebugAndVerbose()) | ||||
|     nd4j_printv("Reshape: new shape", z->getShapeAsVector()); | ||||
| @ -50,17 +52,16 @@ | ||||
|   z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); | ||||
| 
 | ||||
|   return Status::OK(); | ||||
|         } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
|         DECLARE_TYPES(reshape) { | ||||
| DECLARE_TYPES(reshape) { | ||||
|   getOpDescriptor() | ||||
|       ->setAllowedInputTypes(0, sd::DataType::ANY) | ||||
|       ->setAllowedInputTypes(1, {ALL_INTS}) | ||||
|       ->setSameMode(true); | ||||
|         } | ||||
| } | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(reshape) { | ||||
| DECLARE_SHAPE_FN(reshape) { | ||||
| 
 | ||||
|   const auto x = INPUT_VARIABLE(0); | ||||
| 
 | ||||
| @ -77,95 +78,64 @@ | ||||
|    */ | ||||
|   if (block.width() == 1) { | ||||
|     reshapeArgs = *block.getIArguments(); | ||||
|         if(!reshapeArgs.empty()) { | ||||
|         char potentialOrdering = (char) -reshapeArgs[0]; | ||||
|     if (!reshapeArgs.empty()) { | ||||
|       char potentialOrdering = (char)-reshapeArgs[0]; | ||||
|       orderNew = potentialOrdering; | ||||
|         if(potentialOrdering != 'c' && potentialOrdering != 'f') { | ||||
|             throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering. This number is negative for the long array case to flag the difference between an ordering and a dimension being specified."); | ||||
|       if (potentialOrdering != 'c' && potentialOrdering != 'f') { | ||||
|         throw std::runtime_error( | ||||
|             "reshape:: Value passed in must be -99 or -102 for the ordering if " | ||||
|             "an int array is present. -99 represents c ordering and -102 " | ||||
|             "represents f ordering. This number is negative for the long array " | ||||
|             "case to flag the difference between an ordering and a dimension " | ||||
|             "being specified."); | ||||
|       } | ||||
| 
 | ||||
|         nd4j_debug("Reshape Ordering is %c int ordering is %d\n",orderNew,-reshapeArgs[0]); | ||||
|       nd4j_debug("Reshape Ordering is %c int ordering is %d\n", orderNew, | ||||
|                  -reshapeArgs[0]); | ||||
| 
 | ||||
|         if(orderNew == 'c' || orderNew == 'f') | ||||
|                reshapeArgs.erase(reshapeArgs.begin());   // remove first element being order in this case
 | ||||
|       if (orderNew == 'c' || orderNew == 'f') | ||||
|         reshapeArgs.erase( | ||||
|             reshapeArgs | ||||
|                 .begin()); // remove first element being order in this case
 | ||||
|     } | ||||
|         } | ||||
|         else { | ||||
|   } else { | ||||
|     reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>(); | ||||
|         if(block.numI() > 0) { | ||||
|             //Note here that the ordering for this case can not be negative.
 | ||||
|         // Negative is used in the long array case to be used as a flag to differntiate between a 99 or 102 shaped array and
 | ||||
|         //the ordering. You can't have a -99 or -102 shaped array.
 | ||||
|             char potentialOrdering = (char) reshapeArgs[0]; | ||||
|            if(potentialOrdering != 'c' && potentialOrdering != 'f') { | ||||
|             throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering."); | ||||
|     if (block.numI() > 0) { | ||||
|       // Note here that the ordering for this case can not be negative.
 | ||||
|       // Negative is used in the long array case to be used as a flag to
 | ||||
|       // differntiate between a 99 or 102 shaped array and
 | ||||
|       // the ordering. You can't have a -99 or -102 shaped array.
 | ||||
|       char potentialOrdering = (char)reshapeArgs[0]; | ||||
|       if (potentialOrdering != 'c' && potentialOrdering != 'f') { | ||||
|         throw std::runtime_error( | ||||
|             "reshape:: Value passed in must be -99 or -102 for the ordering if " | ||||
|             "an int array is present. -99 represents c ordering and -102 " | ||||
|             "represents f ordering."); | ||||
|       } | ||||
| 
 | ||||
|       orderNew = potentialOrdering; | ||||
|         } | ||||
|         else | ||||
|     } else | ||||
|       orderNew = 'c'; | ||||
|   } | ||||
| 
 | ||||
|         REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); | ||||
| 
 | ||||
|         // Nd4jLong xLen = x->lengthOf();
 | ||||
|         // if(x->isEmpty()) {
 | ||||
|         //     xLen = 1;
 | ||||
|         //     for (uint i = 0; i < x->rankOf(); ++i)                            // take into account possible empty shapes
 | ||||
|         //         if(x->sizeAt(i) != 0)
 | ||||
|         //             xLen *= x->sizeAt(i);
 | ||||
|         // }
 | ||||
| 
 | ||||
|         // for (uint i = 0; i < reshapeArgs.size(); ++i) {
 | ||||
| 
 | ||||
|         //     if (reshapeArgs[i] == -1) {
 | ||||
| 
 | ||||
|         //         uint shapeLength = 1, numOfZeros = 0;
 | ||||
| 
 | ||||
|         //         for(uint j = 0; j < i; ++j)
 | ||||
|         //             if(reshapeArgs[j] != 0)
 | ||||
|         //                 shapeLength *= reshapeArgs[j];
 | ||||
|         //             else
 | ||||
|         //                 ++numOfZeros;
 | ||||
| 
 | ||||
|         //         for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
 | ||||
|         //             REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
 | ||||
|         //             if(reshapeArgs[j] != 0)
 | ||||
|         //                 shapeLength *= reshapeArgs[j];
 | ||||
|         //             else
 | ||||
|         //                 ++numOfZeros;
 | ||||
|         //         }
 | ||||
| 
 | ||||
|         //         const auto dim = xLen / shapeLength;
 | ||||
| 
 | ||||
|         //         if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
 | ||||
|         //             shapeNew.push_back(0);
 | ||||
|         //         else
 | ||||
|         //             shapeNew.push_back(dim);
 | ||||
|         //     }
 | ||||
|         //     else
 | ||||
|         //         shapeNew.push_back(reshapeArgs[i]);
 | ||||
|         // }
 | ||||
|   REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, | ||||
|                "Reshape buffer should have at least 1 dimension !"); | ||||
| 
 | ||||
|   Nd4jLong newShapeLen = 1; | ||||
|   int pos = -1; | ||||
|   bool newShapeEmpty = false; | ||||
| 
 | ||||
|   for (int i = 0; i < reshapeArgs.size(); ++i) { | ||||
| 
 | ||||
|     const int dim = reshapeArgs[i]; | ||||
| 
 | ||||
|     if (dim == -1) { | ||||
|         REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); | ||||
|       REQUIRE_TRUE(pos == -1, 0, | ||||
|                    "Reshape : Only one unknown dimension (-1) is allowed."); | ||||
|       pos = i; | ||||
|       shapeNew.push_back(1); | ||||
|         } | ||||
|         else if (dim == 0) { | ||||
|     } else if (dim == 0) { | ||||
|       shapeNew.push_back(0); | ||||
|       newShapeEmpty = true; | ||||
|         } | ||||
|         else { | ||||
|     } else { | ||||
|       shapeNew.push_back(dim); | ||||
|       newShapeLen *= dim; | ||||
|     } | ||||
| @ -174,10 +144,11 @@ | ||||
|   if (pos != -1) { | ||||
| 
 | ||||
|     Nd4jLong xLen = x->lengthOf(); | ||||
|         if(x->isEmpty()) { | ||||
|     if (x->isEmpty()) { | ||||
|       xLen = 1; | ||||
|         for (uint i = 0; i < x->rankOf(); ++i)                            // take into account possible empty shapes
 | ||||
|         if(x->sizeAt(i) > 0 || !newShapeEmpty) | ||||
|       for (uint i = 0; i < x->rankOf(); | ||||
|            ++i) // take into account possible empty shapes
 | ||||
|         if (x->sizeAt(i) > 0 || !newShapeEmpty) | ||||
|           xLen *= x->sizeAt(i); | ||||
|     } | ||||
| 
 | ||||
| @ -185,14 +156,16 @@ | ||||
|   } | ||||
| 
 | ||||
|   auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); | ||||
|         REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); | ||||
|   REQUIRE_TRUE(x->lengthOf() == len, 0, | ||||
|                "Reshape: lengths before and after reshape should match, but " | ||||
|                "got %i vs %i", | ||||
|                x->lengthOf(), len); | ||||
| 
 | ||||
|         return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(x->dataType(), orderNew, shapeNew)); | ||||
|         } | ||||
|   return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo( | ||||
|       x->dataType(), orderNew, shapeNew)); | ||||
| } | ||||
| 
 | ||||
| } // namespace ops
 | ||||
| } // namespace sd
 | ||||
| 
 | ||||
| 
 | ||||
|         } | ||||
|         } | ||||
| 
 | ||||
|         #endif | ||||
| #endif | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user