2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
2019-06-06 14:21:15 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
2019-06-06 14:21:15 +02:00
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// @author raver119@gmail.com
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_roll)
# include <ops/declarable/headers/parity_ops.h>
# include <ops/declarable/helpers/roll.h>
2019-08-17 13:15:08 +02:00
# include <ops/declarable/helpers/axis.h>
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
2019-08-17 13:15:08 +02:00
CONFIGURABLE_OP_IMPL ( roll , 1 , 1 , true , 0 , 0 ) {
2019-06-06 14:21:15 +02:00
auto output = OUTPUT_VARIABLE ( 0 ) ;
auto input = INPUT_VARIABLE ( 0 ) ;
int inputLen = input - > lengthOf ( ) ;
2019-08-17 13:15:08 +02:00
bool shiftIsLinear = block . width ( ) = = 1 ;
std : : vector < int > axes ;
std : : vector < int > shifts ;
if ( block . width ( ) > 1 ) {
REQUIRE_TRUE ( block . width ( ) = = 3 , 0 , " roll: 3 arguments required for roll - input, shifts and axes. But %i given. " , block . width ( ) ) ;
auto axesI = INPUT_VARIABLE ( 2 ) ;
auto shiftsI = INPUT_VARIABLE ( 1 ) ;
REQUIRE_TRUE ( axesI - > rankOf ( ) = = shiftsI - > rankOf ( ) , 0 , " roll: shifts and axes should be the same rank, but %i and %i given. " , ( int ) shiftsI - > rankOf ( ) , ( int ) axesI - > rankOf ( ) ) ;
REQUIRE_TRUE ( axesI - > lengthOf ( ) = = shiftsI - > lengthOf ( ) , 0 , " roll: shifts and axes should be the same length, but %i and %i given. " , ( int ) shiftsI - > lengthOf ( ) , ( int ) axesI - > lengthOf ( ) ) ;
helpers : : adjustAxis ( axesI - > lengthOf ( ) , axesI , axes ) ;
shifts . resize ( shiftsI - > lengthOf ( ) ) ;
for ( Nd4jLong i = 0 ; i < shiftsI - > lengthOf ( ) ; i + + ) {
auto shift = shiftsI - > e < int > ( i ) ;
if ( shift < 0 ) {
shift - = input - > sizeAt ( i ) * ( shift / inputLen - 1 ) ;
}
2021-02-07 11:27:41 +01:00
else if ( shift ! = 0 ) {
2019-08-17 13:15:08 +02:00
shift % = input - > sizeAt ( i ) ;
}
2021-02-07 11:27:41 +01:00
2019-08-17 13:15:08 +02:00
shifts [ i ] = shift ;
}
2019-06-06 14:21:15 +02:00
}
else {
2019-08-17 13:15:08 +02:00
int shift = INT_ARG ( 0 ) ;
if ( shift < 0 ) {
// convert shift to positive value between 1 and inputLen - 1
shift - = inputLen * ( shift / inputLen - 1 ) ;
}
2021-02-07 11:27:41 +01:00
else if ( shift ! = 0 )
2019-08-17 13:15:08 +02:00
// cut shift to value between 1 and inputLen - 1
shift % = inputLen ;
axes . resize ( block . getIArguments ( ) - > size ( ) - 1 ) ;
if ( axes . size ( ) )
shifts . resize ( axes . size ( ) ) ; //emplace_back(shift);
else
shifts . push_back ( shift ) ;
for ( auto & s : shifts )
s = shift ;
2019-06-06 14:21:15 +02:00
for ( unsigned e = 0 ; e < axes . size ( ) ; + + e ) {
2019-08-17 13:15:08 +02:00
int axis = INT_ARG ( e + 1 ) ;
REQUIRE_TRUE ( axis < input - > rankOf ( ) & & axis > = - input - > rankOf ( ) , 0 , " roll: axe value should be between -%i and %i, but %i was given. " ,
input - > rankOf ( ) , input - > rankOf ( ) - 1 , axis ) ;
axes [ e ] = ( axis < 0 ? ( input - > rankOf ( ) + axis ) : axis ) ;
2019-06-06 14:21:15 +02:00
}
2019-08-17 13:15:08 +02:00
}
if ( block . isInplace ( ) ) output = input ;
2019-12-19 11:10:06 +01:00
shiftIsLinear = ( axes . size ( ) = = 0 ) | | ( input - > rankOf ( ) = = 1 ) ;
2021-02-07 11:27:41 +01:00
nd4j_debug ( " Roll: Shift is linear %d Shift is %d, first dimension is %d \n " , shiftIsLinear , shifts [ 0 ] , axes [ 0 ] ) ;
bool shiftsSumZero = false ;
auto shiftSum = 0 ;
for ( auto & s : shifts ) {
shiftSum + = s ;
nd4j_debug ( " Roll: Shift is %d \n " , s ) ;
}
//all zeros is no op
if ( shiftSum < 1 ) {
nd4j_debug ( " Roll: No shift needed. Shift total was %d \n " , shiftSum ) ;
if ( ! block . isInplace ( ) ) {
output - > assign ( input ) ;
}
return Status : : OK ( ) ;
}
2019-08-17 13:15:08 +02:00
if ( shiftIsLinear ) {
helpers : : rollFunctorLinear ( block . launchContext ( ) , input , output , shifts [ 0 ] , block . isInplace ( ) ) ;
}
else {
helpers : : rollFunctorFull ( block . launchContext ( ) , input , output , shifts , axes , block . isInplace ( ) ) ;
2019-06-06 14:21:15 +02:00
}
return Status : : OK ( ) ;
}
DECLARE_TYPES ( roll ) {
getOpDescriptor ( )
2020-03-02 10:49:41 +01:00
- > setAllowedInputTypes ( 0 , sd : : DataType : : ANY )
- > setAllowedInputTypes ( 1 , sd : : DataType : : INT32 ) // TODO: all ints in future
- > setAllowedInputTypes ( 2 , sd : : DataType : : INT32 )
- > setAllowedOutputTypes ( sd : : DataType : : ANY )
2019-06-06 14:21:15 +02:00
- > setSameMode ( true ) ;
}
}
}
# endif