2021-02-01 13:31:45 +01:00
/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// Created by raver119 on 29/10/17.
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2021-02-05 14:35:41 +01:00
# if NOT_EXCLUDED(OP_reshape)
2019-06-06 14:21:15 +02:00
2021-02-05 14:35:41 +01:00
# include <ops/declarable/CustomOperations.h>
2019-06-06 14:21:15 +02:00
2021-02-05 14:35:41 +01:00
namespace sd {
namespace ops {
2019-06-06 14:21:15 +02:00
2020-02-17 06:04:28 +01:00
//////////////////////////////////////////////////////////////////////////
// here iArgs is a vector with (optional) negative of order as first element:
// ({-order, dim1, dim2, dim3, ...})
2021-02-05 14:35:41 +01:00
CUSTOM_OP_IMPL ( reshape , 1 , 1 , false , 0 , - 2 ) {
2019-06-06 14:21:15 +02:00
2021-02-05 14:35:41 +01:00
auto x = INPUT_VARIABLE ( 0 ) ;
auto z = OUTPUT_VARIABLE ( 0 ) ;
2019-06-06 14:21:15 +02:00
2021-02-05 14:35:41 +01:00
//Special case: empty.reshape(<other empty shape>) -> return empty
if ( x - > isEmpty ( ) ) {
2020-03-31 06:41:16 +02:00
REQUIRE_TRUE ( z - > isEmpty ( ) , 0 , " Reshape: when input is empty, output must also be empty " ) ;
return Status : : OK ( ) ; //No op
2021-02-05 14:35:41 +01:00
}
2019-06-06 14:21:15 +02:00
2021-02-05 14:35:41 +01:00
REQUIRE_TRUE ( x - > lengthOf ( ) = = z - > lengthOf ( ) , 0 , " Reshape: lengths before and after reshape should match, but got %i vs %i " , x - > lengthOf ( ) , z - > lengthOf ( ) ) ;
2019-06-06 14:21:15 +02:00
2021-02-05 14:35:41 +01:00
if ( Environment : : getInstance ( ) . isDebugAndVerbose ( ) )
2020-03-31 06:41:16 +02:00
nd4j_printv ( " Reshape: new shape " , z - > getShapeAsVector ( ) ) ;
2019-06-06 14:21:15 +02:00
2021-02-05 14:35:41 +01:00
z - > assign ( x - > reshape ( z - > ordering ( ) , z - > getShapeAsVector ( ) ) ) ;
2020-02-17 06:04:28 +01:00
2021-02-05 14:35:41 +01:00
return Status : : OK ( ) ;
}
2020-02-17 06:04:28 +01:00
2021-02-05 14:35:41 +01:00
// Data-type constraints for reshape:
//   input 0 - the array being reshaped, any data type;
//   input 1 - (optional) the target shape, integer types only;
//   same mode is enabled on the descriptor.
DECLARE_TYPES(reshape) {
  auto descriptor = getOpDescriptor();
  descriptor->setAllowedInputTypes(0, sd::DataType::ANY);
  descriptor->setAllowedInputTypes(1, {ALL_INTS});
  descriptor->setSameMode(true);
}
2020-03-31 06:41:16 +02:00
2021-02-05 14:35:41 +01:00
// Shape function for reshape: resolves the output ordering and dims, infers a
// single -1 dim, validates that the element count is preserved, and returns the
// output shape info (data type taken from input 0).
//
// Two ways of passing the target shape:
//
// 1) Integer arguments only (block.width() == 1): {-order, dim1, dim2, ...}.
//    The first value is the NEGATED ordering character: -99 -> 'c', -102 -> 'f'.
//    It is negative so that it cannot be mistaken for a dimension.
//
// 2) Shape given as input 1 (integer array of dims only). An optional first
//    integer argument selects the ordering as 99 ('c') or 102 ('f'); the negated
//    form used by path 1 is accepted too. Without it, 'c' is used.
//
// In both cases a dim of -1 (at most one) is inferred from the input length and
// a dim of 0 makes the output empty. Any other negative dim is rejected.
DECLARE_SHAPE_FN(reshape) {
  const auto x = INPUT_VARIABLE(0);
  std::vector<int> reshapeArgs;
  std::vector<Nd4jLong> shapeNew;
  char orderNew = 'c';

  if (block.width() == 1) {
    reshapeArgs = *block.getIArguments();
    if (!reshapeArgs.empty()) {
      // Compare the raw int (not a char-cast of it) so out-of-range values cannot wrap onto 'c'/'f'.
      if (reshapeArgs[0] != -'c' && reshapeArgs[0] != -'f') {
        throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering. This number is negative for the long array case to flag the difference between an ordering and a dimension being specified.");
      }
      orderNew = reshapeArgs[0] == -'c' ? 'c' : 'f';
      nd4j_debug("Reshape Ordering is %c int ordering is %d \n", orderNew, -reshapeArgs[0]);
      reshapeArgs.erase(reshapeArgs.begin());  // drop the ordering flag, keep only the dims
    }
  }
  else {
    // Input 1 holds only the dims; the ordering (if any) comes from the integer
    // arguments. It must NOT be read from the shape buffer itself: that would treat
    // the first dimension as an ordering and index an empty buffer out of bounds.
    reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
    if (block.numI() > 0) {
      const auto orderArg = block.getIArguments()->at(0);
      const auto orderCode = orderArg < 0 ? -orderArg : orderArg;
      if (orderCode != 'c' && orderCode != 'f') {
        throw std::runtime_error("reshape:: Value passed in must be 99 or 102 for the ordering when the shape is given as an input array. 99 represents c ordering and 102 represents f ordering.");
      }
      orderNew = orderCode == 'c' ? 'c' : 'f';
    }
  }

  REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension ! ");

  Nd4jLong newShapeLen = 1;      // product of the explicitly given non-zero dims
  int pos = -1;                  // index of the inferred (-1) dim, if any
  bool newShapeEmpty = false;    // true if any requested dim is 0
  shapeNew.reserve(reshapeArgs.size());

  for (size_t i = 0; i < reshapeArgs.size(); ++i) {
    const int dim = reshapeArgs[i];
    // Without this, e.g. {-2, -3} would have a positive product and pass the length check below.
    REQUIRE_TRUE(dim >= -1, 0, "Reshape: dimensions must be >= 0 or -1 (inferred), but got %i at position %i", dim, static_cast<int>(i));
    if (dim == -1) {
      REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
      pos = static_cast<int>(i);
      shapeNew.push_back(1);  // placeholder, replaced after the loop
    }
    else if (dim == 0) {
      shapeNew.push_back(0);
      newShapeEmpty = true;
    }
    else {
      shapeNew.push_back(dim);
      newShapeLen *= dim;
    }
  }

  if (pos != -1) {
    Nd4jLong xLen = x->lengthOf();
    if (x->isEmpty()) {
      // For an empty input: when the target shape also has a 0 dim, infer from the
      // product of the input's non-zero dims only; otherwise the zero dims are
      // included, so the inferred dim becomes 0.
      xLen = 1;
      for (int i = 0; i < x->rankOf(); ++i)
        if (x->sizeAt(i) > 0 || !newShapeEmpty)
          xLen *= x->sizeAt(i);
    }
    // A non-divisible length yields a truncated dim here; the length check below rejects it.
    shapeNew[pos] = xLen / newShapeLen;
  }

  const auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
  REQUIRE_TRUE(x->lengthOf() == len, 0,
               "Reshape: lengths before and after reshape should match, but got %lld vs %lld",
               static_cast<long long>(x->lengthOf()), static_cast<long long>(len));

  return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(x->dataType(), orderNew, shapeNew));
}
2020-03-31 06:41:16 +02:00
2020-04-01 06:13:34 +02:00
2021-02-05 14:35:41 +01:00
}
}
2019-06-06 14:21:15 +02:00
2021-02-05 14:35:41 +01:00
# endif