2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								/*******************************************************************************
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * Copyright (c) 2015-2018 Skymind, Inc.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * This program and the accompanying materials are made available under the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * terms of the Apache License, Version 2.0 which is available at
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * https://www.apache.org/licenses/LICENSE-2.0.
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * Unless required by applicable law or agreed to in writing, software
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * License for the specific language governing permissions and limitations
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * under the License.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * SPDX-License-Identifier: Apache-2.0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//
 
							 
						 
					
						
							
								
									
										
										
										
											2019-08-02 20:01:03 +03:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								// @author raver119@gmail.com, created on 24.11.17.
 
							 
						 
					
						
							
								
									
										
										
										
											2019-07-20 08:58:44 +03:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								// @author Yurii Shyrma (iuriish@yahoo.com)
 
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								//
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-03-02 12:49:41 +03:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								# include  <system/op_boilerplate.h> 
 
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								# if NOT_EXCLUDED(OP_scatter_add) 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								# include  <ops/declarable/CustomOperations.h> 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								# include  <ops/declarable/generic/helpers/ScatterHelper.h> 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-03-02 12:49:41 +03:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								namespace  sd  { 
							 
						 
					
						
							
								
									
										
										
										
											2019-07-20 08:58:44 +03:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								namespace  ops  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								/**
								 * scatter_add: adds the values of `updates` into `output` at the positions selected by `indices`.
								 *
								 * Inputs:
								 *   0 - input array (must not be a scalar)
								 *   1 - indices
								 *   2 - updates
								 * Output:
								 *   0 - receives a copy of the input (unless the op runs in place), then the scattered additions
								 * Optional boolean arguments:
								 *   0 - lock flag forwarded to helpers::scatter (false when absent)
								 *   1 - checkIndices: when true, indices are validated with helpers::checkIndices
								 *       before scattering (false when fewer than two boolean arguments are given)
								 *
								 * Shape validation of `updates` depends on the input rank:
								 *   - rank 1 input: indices and updates must have the same shape;
								 *   - input rank == updates rank and indices is a vector:
								 *       updates shape must be {indices length, input shape[1:]};
								 *   - otherwise: updates rank must be indRank + inRank - 1 and
								 *       updates shape must be {indices shape, input shape[1:]}.
								 *
								 * NOTE(review): the macro arguments (3, 1, true) presumably mean 3 inputs, 1 output and
								 * "in-place allowed" - confirm against the OP_IMPL definition in op_boilerplate.h.
								 */
								OP_IMPL(scatter_add, 3, 1, true) {
								    auto input   = INPUT_VARIABLE(0);
								    auto indices = INPUT_VARIABLE(1);
								    auto updates = INPUT_VARIABLE(2);

								    auto output = OUTPUT_VARIABLE(0);

								    // When not running in place, the output starts as a copy of the input
								    // and the updates are accumulated on top of it.
								    if (!block.isInplace())
								        output->assign(input);

								    // Both boolean arguments are optional and default to false.
								    const bool lock         = block.getBArguments()->empty() ? false : B_ARG(0);
								    const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);

								    const int inRank  = input->rankOf();
								    const int indRank = indices->rankOf();
								    const int updRank = updates->rankOf();

								    REQUIRE_TRUE(inRank > 0, 0, " SCATTER_ADD OP: input should not be scalar ! ");

								    if (inRank == 1) {
								        // One update value per index.
								        REQUIRE_TRUE(indices->isSameShape(updates), 0, " SCATTER_ADD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly ! ", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
								    }
								    else if (inRank == updRank && indices->isVector()) {
								        std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
								        std::vector<Nd4jLong> inShape  = input->getShapeAsVector();

								        // Expected shape: {number of indices, input dimensions after the first one}.
								        std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
								        expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, inShape.end());

								        REQUIRE_TRUE(expectedUpdShape == updShape, 0, " SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead ! ", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
								    }
								    else {
								        REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, " SCATTER_ADD OP: wrong rank of updates array, expected is %i, but got %i instead ! ", indRank + inRank - 1, updRank);

								        std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
								        std::vector<Nd4jLong> inShape  = input->getShapeAsVector();

								        // Expected shape: {full indices shape, input dimensions after the first one}.
								        std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
								        expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, inShape.end());

								        REQUIRE_TRUE(expectedUpdShape == updShape, 0, " SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead ! ", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
								    }

								    // Nothing to scatter when there are no indices; the output already holds the input copy.
								    if (!indices->isEmpty()) {

								        if (checkIndices) {
								            // NOTE(review): the trailing 0 is presumably the axis the indices refer to - confirm in ScatterHelper.h.
								            const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
								            REQUIRE_TRUE(numOfBadIndx == 0, 0, " SCATTER_ADD OP: please check elements of indices-array, total number of wrong elements is %lld! ", numOfBadIndx);
								        }

								        helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
								    }

								    return Status::OK();
								}
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								// NOTE(review): presumably registers "ScatterAdd" as an alternative name (synonym) for the
								// scatter_add op - the DECLARE_SYN macro is defined outside this file, confirm there.
								DECLARE_SYN ( ScatterAdd ,  scatter_add ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								// Data-type constraints for scatter_add:
								//   input 0 (input array) - any integer or floating-point type
								//   input 1 (indices)     - integer types only
								//   input 2 (updates)     - any integer or floating-point type
								//   output                - any integer or floating-point type
								DECLARE_TYPES(scatter_add) {
								    auto descriptor = getOpDescriptor();

								    descriptor->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
								              ->setAllowedInputTypes(1, {ALL_INTS})
								              ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
								              ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
								}
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								} 
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								} 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								# endif