2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
2019-06-06 14:21:15 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
2019-06-06 14:21:15 +02:00
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// Created by GS <sgazeos@gmail.com>
//
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/helpers/random_crop.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL ( random_crop , 2 , 1 , false , 0 , 0 ) {
auto input = INPUT_VARIABLE ( 0 ) ; // values for crop
auto shape = INPUT_VARIABLE ( 1 ) ; // shape for result
NDArray * reduceShape = nullptr ; // this param is optional
auto output = OUTPUT_VARIABLE ( 0 ) ; //
int seed = 0 ;
if ( block . getIArguments ( ) - > size ( ) > 0 )
seed = INT_ARG ( 0 ) ;
REQUIRE_TRUE ( shape - > isVector ( ) , 0 , " random_crop: Shape tensor should be a vector. " ) ;
REQUIRE_TRUE ( input - > rankOf ( ) = = shape - > lengthOf ( ) , 0 , " random_crop: The length of the shape vector is not match input rank. %i and %i were given. " ,
input - > rankOf ( ) , shape - > lengthOf ( ) ) ;
for ( int e = 0 ; e < shape - > lengthOf ( ) ; + + e ) {
REQUIRE_TRUE ( ( * shape ) . e < Nd4jLong > ( e ) < = input - > sizeAt ( e ) , 0 , " random_crop: Shape tensor should be less than proper input dimension (dim %i, %i > %i). " , e , ( * shape ) . e < Nd4jLong > ( e ) , input - > sizeAt ( e ) ) ;
}
return helpers : : randomCropFunctor ( block , input , shape , output , seed ) ;
}
DECLARE_SHAPE_FN ( random_crop ) {
auto in = INPUT_VARIABLE ( 1 ) ;
auto typeShape = inputShape - > at ( 0 ) ;
std : : vector < Nd4jLong > shape ( in - > lengthOf ( ) ) ;
for ( int e = 0 ; e < shape . size ( ) ; e + + )
shape [ e ] = ( * in ) . e < Nd4jLong > ( e ) ;
2020-06-06 14:26:55 +02:00
auto newShape = ConstantShapeHelper : : getInstance ( ) . createShapeInfo ( ArrayOptions : : dataType ( typeShape ) , ' c ' , shape ) ;
2019-06-06 14:21:15 +02:00
return SHAPELIST ( newShape ) ;
}
DECLARE_TYPES ( random_crop ) {
getOpDescriptor ( )
2020-03-02 10:49:41 +01:00
- > setAllowedInputTypes ( sd : : DataType : : ANY )
2019-06-06 14:21:15 +02:00
- > setAllowedOutputTypes ( { ALL_FLOATS } ) ;
}
}
}