/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Created by raver119 on 24.11.17.
// @author Yurii Shyrma (iuriish@yahoo.com)
//
# include <op_boilerplate.h>
# if NOT_EXCLUDED(OP_scatter_div)
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/generic/helpers/ScatterHelper.h>
namespace nd4j {
namespace ops {

//////////////////////////////////////////////////////////////////////////
// scatter_div
//
// Starts from the values of `input` (copied into the output unless the op
// runs in place) and then lets helpers::scatter apply pairwise::Divide with
// `updates` at the positions addressed by `indices`.
//
// Inputs (by index):
//   0 - input   : array being updated; must have rank > 0
//   1 - indices : integer array; the shape checks below treat it as
//                 addressing the first dimension of `input`
//   2 - updates : values combined into the output
//
// Optional boolean arguments (both default to false when absent):
//   B_ARG(0) - lock, forwarded unchanged to helpers::scatter
//   B_ARG(1) - checkIndices; when true, helpers::checkIndices must report
//              zero bad entries in `indices`, otherwise the op fails
//
// Accepted `updates` shapes:
//   rank(input) == 1                         : same shape as `indices`
//   rank(updates) == rank(input) && indices
//   is a vector                              : [len(indices)] + input.shape[1:]
//   otherwise                                : indices.shape + input.shape[1:]
OP_IMPL(scatter_div, 3, 1, true) {

    auto input   = INPUT_VARIABLE(0);
    auto indices = INPUT_VARIABLE(1);
    auto updates = INPUT_VARIABLE(2);
    auto output  = OUTPUT_VARIABLE(0);

    // Out-of-place execution: begin from a copy of the input values.
    if (!block.isInplace()) {
        output->assign(input);
    }

    // Optional flags, each false when the corresponding B argument is missing.
    const auto numOfBArgs   = block.getBArguments()->size();
    const bool lock         = numOfBArgs > 0 ? B_ARG(0) : false;
    const bool checkIndices = numOfBArgs > 1 ? B_ARG(1) : false;

    const int inRank  = input->rankOf();
    const int indRank = indices->rankOf();
    const int updRank = updates->rankOf();

    REQUIRE_TRUE(inRank > 0, 0, "SCATTER_DIV OP: input should not be scalar !");

    if (inRank == 1) {
        REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_DIV OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
    }
    else {
        // Leading dims of the expected updates shape come from `indices`:
        // just its length for the vector case, its full shape otherwise.
        std::vector<Nd4jLong> expectedUpdShape;

        if (inRank == updRank && indices->isVector()) {
            expectedUpdShape.push_back(indices->lengthOf());
        }
        else {
            REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_DIV OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1, updRank);
            expectedUpdShape = indices->getShapeAsVector();
        }

        // Trailing dims are always input.shape[1:].
        const std::vector<Nd4jLong> inShape = input->getShapeAsVector();
        expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, inShape.end());

        std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
        REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
    }

    // Empty indices: nothing to scatter, the output keeps the values prepared above.
    if (indices->isEmpty()) {
        return Status::OK();
    }

    if (checkIndices) {
        const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
        REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_DIV OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
    }

    helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);

    return Status::OK();
}

// Alternative name `ScatterDiv` for this op (presumably an alias registration — see DECLARE_SYN).
DECLARE_SYN(ScatterDiv, scatter_div);

// Type constraints: input (0) and updates (2) may be integer or floating point,
// indices (1) must be integer; the output may be integer or floating point.
DECLARE_TYPES(scatter_div) {
    getOpDescriptor()
            ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
            ->setAllowedInputTypes(1, {ALL_INTS})
            ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
            ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
}

}  // namespace ops
}  // namespace nd4j
# endif