/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
#pragma once

#include <ops/ops.h>
#include <loops/reduce_float.h>
#include <loops/reduce_same.h>
#include <loops/scalar.h>
#include <loops/indexreduce.h>
#include <loops/broadcasting.h>
#include <loops/transform_float.h>
#include <op_enums.h>
#include <loops/transform_strict.h>
#include <helpers/ConstantTadHelper.h>

#ifdef __CUDACC__
#include <loops/cuda/inplace_loops/reduce_same_inplace.h>
#include <loops/cuda/inplace_loops/transform_strict_inplace.h>
#include <loops/cuda/inplace_loops/scalar_inplace.h>
#endif
namespace functions {
    namespace broadcast {
        template <typename X, typename Y, typename Z>
        class Broadcast;
    }

    namespace transform {
        template <typename X>
        class TransformStrict;
    }

    namespace scalar {
    }

    namespace reduce {
        template <typename X, typename Z>
        class ReduceFloatFunction;

        template <typename X>
        class ReduceSameFunction;
    }
}
namespace simdOps {

    template <typename T, typename Z>
    class Pooling2D {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        inline __host__ __device__
#elif defined(__GNUC__)
#endif
        static int outSize(int size, int k, int s, int p, bool coverAll) {
            if (coverAll)
                return (size + p * 2 - k + s - 1) / s + 1;
            else
                return (size + p * 2 - k) / s + 1;
        }
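
        // Worked example (illustrative, not part of the op): for size = 8, k = 3,
        // s = 2, p = 1, outSize returns (8 + 2 - 3) / 2 + 1 = 4 when coverAll is
        // false, and (8 + 2 - 3 + 2 - 1) / 2 + 1 = 5 when coverAll is true, i.e.
        // coverAll adds a trailing window so every input element is visited.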
#ifdef __CUDACC__
        /**
         * Based on: https://github.com/pjreddie/darknet/blob/master/src/im2col_kernels.cu
         */
        static inline __device__ void execSpecialCuda(
                T *dx, Nd4jLong *xShapeBuffer,
                Z *result, Nd4jLong *zShapeBuffer,
                Z *extraParams,
                int *allocationPointer, Z *reductionPointer,
                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            __shared__ int kH;
            __shared__ int kW;
            __shared__ int sH;
            __shared__ int sW;
            __shared__ int pH;
            __shared__ int pW;
            __shared__ int dH;
            __shared__ int dW;
            __shared__ int poolingMode;
            __shared__ Z extraParam0;

            __shared__ int batchSize;
            __shared__ int inChannels;
            __shared__ int outH;
            __shared__ int outW;
            __shared__ int inH;
            __shared__ int inW;

            //__shared__ int *strideIn;
            //__shared__ int *strideOut;
            __shared__ int strideB;
            __shared__ int strideC;
            __shared__ int strideY;
            __shared__ int strideX;

            __shared__ int strideOB;
            __shared__ int strideOC;
            __shared__ int strideOY;
            __shared__ int strideOX;

            __shared__ int length;
            __shared__ int kHEff;
            __shared__ int kWEff;
            __shared__ bool fOrder;

            if (threadIdx.x == 0) {
                kH = (int) extraParams[0];
                kW = (int) extraParams[1];
                sH = (int) extraParams[2];
                sW = (int) extraParams[3];
                pH = (int) extraParams[4];
                pW = (int) extraParams[5];
                dH = (int) extraParams[6];           // Dilation, height dimension
                dW = (int) extraParams[7];           // Dilation, width dimension
                poolingMode = (int) extraParams[9];
                extraParam0 = extraParams[10];

                batchSize = shape::sizeAt(xShapeBuffer, 0);
                inChannels = shape::sizeAt(xShapeBuffer, 1);
                outH = shape::sizeAt(zShapeBuffer, 2);
                outW = shape::sizeAt(zShapeBuffer, 3);
                inH = shape::sizeAt(xShapeBuffer, 2);
                inW = shape::sizeAt(xShapeBuffer, 3);

                strideB = shape::stride(xShapeBuffer)[0];
                strideC = shape::stride(xShapeBuffer)[1];
                strideY = shape::stride(xShapeBuffer)[2];
                strideX = shape::stride(xShapeBuffer)[3];

                strideOB = shape::stride(zShapeBuffer)[0];
                strideOC = shape::stride(zShapeBuffer)[1];
                strideOY = shape::stride(zShapeBuffer)[2];
                strideOX = shape::stride(zShapeBuffer)[3];

                length = shape::length(zShapeBuffer);

                // Replace kernel H/W with *effective* kernel H/W accounting for dilation
                kHEff = kH + (kH - 1) * (dH - 1);
                kWEff = kW + (kW - 1) * (dW - 1);

                fOrder = shape::order(zShapeBuffer) == 'f';

                /*
                if (blockIdx.x == 0) {
                    printf("kH: %i; kW: %i; sH: %i; sW: %i; pH: %i; pW: %i; dH: %i; dW: %i; poolingMode: %i; extraParam0: %f;\n", kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, (float) extraParam0);
                    printf("batchSize: %i; inChannels: %i; outH: %i; outW: %i; inH: %i; inW: %i; strideB: %i; strideC: %i; strideY: %i; strideX: %i;\n", batchSize, inChannels, outH, outW, inH, inW, strideB, strideC, strideY, strideX);
                }
                */
            }
            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

            for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
                const int pw = index % outW;
                const int ph = (index / outW) % outH;
                const int c = (index / outW / outH) % inChannels;
                const int n = index / outW / outH / inChannels;

                int hstart = sH * ph - pH;
                int wstart = sW * pw - pW;
                int hend = hstart + kHEff;
                int wend = wstart + kWEff;

                // const int hSO = hstart;
                // const int hEO = hend;

                if (hstart < 0) {
                    int f = nd4j::math::nd4j_ceil<Z, int>((Z) -hstart / (Z) dH);
                    hstart += f * dH;
                }
                if (wstart < 0) {
                    int f = nd4j::math::nd4j_ceil<Z, int>((Z) -wstart / (Z) dW);
                    wstart += f * dW;
                }
                if (hend > inH) {
                    int f = nd4j::math::nd4j_ceil<Z, int>((Z)(hend - inH) / (Z) dH);
                    hend -= f * dH;
                }
                if (wend > inW) {
                    int f = nd4j::math::nd4j_ceil<Z, int>((Z)(wend - inW) / (Z) dW);
                    wend -= f * dW;
                }

                // Accounts for dilation
                int pool_size = nd4j::math::nd4j_ceil<double, int>((double)(hend - hstart) / (double) dH)
                              * nd4j::math::nd4j_ceil<double, int>((double)(wend - wstart) / (double) dW);
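
                // Worked example (illustrative): a clipped window [hstart, hend) of
                // height 5 with dH = 2 touches rows hstart, hstart+2, hstart+4, so
                // ceil(5 / 2) = 3 rows; pool_size is that count times the analogous
                // column count, i.e. the number of samples actually pooled.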
                Z sum = poolingMode == 0 ? -nd4j::DataTypeUtils::max<Z>() : static_cast<Z>(0.f);

                T *input_slice = dx + (n * strideB + c * strideC);

                if (poolingMode == 0) {
                    for (int h = hstart; h < hend; h += dH) {
                        for (int w = wstart; w < wend; w += dW) {
                            Z v = static_cast<Z>(input_slice[h * strideY + w * strideX]);
                            if (v > sum)
                                sum = v;
                        }
                    }
                } else if (poolingMode == 1) {
                    for (int h = hstart; h < hend; h += dH) {
                        for (int w = wstart; w < wend; w += dW) {
                            sum += static_cast<Z>(input_slice[h * strideY + w * strideX]);
                        }
                    }
                } else if (poolingMode == 2) {
                    for (int h = hstart; h < hend; h += dH) {
                        for (int w = wstart; w < wend; w += dW) {
                            sum += nd4j::math::nd4j_pow<Z, Z, Z>(static_cast<Z>(nd4j::math::nd4j_abs<T>(input_slice[h * strideY + w * strideX])), extraParam0);
                        }
                    }
                }

                Z res;

                if (poolingMode == 0) {
                    res = sum;
                } else if (poolingMode == 1) {
                    int divide_factor = pool_size;  // Case 0: exclude padding
                    if ((int) extraParam0 == 1)     // Case 1: include padding
                        divide_factor = kH * kW;

                    res = sum / static_cast<Z>(divide_factor);
                } else if (poolingMode == 2) {
                    res = nd4j::math::nd4j_pow<Z, Z, Z>(sum, (Z) 1.0f / extraParam0);
                }

                if (!fOrder) {
                    result[index] = res;
                } else {
                    result[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = res;
                }

                /*
                if (index >= 0 && index < 400000) {
                    printf("index: %i; hstart: %i; hend: %i; wstart: %i; wend: %i; ph: %i; pw: %i; hstart_orig: %i; hend_orig: %i;\n", index, hstart, hend, wstart, wend, ph, pw, hSO, hEO);
                }
                */
            }

            __syncthreads();
        }
#endif
        static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outShapeBuffer, Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
            // input is [bS, iC, iH, iW]
            // output is [bS, iC, oH, oW]
            const Nd4jLong kH = (int) extraParams[0];
            const Nd4jLong kW = (int) extraParams[1];
            const Nd4jLong sH = (int) extraParams[2];
            const Nd4jLong sW = (int) extraParams[3];
            const Nd4jLong pH = (int) extraParams[4];
            const Nd4jLong pW = (int) extraParams[5];
            const Nd4jLong dH = (int) extraParams[6];
            const Nd4jLong dW = (int) extraParams[7];
            Nd4jLong poolingMode = (int) extraParams[9];
            T extraParam0 = extraParams[10];

            if (dH == 0 || dW == 0) {
                printf("Special_ops pooling2d:: dilation must not be zero, but got instead {%lld, %lld}\n", dH, dW);
                throw "";
            }

            const Nd4jLong kHEff = kH + (kH - 1) * (dH - 1);
            const Nd4jLong kWEff = kW + (kW - 1) * (dW - 1);
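
            // Worked example (illustrative): kH = 3 with dH = 2 gives an effective
            // kernel height of 3 + (3 - 1) * (2 - 1) = 5, i.e. three taps spread
            // over five input rows.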
            const int bS = shape::sizeAt(inShapeBuffer, 0);
            const int iC = shape::sizeAt(inShapeBuffer, 1);
            const int iH = shape::sizeAt(inShapeBuffer, 2);
            const int iW = shape::sizeAt(inShapeBuffer, 3);
            const int oH = shape::sizeAt(outShapeBuffer, 2);
            const int oW = shape::sizeAt(outShapeBuffer, 3);

            const Nd4jLong iStride0 = shape::stride(inShapeBuffer)[0];
            const Nd4jLong iStride1 = shape::stride(inShapeBuffer)[1];
            const Nd4jLong iStride2 = shape::stride(inShapeBuffer)[2];
            const Nd4jLong iStride3 = shape::stride(inShapeBuffer)[3];
            const Nd4jLong oStride0 = shape::stride(outShapeBuffer)[0];
            const Nd4jLong oStride1 = shape::stride(outShapeBuffer)[1];
            const Nd4jLong oStride2 = shape::stride(outShapeBuffer)[2];
            const Nd4jLong oStride3 = shape::stride(outShapeBuffer)[3];

            const Nd4jLong iStep2 = dH * iStride2;
            const Nd4jLong iStep3 = dW * iStride3;
            const int kProd = kH * kW;
            const T iStep2Inv = 1. / iStep2;
            const T iStep3Inv = 1. / iStep3;

            Nd4jLong hstart, wstart, hend, wend;
            T sum, *pIn;
            if (poolingMode == 0) {  // max
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
                for (int b = 0; b < bS; ++b) {
                    for (int c = 0; c < iC; ++c) {
                        for (int oh = 0; oh < oH; ++oh) {
                            for (int ow = 0; ow < oW; ++ow) {

                                pIn = in + b * iStride0 + c * iStride1;

                                hstart = oh * sH - pH;
                                wstart = ow * sW - pW;
                                hend = hstart + kHEff;
                                wend = wstart + kWEff;

                                if (hstart < 0)
                                    hstart += dH * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
                                if (wstart < 0)
                                    wstart += dW * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
                                if (hend > iH)
                                    hend -= dH * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(hend - iH) / static_cast<T>(dH));
                                if (wend > iW)
                                    wend -= dW * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(wend - iW) / static_cast<T>(dW));

                                hstart *= iStride2;
                                hend *= iStride2;
                                wstart *= iStride3;
                                wend *= iStride3;

                                sum = -nd4j::DataTypeUtils::max<Z>();

                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) {
                                        T val = pIn[kh + kw];
                                        if (val > sum)
                                            sum = val;
                                    }

                                out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                            }
                        }
                    }
                }
            }
            /*************************************************************************/
            else if (poolingMode == 1) {  // avg
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
                for (int b = 0; b < bS; ++b) {
                    for (int c = 0; c < iC; ++c) {
                        for (int oh = 0; oh < oH; ++oh) {
                            for (int ow = 0; ow < oW; ++ow) {

                                pIn = in + b * iStride0 + c * iStride1;

                                hstart = oh * sH - pH;
                                wstart = ow * sW - pW;
                                hend = hstart + kHEff;
                                wend = wstart + kWEff;

                                if (hstart < 0)
                                    hstart += dH * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
                                if (wstart < 0)
                                    wstart += dW * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
                                if (hend > iH)
                                    hend -= dH * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(hend - iH) / static_cast<T>(dH));
                                if (wend > iW)
                                    wend -= dW * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(wend - iW) / static_cast<T>(dW));

                                hstart *= iStride2;
                                hend *= iStride2;
                                wstart *= iStride3;
                                wend *= iStride3;

                                sum = static_cast<Z>(0.);

                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                        sum += pIn[kh + kw];

                                if ((int) extraParam0 == 0)       // Exclude padding
                                    sum /= static_cast<T>(nd4j::math::nd4j_ceil<double, T>(static_cast<double>(hend - hstart) / static_cast<double>(iStep2))) * static_cast<T>(nd4j::math::nd4j_ceil<double, T>(static_cast<double>(wend - wstart) / static_cast<double>(iStep3)));  // Accounts for dilation
                                else if ((int) extraParam0 == 1)  // Include padding
                                    sum /= kProd;

                                out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                            }
                        }
                    }
                }
            }
            /*************************************************************************/
            else if (poolingMode == 2) {  // pnorm
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
                for (int b = 0; b < bS; ++b) {
                    for (int c = 0; c < iC; ++c) {
                        for (int oh = 0; oh < oH; ++oh) {
                            for (int ow = 0; ow < oW; ++ow) {

                                pIn = in + b * iStride0 + c * iStride1;

                                hstart = oh * sH - pH;
                                wstart = ow * sW - pW;
                                hend = hstart + kHEff;
                                wend = wstart + kWEff;

                                if (hstart < 0)
                                    hstart += dH * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
                                if (wstart < 0)
                                    wstart += dW * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
                                if (hend > iH)
                                    hend -= dH * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(hend - iH) / static_cast<T>(dH));
                                if (wend > iW)
                                    wend -= dW * (Nd4jLong) nd4j::math::nd4j_ceil<T, Nd4jLong>(static_cast<T>(wend - iW) / static_cast<T>(dW));

                                hstart *= iStride2;
                                hend *= iStride2;
                                wstart *= iStride3;
                                wend *= iStride3;

                                sum = static_cast<T>(0.);

                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                        sum += nd4j::math::nd4j_pow<T, T, T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0);

                                sum = nd4j::math::nd4j_pow<T, T, T>(sum, (T) 1. / extraParam0);

                                out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                            }
                        }
                    }
                }
            }
            else {
                nd4j_printf("Special_ops::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead!\n", poolingMode);
                throw "";
            }
        }
        op_def static T op(T d1, Z *params) {
            return d1;
        }

        /** Calculate buffer offset (like Shape.getOffset) without checking the input for negative indices etc.
         * Normally negative indices are bad; they are OK here because of other checks on input indices.
         * Uses an unrolled loop specifically for length 4.
         */
        static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[2] != 1) offset += indices[2] * stride[2];
            if (shape[3] != 1) offset += indices[3] * stride[3];
            return offset;
        }
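
        // Worked example (illustrative): for shape {2, 3, 4, 5} with c-order strides
        // {60, 20, 5, 1}, indices {1, 2, 0, 3} give offset 1*60 + 2*20 + 0*5 + 3*1 = 103;
        // size-1 dimensions are skipped so a broadcast index cannot move the offset.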
        /**
         * A version of Shape.getOffset without checking the input for negative indices etc.
         * Normally negative indices are bad; they are OK here because of other checks on input indices.
         * Uses an unrolled loop specifically for length 6, where indices[2] and indices[3] are zero (always are here).
         */
        static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[4] != 1) offset += indices[4] * stride[4];
            if (shape[5] != 1) offset += indices[5] * stride[5];
            return offset;
        }
    };
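
    // A minimal host-side sketch of how this op is parameterized (buffers, shape
    // infos and the nullptr TAD arguments are hypothetical; the layout below
    // mirrors the extraParams reads above, index 8 being unused here):
    //
    //   float extraParams[11] = { kH, kW, sH, sW, pH, pW, dH, dW, 0,
    //                             poolingMode /* 0 = max, 1 = avg, 2 = pnorm */,
    //                             extraParam0 /* avg: 1 = include padding; pnorm: p */ };
    //   simdOps::Pooling2D<float, float>::execSpecial(in, inShapeInfo,
    //                                                 out, outShapeInfo,
    //                                                 extraParams, nullptr, nullptr);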
    FORCEINLINE bool is_a_ge_zero_and_a_lt_b(int a, int b) {
        return static_cast<unsigned>(a) < static_cast<unsigned>(b);
    }
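
    // The single unsigned comparison above folds two checks into one: casting a
    // negative `a` to unsigned wraps it to a huge value (e.g. -1 becomes
    // 0xFFFFFFFF), which can never be below a non-negative `b`, so the result is
    // true exactly when 0 <= a && a < b.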
    template <typename T>
    class Im2col {
    public:
        static const bool requiresSpecial = true;

        static _CUDA_HD int outSize(int size, int k, int s, int p, bool coverAll) {
            if (coverAll)
                return (size + p * 2 - k + s - 1) / s + 1;
            else
                return (size + p * 2 - k) / s + 1;
        }
#ifdef __CUDACC__
        /**
         * Based on: https://github.com/pjreddie/darknet/blob/master/src/im2col_kernels.cu
         */
        static inline __device__ void execSpecialCuda(
                T *dx, Nd4jLong *xShapeBuffer,
                T *result, Nd4jLong *zShapeBuffer,
                T *extraParams,
                int *allocationPointer, T *reductionPointer,
                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            /* kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false */
            __shared__ int kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, kSize, samples, depth, height, width, strideex, stridech, strideh, stridew, height_col, width_col, n;
            __shared__ T zeroPadVal;
            __shared__ Nd4jLong *outShape, *outStride, *inShape, *inStride;
            __shared__ char resultOrder;

            if (threadIdx.x == 0) {
                kernelHeight = (int) extraParams[0];
                kernelWidth = (int) extraParams[1];
                strideY = (int) extraParams[2];
                strideX = (int) extraParams[3];
                padHeight = (int) extraParams[4];
                padWidth = (int) extraParams[5];
                dY = (int) extraParams[6];  // Dilation, height/y dimension
                dX = (int) extraParams[7];  // Dilation, width/x dimension
                kSize = kernelWidth * kernelHeight;
                zeroPadVal = (T) extraParams[9];  // Value to use when value is padding. Usually 0 but not always

                outShape = shape::shapeOf(zShapeBuffer);
                resultOrder = shape::order(zShapeBuffer);
                outStride = shape::stride(zShapeBuffer);

                inShape = shape::shapeOf(xShapeBuffer);
                inStride = shape::stride(xShapeBuffer);

                samples = (int) inShape[0];
                depth = (int) inShape[1];
                height = (int) inShape[2];
                width = (int) inShape[3];

                strideex = (int) inStride[0];
                stridech = (int) inStride[1];
                strideh = (int) inStride[2];
                stridew = (int) inStride[3];

                // (height + 2 * padHeight - kernelHeight) / strideX + 1; //
                // (width + 2 * padWidth - kernelWidth) / strideY + 1; //
                height_col = (int) outShape[4];
                width_col = (int) outShape[5];

                n = samples * depth * height_col * width_col;
            }
            __syncthreads();

            int index = blockIdx.x * blockDim.x + threadIdx.x;
            for (; index < n; index += blockDim.x * gridDim.x) {
                int h_index = index / width_col;
                int h_col = h_index % height_col;
                int w_col = index % width_col;

                int c_im = h_index / height_col;
                int c_col = c_im * kSize;

                int depth_im = c_im % depth;
                int num_im = c_im / depth;
                int h_offset = h_col * strideY - padHeight;
                int w_offset = w_col * strideX - padWidth;

                T *data_col_ptr = result;

                int i_c = (c_col * height_col + h_col) * width_col + w_col;
                data_col_ptr += (c_col * height_col + h_col) * width_col + w_col;

                T *data_im_ptr = dx;
                data_im_ptr += num_im * strideex + depth_im * stridech + h_offset * strideh + w_offset * stridew;

                for (int i = 0; i < kernelHeight; ++i) {
                    for (int j = 0; j < kernelWidth; ++j) {
                        int h_im = h_offset + i * dY;
                        int w_im = w_offset + j * dX;
                        int i_f = 0;
                        int i_c_temp = i_c;
                        for (int dim = 5; dim >= 0; dim--) {
                            i_f += (i_c_temp % outShape[dim]) * outStride[dim];
                            i_c_temp = i_c_temp / outShape[dim];
                        }
                        if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
                            result[i_f] = data_im_ptr[i * dY * strideh + j * dX * stridew];
                        } else
                            result[i_f] = zeroPadVal;

                        //result[i_f] = (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ? data_im_ptr[i * strideh + j * stridew] : 0;
                        data_col_ptr += height_col * width_col;
                        i_c += height_col * width_col;
                    }
                }
            }
        }
#endif
# endif
static void execSpecial (
T * imBuff ,
Nd4jLong * imShapeBuffer ,
T * colBuff ,
Nd4jLong * colShapeBuffer ,
T * extraParams , Nd4jLong * tadShapeInfo , Nd4jLong * tadOffsets ) {
/*kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false*/
2019-09-11 19:12:09 +02:00
// [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
2019-06-06 14:21:15 +02:00
int kH = ( int ) extraParams [ 0 ] ;
int kW = ( int ) extraParams [ 1 ] ;
int sH = ( int ) extraParams [ 2 ] ;
int sW = ( int ) extraParams [ 3 ] ;
int pH = ( int ) extraParams [ 4 ] ;
int pW = ( int ) extraParams [ 5 ] ;
int dH = ( int ) extraParams [ 6 ] ; //Dilation, height/y dimension
2019-09-11 19:12:09 +02:00
int dW = ( int ) extraParams [ 7 ] ; //Dilation, width/x dimension
2019-06-06 14:21:15 +02:00
T zeroPadVal = extraParams [ 9 ] ;
auto colShape = shape : : shapeOf ( colShapeBuffer ) ;
auto colStride = shape : : stride ( colShapeBuffer ) ;
auto imShape = shape : : shapeOf ( imShapeBuffer ) ;
auto imStride = shape : : stride ( imShapeBuffer ) ;
const int bS = imShape [ 0 ] ;
const int iC = imShape [ 1 ] ;
const int iH = imShape [ 2 ] ;
const int iW = imShape [ 3 ] ;
const int oH = colShape [ 4 ] ;
const int oW = colShape [ 5 ] ;
const Nd4jLong colStride0 = colStride [ 0 ] ;
const Nd4jLong colStride1 = colStride [ 1 ] ;
const Nd4jLong colStride2 = colStride [ 2 ] ;
const Nd4jLong colStride3 = colStride [ 3 ] ;
const Nd4jLong colStride4 = colStride [ 4 ] ;
const Nd4jLong colStride5 = colStride [ 5 ] ;
const Nd4jLong imStride0 = imStride [ 0 ] ;
const Nd4jLong imStride1 = imStride [ 1 ] ;
const Nd4jLong imStride2 = imStride [ 2 ] ;
const Nd4jLong imStride3 = imStride [ 3 ] ;
T * col , * im ;
int imRow , imCol ;
2019-09-11 19:12:09 +02:00
2019-06-06 14:21:15 +02:00
            if (shape::order(imShapeBuffer) == 'c' && shape::order(colShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(imShapeBuffer) && shape::strideDescendingCAscendingF(colShapeBuffer)) {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, im, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {
                    for (int c = 0; c < iC; ++c) {
                        for (int kRow = 0; kRow < kH; ++kRow) {
                            for (int kCol = 0; kCol < kW; ++kCol) {
                                for (int colH = 0; colH < oH; ++colH) {
                                    for (int colW = 0; colW < oW; ++colW) {

                                        imRow = (-pH + kRow * dH) + colH * sH;
                                        imCol = (-pW + kCol * dW) + colW * sW;

                                        col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5;
                                        im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3;

                                        if (static_cast<unsigned>(imRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(imCol) >= static_cast<unsigned>(iW))
                                            *col = zeroPadVal;
                                        else
                                            *col = *im;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            else {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(im, col, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {
                    for (int colH = 0; colH < oH; ++colH) {
                        for (int colW = 0; colW < oW; ++colW) {
                            for (int c = 0; c < iC; ++c) {
                                for (int kRow = 0; kRow < kH; ++kRow) {
                                    for (int kCol = 0; kCol < kW; ++kCol) {

                                        imRow = (-pH + kRow * dH) + colH * sH;
                                        imCol = (-pW + kCol * dW) + colW * sW;

                                        col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5;
                                        im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3;

                                        if (static_cast<unsigned>(imRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(imCol) >= static_cast<unsigned>(iW))
                                            *col = zeroPadVal;
                                        else
                                            *col = *im;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        op_def static T op(T d1, T *params) {
            return d1;
        }

        /** Calculate buffer offset (like Shape.getOffset) without checking the input for negative indices etc.
         * Normally negative indices are bad; they are OK here because of other checks on input indices.
         * Uses an unrolled loop specifically for length 4.
         */
        static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[2] != 1) offset += indices[2] * stride[2];
            if (shape[3] != 1) offset += indices[3] * stride[3];
            return offset;
        }

        /**
         * A version of Shape.getOffset without checking the input for negative indices etc.
         * Normally negative indices are bad; they are OK here because of other checks on input indices.
         * Uses an unrolled loop specifically for length 6, where indices[2] and indices[3] are zero (always are here).
         */
        static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[4] != 1) offset += indices[4] * stride[4];
            if (shape[5] != 1) offset += indices[5] * stride[5];
            return offset;
        }
    };
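
    // In element terms (a sketch of the mapping implemented above), im2col copies
    //   col[b, c, kRow, kCol, colH, colW] = im[b, c, -pH + kRow*dH + colH*sH,
    //                                              -pW + kCol*dW + colW*sW]
    // and writes zeroPadVal whenever the source row/column falls outside the
    // image, so a convolution over `im` becomes a single matrix multiply over `col`.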
    template <typename T, typename Z>
    class Histogram {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        static inline __device__ void execSpecialCuda(
                T *dx, Nd4jLong *xShapeBuffer,
                Z *result, Nd4jLong *zShapeBuffer,
                Z *extraParams,
                int *allocationPointer, Z *reductionPointer,
                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
        }
#endif

        static void execSpecial(
                T *dx,
                Nd4jLong *xShapeBuffer,
                Z *result,
                Nd4jLong *zShapeBuffer,
                Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
        }

        op_def static T op(T d1, Z *params) {
            return d1;
        }
    };
    template <typename X>
    class Col2Im {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        /**
         * https://github.com/pjreddie/darknet/blob/master/src/col2im_kernels.cu
         */
        static inline __device__ void execSpecialCuda(
                X *dx, Nd4jLong *xShapeBuffer,
                X *result, Nd4jLong *zShapeBuffer,
                X *extraParams, int *allocationPointer,
                X *reductionPointer,
                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            __shared__ int strideex, stridech, stridekrow, stridekcol, striderow, stridecol, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX, samples, depth, imgH, imgW, height_col, width_col, n, kEffectiveW, kEffectiveH;
            __shared__ Nd4jLong *inShape, *inStride, *outShape, *outStride;
            __shared__ char resultOrder;

            if (threadIdx.x == 0) {
                inShape = shape::shapeOf(xShapeBuffer);
                inStride = shape::stride(xShapeBuffer);

                strideex = (int) inStride[0];
                stridech = (int) inStride[1];
                stridekrow = (int) inStride[2];
                stridekcol = (int) inStride[3];
                striderow = (int) inStride[4];
                stridecol = (int) inStride[5];

                kernelHeight = (int) inShape[2];
                kernelWidth = (int) inShape[3];

                strideY = (int) extraParams[0];
                strideX = (int) extraParams[1];
                padHeight = (int) extraParams[2];
                padWidth = (int) extraParams[3];
                imgHeight = (int) extraParams[4];
                imgWidth = (int) extraParams[5];
                dY = (int) extraParams[6];  // Dilation in height/y dimension
                dX = (int) extraParams[7];  // Dilation in width/x dimension

                outShape = shape::shapeOf(zShapeBuffer);
                resultOrder = shape::order(zShapeBuffer);
                outStride = shape::stride(zShapeBuffer);

                samples = (int) outShape[0];
                depth = (int) outShape[1];
                imgH = (int) outShape[2];
                imgW = (int) outShape[3];

                height_col = inShape[4];  // (imgHeight + 2 * padHeight - kernelHeight) / strideX + 1;
                width_col = inShape[5];   // (imgWidth + 2 * padWidth - kernelWidth) / strideY + 1;

                n = samples * depth * imgHeight * imgWidth;

                // Effective kernel size, accounting for dilation
                kEffectiveW = kernelWidth + (kernelWidth - 1) * (dX - 1);
                kEffectiveH = kernelHeight + (kernelHeight - 1) * (dY - 1);
            }
            __syncthreads();

            for (int i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
                X val = 0;
                int w_im = i % imgWidth + padWidth;
                int h_im = (i / imgWidth) % imgHeight + padHeight;
                int c_im = i / (imgWidth * imgHeight);

                int num_im = c_im / depth;
                int depth_im = c_im % depth;

                // compute the start and end of the output
                // These are the indexes for dimensions ??? in the 6d col matrix
                int w_col_start = (w_im < kEffectiveW) ? 0 : (w_im - kEffectiveW) / strideX + 1;
                int w_col_end = nd4j::math::nd4j_min<int>(w_im / strideX + 1, width_col);

                int h_col_start = (h_im < kEffectiveH) ? 0 : (h_im - kEffectiveH) / strideY + 1;
                int h_col_end = nd4j::math::nd4j_min<int>(h_im / strideY + 1, height_col);

                // Iterate over col entries in the 6d array... these are added up
                for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
                    for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
                        int h_k = (h_im - h_col * strideY);
                        int w_k = (w_im - w_col * strideX);

                        if (h_k % dY == 0 && w_k % dX == 0) {
                            h_k /= dY;
                            w_k /= dX;

                            int data_col_index = num_im * strideex + depth_im * stridech + h_k * stridekrow + w_k * stridekcol + h_col * striderow + w_col * stridecol;
                            val += dx[data_col_index];
                        }
                    }
                }

                int i_f = 0;
                int i_c = i;
                for (int dim = 3; dim >= 0; dim--) {
                    i_f += (i_c % outShape[dim]) * outStride[dim];
                    i_c = i_c / outShape[dim];
                }

                result[i_f] = val;
            }
        }
#endif
# endif
static void execSpecial (
X * colBuff ,
Nd4jLong * colShapeBuffer ,
X * imBuff ,
Nd4jLong * imShapeBuffer ,
X * extraParams ,
Nd4jLong * tadShapeInfo ,
Nd4jLong * tadOffsets ) {
// [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
auto colShape = shape : : shapeOf ( colShapeBuffer ) ;
auto colStride = shape : : stride ( colShapeBuffer ) ;
auto imShape = shape : : shapeOf ( imShapeBuffer ) ;
2019-09-11 19:12:09 +02:00
auto imStride = shape : : stride ( imShapeBuffer ) ;
2019-06-06 14:21:15 +02:00
const int sH = ( int ) extraParams [ 0 ] ;
const int sW = ( int ) extraParams [ 1 ] ;
const int pH = ( int ) extraParams [ 2 ] ;
const int pW = ( int ) extraParams [ 3 ] ;
const int iH = ( int ) extraParams [ 4 ] ;
const int iW = ( int ) extraParams [ 5 ] ;
2019-09-11 19:12:09 +02:00
const int dH = ( int ) extraParams [ 6 ] ;
const int dW = ( int ) extraParams [ 7 ] ;
2019-06-06 14:21:15 +02:00
const int bS = imShape [ 0 ] ;
const int iC = imShape [ 1 ] ;
const int kH = colShape [ 2 ] ;
2019-09-11 19:12:09 +02:00
const int kW = colShape [ 3 ] ;
2019-06-06 14:21:15 +02:00
const int oH = colShape [ 4 ] ;
const int oW = colShape [ 5 ] ;
const Nd4jLong colStride0 = colStride [ 0 ] ;
const Nd4jLong colStride1 = colStride [ 1 ] ;
const Nd4jLong colStride2 = colStride [ 2 ] ;
const Nd4jLong colStride3 = colStride [ 3 ] ;
const Nd4jLong colStride4 = colStride [ 4 ] ;
const Nd4jLong colStride5 = colStride [ 5 ] ;
const Nd4jLong imStride0 = imStride [ 0 ] ;
const Nd4jLong imStride1 = imStride [ 1 ] ;
const Nd4jLong imStride2 = imStride [ 2 ] ;
const Nd4jLong imStride3 = imStride [ 3 ] ;
auto zLength = shape : : length ( imShapeBuffer ) ;
// initial zeroing of image content
memset ( imBuff , 0 , zLength * sizeof ( X ) ) ;
X * col , * im ;
int imRow , imCol ;
            if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, im, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {
                    for (int c = 0; c < iC; ++c) {
                        for (int kRow = 0; kRow < kH; ++kRow) {
                            for (int kCol = 0; kCol < kW; ++kCol) {
                                for (int colH = 0; colH < oH; ++colH) {
                                    for (int colW = 0; colW < oW; ++colW) {

                                        imRow = (-pH + kRow * dH) + colH * sH;
                                        imCol = (-pW + kCol * dW) + colW * sW;

                                        col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5;
                                        im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3;

                                        if (static_cast<unsigned>(imRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(imCol) < static_cast<unsigned>(iW))
                                            *im += *col;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            else {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(im, col, imRow, imCol))
                for (int b = 0; b < bS; b++) {
                    for (int colH = 0; colH < oH; ++colH) {
                        for (int colW = 0; colW < oW; ++colW) {
                            for (int c = 0; c < iC; ++c) {
                                for (int kRow = 0; kRow < kH; ++kRow) {
                                    for (int kCol = 0; kCol < kW; ++kCol) {

                                        imRow = (-pH + kRow * dH) + colH * sH;
                                        imCol = (-pW + kCol * dW) + colW * sW;

                                        col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5;
                                        im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3;

                                        if (static_cast<unsigned>(imRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(imCol) < static_cast<unsigned>(iW))
                                            *im += *col;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        op_def static X op(X d1, X *params) {
            return d1;
        }

        /** Calculate buffer offset (like Shape.getOffset) without checking the input for negative indices etc.
         * Normally negative indices are bad; they are OK here because of other checks on input indices.
         * Uses an unrolled loop specifically for length 4.
         */
        static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[2] != 1) offset += indices[2] * stride[2];
            if (shape[3] != 1) offset += indices[3] * stride[3];
            return offset;
        }

        /** A version of Shape.getOffset without checking the input for negative indices etc.
         * Normally negative indices are bad; they are OK here because of other checks on input indices.
         * Uses an unrolled loop specifically for length 6, where indices[2] and indices[3] are zero (always are here).
         */
        static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[4] != 1) offset += indices[4] * stride[4];
            if (shape[5] != 1) offset += indices[5] * stride[5];
            return offset;
        }
    };
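
    // Col2Im accumulates (`*im += *col`) rather than assigns because overlapping
    // windows map several col entries onto the same image pixel; summing them
    // makes the op the adjoint of Im2col, which is exactly what the convolution
    // backward pass needs when gradients flow from the column buffer to the image.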
    template <typename X>
    class Reverse {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        static inline __device__ void execSpecialCuda(X *dx, Nd4jLong *xShapeBuffer,
                                                      X *result, Nd4jLong *zShapeBuffer,
                                                      X *extraParams, int *allocationPointer,
                                                      X *reductionPointer,
                                                      Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
            __shared__ Nd4jLong xLength;
            __shared__ int xEWS;
            __shared__ char xOrder;
            __shared__ Nd4jLong sLength;
            __shared__ X *shmem;
            int tid = threadIdx.x + blockIdx.x * blockDim.x;

            if (threadIdx.x == 0) {
                xLength = shape::length(xShapeBuffer);
                xEWS = shape::elementWiseStride(xShapeBuffer);
                xOrder = shape::order(xShapeBuffer);
                sLength = xLength - 1;

                extern __shared__ unsigned char shrd[];
                shmem = (X *) shrd;
            }
            __syncthreads();

            if (dx == result) {
                if (xEWS == 1) {
                    for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) {
                        Nd4jLong idx = sLength - e;
                        X tmp = dx[e];
                        dx[e] = dx[idx];
                        dx[idx] = tmp;
                    }
                } else if (xEWS >= 1) {
                    for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) {
                        Nd4jLong idx1 = (sLength - e) * xEWS;
                        Nd4jLong idx2 = e * xEWS;
                        X tmp = dx[idx2];
                        dx[idx2] = dx[idx1];
                        dx[idx1] = tmp;
                    }
                }
                else {
                    for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) {
                        auto xOffset = shape::getIndexOffset(e, xShapeBuffer);
                        auto zOffset = shape::getIndexOffset(sLength - e, xShapeBuffer);
                        result[zOffset] = dx[xOffset];
                    }
                }
            } else {
                __shared__ int zEWS;
                __shared__ char zOrder;

                if (threadIdx.x == 0) {
                    zEWS = shape::elementWiseStride(zShapeBuffer);
                    zOrder = shape::order(zShapeBuffer);
                }
                __syncthreads();

                if (xEWS == 1 && zEWS == 1 && xOrder == zOrder) {
                    // loop for whole array
                    for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) {
                        result[sLength - e] = dx[e];
                    }
                } else if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) {
                    for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) {
                        result[(sLength - e) * zEWS] = dx[e * xEWS];
                    }
                }
                else {
                    for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) {
                        auto xOffset = shape::getIndexOffset(e, xShapeBuffer);
                        auto zOffset = shape::getIndexOffset(sLength - e, zShapeBuffer);
                        result[zOffset] = dx[xOffset];
                    }
                }
            }
        }
#endif
        static void execSpecial(X *dx, Nd4jLong *xShapeBuffer, X *result, Nd4jLong *zShapeBuffer, X *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
            Nd4jLong xLength = shape::length(xShapeBuffer);
            int xEWS = shape::elementWiseStride(xShapeBuffer);
            char xOrder = shape::order(xShapeBuffer);
            Nd4jLong sLength = xLength - 1;

            // two step phase here
            if (dx == result) {
                if (xEWS == 1) {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength / 2; e++) {
                        Nd4jLong idx = sLength - e;
                        auto tmp = dx[e];
                        dx[e] = dx[idx];
                        dx[idx] = tmp;
                    }
                } else if (xEWS > 1) {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength / 2; e++) {
                        Nd4jLong idx1 = (sLength - e) * xEWS;
                        Nd4jLong idx2 = e * xEWS;
                        auto tmp = dx[idx2];
                        dx[idx2] = dx[idx1];
                        dx[idx1] = tmp;
                    }
                }
                else {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength / 2; e++) {
                        auto xOffset = shape::getIndexOffset(e, xShapeBuffer);
                        auto zOffset = shape::getIndexOffset(sLength - e, xShapeBuffer);
                        result[zOffset] = dx[xOffset];
                    }
                }
            } else {
                // single step phase here
                auto zEWS = shape::elementWiseStride(zShapeBuffer);
                auto zOrder = shape::order(zShapeBuffer);

                if (xEWS == 1 && zEWS == 1 && xOrder == zOrder) {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength; e++) {
                        result[sLength - e] = dx[e];
                    }
                } else if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength; e++) {
                        result[(sLength - e) * zEWS] = dx[e * xEWS];
                    }
                }
                else {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength; e++) {
                        auto xOffset = shape::getIndexOffset(e, xShapeBuffer);
                        auto zOffset = shape::getIndexOffset(sLength - e, zShapeBuffer);
                        result[zOffset] = dx[xOffset];
                    }
                }
            }
        }

        op_def static X op(X d1, X *params) {
            return d1;
        }
    };
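
    // Note on the in-place path above: iterating only e < xLength / 2 and swapping
    // dx[e] with dx[sLength - e] touches each pair exactly once; for an odd length
    // the middle element stays put, e.g. {1, 2, 3, 4, 5} -> {5, 4, 3, 2, 1} via the
    // swaps (1,5) and (2,4).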
    template <typename X>
    class SoftMax {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        /**
         *
         */
        static inline __device__ void execSpecialCuda(
                void *vx, Nd4jLong *xShapeBuffer,
                void *vresult, Nd4jLong *zShapeBuffer,
                void *vextraParams,
                int *allocationPointer, void *reductionPointer,
                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

            auto shape = shape::shapeOf(xShapeBuffer);
            __shared__ X maxResult;
            __shared__ Nd4jLong *maxResultShapeBuffer;

            auto length = shape::length(xShapeBuffer);
            auto stride = shape::stride(xShapeBuffer);

            // compute the row wise maxes
            __shared__ Nd4jLong maxShape[2];  // it's always 2d here
            __shared__ Nd4jLong tempBuffer[8];

            if (threadIdx.x == 0) {
                maxResult = (X) 0.0;
                maxShape[0] = shape[0];
                maxShape[1] = 1;
                maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
            }
            __syncthreads();

            functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
            __syncthreads();

            // subtract max of each row
            functions::scalar::ScalarInplace<X, X, X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
            __syncthreads();

            // after subtracting the row wise maxes take the exp
            functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
            __syncthreads();

            // take the sum for the exponential
            functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
            __syncthreads();

            // divide by the sum
            functions::scalar::ScalarInplace<X, X, X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
        }
#endif
        static void execSpecial(
                void *vx,
                Nd4jLong *xShapeInfo,
                void *vz,
                Nd4jLong *zShapeInfo,
                void *vextraParams,
                Nd4jLong *tadShapeInfo,
                Nd4jLong *tadOffsets) {

            auto x = reinterpret_cast<X *>(vx);
            auto z = reinterpret_cast<X *>(vz);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

            if (shape::isMatrix(xShapeInfo)) {
                if (shape::equalsStrict(xShapeInfo, zShapeInfo)) {

                    if (tadShapeInfo == nullptr) {
                        auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, 1);
                        tadShapeInfo = tadPack.primaryShapeInfo();
                        tadOffsets = tadPack.primaryOffsets();
                    }

                    const uint tadLen = shape::length(tadShapeInfo);
                    const uint numOfTads = shape::length(xShapeInfo) / tadLen;

                    if (shape::elementWiseStride(tadShapeInfo) == 1) {
                        PRAGMA_OMP_PARALLEL_FOR_SIMD
                        for (uint i = 0; i < numOfTads; ++i) {
                            X *inBuff = x + tadOffsets[i];
                            X *outBuff = z + tadOffsets[i];

                            X max = -nd4j::DataTypeUtils::max<X>();
                            X sum = 0;

                            for (uint j = 0; j < tadLen; ++j)
                                max = nd4j::math::nd4j_max<X>(max, inBuff[j]);

                            for (uint j = 0; j < tadLen; ++j) {
                                X temp = nd4j::math::nd4j_exp<X, X>(inBuff[j] - max);
                                outBuff[j] = temp;
                                sum += temp;
                            }

                            for (uint j = 0; j < tadLen; ++j)
                                outBuff[j] /= sum;
                        }
                    }
                    else {
                        uint xShapeInfoCast[MAX_RANK];
                        bool canCast = nd4j::DataTypeUtils::castShapeInfo(tadShapeInfo, xShapeInfoCast);

                        auto offsets = new Nd4jLong[tadLen];
                        shape::calcOffsets(tadShapeInfo, offsets);

                        PRAGMA_OMP_PARALLEL_FOR_SIMD
                        for (uint i = 0; i < numOfTads; ++i) {
                            X *inBuff = x + tadOffsets[i];
                            X *outBuff = z + tadOffsets[i];

                            X max = -nd4j::DataTypeUtils::max<X>();
                            X sum = 0.f;

                            for (uint j = 0; j < tadLen; ++j)
                                max = nd4j::math::nd4j_max<X>(max, inBuff[offsets[j]]);

                            for (uint j = 0; j < tadLen; ++j) {
                                X temp = nd4j::math::nd4j_exp<X, X>(inBuff[offsets[j]] - max);
                                outBuff[offsets[j]] = temp;
                                sum += temp;
                            }

                            for (uint j = 0; j < tadLen; ++j)
                                outBuff[offsets[j]] /= sum;
                        }
                        delete[] offsets;
                    }
                }
                else {
                    auto shape = shape::shapeOf(xShapeInfo);
                    // iterate along rows
                    int dimension[1] = {0};
                    int maxDimension[1] = {1};
                    // compute the row wise maxes
                    auto maxResult = new X[shape[0]];
                    for (int i = 0; i < shape[0]; i++)
                        maxResult[i] = 0.0;

                    Nd4jLong maxShape[2] = {shape[0], 1};
                    auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape);

                    functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Max, x, xShapeInfo, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

                    // subtract max of each row
                    functions::broadcast::Broadcast<X, X, X>::exec(nd4j::broadcast::Subtract, x, xShapeInfo, maxResult, maxResultShapeBuffer, z, zShapeInfo, dimension, 1, nullptr, nullptr, nullptr, nullptr);

                    // after subtracting the row wise maxes take the exp
                    functions::transform::TransformStrict<X>::exec(nd4j::transform::Exp, z, zShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);

                    // take the sum for the exponential
                    functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Sum, z, zShapeInfo, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

                    // divide by the sum
                    functions::broadcast::Broadcast<X, X, X>::exec(nd4j::broadcast::Divide, z, zShapeInfo, maxResult, maxResultShapeBuffer, z, zShapeInfo, dimension, 1, nullptr, nullptr, nullptr, nullptr);

                    delete[] maxResultShapeBuffer;
                    delete[] maxResult;
                }
            }
            else if (shape::isVector(xShapeInfo)) {
                auto max = -nd4j::DataTypeUtils::max<X>();
                X sum = 0;
                int elementWiseStride = shape::elementWiseStride(xShapeInfo);
                int resultElementWiseStride = shape::elementWiseStride(zShapeInfo);
                int length = shape::length(xShapeInfo);

                if (elementWiseStride >= 1 && resultElementWiseStride >= 1) {
                    if (elementWiseStride == 1 && resultElementWiseStride == 1) {
                        for (int i = 0; i < length; i++) {
                            max = nd4j::math::nd4j_max<X>(max, x[i]);
                        }
                        for (int i = 0; i < length; i++) {
                            z[i] = nd4j::math::nd4j_exp<X, X>(x[i] - max);
                            sum += z[i];
                        }
                        PRAGMA_OMP_SIMD
                        for (int i = 0; i < length; i++) {
                            z[i] /= sum;
                        }
                    }
                    else {
                        for (int i = 0; i < length; i++) {
                            max = nd4j::math::nd4j_max<X>(max, x[i * elementWiseStride]);
                        }
                        for (int i = 0; i < length; i++) {
                            auto r = nd4j::math::nd4j_exp<X, X>(x[i * elementWiseStride] - max);
                            z[i * resultElementWiseStride] = r;
                            sum += r;
                        }
                        for (int i = 0; i < length; i++) {
                            z[i * resultElementWiseStride] /= sum;
                        }
                    }
                }
            }
        }

        op_def static X op(X d1, X *params) {
            return d1;
        }
    };
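
    // Worked example (illustrative) of the max-shift performed above: for a row
    // x = {1, 2, 3}, max = 3, so exp gives {e^-2, e^-1, e^0} ~= {0.135, 0.368, 1.0}
    // with sum ~= 1.503, and the normalized output is ~{0.090, 0.245, 0.665}.
    // Subtracting the row max first changes nothing mathematically but keeps every
    // exponential <= 1, avoiding overflow for large inputs.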
    template <typename X>
    class LogSoftMax {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        /**
         *
         */
        static inline __device__ void execSpecialCuda(
                void *vx, Nd4jLong *xShapeBuffer,
                void *vresult, Nd4jLong *zShapeBuffer,
                void *vextraParams,
                int *allocationPointer, void *reductionPointer,
                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto shape = shape::shapeOf(xShapeBuffer);
            auto stride = shape::stride(xShapeBuffer);
            // iterate along rows
            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

            __shared__ X maxResult;
            __shared__ Nd4jLong *maxResultShapeBuffer;
            if (threadIdx.x == 0) {
                maxResult = (X) 0.0;
            }
            __syncthreads();

            // compute the row wise maxes
            Nd4jLong maxShape[2] = {shape[0], 1};

            __shared__ Nd4jLong tempBuffer[8];

            if (threadIdx.x == 0)
                maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
            __syncthreads();

            functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
            __syncthreads();

            // subtract max of each row
            functions::scalar::ScalarInplace<X, X, X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
            __syncthreads();

            // after subtracting the row wise maxes take the exp
            functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
            __syncthreads();

            // take the sum for the exponential
            functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
            __syncthreads();

            // divide by the sum
            functions::scalar::ScalarInplace<X, X, X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
            __syncthreads();

            functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Log, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
        }
#endif
static void execSpecial (
void * vx ,
Nd4jLong * xShapeBuffer ,
void * vresult ,
Nd4jLong * zShapeBuffer ,
void * vextraParams ,
Nd4jLong * tadShapeInfo ,
Nd4jLong * tadOffsets ) {
auto dx = reinterpret_cast < X * > ( vx ) ;
auto result = reinterpret_cast < X * > ( vresult ) ;
auto extraParams = reinterpret_cast < X * > ( vextraParams ) ;
if ( shape : : isMatrix ( xShapeBuffer , 2 ) ) {
auto shape = shape : : shapeOf ( xShapeBuffer ) ;
//iterate along rows
int dimension [ 1 ] = { 0 } ;
int maxDimension [ 1 ] = { 1 } ;
//compute the row wise maxes
auto maxResult = new X [ shape [ 0 ] ] ;
PRAGMA_OMP_SIMD
for ( int i = 0 ; i < shape [ 0 ] ; i + + )
maxResult [ i ] = 0.0 ;
Nd4jLong maxShape [ 2 ] = { shape [ 0 ] , 1 } ;
auto maxResultShapeBuffer = shape : : shapeBuffer ( 2 , nd4j : : DataTypeUtils : : fromT < X > ( ) , maxShape ) ;
functions : : reduce : : ReduceSameFunction < X > : : exec ( nd4j : : reduce : : Max , dx , xShapeBuffer , extraParams , maxResult , maxResultShapeBuffer , maxDimension , 1 , nullptr , nullptr ) ;
//subtract max of each row
functions : : broadcast : : Broadcast < X , X , X > : : exec ( nd4j : : broadcast : : Subtract , dx , xShapeBuffer , maxResult , maxResultShapeBuffer , result , zShapeBuffer , dimension , 1 , nullptr , nullptr , nullptr , nullptr ) ;
//after subtracting the row wise maxes take the exp
functions : : transform : : TransformStrict < X > : : exec ( nd4j : : transform : : Exp , result , zShapeBuffer , result , zShapeBuffer , extraParams , tadShapeInfo , tadOffsets ) ;
//take the sum for the exponential
functions : : reduce : : ReduceSameFunction < X > : : exec ( nd4j : : reduce : : Sum , result , zShapeBuffer , extraParams , maxResult , maxResultShapeBuffer , maxDimension , 1 , nullptr , nullptr ) ;
//divide by the sum
functions : : broadcast : : Broadcast < X , X , X > : : exec ( nd4j : : broadcast : : Divide , result , zShapeBuffer , maxResult , maxResultShapeBuffer , result , zShapeBuffer , dimension , 1 , nullptr , nullptr , nullptr , nullptr ) ;
functions : : transform : : TransformStrict < X > : : exec ( nd4j : : transform : : Log , result , zShapeBuffer , result , zShapeBuffer , extraParams , tadShapeInfo , tadOffsets ) ;
delete [ ] maxResultShapeBuffer ;
}
else if ( shape : : isVector ( xShapeBuffer , 2 ) ) {
auto max = - FLOAT_MAX_VALUE ;
X sum = 0 ;
auto elementWiseStride = shape : : elementWiseStride ( xShapeBuffer ) ;
auto length = shape : : length ( xShapeBuffer ) ;
if ( elementWiseStride = = 1 ) {
for ( int i = 0 ; i < length ; i + + ) {
max = nd4j : : math : : nd4j_max < X > ( max , result [ i ] ) ;
}
for ( int i = 0 ; i < length ; i + + ) {
result [ i ] = nd4j : : math : : nd4j_exp < X , X > ( dx [ i ] - max ) ;
sum + = result [ i ] ;
}
PRAGMA_OMP_SIMD
for ( int i = 0 ; i < length ; i + + ) {
result [ i ] / = sum ;
result [ i ] = nd4j : : math : : nd4j_log < X , X > ( result [ i ] ) ;
}
}
else if ( elementWiseStride > 1 ) {
for ( int i = 0 ; i < length ; i + + ) {
max = nd4j : : math : : nd4j_max < X > ( max , result [ i * elementWiseStride ] ) ;
}
for ( int i = 0 ; i < length ; i + + ) {
result [ i * elementWiseStride ] = nd4j : : math : : nd4j_exp < X , X > ( dx [ i * elementWiseStride ] - max ) ;
sum + = result [ i * elementWiseStride ] ;
}
for ( int i = 0 ; i < length ; i + + ) {
result [ i * elementWiseStride ] / = sum ;
result [ i * elementWiseStride ] = nd4j : : math : : nd4j_log < X , X > ( result [ i * elementWiseStride ] ) ;
}
}
}
}
        op_def static X op(X d1, X *params) {
            return d1;
        }
    };
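    // A minimal, self-contained sketch (not used by the op above; `logSoftMaxRowSketch`
    // is a hypothetical name) of the numerically stable recurrence the special op
    // implements per row: shift by the max, exponentiate, normalize, then take the log.
    template <typename X>
    static void logSoftMaxRowSketch(const X *in, X *out, int length) {
        // shift by the row max so exp() cannot overflow
        X max = in[0];
        for (int i = 1; i < length; i++)
            max = in[i] > max ? in[i] : max;

        // exponentiate and accumulate the normalizer
        X sum = static_cast<X>(0);
        for (int i = 0; i < length; i++) {
            out[i] = nd4j::math::nd4j_exp<X, X>(in[i] - max);
            sum += out[i];
        }

        // log(exp(x - max) / sum) == (x - max) - log(sum)
        for (int i = 0; i < length; i++)
            out[i] = nd4j::math::nd4j_log<X, X>(out[i] / sum);
    }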
    /**
     * Derivative of softmax: softmax(x) * (1 - softmax(x)), applied elementwise
     */
    template <typename X>
    class SoftMaxDerivative {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        static inline __device__ void execSpecialCuda(
                void *vx, Nd4jLong *xShapeBuffer,
                void *vresult, Nd4jLong *zShapeBuffer,
                void *vextraParams,
                int *allocationPointer, void *reductionPointer,
                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

            auto shape = shape::shapeOf(xShapeBuffer);
            __shared__ X maxResult;
            __shared__ Nd4jLong *maxResultShapeBuffer;
            __shared__ Nd4jLong resultEWS;

            auto length = shape::length(xShapeBuffer);

            if (threadIdx.x == 0) {
                resultEWS = shape::elementWiseStride(zShapeBuffer);
                maxResult = (X) 0.0;
            }
            __syncthreads();

            auto stride = shape::stride(xShapeBuffer);
            Nd4jLong maxShape[2] = {shape[0], 1};

            __shared__ Nd4jLong tempBuffer[8];

            if (threadIdx.x == 0)
                maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
            __syncthreads();
            // row-wise max of the input
            functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
            __syncthreads();

            // subtract the max of each row
            functions::scalar::ScalarInplace<X, X, X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
            __syncthreads();

            // after subtracting the row-wise maxes, take the exp
            functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
            __syncthreads();

            // take the sum of the exponentials
            functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
            __syncthreads();

            // divide by the sum
            functions::scalar::ScalarInplace<X, X, X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
            __syncthreads();

            // elementwise derivative: s * (1 - s)
            if (resultEWS >= 1) {
                for (int i = threadIdx.x; i < length; i += blockDim.x) {
                    result[i * resultEWS] = result[i * resultEWS] * ((X) 1.0 - result[i * resultEWS]);
                }
            }
            else {
                printf("Non element wise stride not supported right now\n");
            }
        }
#endif

        static void execSpecial(
                void *vx,
                Nd4jLong *xShapeBuffer,
                void *vresult,
                Nd4jLong *zShapeBuffer,
                void *vextraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);
            if (shape::isMatrix(xShapeBuffer, 2)) {
                auto shape = shape::shapeOf(xShapeBuffer);
                auto resultEleStide = shape::elementWiseStride(zShapeBuffer);

                // iterate along rows
                int dimension[1] = {0};
                int maxDimension[1] = {1};
                auto len = shape::length(xShapeBuffer);

                // compute the row-wise maxes
                auto maxResult = new X[shape[0]];

                PRAGMA_OMP_SIMD
                for (int i = 0; i < shape[0]; i++)
                    maxResult[i] = 0.0f;

                Nd4jLong maxShape[2] = {shape[0], 1};
                auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape);
                functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

                // subtract the max of each row
                functions::broadcast::Broadcast<X, X, X>::exec(nd4j::broadcast::Subtract, dx, xShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

                // after subtracting the row-wise maxes, take the exp
                functions::transform::TransformStrict<X>::exec(nd4j::transform::Exp, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets);

                // take the sum of the exponentials
                functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

                // divide by the sum
                functions::broadcast::Broadcast<X, X, X>::exec(nd4j::broadcast::Divide, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);
                // elementwise derivative: s * (1 - s)
                if (resultEleStide >= 1) {
                    if (resultEleStide == 1) {
                        PRAGMA_OMP_SIMD
                        for (int i = 0; i < len; i++) {
                            result[i] = result[i] * (static_cast<X>(1.0f) - result[i]);
                        }
                    }
                    else {
                        PRAGMA_OMP_SIMD
                        for (int i = 0; i < len; i++) {
                            result[i * resultEleStide] = result[i * resultEleStide] * (static_cast<X>(1.0f) - result[i * resultEleStide]);
                        }
                    }
                }
                else {
                    for (int i = 0; i < len; i++) {
                        Nd4jLong zOffset = shape::getIndexOffset(i, zShapeBuffer);
                        result[zOffset] = result[zOffset] * ((X) 1.0f - result[zOffset]);
                    }
                }

                delete[] maxResultShapeBuffer;
                delete[] maxResult;
            }
            else if (shape::isVector(xShapeBuffer, 2)) {
                auto max = -nd4j::DataTypeUtils::max<X>();
                X sum = 0;
                auto elementWiseStride = shape::elementWiseStride(xShapeBuffer);
                auto length = shape::length(xShapeBuffer);

                if (elementWiseStride == 1) {
                    // find the max of the input row
                    for (int i = 0; i < length; i++) {
                        max = nd4j::math::nd4j_max<X>(max, dx[i]);
                    }

                    // exponentiate the shifted values and accumulate their sum
                    for (int i = 0; i < length; i++) {
                        result[i] = nd4j::math::nd4j_exp<X, X>(dx[i] - max);
                        sum += result[i];
                    }

                    // normalize to obtain the softmax
                    for (int i = 0; i < length; i++) {
                        result[i] /= sum;
                    }

                    // elementwise derivative: s * (1 - s)
                    for (int i = 0; i < length; i++) {
                        result[i] = result[i] * ((X) 1.0f - result[i]);
                    }
                } else if (elementWiseStride > 1) {
                    for (int i = 0; i < length; i++) {
                        max = nd4j::math::nd4j_max<X>(max, dx[i * elementWiseStride]);
                    }

                    for (int i = 0; i < length; i++) {
                        result[i * elementWiseStride] = nd4j::math::nd4j_exp<X, X>(dx[i * elementWiseStride] - max);
                        sum += result[i * elementWiseStride];
                    }

                    PRAGMA_OMP_SIMD
                    for (int i = 0; i < length; i++) {
                        result[i * elementWiseStride] /= sum;
                    }

                    PRAGMA_OMP_SIMD
                    for (int i = 0; i < length; i++) {
                        result[i * elementWiseStride] = result[i * elementWiseStride] * ((X) 1.0f - result[i * elementWiseStride]);
                    }
                } else {
                    printf("non-ews access on row not implemented yet\n");
                }
            }
        }
        op_def static X op(X d1, X *params) {
            return d1;
        }
    };
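    // A minimal sketch (hypothetical helper, not invoked by SoftMaxDerivative) of what
    // the special op computes for one contiguous row: the softmax s_i followed by the
    // elementwise derivative s_i * (1 - s_i), i.e. the diagonal of the softmax Jacobian.
    template <typename X>
    static void softMaxDerivativeRowSketch(const X *in, X *out, int length) {
        X max = in[0];
        for (int i = 1; i < length; i++)
            max = in[i] > max ? in[i] : max;

        X sum = static_cast<X>(0);
        for (int i = 0; i < length; i++) {
            out[i] = nd4j::math::nd4j_exp<X, X>(in[i] - max);
            sum += out[i];
        }

        for (int i = 0; i < length; i++) {
            out[i] /= sum;                                   // s_i
            out[i] = out[i] * (static_cast<X>(1) - out[i]);  // s_i * (1 - s_i)
        }
    }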
    template <typename X, typename Z>
    class IsMax {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        static inline __device__ void doAllCuda(
                void *vx,
                Nd4jLong *xShapeBuffer,
                void *vresult,
                Nd4jLong *zShapeBuffer,
                void *vextraParams,
                int *allocationPointer, void *reductionPointer) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<Z *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);
            // intentionally a no-op: the commented-out IndexMax-based implementation is never used
        }
#endif

#ifdef __CUDACC__
        inline __host__
#elif defined(__GNUC__)
#endif
        static void doAll(
                void *vx,
                Nd4jLong *xShapeBuffer,
                void *vresult,
                Nd4jLong *zShapeBuffer,
                void *vextraParams) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<Z *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

            auto length = shape::length(xShapeBuffer);
            auto eleStride = shape::elementWiseStride(xShapeBuffer);
            auto resultEleStride = shape::elementWiseStride(zShapeBuffer);
            auto xOrder = shape::order(xShapeBuffer);
            auto resultOrder = shape::order(zShapeBuffer);
            if (xOrder == resultOrder && xOrder == 'c') {
                if (eleStride == 1 && resultEleStride == 1) {
                    if (length < ELEMENT_THRESHOLD) {
                        // small array: single-threaded argmax plus one-hot fill
                        int maxIdx = 0;
                        auto currMax = dx[0];

                        for (int i = 0; i < length; i++) {
                            if (currMax < dx[i]) {
                                currMax = dx[i];
                                maxIdx = i;
                            }
                            result[i] = static_cast<Z>(0);
                        }

                        result[maxIdx] = static_cast<Z>(1);
                    }
                    else {
                        int maxIdx = 0;
                        auto currMax = dx[0];

                        PRAGMA_OMP_PARALLEL
                        {
                            // thread-local argmax, merged under a critical section
                            int maxIdxLocal = maxIdx;
                            auto currMaxLocal = currMax;

                            for (int i = 0; i < length; i++) {
                                if (currMaxLocal < dx[i]) {
                                    currMaxLocal = dx[i];
                                    maxIdxLocal = i;
                                }
                                result[i] = static_cast<Z>(0);
                            }

                            PRAGMA_OMP_CRITICAL
                            {
                                if (currMax < currMaxLocal) {
                                    currMax = currMaxLocal;
                                    maxIdx = maxIdxLocal;
                                }
                            }
                        }

                        result[maxIdx] = static_cast<Z>(1);
                    }
                }
                else {
                    if (length < ELEMENT_THRESHOLD) {
                        int maxIdx = 0;
                        auto currMax = dx[0];

                        for (int i = 0; i < length; i++) {
                            result[i * resultEleStride] = static_cast<Z>(0);
                            if (currMax < dx[i * eleStride]) {
                                currMax = dx[i * eleStride];
                                maxIdx = i;
                            }
                        }

                        result[maxIdx * resultEleStride] = static_cast<Z>(1);
                    }
                    else {
                        int maxIdx = 0;
                        auto currMax = dx[0];

                        PRAGMA_OMP_PARALLEL
                        {
                            int maxIdxLocal = maxIdx;
                            auto currMaxLocal = currMax;

                            for (int i = 0; i < length; i++) {
                                result[i * resultEleStride] = static_cast<Z>(0);
                                if (currMaxLocal < dx[i * eleStride]) {
                                    currMaxLocal = dx[i * eleStride];
                                    maxIdxLocal = i;
                                }
                            }

                            PRAGMA_OMP_CRITICAL
                            {
                                if (currMax < currMaxLocal) {
                                    currMax = currMaxLocal;
                                    maxIdx = maxIdxLocal;
                                }
                            }
                        }

                        result[maxIdx * resultEleStride] = static_cast<Z>(1);
                    }
                }
            }
            else {
                Nd4jLong shapeIter[MAX_RANK];
                Nd4jLong coord[MAX_RANK];
                int dim;
                Nd4jLong xStridesIter[MAX_RANK];
                Nd4jLong resultStridesIter[MAX_RANK];
                auto xShape = shape::shapeOf(xShapeBuffer);
                auto xStride = shape::stride(xShapeBuffer);
                auto resultStride = shape::stride(zShapeBuffer);
                auto rank = shape::rank(xShapeBuffer);
                auto originalResult = result;

                if (PrepareTwoRawArrayIter<X, Z>(rank,
                                                 xShape,
                                                 dx,
                                                 xStride,
                                                 result,
                                                 resultStride,
                                                 &rank,
                                                 shapeIter,
                                                 &dx,
                                                 xStridesIter,
                                                 &result,
                                                 resultStridesIter) >= 0) {
                    auto value = dx[0];
                    int idx = 0;
                    int maxIdx = 0;

                    // single pass: track the linear index of the max while zeroing the output
                    ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); {
                        if (dx[0] > value) {
                            value = dx[0];
                            maxIdx = idx;
                        }

                        idx++;
                        result[0] = static_cast<Z>(0);
                    }
                    ND4J_RAW_ITER_TWO_NEXT(
                            dim,
                            rank,
                            coord,
                            shapeIter,
                            dx,
                            xStridesIter,
                            result,
                            resultStridesIter);

                    // write the 1 at the position of the max value
                    if (shape::order(zShapeBuffer) == 'c' || (shape::order(zShapeBuffer) == 'f' &&
                            maxIdx * shape::stride(zShapeBuffer)[shape::rank(zShapeBuffer) - 1] >=
                            shape::length(zShapeBuffer)))
                        originalResult[maxIdx] = static_cast<Z>(1);
                    else
                        originalResult[maxIdx * shape::stride(zShapeBuffer)[shape::rank(zShapeBuffer) - 1]] = static_cast<Z>(1);
                }
            }
        }
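        // A minimal sketch (hypothetical helper, not called by doAll) isolating the
        // parallel reduction pattern used above for large arrays: each thread keeps a
        // local (currMaxLocal, maxIdxLocal) pair and merges it into the shared winner
        // inside a critical section, assuming plain OpenMP pragmas are available.
        static int argMaxSketch(const X *data, int length) {
            int maxIdx = 0;
            X currMax = data[0];

#pragma omp parallel
            {
                int maxIdxLocal = 0;
                X currMaxLocal = data[0];

#pragma omp for
                for (int i = 0; i < length; i++) {
                    if (currMaxLocal < data[i]) {
                        currMaxLocal = data[i];
                        maxIdxLocal = i;
                    }
                }

#pragma omp critical
                {
                    if (currMax < currMaxLocal) {
                        currMax = currMaxLocal;
                        maxIdx = maxIdxLocal;
                    }
                }
            }

            return maxIdx;
        }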
    public:
#ifdef __CUDACC__
        static inline __device__ void execSpecialCuda(
                void *vx, Nd4jLong *xShapeBuffer,
                void *vresult, Nd4jLong *zShapeBuffer,
                void *vextraParams, int *allocationPointer,
                void *reductionPointer,
                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<Z *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

            // FIXME: MAX_DIMENSION is lower than the FP16 frame
            if (extraParams == nullptr || (int) extraParams[0] == MAX_DIMENSION) {
                doAllCuda(dx, xShapeBuffer, result, zShapeBuffer, extraParams, allocationPointer, reductionPointer);
            }
        }
#endif

        static void execSpecial(
                void *vx,
                Nd4jLong *xShapeBuffer,
                void *vresult,
                Nd4jLong *zShapeBuffer,
                void *vextraParams,
                Nd4jLong *tadShapeInfo,
                Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<Z *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

            // FIXME: this op should be moved to CustomOps
            if (extraParams == nullptr || (int) extraParams[0] == 0 ||
                    ((int) extraParams[0] == 1 && (int) extraParams[1] == MAX_DIMENSION)) {
                doAll(dx, xShapeBuffer, result, zShapeBuffer, extraParams);
            }
            else if (shape::isVector(xShapeBuffer)) {
                auto dimensionLength = (int) extraParams[0];
                auto dimension = new int[dimensionLength];
                auto length = shape::length(xShapeBuffer);

                for (int i = 0; i < dimensionLength; i++) {
                    dimension[i] = (int) extraParams[i + 1];
                }

                if (shape::shapeOf(xShapeBuffer)[dimension[0]] == 1) {
                    // reducing over a dimension of size 1: every element is its own max
                    for (int i = 0; i < length; i++) {
                        result[i] = static_cast<Z>(1);
                    }
                }
                else {
                    auto eleStride = shape::elementWiseStride(xShapeBuffer);
                    if (eleStride == 1) {
                        int maxIdx = 0;
                        auto currMax = dx[0];
                        if (length < ELEMENT_THRESHOLD) {
                            for (int i = 0; i < length; i++) {
                                if (currMax < dx[i]) {
                                    currMax = dx[i];
                                    maxIdx = i;
                                }
                                result[i] = static_cast<Z>(0);
                            }
                        }
                        else {
                            PRAGMA_OMP_PARALLEL
                            {
                                int maxIdxLocal = maxIdx;
                                auto currMaxLocal = currMax;

                                for (int i = 0; i < length; i++) {
                                    if (currMaxLocal < dx[i]) {
                                        currMaxLocal = dx[i];
                                        maxIdxLocal = i;
                                    }
                                    result[i] = static_cast<Z>(0);
                                }

                                PRAGMA_OMP_CRITICAL
                                {
                                    if (currMax < currMaxLocal) {
                                        currMax = currMaxLocal;
                                        maxIdx = maxIdxLocal;
                                    }
                                }
                            }
                        }

                        result[maxIdx] = static_cast<Z>(1);
                    }
                    else {
                        int maxIdx = 0;
                        auto currMax = dx[0];
                        if (length < ELEMENT_THRESHOLD) {
                            for (int i = 0; i < length; i++) {
                                if (currMax < dx[i * eleStride]) {
                                    currMax = dx[i * eleStride];
                                    maxIdx = i;
                                }
                                result[i] = static_cast<Z>(0);
                            }
                        }
                        else {
                            PRAGMA_OMP_PARALLEL
                            {
                                int maxIdxLocal = maxIdx;
                                auto currMaxLocal = currMax;

                                for (int i = 0; i < length; i++) {
                                    if (currMaxLocal < dx[i * eleStride]) {
                                        currMaxLocal = dx[i * eleStride];
                                        maxIdxLocal = i;
                                    }
                                    result[i] = static_cast<Z>(0);
                                }

                                PRAGMA_OMP_CRITICAL
                                {
                                    if (currMax < currMaxLocal) {
                                        currMax = currMaxLocal;
                                        maxIdx = maxIdxLocal;
                                    }
                                }
                            }
                        }

                        result[maxIdx] = static_cast<Z>(1);
                    }
                }

                delete[] dimension;
            }
            else {
                auto dimensionLength = (int) extraParams[0];
                auto dimension = new int[dimensionLength];

                PRAGMA_OMP_SIMD
                for (int i = 0; i < dimensionLength; i++) {
                    dimension[i] = (int) extraParams[i + 1];
                }

                // decompose into several sub-TADs after
                // moving all dimensions (in sorted order)
                // to the back
                // permuted version of the x shape info for setting up the TAD problem
                auto tadShapeShapeInfo = tadShapeInfo;
                if (tadShapeInfo == nullptr) {
                    auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeBuffer, dimension, dimensionLength);
                    tadShapeShapeInfo = tadPack.primaryShapeInfo();
                    tadOffsets = tadPack.primaryOffsets();
                    tadShapeInfo = tadShapeShapeInfo;
                }
                auto tadLength = shape::length(tadShapeInfo); // shape::tadLength(xShapeBuffer, dimension, dimensionLength);
                auto tads = shape::length(xShapeBuffer) / tadLength;

                int tadsPerThread = tads / TAD_THRESHOLD;
                int num_threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
                num_threads = nd4j::math::nd4j_min<int>(num_threads, omp_get_max_threads());

                auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo);
                auto zEWS = tadEWS;

                // split the TADs into contiguous per-thread spans
                int span = (tads / num_threads) + 8;

                PRAGMA_OMP_PARALLEL_THREADS(num_threads)
                {
                    int tid = omp_get_thread_num();
                    int start = span * tid;
                    int end = span * (tid + 1);
                    if (end > tads) end = tads;
                    for (int r = start; r < end; r++) {
                        if (tadEWS > 0 && zEWS > 0 && dimensionLength == 1) {
                            auto rX = dx + tadOffsets[r];
                            auto rZ = result + tadOffsets[r];

                            auto maxValue = rX[0];
                            int maxIdx = 0;

                            if (tadEWS == 1 && zEWS == 1) {
                                for (int i = 0; i < tadLength; i++) {
                                    if (rX[i] > maxValue) {
                                        maxIdx = i;
                                        maxValue = rX[i];
                                    }
                                }

                                for (int i = 0; i < tadLength; i++) {
                                    rZ[i] = static_cast<Z>(maxIdx == i);
                                }
                            } else {
                                for (int i = 0; i < tadLength; i++) {
                                    if (rX[i * tadEWS] > maxValue) {
                                        maxIdx = i;
                                        maxValue = rX[i * tadEWS];
                                    }
                                }

                                for (int i = 0; i < tadLength; i++) {
                                    rZ[i * zEWS] = static_cast<Z>(maxIdx == i);
                                }
                            }
                        } else {
                            auto offset = tadOffsets[r];
                            Nd4jLong shapeIter[MAX_RANK];
                            Nd4jLong coord[MAX_RANK];
                            int dim;
                            Nd4jLong xStridesIter[MAX_RANK];
                            Nd4jLong resultStridesIter[MAX_RANK];
                            auto xShape = shape::shapeOf(tadShapeShapeInfo);
                            auto xStride = shape::stride(tadShapeShapeInfo);
                            auto resultStride = shape::stride(tadShapeShapeInfo);
                            int rank = shape::rank(tadShapeShapeInfo);
                            auto xPointer = dx + offset;
                            auto resultPointer = result + offset;
                            auto maxValue = xPointer[0];

                            auto maxCursor = resultPointer;
                            Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);

                            if (PrepareTwoRawArrayIter<X, Z>(rank,
                                                             xShape,
                                                             xPointer,
                                                             xStride,
                                                             resultPointer,
                                                             resultStride,
                                                             &rank,
                                                             shapeIter,
                                                             &xPointer,
                                                             xStridesIter,
                                                             &resultPointer,
                                                             resultStridesIter) >= 0) {
                                // zero the TAD while remembering where its max lives
                                ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); {
                                    if (maxValue < xPointer[0]) {
                                        maxCursor = resultPointer;
                                        maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
                                        maxValue = xPointer[0];
                                    }

                                    resultPointer[0] = static_cast<Z>(0);
                                }
                                ND4J_RAW_ITER_TWO_NEXT(dim,
                                                       rank,
                                                       coord,
                                                       shapeIter,
                                                       xPointer,
                                                       xStridesIter,
                                                       resultPointer,
                                                       resultStridesIter);

                                maxCursor = reinterpret_cast<Z *>(maxCursorLong);
                                maxCursor[0] = static_cast<Z>(1);
                            }
                        }
                    }
                }

                delete[] dimension;
            }
        }

        op_def static Z op(X d1, X *params) {
            // effectively unused: requiresSpecial routes execution to execSpecial
            return nd4j::math::softplus<X, Z>(d1);
        }
    };
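    // A minimal sketch (hypothetical helper, separate from IsMax itself) of the one-hot
    // output the special op produces along each contiguous TAD: 1 at the argmax, 0 elsewhere.
    template <typename X, typename Z>
    static void isMaxRowSketch(const X *in, Z *out, int length) {
        int maxIdx = 0;
        for (int i = 1; i < length; i++)
            if (in[i] > in[maxIdx])
                maxIdx = i;

        for (int i = 0; i < length; i++)
            out[i] = static_cast<Z>(i == maxIdx);
    }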
}