2294 lines
84 KiB
C
2294 lines
84 KiB
C
|
/*******************************************************************************
|
||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||
|
*
|
||
|
* This program and the accompanying materials are made available under the
|
||
|
* terms of the Apache License, Version 2.0 which is available at
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* License for the specific language governing permissions and limitations
|
||
|
* under the License.
|
||
|
*
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
******************************************************************************/
|
||
|
|
||
|
#pragma once
|
||
|
#include <ops/ops.h>
|
||
|
#include <loops/reduce_float.h>
|
||
|
#include <loops/reduce_same.h>
|
||
|
#include <loops/scalar.h>
|
||
|
#include <loops/indexreduce.h>
|
||
|
#include <loops/broadcasting.h>
|
||
|
#include <loops/transform_float.h>
|
||
|
#include <op_enums.h>
|
||
|
#include <loops/transform_strict.h>
|
||
|
#include <helpers/ConstantTadHelper.h>
|
||
|
|
||
|
#ifdef __CUDACC__
|
||
|
#include <loops/cuda/inplace_loops/reduce_same_inplace.h>
|
||
|
#include <loops/cuda/inplace_loops/transform_strict_inplace.h>
|
||
|
#include <loops/cuda/inplace_loops/scalar_inplace.h>
|
||
|
#endif
|
||
|
|
||
|
// Forward declarations of the loop templates from namespace `functions` that the
// special ops in this header refer to; full definitions come from the loops/ headers.
namespace functions {

    namespace broadcast {
        template <typename X, typename Y, typename Z> class Broadcast;
    }

    namespace transform {
        template <typename X> class TransformStrict;
    }

    // reopened without declarations
    namespace scalar {
    }

    namespace reduce {
        template <typename X, typename Z> class ReduceFloatFunction;
        template <typename X> class ReduceSameFunction;
    }
}
|
||
|
|
||
|
namespace simdOps {
|
||
|
|
||
|
template<typename T, typename Z>
|
||
|
class Pooling2D {
|
||
|
public:
|
||
|
static const bool requiresSpecial = true;
|
||
|
#ifdef __CUDACC__
|
||
|
inline __host__ __device__
|
||
|
#elif defined(__GNUC__)
|
||
|
|
||
|
#endif
|
||
|
static int outSize(int size, int k, int s, int p, bool coverAll) {
|
||
|
if (coverAll)
|
||
|
return (size + p * 2 - k + s - 1) / s + 1;
|
||
|
else
|
||
|
return (size + p * 2 - k) / s + 1;
|
||
|
}
|
||
|
|
||
|
#ifdef __CUDACC__
|
||
|
/**
|
||
|
* Based on: https://github.com/pjreddie/darknet/blob/master/src/im2col_kernels.cu
|
||
|
*/
|
||
|
|
||
|
static inline __device__ void execSpecialCuda(
|
||
|
T *dx, Nd4jLong *xShapeBuffer,
|
||
|
Z *result, Nd4jLong *zShapeBuffer,
|
||
|
Z *extraParams,
|
||
|
int *allocationPointer, Z *reductionPointer,
|
||
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||
|
|
||
|
__shared__ int kH;
|
||
|
__shared__ int kW;
|
||
|
__shared__ int sH;
|
||
|
__shared__ int sW;
|
||
|
__shared__ int pH;
|
||
|
__shared__ int pW;
|
||
|
__shared__ int dH;
|
||
|
__shared__ int dW;
|
||
|
__shared__ int poolingMode;
|
||
|
__shared__ Z extraParam0;
|
||
|
|
||
|
__shared__ int batchSize;
|
||
|
__shared__ int inChannels;
|
||
|
__shared__ int outH;
|
||
|
__shared__ int outW;
|
||
|
__shared__ int inH;
|
||
|
__shared__ int inW;
|
||
|
|
||
|
//__shared__ int *strideIn;
|
||
|
//__shared__ int *strideOut;
|
||
|
__shared__ int strideB;
|
||
|
__shared__ int strideC;
|
||
|
__shared__ int strideY;
|
||
|
__shared__ int strideX;
|
||
|
|
||
|
__shared__ int strideOB;
|
||
|
__shared__ int strideOC;
|
||
|
__shared__ int strideOY;
|
||
|
__shared__ int strideOX;
|
||
|
|
||
|
__shared__ int length;
|
||
|
__shared__ int kHEff;
|
||
|
__shared__ int kWEff;
|
||
|
__shared__ bool fOrder;
|
||
|
|
||
|
|
||
|
if (threadIdx.x == 0) {
|
||
|
kH = (int)extraParams[0];
|
||
|
kW = (int)extraParams[1];
|
||
|
sH = (int)extraParams[2];
|
||
|
sW = (int)extraParams[3];
|
||
|
pH = (int)extraParams[4];
|
||
|
pW = (int)extraParams[5];
|
||
|
dH = (int)extraParams[6]; //Dilation, height dimension
|
||
|
dW = (int)extraParams[7]; //Dilation, width dimension
|
||
|
poolingMode = (int)extraParams[9];
|
||
|
extraParam0 = extraParams[10];
|
||
|
|
||
|
batchSize = shape::sizeAt(xShapeBuffer, 0);
|
||
|
inChannels = shape::sizeAt(xShapeBuffer, 1);
|
||
|
outH = shape::sizeAt(zShapeBuffer, 2);
|
||
|
outW = shape::sizeAt(zShapeBuffer, 3);
|
||
|
inH = shape::sizeAt(xShapeBuffer, 2);
|
||
|
inW = shape::sizeAt(xShapeBuffer, 3);
|
||
|
|
||
|
strideB = shape::stride(xShapeBuffer)[0];
|
||
|
strideC = shape::stride(xShapeBuffer)[1];
|
||
|
strideY = shape::stride(xShapeBuffer)[2];
|
||
|
strideX = shape::stride(xShapeBuffer)[3];
|
||
|
|
||
|
strideOB = shape::stride(zShapeBuffer)[0];
|
||
|
strideOC = shape::stride(zShapeBuffer)[1];
|
||
|
strideOY = shape::stride(zShapeBuffer)[2];
|
||
|
strideOX = shape::stride(zShapeBuffer)[3];
|
||
|
|
||
|
length = shape::length(zShapeBuffer);
|
||
|
|
||
|
//Replace kernel H/W with *effective* kernel H/W accounting for dilatyon
|
||
|
kHEff = kH + (kH-1)*(dH-1);
|
||
|
kWEff = kW + (kW-1)*(dW-1);
|
||
|
|
||
|
fOrder = shape::order(zShapeBuffer) == 'f';
|
||
|
/*
|
||
|
if (blockIdx.x == 0) {
|
||
|
printf("kH: %i; kW: %i; sH: %i; sW: %i; pH: %i; pW: %i; dH: %i; dW: %i; poolingMode: %i; extraParam0: %f;\n", kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, (float) extraParam0);
|
||
|
printf("batchSize: %i; inChannels: %i; outH: %i; outW: %i; inH: %i; inW: %i; strideB: %i; strideC: %i; strideY: %i; strideX: %i;\n", batchSize, inChannels, outH, outW, inH, inW, strideB, strideC, strideY, strideX);
|
||
|
}
|
||
|
*/
|
||
|
}
|
||
|
__syncthreads();
|
||
|
|
||
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||
|
|
||
|
for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
|
||
|
const int pw = index % outW;
|
||
|
const int ph = (index / outW) % outH;
|
||
|
const int c = (index / outW / outH) % inChannels;
|
||
|
const int n = index / outW / outH / inChannels;
|
||
|
int hstart = sH * ph - pH;
|
||
|
int wstart = sW * pw - pW;
|
||
|
int hend = hstart + kHEff;
|
||
|
int wend = wstart + kWEff;
|
||
|
|
||
|
// const int hSO = hstart;
|
||
|
// const int hEO = hend;
|
||
|
|
||
|
if(hstart < 0){
|
||
|
int f = nd4j::math::nd4j_ceil<Z,int>((Z) -hstart / (Z)dH);
|
||
|
hstart += f * dH;
|
||
|
}
|
||
|
if(wstart < 0){
|
||
|
int f = nd4j::math::nd4j_ceil<Z,int>((Z) -wstart / (Z) dW);
|
||
|
wstart += f * dW;
|
||
|
}
|
||
|
if(hend > inH){
|
||
|
int f = nd4j::math::nd4j_ceil<Z,int>((Z) (hend-inH) / (Z) dH);
|
||
|
hend -= f * dH;
|
||
|
}
|
||
|
if(wend > inW){
|
||
|
int f = nd4j::math::nd4j_ceil<Z,int>((Z) (wend-inW) / (Z) dW);
|
||
|
wend -= f * dW;
|
||
|
}
|
||
|
//Accounts for dilation
|
||
|
int pool_size = nd4j::math::nd4j_ceil<double,int>((double) (hend-hstart) / (double) dH) * nd4j::math::nd4j_ceil<double,int>((double) (wend-wstart) / (double) dW);
|
||
|
|
||
|
Z sum = poolingMode == 0 ? -nd4j::DataTypeUtils::max<Z>() : static_cast<Z>(0.f);
|
||
|
|
||
|
T *input_slice = dx + (n * strideB + c * strideC);
|
||
|
if (poolingMode == 0) {
|
||
|
for (int h = hstart; h < hend; h += dH) {
|
||
|
for (int w = wstart; w < wend; w += dW) {
|
||
|
Z v = static_cast<Z>(input_slice[h * strideY + w * strideX]);
|
||
|
if (v > sum)
|
||
|
sum = v;
|
||
|
}
|
||
|
}
|
||
|
} else if (poolingMode == 1) {
|
||
|
for (int h = hstart; h < hend; h += dH) {
|
||
|
for (int w = wstart; w < wend; w += dW) {
|
||
|
sum += static_cast<Z>(input_slice[h * strideY + w * strideX]);
|
||
|
}
|
||
|
}
|
||
|
} else if (poolingMode == 2) {
|
||
|
for (int h = hstart; h < hend; h += dH) {
|
||
|
for (int w = wstart; w < wend; w += dW) {
|
||
|
sum += nd4j::math::nd4j_pow<Z,Z,Z>(static_cast<Z>(nd4j::math::nd4j_abs<T>(input_slice[h * strideY + w * strideX])), extraParam0);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
Z res;
|
||
|
|
||
|
if (poolingMode == 0) {
|
||
|
res = sum;
|
||
|
} else if (poolingMode == 1) {
|
||
|
int divide_factor = pool_size; //Case 0: exclude padding
|
||
|
if ((int) extraParam0 == 1) //Case 1: include padding
|
||
|
divide_factor = kH * kW;
|
||
|
|
||
|
res = sum / static_cast<Z>(divide_factor);
|
||
|
} else if (poolingMode == 2) {
|
||
|
res = nd4j::math::nd4j_pow<Z,Z,Z>(sum, (Z) 1.0f / extraParam0);
|
||
|
}
|
||
|
|
||
|
|
||
|
if (!fOrder) {
|
||
|
result[index] = res;
|
||
|
} else {
|
||
|
result[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = res;
|
||
|
}
|
||
|
/*
|
||
|
if (index >= 0 && index < 400000) {
|
||
|
printf("index: %i; hstart: %i; hend: %i; wstart: %i; wend: %i; ph: %i; pw: %i; hstart_orig: %i; hend_orig: %i;\n", index, hstart, hend, wstart, wend, ph, pw, hSO, hEO);
|
||
|
}
|
||
|
*/
|
||
|
}
|
||
|
|
||
|
__syncthreads();
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
|
||
|
static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outShapeBuffer, Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||
|
// input is [bS, iC, iH, iW]
|
||
|
// output is [bS, iC, oH, oW]
|
||
|
|
||
|
const Nd4jLong kH = (int)extraParams[0];
|
||
|
const Nd4jLong kW = (int)extraParams[1];
|
||
|
const Nd4jLong sH = (int)extraParams[2];
|
||
|
const Nd4jLong sW = (int)extraParams[3];
|
||
|
const Nd4jLong pH = (int)extraParams[4];
|
||
|
const Nd4jLong pW = (int)extraParams[5];
|
||
|
const Nd4jLong dH = (int)extraParams[6];
|
||
|
const Nd4jLong dW = (int)extraParams[7];
|
||
|
Nd4jLong poolingMode = (int)extraParams[9];
|
||
|
T extraParam0 = extraParams[10];
|
||
|
|
||
|
if(dH == 0 || dW == 0) {
|
||
|
printf("Special_ops pooling2d:: dilation must not be zero, but got instead {%lld, %lld} \n", dH, dW);
|
||
|
throw "";
|
||
|
}
|
||
|
|
||
|
const Nd4jLong kHEff = kH + (kH-1)*(dH-1);
|
||
|
const Nd4jLong kWEff = kW + (kW-1)*(dW-1);
|
||
|
|
||
|
const int bS = shape::sizeAt(inShapeBuffer, 0);
|
||
|
const int iC = shape::sizeAt(inShapeBuffer, 1);
|
||
|
const int iH = shape::sizeAt(inShapeBuffer, 2);
|
||
|
const int iW = shape::sizeAt(inShapeBuffer, 3);
|
||
|
const int oH = shape::sizeAt(outShapeBuffer, 2);
|
||
|
const int oW = shape::sizeAt(outShapeBuffer, 3);
|
||
|
const Nd4jLong iStride0 = shape::stride(inShapeBuffer)[0];
|
||
|
const Nd4jLong iStride1 = shape::stride(inShapeBuffer)[1];
|
||
|
const Nd4jLong iStride2 = shape::stride(inShapeBuffer)[2];
|
||
|
const Nd4jLong iStride3 = shape::stride(inShapeBuffer)[3];
|
||
|
const Nd4jLong oStride0 = shape::stride(outShapeBuffer)[0];
|
||
|
const Nd4jLong oStride1 = shape::stride(outShapeBuffer)[1];
|
||
|
const Nd4jLong oStride2 = shape::stride(outShapeBuffer)[2];
|
||
|
const Nd4jLong oStride3 = shape::stride(outShapeBuffer)[3];
|
||
|
|
||
|
const Nd4jLong iStep2 = dH*iStride2;
|
||
|
const Nd4jLong iStep3 = dW*iStride3;
|
||
|
const int kProd = kH*kW;
|
||
|
const T iStep2Inv = 1./iStep2;
|
||
|
const T iStep3Inv = 1./iStep3;
|
||
|
|
||
|
Nd4jLong hstart, wstart, hend, wend;
|
||
|
T sum, *pIn;
|
||
|
|
||
|
if(poolingMode == 0) { // max
|
||
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
|
||
|
for(int b = 0; b < bS; ++b) {
|
||
|
for(int c = 0; c < iC; ++c) {
|
||
|
for(int oh = 0; oh < oH; ++oh) {
|
||
|
for(int ow = 0; ow < oW; ++ow) {
|
||
|
|
||
|
pIn = in + b * iStride0 + c * iStride1;
|
||
|
|
||
|
hstart = oh * sH - pH;
|
||
|
wstart = ow * sW - pW;
|
||
|
hend = hstart + kHEff;
|
||
|
wend = wstart + kWEff;
|
||
|
|
||
|
if(hstart < 0)
|
||
|
hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
|
||
|
if(wstart < 0)
|
||
|
wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
|
||
|
if(hend > iH)
|
||
|
hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(hend-iH) / static_cast<T>(dH));
|
||
|
if(wend > iW)
|
||
|
wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(wend-iW) / static_cast<T>(dW));
|
||
|
|
||
|
hstart *= iStride2;
|
||
|
hend *= iStride2;
|
||
|
wstart *= iStride3;
|
||
|
wend *= iStride3;
|
||
|
|
||
|
sum = -nd4j::DataTypeUtils::max<Z>();
|
||
|
|
||
|
for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
|
||
|
for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) {
|
||
|
T val = pIn[kh + kw];
|
||
|
if (val > sum)
|
||
|
sum = val;
|
||
|
}
|
||
|
out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
/*************************************************************************/
|
||
|
else if(poolingMode == 1) { // avg
|
||
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
|
||
|
for(int b = 0; b < bS; ++b) {
|
||
|
for(int c = 0; c < iC; ++c) {
|
||
|
for(int oh = 0; oh < oH; ++oh) {
|
||
|
for(int ow = 0; ow < oW; ++ow) {
|
||
|
|
||
|
pIn = in + b * iStride0 + c * iStride1;
|
||
|
|
||
|
hstart = oh * sH - pH;
|
||
|
wstart = ow * sW - pW;
|
||
|
hend = hstart + kHEff;
|
||
|
wend = wstart + kWEff;
|
||
|
|
||
|
if(hstart < 0)
|
||
|
hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
|
||
|
if(wstart < 0)
|
||
|
wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
|
||
|
if(hend > iH)
|
||
|
hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(hend-iH) / static_cast<T>(dH));
|
||
|
if(wend > iW)
|
||
|
wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(wend-iW) / static_cast<T>(dW));
|
||
|
|
||
|
hstart *= iStride2;
|
||
|
hend *= iStride2;
|
||
|
wstart *= iStride3;
|
||
|
wend *= iStride3;
|
||
|
|
||
|
sum = static_cast<Z>(0.);
|
||
|
|
||
|
for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
|
||
|
for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
|
||
|
sum += pIn[kh + kw];
|
||
|
|
||
|
if ((int) extraParam0 == 0) //Exclude padding
|
||
|
sum /= static_cast<T>(nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend-hstart) / static_cast<double>(iStep2))) * static_cast<T>(nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend-wstart) / static_cast<double>(iStep3))); //Accounts for dilation
|
||
|
else if ((int) extraParam0 == 1) //Include padding
|
||
|
sum /= kProd;
|
||
|
|
||
|
out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
/*************************************************************************/
|
||
|
else if(poolingMode == 2) { // pnorm
|
||
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
|
||
|
for(int b = 0; b < bS; ++b) {
|
||
|
for(int c = 0; c < iC; ++c) {
|
||
|
for(int oh = 0; oh < oH; ++oh) {
|
||
|
for(int ow = 0; ow < oW; ++ow) {
|
||
|
|
||
|
pIn = in + b * iStride0 + c * iStride1;
|
||
|
|
||
|
hstart = oh * sH - pH;
|
||
|
wstart = ow * sW - pW;
|
||
|
hend = hstart + kHEff;
|
||
|
wend = wstart + kWEff;
|
||
|
|
||
|
if(hstart < 0)
|
||
|
hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
|
||
|
if(wstart < 0)
|
||
|
wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
|
||
|
if(hend > iH)
|
||
|
hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(hend-iH) / static_cast<T>(dH));
|
||
|
if(wend > iW)
|
||
|
wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(wend-iW) / static_cast<T>(dW));
|
||
|
|
||
|
hstart *= iStride2;
|
||
|
hend *= iStride2;
|
||
|
wstart *= iStride3;
|
||
|
wend *= iStride3;
|
||
|
|
||
|
sum = static_cast<T>(0.);
|
||
|
|
||
|
for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
|
||
|
for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
|
||
|
sum += nd4j::math::nd4j_pow<T, T, T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0);
|
||
|
|
||
|
sum = nd4j::math::nd4j_pow<T,T,T>(sum, (T) 1. / extraParam0);
|
||
|
|
||
|
out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
else {
|
||
|
nd4j_printf("Special_ops::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
||
|
throw "";
|
||
|
}
|
||
|
}
|
||
|
|
||
|
op_def static T op(T d1, Z *params) {
|
||
|
return d1;
|
||
|
}
|
||
|
|
||
|
|
||
|
/** Calculate buffer offset (like Shape.getOffset) without checking on input for negative indices etc
|
||
|
* normally negative indices are bad, OK here because of other checks on input indices
|
||
|
* Uses unrolled loop specifically for length 4
|
||
|
*/
|
||
|
static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
|
||
|
int offset = baseOffset;
|
||
|
if (shape[0] != 1) offset += indices[0] * stride[0];
|
||
|
if (shape[1] != 1) offset += indices[1] * stride[1];
|
||
|
if (shape[2] != 1) offset += indices[2] * stride[2];
|
||
|
if (shape[3] != 1) offset += indices[3] * stride[3];
|
||
|
return offset;
|
||
|
}
|
||
|
|
||
|
|
||
|
/**
|
||
|
* A version of Shape.getOffset without checking on input for negative indices etc
|
||
|
* normally negative indices are bad, OK here because of other checks on input indices
|
||
|
* Uses unrolled loop specifically for length 6, where indices[2] and indices[3] are zero (always are here)
|
||
|
*/
|
||
|
static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
|
||
|
int offset = baseOffset;
|
||
|
if (shape[0] != 1) offset += indices[0] * stride[0];
|
||
|
if (shape[1] != 1) offset += indices[1] * stride[1];
|
||
|
if (shape[4] != 1) offset += indices[4] * stride[4];
|
||
|
if (shape[5] != 1) offset += indices[5] * stride[5];
|
||
|
return offset;
|
||
|
}
|
||
|
|
||
|
};
|
||
|
|
||
|
|
||
|
/**
 * Branch-free test for 0 <= a < b.
 * Both values are reinterpreted as unsigned, so a negative `a` becomes a very large
 * value and fails the single comparison together with a >= b.
 * NOTE(review): the equivalence to `a >= 0 && a < b` presumably relies on b >= 0 — confirm at call sites.
 */
FORCEINLINE bool is_a_ge_zero_and_a_lt_b(int a, int b) {
    const unsigned asUnsigned = static_cast<unsigned>(a);
    const unsigned limit = static_cast<unsigned>(b);
    return asUnsigned < limit;
}
|
||
|
|
||
|
template<typename T>
|
||
|
class
|
||
|
Im2col {
|
||
|
public:
|
||
|
static const bool requiresSpecial = true;
|
||
|
|
||
|
static _CUDA_HD int outSize(int size, int k, int s, int p, bool coverAll) {
|
||
|
if (coverAll)
|
||
|
return (size + p * 2 - k + s - 1) / s + 1;
|
||
|
else
|
||
|
return (size + p * 2 - k) / s + 1;
|
||
|
}
|
||
|
|
||
|
#ifdef __CUDACC__
|
||
|
/**
|
||
|
* Based on: https://github.com/pjreddie/darknet/blob/master/src/im2col_kernels.cu
|
||
|
*/
|
||
|
|
||
|
static inline __device__ void execSpecialCuda(
|
||
|
T *dx, Nd4jLong *xShapeBuffer,
|
||
|
T *result, Nd4jLong *zShapeBuffer,
|
||
|
T *extraParams,
|
||
|
int *allocationPointer, T *reductionPointer,
|
||
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||
|
|
||
|
/*kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false*/
|
||
|
__shared__ int kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, kSize, samples, depth, height, width, strideex, stridech, strideh, stridew, height_col, width_col, n;
|
||
|
__shared__ T zeroPadVal;
|
||
|
__shared__ Nd4jLong *outShape, *outStride, *inShape, *inStride;
|
||
|
__shared__ char resultOrder;
|
||
|
|
||
|
if (threadIdx.x == 0) {
|
||
|
kernelHeight = (int) extraParams[0];
|
||
|
kernelWidth = (int) extraParams[1];
|
||
|
strideY = (int) extraParams[2];
|
||
|
strideX = (int) extraParams[3];
|
||
|
padHeight = (int) extraParams[4];
|
||
|
padWidth = (int) extraParams[5];
|
||
|
dY = (int) extraParams[6]; //Dilation, height/y dimension
|
||
|
dX = (int) extraParams[7]; //Dilation, width/x dimension
|
||
|
kSize = kernelWidth * kernelHeight;
|
||
|
zeroPadVal = (T) extraParams[9]; //Value to use when value is padding. Usually 0 but not always
|
||
|
|
||
|
outShape = shape::shapeOf(zShapeBuffer);
|
||
|
resultOrder = shape::order(zShapeBuffer);
|
||
|
outStride = shape::stride(zShapeBuffer);
|
||
|
|
||
|
inShape = shape::shapeOf(xShapeBuffer);
|
||
|
inStride = shape::stride(xShapeBuffer);
|
||
|
|
||
|
samples = (int) inShape[0];
|
||
|
depth = (int) inShape[1];
|
||
|
height = (int) inShape[2];
|
||
|
width = (int) inShape[3];
|
||
|
|
||
|
|
||
|
strideex = (int) inStride[0];
|
||
|
stridech = (int) inStride[1];
|
||
|
strideh = (int) inStride[2];
|
||
|
stridew = (int) inStride[3];
|
||
|
|
||
|
// (height + 2 * padHeight - kernelHeight) / strideX + 1; //
|
||
|
// (width + 2 * padWidth - kernelWidth) / strideY + 1; //
|
||
|
height_col = (int) outShape[4];
|
||
|
width_col = (int) outShape[5];
|
||
|
|
||
|
n = samples * depth * height_col * width_col;
|
||
|
}
|
||
|
__syncthreads();
|
||
|
|
||
|
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||
|
for (; index < n; index += blockDim.x*gridDim.x) {
|
||
|
int h_index = index / width_col;
|
||
|
int h_col = h_index % height_col;
|
||
|
int w_col = index % width_col;
|
||
|
|
||
|
int c_im = h_index / height_col;
|
||
|
int c_col = c_im * kSize;
|
||
|
|
||
|
int depth_im = c_im % depth;
|
||
|
int num_im = c_im / depth;
|
||
|
int h_offset = h_col * strideY - padHeight;
|
||
|
int w_offset = w_col * strideX - padWidth;
|
||
|
|
||
|
T* data_col_ptr = result;
|
||
|
|
||
|
int i_c = (c_col * height_col + h_col) * width_col + w_col;
|
||
|
data_col_ptr += (c_col * height_col + h_col) * width_col + w_col;
|
||
|
|
||
|
T* data_im_ptr = dx;
|
||
|
|
||
|
data_im_ptr += num_im * strideex + depth_im * stridech + h_offset * strideh + w_offset*stridew;
|
||
|
|
||
|
for (int i = 0; i < kernelHeight; ++i) {
|
||
|
for (int j = 0; j < kernelWidth; ++j) {
|
||
|
int h_im = h_offset + i * dY;
|
||
|
int w_im = w_offset + j * dX;
|
||
|
int i_f = 0;
|
||
|
int i_c_temp = i_c;
|
||
|
for (int dim = 5; dim >= 0; dim--) {
|
||
|
i_f += (i_c_temp % outShape[dim]) * outStride[dim];
|
||
|
i_c_temp = i_c_temp / outShape[dim];
|
||
|
}
|
||
|
if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width){
|
||
|
result[i_f] = data_im_ptr[i * dY * strideh + j * dX * stridew];
|
||
|
} else result[i_f] = zeroPadVal;
|
||
|
|
||
|
//result[i_f] = (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ? data_im_ptr[i * strideh + j*stridew] : 0;
|
||
|
data_col_ptr += height_col * width_col;
|
||
|
i_c += height_col * width_col;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
|
||
|
static void execSpecial(
|
||
|
T *imBuff,
|
||
|
Nd4jLong *imShapeBuffer,
|
||
|
T *colBuff,
|
||
|
Nd4jLong *colShapeBuffer,
|
||
|
T *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||
|
/*kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false*/
|
||
|
|
||
|
// [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||
|
|
||
|
int kH = (int)extraParams[0];
|
||
|
int kW = (int)extraParams[1];
|
||
|
int sH = (int)extraParams[2];
|
||
|
int sW = (int)extraParams[3];
|
||
|
int pH = (int)extraParams[4];
|
||
|
int pW = (int)extraParams[5];
|
||
|
int dH = (int)extraParams[6]; //Dilation, height/y dimension
|
||
|
int dW = (int)extraParams[7]; //Dilation, width/x dimension
|
||
|
T zeroPadVal = extraParams[9];
|
||
|
|
||
|
auto colShape = shape::shapeOf(colShapeBuffer);
|
||
|
auto colStride = shape::stride(colShapeBuffer);
|
||
|
auto imShape = shape::shapeOf(imShapeBuffer);
|
||
|
auto imStride = shape::stride(imShapeBuffer);
|
||
|
|
||
|
const int bS = imShape[0];
|
||
|
const int iC = imShape[1];
|
||
|
const int iH = imShape[2];
|
||
|
const int iW = imShape[3];
|
||
|
const int oH = colShape[4];
|
||
|
const int oW = colShape[5];
|
||
|
const Nd4jLong colStride0 = colStride[0];
|
||
|
const Nd4jLong colStride1 = colStride[1];
|
||
|
const Nd4jLong colStride2 = colStride[2];
|
||
|
const Nd4jLong colStride3 = colStride[3];
|
||
|
const Nd4jLong colStride4 = colStride[4];
|
||
|
const Nd4jLong colStride5 = colStride[5];
|
||
|
const Nd4jLong imStride0 = imStride[0];
|
||
|
const Nd4jLong imStride1 = imStride[1];
|
||
|
const Nd4jLong imStride2 = imStride[2];
|
||
|
const Nd4jLong imStride3 = imStride[3];
|
||
|
|
||
|
T *col, *im;
|
||
|
int imRow, imCol;
|
||
|
|
||
|
if (shape::order(imShapeBuffer) == 'c' && shape::order(colShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(imShapeBuffer) && shape::strideDescendingCAscendingF(colShapeBuffer)) {
|
||
|
|
||
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, im, imRow, imCol) collapse(2))
|
||
|
for (int b = 0; b < bS; b++) {
|
||
|
for (int c = 0; c < iC; ++c) {
|
||
|
for (int kRow = 0; kRow < kH; ++kRow) {
|
||
|
for (int kCol = 0; kCol < kW; ++kCol) {
|
||
|
for (int colH = 0; colH < oH; ++colH) {
|
||
|
for (int colW = 0; colW < oW; ++colW) {
|
||
|
|
||
|
imRow = (-pH + kRow * dH) + colH*sH;
|
||
|
imCol = (-pW + kCol * dW) + colW*sW;
|
||
|
|
||
|
col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
|
||
|
im = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;
|
||
|
|
||
|
if (static_cast<unsigned>(imRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(imCol) >= static_cast<unsigned>(iW))
|
||
|
*col = zeroPadVal;
|
||
|
else
|
||
|
*col = *im;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
else {
|
||
|
|
||
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(private(im, col, imRow, imCol) collapse(2))
|
||
|
for (int b = 0; b < bS; b++) {
|
||
|
for (int colH = 0; colH < oH; ++colH) {
|
||
|
for (int colW = 0; colW < oW; ++colW) {
|
||
|
for (int c = 0; c < iC; ++c) {
|
||
|
for (int kRow = 0; kRow < kH; ++kRow) {
|
||
|
for (int kCol = 0; kCol < kW; ++kCol) {
|
||
|
|
||
|
imRow = (-pH + kRow * dH) + colH*sH;
|
||
|
imCol = (-pW + kCol * dW) + colW*sW;
|
||
|
|
||
|
col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
|
||
|
im = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;
|
||
|
|
||
|
if (static_cast<unsigned>(imRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(imCol) >= static_cast<unsigned>(iW))
|
||
|
*col = zeroPadVal;
|
||
|
else
|
||
|
*col = *im;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
op_def static T op(T d1, T *params) {
|
||
|
return d1;
|
||
|
}
|
||
|
|
||
|
|
||
|
/** Calculate buffer offset (like Shape.getOffset) without checking on input for negative indices etc
|
||
|
* normally negative indices are bad, OK here because of other checks on input indices
|
||
|
* Uses unrolled loop specifically for length 4
|
||
|
*/
|
||
|
static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
|
||
|
int offset = baseOffset;
|
||
|
if (shape[0] != 1) offset += indices[0] * stride[0];
|
||
|
if (shape[1] != 1) offset += indices[1] * stride[1];
|
||
|
if (shape[2] != 1) offset += indices[2] * stride[2];
|
||
|
if (shape[3] != 1) offset += indices[3] * stride[3];
|
||
|
return offset;
|
||
|
}
|
||
|
|
||
|
|
||
|
/**
|
||
|
* A version of Shape.getOffset without checking on input for negative indices etc
|
||
|
* normally negative indices are bad, OK here because of other checks on input indices
|
||
|
* Uses unrolled loop specifically for length 6, where indices[2] and indices[3] are zero (always are here)
|
||
|
*/
|
||
|
static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
|
||
|
int offset = baseOffset;
|
||
|
if (shape[0] != 1) offset += indices[0] * stride[0];
|
||
|
if (shape[1] != 1) offset += indices[1] * stride[1];
|
||
|
if (shape[4] != 1) offset += indices[4] * stride[4];
|
||
|
if (shape[5] != 1) offset += indices[5] * stride[5];
|
||
|
return offset;
|
||
|
}
|
||
|
|
||
|
};
|
||
|
|
||
|
template<typename T, typename Z>
|
||
|
class Histogram {
|
||
|
public:
|
||
|
static const bool requiresSpecial = true;
|
||
|
|
||
|
#ifdef __CUDACC__
|
||
|
static inline __device__ void execSpecialCuda(
|
||
|
T *dx, Nd4jLong *xShapeBuffer,
|
||
|
Z *result, Nd4jLong *zShapeBuffer,
|
||
|
Z *extraParams,
|
||
|
int *allocationPointer, Z *reductionPointer,
|
||
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||
|
|
||
|
|
||
|
|
||
|
};
|
||
|
#endif
|
||
|
|
||
|
static void execSpecial(
|
||
|
T *dx,
|
||
|
Nd4jLong *xShapeBuffer,
|
||
|
Z *result,
|
||
|
Nd4jLong *zShapeBuffer,
|
||
|
Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||
|
|
||
|
|
||
|
|
||
|
}
|
||
|
|
||
|
|
||
|
op_def static T op(T d1, Z *params) {
|
||
|
return d1;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template<typename X>
|
||
|
class Col2Im {
|
||
|
|
||
|
public:
|
||
|
static const bool requiresSpecial = true;
|
||
|
#ifdef __CUDACC__
|
||
|
/**
|
||
|
* https://github.com/pjreddie/darknet/blob/master/src/col2im_kernels.cu
|
||
|
*/
|
||
|
|
||
|
static inline __device__ void execSpecialCuda(
|
||
|
X *dx, Nd4jLong *xShapeBuffer,
|
||
|
X *result, Nd4jLong *zShapeBuffer,
|
||
|
X *extraParams, int *allocationPointer,
|
||
|
X *reductionPointer,
|
||
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||
|
|
||
|
__shared__ int strideex, stridech, stridekrow, stridekcol, striderow, stridecol, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX, samples, depth, imgH, imgW, height_col, width_col, n, kEffectiveW, kEffectiveH;
|
||
|
__shared__ Nd4jLong *inShape, *inStride, *outShape, *outStride;
|
||
|
__shared__ char resultOrder;
|
||
|
|
||
|
if (threadIdx.x == 0) {
|
||
|
inShape = shape::shapeOf(xShapeBuffer);
|
||
|
inStride = shape::stride(xShapeBuffer);
|
||
|
|
||
|
strideex = (int) inStride[0];
|
||
|
stridech = (int) inStride[1];
|
||
|
stridekrow = (int) inStride[2];
|
||
|
stridekcol = (int) inStride[3];
|
||
|
striderow = (int) inStride[4];
|
||
|
stridecol = (int) inStride[5];
|
||
|
|
||
|
kernelHeight = (int) inShape[2];
|
||
|
kernelWidth = (int) inShape[3];
|
||
|
|
||
|
strideY = (int) extraParams[0];
|
||
|
strideX = (int) extraParams[1];
|
||
|
padHeight = (int) extraParams[2];
|
||
|
padWidth = (int) extraParams[3];
|
||
|
imgHeight = (int) extraParams[4];
|
||
|
imgWidth = (int) extraParams[5];
|
||
|
dY = (int) extraParams[6]; //Dilation in height/y dimension
|
||
|
dX = (int) extraParams[7]; //Dilation in width/x dimension
|
||
|
|
||
|
outShape = shape::shapeOf(zShapeBuffer);
|
||
|
resultOrder = shape::order(zShapeBuffer);
|
||
|
outStride = shape::stride(zShapeBuffer);
|
||
|
|
||
|
samples = (int) outShape[0];
|
||
|
depth = (int) outShape[1];
|
||
|
imgH = (int) outShape[2];
|
||
|
imgW = (int) outShape[3];
|
||
|
|
||
|
height_col = inShape[4];//(imgHeight + 2 * padHeight - kernelHeight) / strideX + 1;
|
||
|
width_col = inShape[5];//(imgWidth + 2 * padWidth - kernelWidth) / strideY + 1;
|
||
|
|
||
|
n = samples * depth * imgHeight * imgWidth;
|
||
|
|
||
|
//Effective kernel size, accounting for dilation
|
||
|
kEffectiveW = kernelWidth + (kernelWidth - 1) * (dX - 1);
|
||
|
kEffectiveH = kernelHeight + (kernelHeight - 1) * (dY - 1);
|
||
|
}
|
||
|
__syncthreads();
|
||
|
|
||
|
for (int i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
|
||
|
X val = 0;
|
||
|
int w_im = i % imgWidth + padWidth;
|
||
|
int h_im = (i / imgWidth) % imgHeight + padHeight;
|
||
|
int c_im = i / (imgWidth * imgHeight);
|
||
|
|
||
|
int num_im = c_im / depth;
|
||
|
int depth_im = c_im % depth;
|
||
|
|
||
|
// compute the start and end of the output
|
||
|
// These are the indexes for dimensions ??? in the 6d col matrix
|
||
|
int w_col_start = (w_im < kEffectiveW) ? 0 : (w_im - kEffectiveW) / strideX + 1;
|
||
|
int w_col_end = nd4j::math::nd4j_min<int>(w_im / strideX + 1, width_col);
|
||
|
|
||
|
int h_col_start = (h_im < kEffectiveH) ? 0 : (h_im - kEffectiveH) / strideY + 1;
|
||
|
int h_col_end = nd4j::math::nd4j_min<int>(h_im / strideY + 1, height_col);
|
||
|
|
||
|
|
||
|
//Iterate over col entries in the 6d array... these are added up
|
||
|
for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
|
||
|
for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
|
||
|
int h_k = (h_im - h_col * strideY);
|
||
|
int w_k = (w_im - w_col * strideX);
|
||
|
|
||
|
if(h_k % dY == 0 && w_k % dX == 0){
|
||
|
h_k /= dY;
|
||
|
w_k /= dX;
|
||
|
|
||
|
int data_col_index = num_im * strideex + depth_im * stridech + h_k * stridekrow + w_k * stridekcol + h_col * striderow + w_col * stridecol;
|
||
|
val += dx[data_col_index];
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
int i_f = 0;
|
||
|
int i_c = i;
|
||
|
for (int dim = 3; dim >= 0; dim--)
|
||
|
{
|
||
|
i_f += (i_c % outShape[dim]) * outStride[dim];
|
||
|
i_c = i_c / outShape[dim];
|
||
|
}
|
||
|
result[i_f] = val;
|
||
|
}
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
static void execSpecial(
        X *colBuff,
        Nd4jLong *colShapeBuffer,
        X *imBuff,
        Nd4jLong *imShapeBuffer,
        X *extraParams,
        Nd4jLong *tadShapeInfo,
        Nd4jLong *tadOffsets) {

    // col2im (host): the column buffer [bS, iC, kH, kW, oH, oW] is accumulated back into the
    // image buffer [bS, iC, iH, iW]; contributions of overlapping kernel windows are summed.
    //
    // extraParams (read as ints): [0] sH, [1] sW  - strides
    //                             [2] pH, [3] pW  - padding
    //                             [4] iH, [5] iW  - image height / width
    //                             [6] dH, [7] dW  - dilation
    // tadShapeInfo / tadOffsets are not used.

    auto colShape  = shape::shapeOf(colShapeBuffer);
    auto colStride = shape::stride(colShapeBuffer);
    auto imShape   = shape::shapeOf(imShapeBuffer);
    auto imStride  = shape::stride(imShapeBuffer);

    const int sH = (int)extraParams[0];
    const int sW = (int)extraParams[1];
    const int pH = (int)extraParams[2];
    const int pW = (int)extraParams[3];
    const int iH = (int)extraParams[4];
    const int iW = (int)extraParams[5];
    const int dH = (int)extraParams[6];
    const int dW = (int)extraParams[7];

    const int bS = imShape[0];
    const int iC = imShape[1];
    const int kH = colShape[2];
    const int kW = colShape[3];
    const int oH = colShape[4];
    const int oW = colShape[5];

    const Nd4jLong colStride0 = colStride[0];
    const Nd4jLong colStride1 = colStride[1];
    const Nd4jLong colStride2 = colStride[2];
    const Nd4jLong colStride3 = colStride[3];
    const Nd4jLong colStride4 = colStride[4];
    const Nd4jLong colStride5 = colStride[5];
    const Nd4jLong imStride0 = imStride[0];
    const Nd4jLong imStride1 = imStride[1];
    const Nd4jLong imStride2 = imStride[2];
    const Nd4jLong imStride3 = imStride[3];

    auto zLength = shape::length(imShapeBuffer);

    // Initial zeroing of the image. memset over zLength elements is only correct when the
    // image is dense (ews == 1); for a strided view it would miss real elements and clobber
    // memory that belongs to someone else, so zero element-by-element through the strides.
    if (shape::elementWiseStride(imShapeBuffer) == 1) {
        memset(imBuff, 0, zLength * sizeof(X));
    }
    else {
        const int imH = imShape[2];
        const int imW = imShape[3];

        PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(2))
        for (int b = 0; b < bS; b++) {
            for (int c = 0; c < iC; ++c) {
                for (int h = 0; h < imH; ++h) {
                    for (int w = 0; w < imW; ++w) {
                        imBuff[b*imStride0 + c*imStride1 + h*imStride2 + w*imStride3] = static_cast<X>(0);
                    }
                }
            }
        }
    }

    X *col, *im;
    int imRow, imCol;

    if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) {

        // Each (b, c) pair owns a disjoint image slice, so (b, c) can be split across threads.
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, im, imRow, imCol) collapse(2))
        for (int b = 0; b < bS; b++) {
            for (int c = 0; c < iC; ++c) {
                for (int kRow = 0; kRow < kH; ++kRow) {
                    for (int kCol = 0; kCol < kW; ++kCol) {
                        for (int colH = 0; colH < oH; ++colH) {
                            for (int colW = 0; colW < oW; ++colW) {

                                imRow = (-pH + kRow * dH) + colH*sH;
                                imCol = (-pW + kCol * dW) + colW*sW;

                                // unsigned compare rejects negative (padding) and too-large coords at once;
                                // pointers are only formed for in-range coordinates
                                if (static_cast<unsigned>(imRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(imCol) < static_cast<unsigned>(iW)) {
                                    col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                    im  = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;
                                    *im += *col;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    else {

        // Only the batch index is split across threads: each b owns a disjoint image slice.
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(im, col, imRow, imCol))
        for (int b = 0; b < bS; b++) {
            for (int colH = 0; colH < oH; ++colH) {
                for (int colW = 0; colW < oW; ++colW) {
                    for (int c = 0; c < iC; ++c) {
                        for (int kRow = 0; kRow < kH; ++kRow) {
                            for (int kCol = 0; kCol < kW; ++kCol) {

                                imRow = (-pH + kRow * dH) + colH*sH;
                                imCol = (-pW + kCol * dW) + colW*sW;

                                if (static_cast<unsigned>(imRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(imCol) < static_cast<unsigned>(iW)) {
                                    col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                    im  = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;
                                    *im += *col;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
|
||
|
|
||
|
op_def static X op(X d1, X *params) {
    // Identity transform: d1 is handed back untouched and params is ignored.
    const X unchanged = d1;
    return unchanged;
}
|
||
|
|
||
|
|
||
|
/** Buffer offset for a rank-4 index (like Shape.getOffset), without validating the indices.
 *  Negative indices are not rejected here; that is acceptable because the callers check their
 *  input indices beforehand. Dimensions of size 1 contribute nothing to the offset.
 *  The four terms are spelled out explicitly instead of looping.
 */
static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
    const int term0 = (shape[0] == 1) ? 0 : indices[0] * stride[0];
    const int term1 = (shape[1] == 1) ? 0 : indices[1] * stride[1];
    const int term2 = (shape[2] == 1) ? 0 : indices[2] * stride[2];
    const int term3 = (shape[3] == 1) ? 0 : indices[3] * stride[3];
    return baseOffset + term0 + term1 + term2 + term3;
}
|
||
|
|
||
|
/** Buffer offset for a rank-6 index (like Shape.getOffset), without validating the indices.
 *  Negative indices are not rejected here; the callers check their input indices beforehand.
 *  Only dimensions 0, 1, 4 and 5 are summed: in this code indices[2] and indices[3] are
 *  always zero. Dimensions of size 1 contribute nothing to the offset.
 */
static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
    const int term0 = (shape[0] == 1) ? 0 : indices[0] * stride[0];
    const int term1 = (shape[1] == 1) ? 0 : indices[1] * stride[1];
    const int term4 = (shape[4] == 1) ? 0 : indices[4] * stride[4];
    const int term5 = (shape[5] == 1) ? 0 : indices[5] * stride[5];
    return baseOffset + term0 + term1 + term4 + term5;
}
|
||
|
|
||
|
};
|
||
|
|
||
|
|
||
|
template<typename X>
class Reverse {
public:
    static const bool requiresSpecial = true;

#ifdef __CUDACC__
    /**
     * Device-side reverse of the flat element order: result[len - 1 - i] = dx[i].
     *
     * In-place (dx == result): mirrored pairs (i, len - 1 - i), i < len / 2, are swapped, so the
     * middle element of an odd-length array stays put. Out-of-place: each element is copied once
     * to its mirrored position. Element-wise strides are used when available, otherwise offsets
     * are resolved with shape::getIndexOffset against the x (read) / z (write) shape info.
     * extraParams, allocationPointer, reductionPointer, tadShapeInfo and tadOffsets are unused.
     */
    static inline __device__ void execSpecialCuda(X *dx, Nd4jLong *xShapeBuffer,
                                                  X *result, Nd4jLong *zShapeBuffer,
                                                  X *extraParams, int *allocationPointer,
                                                  X *reductionPointer,
                                                  Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

        __shared__ Nd4jLong xLength;
        __shared__ int xEWS;
        __shared__ char xOrder;
        __shared__ Nd4jLong sLength;

        // 64-bit indexing: xLength is Nd4jLong and may exceed INT_MAX
        const Nd4jLong tid  = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
        const Nd4jLong step = static_cast<Nd4jLong>(blockDim.x) * gridDim.x;

        if (threadIdx.x == 0) {
            xLength = shape::length(xShapeBuffer);
            xEWS = shape::elementWiseStride(xShapeBuffer);
            xOrder = shape::order(xShapeBuffer);
            sLength = xLength - 1;
        }
        __syncthreads();

        if (dx == result) {
            const Nd4jLong half = xLength / 2;

            if (xEWS >= 1) {
                for (Nd4jLong e = tid; e < half; e += step) {
                    const Nd4jLong lo = e * xEWS;
                    const Nd4jLong hi = (sLength - e) * xEWS;
                    X tmp = dx[lo];
                    dx[lo] = dx[hi];
                    dx[hi] = tmp;
                }
            }
            else {
                // In-place requires a real swap: writing only the mirrored side would overwrite
                // the second half before it is read, leaving a mirror of the first half.
                for (Nd4jLong e = tid; e < half; e += step) {
                    const auto lo = shape::getIndexOffset(e, xShapeBuffer, xLength);
                    const auto hi = shape::getIndexOffset(sLength - e, xShapeBuffer, xLength);
                    X tmp = dx[lo];
                    dx[lo] = dx[hi];
                    dx[hi] = tmp;
                }
            }
        }
        else {
            // dx == result is presumably the same for every thread of the block (kernel
            // arguments), so the barrier below is reached uniformly.
            __shared__ int zEWS;
            __shared__ char zOrder;

            if (threadIdx.x == 0) {
                zEWS = shape::elementWiseStride(zShapeBuffer);
                zOrder = shape::order(zShapeBuffer);
            }
            __syncthreads();

            if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) {
                for (Nd4jLong e = tid; e < xLength; e += step)
                    result[(sLength - e) * zEWS] = dx[e * xEWS];
            }
            else {
                for (Nd4jLong e = tid; e < xLength; e += step) {
                    const auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength);
                    // write offsets must come from the OUTPUT shape info
                    const auto zOffset = shape::getIndexOffset(sLength - e, zShapeBuffer, xLength);
                    result[zOffset] = dx[xOffset];
                }
            }
        }
    }
#endif

    /**
     * Host-side reverse of the flat element order: result[len - 1 - i] = dx[i].
     * In-place (dx == result) swaps mirrored pairs; otherwise every element is copied once.
     * Iterations touch disjoint elements/pairs, so the loops are parallelised with OpenMP.
     * extraParams, tadShapeInfo and tadOffsets are unused.
     */
    static void execSpecial(X *dx, Nd4jLong *xShapeBuffer, X *result, Nd4jLong *zShapeBuffer, X *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
        const Nd4jLong xLength = shape::length(xShapeBuffer);
        const int xEWS = shape::elementWiseStride(xShapeBuffer);
        const char xOrder = shape::order(xShapeBuffer);
        const Nd4jLong sLength = xLength - 1;

        if (dx == result) {
            // two-sided phase: swap element e with its mirror, for the first half only
            const Nd4jLong half = xLength / 2;

            if (xEWS >= 1) {
                PRAGMA_OMP_PARALLEL_FOR_SIMD
                for (Nd4jLong e = 0; e < half; e++) {
                    const Nd4jLong lo = e * xEWS;
                    const Nd4jLong hi = (sLength - e) * xEWS;
                    auto tmp = dx[lo];
                    dx[lo] = dx[hi];
                    dx[hi] = tmp;
                }
            }
            else {
                PRAGMA_OMP_PARALLEL_FOR_SIMD
                for (Nd4jLong e = 0; e < half; e++) {
                    const auto lo = shape::getIndexOffset(e, xShapeBuffer, xLength);
                    const auto hi = shape::getIndexOffset(sLength - e, xShapeBuffer, xLength);
                    auto tmp = dx[lo];
                    dx[lo] = dx[hi];
                    dx[hi] = tmp;
                }
            }
        }
        else {
            // single step phase: copy each element to its mirrored slot in result
            const auto zEWS = shape::elementWiseStride(zShapeBuffer);
            const auto zOrder = shape::order(zShapeBuffer);

            if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) {
                PRAGMA_OMP_PARALLEL_FOR_SIMD
                for (Nd4jLong e = 0; e < xLength; e++)
                    result[(sLength - e) * zEWS] = dx[e * xEWS];
            }
            else {
                PRAGMA_OMP_PARALLEL_FOR_SIMD
                for (Nd4jLong e = 0; e < xLength; e++) {
                    const auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength);
                    const auto zOffset = shape::getIndexOffset(sLength - e, zShapeBuffer, xLength);
                    result[zOffset] = dx[xOffset];
                }
            }
        }
    }

    op_def static X op(X d1, X *params) {
        return d1;
    }
};
|
||
|
|
||
|
template<typename X>
class SoftMax {
public:
    static const bool requiresSpecial = true;

#ifdef __CUDACC__
    /**
     * Device-side softmax of x into result.
     *
     * Steps, each followed by a block barrier: Max-reduce x into the shared scalar maxResult;
     * result = x - max; result = exp(result); Sum-reduce result into maxResult; result /= sum.
     * NOTE(review): a single scalar max/sum is used for the whole input, i.e. this is a
     * whole-array softmax; presumably callers invoke it once per row/TAD — confirm at call sites.
     */
    static inline __device__ void execSpecialCuda(
            void *vx, Nd4jLong *xShapeBuffer,
            void *vresult, Nd4jLong *zShapeBuffer,
            void *vextraParams,
            int *allocationPointer, void *reductionPointer,
            Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

        auto dx = reinterpret_cast<X *>(vx);
        auto result = reinterpret_cast<X *>(vresult);
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        auto xShape = shape::shapeOf(xShapeBuffer);

        __shared__ X maxResult;
        __shared__ Nd4jLong *maxResultShapeBuffer;
        __shared__ Nd4jLong maxShape[2];
        // backing storage for the rank-2 [rows, 1] shape info describing maxResult
        __shared__ Nd4jLong tempBuffer[8];

        if (threadIdx.x == 0) {
            maxResult = (X) 0.0;
            maxShape[0] = xShape[0];
            maxShape[1] = 1;
            maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
        }
        __syncthreads();

        functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
        __syncthreads();

        // subtract the max before exponentiating
        functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
        __syncthreads();

        functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
        __syncthreads();

        // maxResult is reused to hold the sum of the exponentials
        functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
        __syncthreads();

        functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
    }
#endif

    /**
     * Host-side softmax of x into z.
     *
     * Matrix input: softmax is taken along each row (TADs over dimension 1). When x and z have
     * identical shape info a direct per-row loop is used; otherwise the op is composed from
     * reduce/broadcast/transform calls. Vector input: a single softmax over all elements.
     * Inputs that are neither a matrix nor a vector are left untouched.
     */
    static void execSpecial(
            void *vx,
            Nd4jLong *xShapeInfo,
            void *vz,
            Nd4jLong *zShapeInfo,
            void *vextraParams,
            Nd4jLong *tadShapeInfo,
            Nd4jLong *tadOffsets) {

        auto x = reinterpret_cast<X *>(vx);
        auto z = reinterpret_cast<X *>(vz);
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        if (shape::isMatrix(xShapeInfo)) {

            if (shape::equalsStrict(xShapeInfo, zShapeInfo)) {
                // identical shape info: the same row TAD offsets address both x and z
                if (tadShapeInfo == nullptr) {
                    auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, 1);
                    tadShapeInfo = tadPack.primaryShapeInfo();
                    tadOffsets = tadPack.primaryOffsets();
                }

                const uint tadLen = shape::length(tadShapeInfo);
                const uint numOfTads = shape::length(xShapeInfo) / tadLen;

                if (shape::elementWiseStride(tadShapeInfo) == 1) {

                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (uint i = 0; i < numOfTads; ++i) {

                        X* inBuff = x + tadOffsets[i];
                        X* outBuff = z + tadOffsets[i];

                        X max = -nd4j::DataTypeUtils::max<X>();
                        X sum = 0;

                        for (uint j = 0; j < tadLen; ++j)
                            max = nd4j::math::nd4j_max<X>(max, inBuff[j]);

                        for (uint j = 0; j < tadLen; ++j) {
                            X temp = nd4j::math::nd4j_exp<X,X>(inBuff[j] - max);
                            outBuff[j] = temp;
                            sum += temp;
                        }

                        for (uint j = 0; j < tadLen; ++j)
                            outBuff[j] /= sum;
                    }
                }
                else {
                    // element offsets inside a TAD are the same for every TAD: compute them once
                    auto offsets = new Nd4jLong[tadLen];
                    shape::calcOffsets(tadShapeInfo, offsets);

                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (uint i = 0; i < numOfTads; ++i) {

                        X* inBuff = x + tadOffsets[i];
                        X* outBuff = z + tadOffsets[i];

                        X max = -nd4j::DataTypeUtils::max<X>();
                        X sum = 0.f;

                        for (uint j = 0; j < tadLen; ++j)
                            max = nd4j::math::nd4j_max<X>(max, inBuff[offsets[j]]);

                        for (uint j = 0; j < tadLen; ++j) {
                            X temp = nd4j::math::nd4j_exp<X,X>(inBuff[offsets[j]] - max);
                            outBuff[offsets[j]] = temp;
                            sum += temp;
                        }

                        for (uint j = 0; j < tadLen; ++j)
                            outBuff[offsets[j]] /= sum;
                    }

                    delete []offsets;
                }
            }
            else {

                auto xShape = shape::shapeOf(xShapeInfo);
                int dimension[1] = { 0 };      // broadcast axis of the per-row [rows, 1] vector
                int maxDimension[1] = { 1 };   // reduce along columns -> one value per row

                // compute the row wise maxes
                auto maxResult = new X[xShape[0]];
                for (Nd4jLong i = 0; i < xShape[0]; i++)
                    maxResult[i] = 0.0;

                Nd4jLong maxShape[2] = { xShape[0], 1 };
                auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape);
                functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Max, x, xShapeInfo, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

                // subtract max of each row
                functions::broadcast::Broadcast<X, X, X>::exec(nd4j::broadcast::Subtract, x, xShapeInfo, maxResult, maxResultShapeBuffer, z, zShapeInfo, dimension, 1, nullptr, nullptr, nullptr, nullptr);

                // after subtracting the row wise maxes take the exp
                functions::transform::TransformStrict<X>::exec(nd4j::transform::Exp, z, zShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);

                // take the sum for the exponential (reusing maxResult)
                functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Sum, z, zShapeInfo, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

                // divide by the sum
                functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Divide, z, zShapeInfo, maxResult, maxResultShapeBuffer, z, zShapeInfo, dimension, 1, nullptr, nullptr, nullptr, nullptr);

                delete[] maxResultShapeBuffer;
                delete[] maxResult;
            }
        }
        else if (shape::isVector(xShapeInfo)) {
            auto max = -nd4j::DataTypeUtils::max<X>();
            X sum = 0;
            const int elementWiseStride = shape::elementWiseStride(xShapeInfo);
            const int resultElementWiseStride = shape::elementWiseStride(zShapeInfo);
            const Nd4jLong length = shape::length(xShapeInfo);

            if (elementWiseStride >= 1 && resultElementWiseStride >= 1) {
                if (elementWiseStride == 1 && resultElementWiseStride == 1) {

                    for (Nd4jLong i = 0; i < length; i++)
                        max = nd4j::math::nd4j_max<X>(max, x[i]);

                    for (Nd4jLong i = 0; i < length; i++) {
                        z[i] = nd4j::math::nd4j_exp<X,X>(x[i] - max);
                        sum += z[i];
                    }

                    PRAGMA_OMP_SIMD
                    for (Nd4jLong i = 0; i < length; i++)
                        z[i] /= sum;
                }
                else {

                    for (Nd4jLong i = 0; i < length; i++)
                        max = nd4j::math::nd4j_max<X>(max, x[i * elementWiseStride]);

                    for (Nd4jLong i = 0; i < length; i++) {
                        auto r = nd4j::math::nd4j_exp<X, X>(x[i * elementWiseStride] - max);
                        z[i * resultElementWiseStride] = r;
                        sum += r;
                    }

                    for (Nd4jLong i = 0; i < length; i++)
                        z[i * resultElementWiseStride] /= sum;
                }
            }
            else {
                // no usable element-wise stride on x or z: resolve every offset from the shape info
                for (Nd4jLong i = 0; i < length; i++)
                    max = nd4j::math::nd4j_max<X>(max, x[shape::getIndexOffset(i, xShapeInfo, length)]);

                for (Nd4jLong i = 0; i < length; i++) {
                    auto r = nd4j::math::nd4j_exp<X, X>(x[shape::getIndexOffset(i, xShapeInfo, length)] - max);
                    z[shape::getIndexOffset(i, zShapeInfo, length)] = r;
                    sum += r;
                }

                for (Nd4jLong i = 0; i < length; i++)
                    z[shape::getIndexOffset(i, zShapeInfo, length)] /= sum;
            }
        }
    }

    op_def static X op(X d1, X *params) {
        return d1;
    }
};
|
||
|
|
||
|
|
||
|
|
||
|
template<typename X>
class LogSoftMax {
public:
    static const bool requiresSpecial = true;

#ifdef __CUDACC__
    /**
     * Device-side log-softmax of x into result.
     *
     * Steps, each followed by a block barrier: Max-reduce x into the shared scalar maxResult;
     * result = x - max; result = exp(result); Sum-reduce result into maxResult; result /= sum;
     * result = log(result).
     * NOTE(review): one scalar max/sum covers the whole input; presumably invoked per row/TAD.
     */
    static inline __device__ void execSpecialCuda(
            void *vx, Nd4jLong *xShapeBuffer,
            void *vresult, Nd4jLong *zShapeBuffer,
            void *vextraParams,
            int *allocationPointer, void *reductionPointer,
            Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

        auto xShape = shape::shapeOf(xShapeBuffer);

        auto dx = reinterpret_cast<X *>(vx);
        auto result = reinterpret_cast<X *>(vresult);
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        __shared__ X maxResult;
        __shared__ Nd4jLong *maxResultShapeBuffer;
        // backing storage for the rank-2 [rows, 1] shape info describing maxResult
        __shared__ Nd4jLong tempBuffer[8];

        Nd4jLong maxShape[2] = { xShape[0], 1 };

        if (threadIdx.x == 0) {
            maxResult = (X) 0.0;
            maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
        }
        __syncthreads();

        functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
        __syncthreads();

        // subtract the max before exponentiating
        functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
        __syncthreads();

        functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
        __syncthreads();

        // maxResult is reused to hold the sum of the exponentials
        functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
        __syncthreads();

        functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
        __syncthreads();

        functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Log, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
    }
#endif

    /**
     * Host-side log-softmax of dx into result.
     *
     * Matrix input (rank 2): row-wise, composed from reduce/broadcast/transform calls
     * (max per row, subtract, exp, sum per row, divide, log).
     * Vector input: a single log-softmax over all elements, reading dx and writing result with
     * their own strides. Other inputs are left untouched.
     */
    static void execSpecial(
            void *vx,
            Nd4jLong *xShapeBuffer,
            void *vresult,
            Nd4jLong *zShapeBuffer,
            void *vextraParams,
            Nd4jLong *tadShapeInfo,
            Nd4jLong *tadOffsets) {

        auto dx = reinterpret_cast<X *>(vx);
        auto result = reinterpret_cast<X *>(vresult);
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        if (shape::isMatrix(xShapeBuffer, 2)) {
            auto xShape = shape::shapeOf(xShapeBuffer);
            int dimension[1] = { 0 };      // broadcast axis of the per-row [rows, 1] vector
            int maxDimension[1] = { 1 };   // reduce along columns -> one value per row

            // compute the row wise maxes
            auto maxResult = new X[xShape[0]];

            PRAGMA_OMP_SIMD
            for (Nd4jLong i = 0; i < xShape[0]; i++)
                maxResult[i] = 0.0;

            Nd4jLong maxShape[2] = { xShape[0], 1 };
            auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape);
            functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

            // subtract max of each row
            functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Subtract, dx, xShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

            // after subtracting the row wise maxes take the exp
            functions::transform::TransformStrict<X>::exec(nd4j::transform::Exp, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets);

            // take the sum for the exponential (reusing maxResult)
            functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

            // divide by the sum
            functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Divide, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

            functions::transform::TransformStrict<X>::exec(nd4j::transform::Log, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets);

            delete[] maxResultShapeBuffer;
            delete[] maxResult;   // previously leaked
        }
        else if (shape::isVector(xShapeBuffer, 2)) {
            // X-typed lower bound, consistent with the other softmax variants
            X max = -nd4j::DataTypeUtils::max<X>();
            X sum = 0;

            const auto xEWS = shape::elementWiseStride(xShapeBuffer);
            const auto zEWS = shape::elementWiseStride(zShapeBuffer);
            const auto length = shape::length(xShapeBuffer);

            if (xEWS >= 1 && zEWS >= 1) {
                // the max must come from the INPUT; result is not written yet when out-of-place
                for (Nd4jLong i = 0; i < length; i++)
                    max = nd4j::math::nd4j_max<X>(max, dx[i * xEWS]);

                for (Nd4jLong i = 0; i < length; i++) {
                    const X r = nd4j::math::nd4j_exp<X, X>(dx[i * xEWS] - max);
                    result[i * zEWS] = r;
                    sum += r;
                }

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < length; i++)
                    result[i * zEWS] = nd4j::math::nd4j_log<X, X>(result[i * zEWS] / sum);
            }
            else {
                // no usable element-wise stride: resolve every offset from the shape info
                for (Nd4jLong i = 0; i < length; i++)
                    max = nd4j::math::nd4j_max<X>(max, dx[shape::getIndexOffset(i, xShapeBuffer, length)]);

                for (Nd4jLong i = 0; i < length; i++) {
                    const X r = nd4j::math::nd4j_exp<X, X>(dx[shape::getIndexOffset(i, xShapeBuffer, length)] - max);
                    result[shape::getIndexOffset(i, zShapeBuffer, length)] = r;
                    sum += r;
                }

                for (Nd4jLong i = 0; i < length; i++) {
                    const auto zOffset = shape::getIndexOffset(i, zShapeBuffer, length);
                    result[zOffset] = nd4j::math::nd4j_log<X, X>(result[zOffset] / sum);
                }
            }
        }
    }

    op_def static X op(X d1, X *params) {
        return d1;
    }
};
|
||
|
|
||
|
|
||
|
/**
 * Element-wise softmax derivative: s * (1 - s), where s = softmax(x).
 * This is the diagonal of the softmax Jacobian only; off-diagonal terms are not produced.
 */
template<typename X>
class SoftMaxDerivative {
public:
    static const bool requiresSpecial = true;

#ifdef __CUDACC__
    /**
     * Device-side s * (1 - s) with s = softmax(x), written into result.
     *
     * Softmax steps (each followed by a block barrier): Max-reduce x into the shared scalar
     * maxResult; result = x - max; exp; Sum-reduce into maxResult; divide. The final
     * element-wise step is distributed over the block's threads.
     * NOTE(review): one scalar max/sum covers the whole input; presumably invoked per row/TAD.
     */
    static inline __device__ void execSpecialCuda(
            void *vx, Nd4jLong *xShapeBuffer,
            void *vresult, Nd4jLong *zShapeBuffer,
            void *vextraParams,
            int *allocationPointer, void *reductionPointer,
            Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

        auto dx = reinterpret_cast<X *>(vx);
        auto result = reinterpret_cast<X *>(vresult);
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        auto xShape = shape::shapeOf(xShapeBuffer);

        __shared__ X maxResult;
        __shared__ Nd4jLong *maxResultShapeBuffer;
        __shared__ Nd4jLong resultEWS;
        // backing storage for the rank-2 [rows, 1] shape info describing maxResult
        __shared__ Nd4jLong tempBuffer[8];

        auto length = shape::length(xShapeBuffer);
        Nd4jLong maxShape[2] = { xShape[0], 1 };

        if (threadIdx.x == 0) {
            resultEWS = shape::elementWiseStride(zShapeBuffer);
            maxResult = (X) 0.0;
            maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
        }
        __syncthreads();

        functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
        __syncthreads();

        // subtract the max before exponentiating
        functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
        __syncthreads();

        functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
        __syncthreads();

        // maxResult is reused to hold the sum of the exponentials
        functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
        __syncthreads();

        functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
        __syncthreads();

        const X one = static_cast<X>(1.0f);

        if (resultEWS >= 1) {
            for (Nd4jLong i = threadIdx.x; i < length; i += blockDim.x) {
                const Nd4jLong o = i * resultEWS;
                result[o] = result[o] * (one - result[o]);
            }
        }
        else {
            // no usable element-wise stride: resolve each offset from the output shape info
            for (Nd4jLong i = threadIdx.x; i < length; i += blockDim.x) {
                const auto zOffset = shape::getIndexOffset(i, zShapeBuffer, length);
                result[zOffset] = result[zOffset] * (one - result[zOffset]);
            }
        }
    }
#endif

    /**
     * Host-side s * (1 - s) with s = softmax(dx), written into result.
     *
     * Matrix input (rank 2): row-wise softmax composed from reduce/broadcast/transform calls,
     * then the element-wise derivative. Vector input: a single softmax over all elements.
     * dx is only read; result is only written (works both in-place and out-of-place).
     * Other inputs are left untouched.
     */
    static void execSpecial(
            void *vx,
            Nd4jLong *xShapeBuffer,
            void *vresult,
            Nd4jLong *zShapeBuffer,
            void *vextraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

        auto dx = reinterpret_cast<X *>(vx);
        auto result = reinterpret_cast<X *>(vresult);
        auto extraParams = reinterpret_cast<X *>(vextraParams);
        const X one = static_cast<X>(1.0f);

        if (shape::isMatrix(xShapeBuffer, 2)) {
            auto xShape = shape::shapeOf(xShapeBuffer);
            const auto resultEleStride = shape::elementWiseStride(zShapeBuffer);

            int dimension[1] = { 0 };      // broadcast axis of the per-row [rows, 1] vector
            int maxDimension[1] = { 1 };   // reduce along columns -> one value per row
            const auto len = shape::length(xShapeBuffer);

            // compute the row wise maxes
            auto maxResult = new X[xShape[0]];

            PRAGMA_OMP_SIMD
            for (Nd4jLong i = 0; i < xShape[0]; i++)
                maxResult[i] = 0.0f;

            Nd4jLong maxShape[2] = { xShape[0], 1 };
            auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape);
            functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

            // result = dx - rowMax; the source must be dx, result holds nothing yet when out-of-place
            functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Subtract, dx, xShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

            // after subtracting the row wise maxes take the exp
            functions::transform::TransformStrict<X>::exec(nd4j::transform::Exp, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets);

            // take the sum for the exponential (reusing maxResult)
            functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

            // divide by the sum
            functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Divide, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

            // element-wise derivative s * (1 - s)
            if (resultEleStride >= 1) {
                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < len; i++) {
                    const Nd4jLong o = i * resultEleStride;
                    result[o] = result[o] * (one - result[o]);
                }
            }
            else {
                for (Nd4jLong i = 0; i < len; i++) {
                    Nd4jLong zOffset = shape::getIndexOffset(i, zShapeBuffer, len);
                    result[zOffset] = result[zOffset] * (one - result[zOffset]);
                }
            }

            delete[] maxResultShapeBuffer;
            delete[] maxResult;
        }
        else if (shape::isVector(xShapeBuffer, 2)) {
            auto max = -nd4j::DataTypeUtils::max<X>();
            X sum = 0;

            const auto xEWS = shape::elementWiseStride(xShapeBuffer);
            const auto zEWS = shape::elementWiseStride(zShapeBuffer);
            const auto length = shape::length(xShapeBuffer);

            if (xEWS >= 1 && zEWS >= 1) {
                // read the input: the previous code took max/exp of result, ignoring dx entirely
                for (Nd4jLong i = 0; i < length; i++)
                    max = nd4j::math::nd4j_max<X>(max, dx[i * xEWS]);

                for (Nd4jLong i = 0; i < length; i++) {
                    const X r = nd4j::math::nd4j_exp<X, X>(dx[i * xEWS] - max);
                    result[i * zEWS] = r;
                    sum += r;
                }

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < length; i++) {
                    const X s = result[i * zEWS] / sum;
                    result[i * zEWS] = s * (one - s);
                }
            }
            else {
                // no usable element-wise stride: resolve every offset from the shape info
                for (Nd4jLong i = 0; i < length; i++)
                    max = nd4j::math::nd4j_max<X>(max, dx[shape::getIndexOffset(i, xShapeBuffer, length)]);

                for (Nd4jLong i = 0; i < length; i++) {
                    const X r = nd4j::math::nd4j_exp<X, X>(dx[shape::getIndexOffset(i, xShapeBuffer, length)] - max);
                    result[shape::getIndexOffset(i, zShapeBuffer, length)] = r;
                    sum += r;
                }

                for (Nd4jLong i = 0; i < length; i++) {
                    const auto zOffset = shape::getIndexOffset(i, zShapeBuffer, length);
                    const X s = result[zOffset] / sum;
                    result[zOffset] = s * (one - s);
                }
            }
        }
    }

    op_def static X op(X d1, X *params) {
        return d1;
    }
};
|
||
|
|
||
|
|
||
|
template<typename X, typename Z>
|
||
|
class IsMax {
|
||
|
public:
|
||
|
static const bool requiresSpecial = true;
|
||
|
|
||
|
|
||
|
#ifdef __CUDACC__
|
||
|
|
||
|
static inline __device__ void doAllCuda(
|
||
|
void *vx,
|
||
|
Nd4jLong *xShapeBuffer,
|
||
|
void *vresult,
|
||
|
Nd4jLong *zShapeBuffer,
|
||
|
void *vextraParams,
|
||
|
int *allocationPointer, void *reductionPointer) {
|
||
|
|
||
|
auto dx = reinterpret_cast<X *>(vx);
|
||
|
auto result = reinterpret_cast<Z *>(vresult);
|
||
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||
|
|
||
|
// this code is safe to delete, it's never used
|
||
|
/*
|
||
|
__shared__ int maxIdx;
|
||
|
__shared__ int length;
|
||
|
if (threadIdx.x == 0) {
|
||
|
length = shape::length(zShapeBuffer);
|
||
|
}
|
||
|
__syncthreads();
|
||
|
|
||
|
functions::indexreduce::IndexReduce<T>::template transform<simdOps::IndexMax<T>>(
|
||
|
dx,
|
||
|
xShapeBuffer,
|
||
|
extraParams,
|
||
|
result,
|
||
|
zShapeBuffer,
|
||
|
nullptr,
|
||
|
1,
|
||
|
1, allocationPointer, reductionPointer, nullptr, nullptr);
|
||
|
|
||
|
__syncthreads();
|
||
|
if (threadIdx.x == 0)
|
||
|
maxIdx = (int)result[0];
|
||
|
__syncthreads();
|
||
|
|
||
|
for (int i = threadIdx.x; i < length; i += blockDim.x)
|
||
|
result[i] = 0;
|
||
|
__syncthreads();
|
||
|
|
||
|
if (threadIdx.x == 0) {
|
||
|
result[maxIdx] = 1.0;
|
||
|
}
|
||
|
*/
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
#ifdef __CUDACC__
|
||
|
inline __host__
|
||
|
|
||
|
#elif defined(__GNUC__)
|
||
|
|
||
|
|
||
|
#endif
|
||
|
static void doAll(
|
||
|
void *vx,
|
||
|
Nd4jLong *xShapeBuffer,
|
||
|
void *vresult,
|
||
|
Nd4jLong *zShapeBuffer,
|
||
|
void *vextraParams) {
|
||
|
|
||
|
auto dx = reinterpret_cast<X *>(vx);
|
||
|
auto result = reinterpret_cast<Z *>(vresult);
|
||
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||
|
|
||
|
auto length = shape::length(xShapeBuffer);
|
||
|
auto eleStride = shape::elementWiseStride(xShapeBuffer);
|
||
|
auto resultEleStride = shape::elementWiseStride(zShapeBuffer);
|
||
|
auto xOrder = shape::order(xShapeBuffer);
|
||
|
auto resultOrder = shape::order(zShapeBuffer);
|
||
|
|
||
|
if (xOrder == resultOrder && xOrder == 'c') {
|
||
|
if (eleStride == 1 && resultEleStride == 1) {
|
||
|
if (length < ELEMENT_THRESHOLD) {
|
||
|
int maxIdx = 0;
|
||
|
auto currMax = dx[0];
|
||
|
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
if (currMax < dx[i]) {
|
||
|
currMax = dx[i];
|
||
|
maxIdx = i;
|
||
|
}
|
||
|
|
||
|
result[i] = static_cast<Z>(0);
|
||
|
|
||
|
}
|
||
|
|
||
|
result[maxIdx] = static_cast<Z>(1);
|
||
|
|
||
|
}
|
||
|
else {
|
||
|
int maxIdx = 0;
|
||
|
auto currMax = dx[0];
|
||
|
|
||
|
|
||
|
{
|
||
|
int maxIdxLocal = maxIdx;
|
||
|
auto currMaxLocal = currMax;
|
||
|
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
if (currMaxLocal < dx[i]) {
|
||
|
currMaxLocal = dx[i];
|
||
|
maxIdxLocal = i;
|
||
|
}
|
||
|
result[i] = static_cast<Z>(0);
|
||
|
}
|
||
|
|
||
|
PRAGMA_OMP_CRITICAL
|
||
|
{
|
||
|
if (currMax < currMaxLocal) {
|
||
|
currMax = currMaxLocal;
|
||
|
maxIdx = maxIdxLocal;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
result[maxIdx] = static_cast<Z>(1);
|
||
|
}
|
||
|
|
||
|
}
|
||
|
else {
|
||
|
if (length < ELEMENT_THRESHOLD) {
|
||
|
int maxIdx = 0;
|
||
|
auto currMax = dx[0];
|
||
|
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
result[i * resultEleStride] = static_cast<Z>(0);
|
||
|
if (currMax < dx[i * eleStride]) {
|
||
|
currMax = dx[i * eleStride];
|
||
|
maxIdx = i;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
result[maxIdx * resultEleStride] = static_cast<Z>(1);
|
||
|
|
||
|
}
|
||
|
else {
|
||
|
int maxIdx = 0;
|
||
|
auto currMax = dx[0];
|
||
|
|
||
|
|
||
|
{
|
||
|
int maxIdxLocal = maxIdx;
|
||
|
auto currMaxLocal = currMax;
|
||
|
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
result[i * resultEleStride] = static_cast<Z>(0);
|
||
|
if (currMaxLocal < dx[i * eleStride]) {
|
||
|
currMaxLocal = dx[i * eleStride];
|
||
|
maxIdxLocal = i;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
PRAGMA_OMP_CRITICAL
|
||
|
{
|
||
|
if (currMax < currMaxLocal) {
|
||
|
currMax = currMaxLocal;
|
||
|
maxIdx = maxIdxLocal;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
result[maxIdx * resultEleStride] = static_cast<Z>(1);
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
else {
|
||
|
Nd4jLong shapeIter[MAX_RANK];
|
||
|
Nd4jLong coord[MAX_RANK];
|
||
|
int dim;
|
||
|
Nd4jLong xStridesIter[MAX_RANK];
|
||
|
Nd4jLong resultStridesIter[MAX_RANK];
|
||
|
auto xShape = shape::shapeOf(xShapeBuffer);
|
||
|
auto xStride = shape::stride(xShapeBuffer);
|
||
|
auto resultStride = shape::stride(zShapeBuffer);
|
||
|
auto rank = shape::rank(xShapeBuffer);
|
||
|
auto originalResult = result;
|
||
|
if (PrepareTwoRawArrayIter<X, Z>(rank,
|
||
|
xShape,
|
||
|
dx,
|
||
|
xStride,
|
||
|
result,
|
||
|
resultStride,
|
||
|
&rank,
|
||
|
shapeIter,
|
||
|
&dx,
|
||
|
xStridesIter,
|
||
|
&result,
|
||
|
resultStridesIter) >= 0) {
|
||
|
auto value = dx[0];
|
||
|
int idx = 0;
|
||
|
int maxIdx = 0;
|
||
|
ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); {
|
||
|
if (dx[0] > value) {
|
||
|
value = dx[0];
|
||
|
maxIdx = idx;
|
||
|
}
|
||
|
|
||
|
idx++;
|
||
|
result[0] = static_cast<Z>(0);
|
||
|
|
||
|
}
|
||
|
ND4J_RAW_ITER_TWO_NEXT(
|
||
|
dim,
|
||
|
rank,
|
||
|
coord,
|
||
|
shapeIter,
|
||
|
dx,
|
||
|
xStridesIter,
|
||
|
result,
|
||
|
resultStridesIter);
|
||
|
|
||
|
//pointer to where max value would be
|
||
|
if (shape::order(zShapeBuffer) == 'c' || (shape::order(zShapeBuffer) == 'f' &&
|
||
|
maxIdx * shape::stride(zShapeBuffer)[shape::rank(zShapeBuffer) - 1] >=
|
||
|
shape::length(zShapeBuffer)))
|
||
|
originalResult[maxIdx] = static_cast<Z>(1);
|
||
|
else
|
||
|
originalResult[maxIdx * shape::stride(zShapeBuffer)[shape::rank(zShapeBuffer) - 1]] = static_cast<Z>(1);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
}
|
||
|
public:
|
||
|
|
||
|
|
||
|
#ifdef __CUDACC__
|
||
|
/**
|
||
|
*
|
||
|
*/
|
||
|
|
||
|
static inline __device__ void execSpecialCuda(
|
||
|
void *vx, Nd4jLong *xShapeBuffer,
|
||
|
void *vresult, Nd4jLong *zShapeBuffer,
|
||
|
void *vextraParams, int *allocationPointer,
|
||
|
void *reductionPointer,
|
||
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||
|
|
||
|
auto dx = reinterpret_cast<X *>(vx);
|
||
|
auto result = reinterpret_cast<Z *>(vresult);
|
||
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||
|
|
||
|
// FIXME: MAX_DIMENSION is lower then FP16 frame
|
||
|
if (extraParams == nullptr || (int) extraParams[0] == MAX_DIMENSION) {
|
||
|
doAllCuda(dx, xShapeBuffer, result, zShapeBuffer, extraParams, allocationPointer, reductionPointer);
|
||
|
}
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
static void execSpecial(
|
||
|
void *vx,
|
||
|
Nd4jLong *xShapeBuffer,
|
||
|
void *vresult,
|
||
|
Nd4jLong *zShapeBuffer,
|
||
|
void *vextraParams,
|
||
|
Nd4jLong *tadShapeInfo,
|
||
|
Nd4jLong *tadOffsets) {
|
||
|
|
||
|
auto dx = reinterpret_cast<X *>(vx);
|
||
|
auto result = reinterpret_cast<Z *>(vresult);
|
||
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||
|
|
||
|
//FIXME: this op should be moved to CustomOps
|
||
|
if (extraParams == nullptr || (int)extraParams[0] == 0 ||
|
||
|
((int)extraParams[0] == 1 && (int)extraParams[1] == MAX_DIMENSION)) {
|
||
|
doAll(dx, xShapeBuffer, result, zShapeBuffer, extraParams);
|
||
|
}
|
||
|
else if (shape::isVector(xShapeBuffer)) {
|
||
|
auto dimensionLength = (int)extraParams[0];
|
||
|
auto dimension = new int[dimensionLength];
|
||
|
auto length = shape::length(xShapeBuffer);
|
||
|
for (int i = 0; i < dimensionLength; i++) {
|
||
|
dimension[i] = (int)extraParams[i + 1];
|
||
|
}
|
||
|
if (shape::shapeOf(xShapeBuffer)[dimension[0]] == 1) {
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
result[i] = static_cast<Z>(1);
|
||
|
}
|
||
|
}
|
||
|
else {
|
||
|
auto eleStride = shape::elementWiseStride(xShapeBuffer);
|
||
|
if (eleStride == 1) {
|
||
|
int maxIdx = 0;
|
||
|
auto currMax = dx[0];
|
||
|
if (length < ELEMENT_THRESHOLD) {
|
||
|
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
if (currMax < dx[i]) {
|
||
|
currMax = dx[i];
|
||
|
maxIdx = i;
|
||
|
}
|
||
|
|
||
|
result[i] = static_cast<Z>(0);
|
||
|
|
||
|
}
|
||
|
}
|
||
|
else {
|
||
|
PRAGMA_OMP_PARALLEL
|
||
|
{
|
||
|
int maxIdxLocal = maxIdx;
|
||
|
auto currMaxLocal = currMax;
|
||
|
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
if (currMaxLocal < dx[i]) {
|
||
|
currMaxLocal = dx[i];
|
||
|
maxIdxLocal = i;
|
||
|
}
|
||
|
|
||
|
result[i] = static_cast<Z>(0);
|
||
|
|
||
|
}
|
||
|
|
||
|
PRAGMA_OMP_CRITICAL
|
||
|
{
|
||
|
if (currMax < currMaxLocal) {
|
||
|
currMax = currMaxLocal;
|
||
|
maxIdx = maxIdxLocal;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
result[maxIdx] = static_cast<Z>(1);
|
||
|
|
||
|
}
|
||
|
|
||
|
|
||
|
else {
|
||
|
int maxIdx = 0;
|
||
|
auto currMax = dx[0];
|
||
|
if (length < ELEMENT_THRESHOLD) {
|
||
|
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
if (currMax < dx[i * eleStride]) {
|
||
|
currMax = dx[i * eleStride];
|
||
|
maxIdx = i;
|
||
|
}
|
||
|
|
||
|
result[i] = static_cast<Z>(0);
|
||
|
}
|
||
|
}
|
||
|
else {
|
||
|
|
||
|
{
|
||
|
int maxIdxLocal = maxIdx;
|
||
|
auto currMaxLocal = currMax;
|
||
|
|
||
|
for (int i = 0; i < length; i++) {
|
||
|
if (currMaxLocal < dx[i * eleStride]) {
|
||
|
currMaxLocal = dx[i * eleStride];
|
||
|
maxIdxLocal = i;
|
||
|
}
|
||
|
|
||
|
result[i] = static_cast<Z>(0);
|
||
|
}
|
||
|
|
||
|
PRAGMA_OMP_CRITICAL
|
||
|
{
|
||
|
if (currMax < currMaxLocal) {
|
||
|
currMax = currMaxLocal;
|
||
|
maxIdx = maxIdxLocal;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
result[maxIdx] = static_cast<Z>(1);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
}
|
||
|
else {
|
||
|
auto dimensionLength = (int) extraParams[0];
|
||
|
auto dimension = new int[dimensionLength];
|
||
|
|
||
|
PRAGMA_OMP_SIMD
|
||
|
for (int i = 0; i < dimensionLength; i++) {
|
||
|
dimension[i] = (int) extraParams[i + 1];
|
||
|
}
|
||
|
//decompose in to several sub tads after
|
||
|
//moving all dimensions (in sorted order)
|
||
|
//to the back.
|
||
|
//permuted version of the x shape info for setting up the tad problem
|
||
|
auto tadShapeShapeInfo = tadShapeInfo;
|
||
|
if(tadShapeInfo==nullptr) {
|
||
|
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeBuffer, dimension, dimensionLength);
|
||
|
|
||
|
tadShapeShapeInfo = tadPack.primaryShapeInfo();
|
||
|
tadOffsets = tadPack.primaryOffsets();
|
||
|
tadShapeInfo = tadShapeShapeInfo;
|
||
|
}
|
||
|
|
||
|
auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeBuffer, dimension, dimensionLength);
|
||
|
auto tads = shape::length(xShapeBuffer) / tadLength;
|
||
|
|
||
|
int tadsPerThread = tads / TAD_THRESHOLD;
|
||
|
int num_threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
||
|
num_threads = nd4j::math::nd4j_min<int>(num_threads, omp_get_max_threads());
|
||
|
|
||
|
auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo);
|
||
|
auto zEWS = tadEWS;
|
||
|
|
||
|
int span = (tads / num_threads) + 8;
|
||
|
|
||
|
PRAGMA_OMP_PARALLEL_THREADS(num_threads)
|
||
|
{
|
||
|
int tid = omp_get_thread_num();
|
||
|
int start = span * tid;
|
||
|
int end = span * (tid + 1);
|
||
|
if (end > tads) end = tads;
|
||
|
|
||
|
for (int r = start; r < end; r++) {
|
||
|
if (tadEWS > 0 && zEWS > 0 && dimensionLength == 1) {
|
||
|
auto rX = dx + tadOffsets[r];
|
||
|
auto rZ = result + tadOffsets[r];
|
||
|
|
||
|
auto maxValue = rX[0];
|
||
|
int maxIdx = 0;
|
||
|
if (tadEWS == 1 && zEWS == 1) {
|
||
|
|
||
|
for (int i = 0; i < tadLength; i++) {
|
||
|
if (rX[i] > maxValue) {
|
||
|
maxIdx = i;
|
||
|
maxValue = rX[i];
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
for (int i = 0; i < tadLength; i++) {
|
||
|
rZ[i] = static_cast<Z>(maxIdx == i);
|
||
|
}
|
||
|
|
||
|
} else {
|
||
|
|
||
|
for (int i = 0; i < tadLength; i++) {
|
||
|
if (rX[i * tadEWS] > maxValue) {
|
||
|
maxIdx = i;
|
||
|
maxValue = rX[i * tadEWS];
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for (int i = 0; i < tadLength; i++) {
|
||
|
rZ[i * zEWS] = static_cast<Z>(maxIdx == i);
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
int tadsPerThread = tads / TAD_THRESHOLD;
|
||
|
int num_threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
||
|
num_threads = nd4j::math::nd4j_min<int>(num_threads, omp_get_max_threads());
|
||
|
|
||
|
auto offset = tadOffsets[r];
|
||
|
Nd4jLong shapeIter[MAX_RANK];
|
||
|
Nd4jLong coord[MAX_RANK];
|
||
|
int dim;
|
||
|
Nd4jLong xStridesIter[MAX_RANK];
|
||
|
Nd4jLong resultStridesIter[MAX_RANK];
|
||
|
auto xShape = shape::shapeOf(tadShapeShapeInfo);
|
||
|
auto xStride = shape::stride(tadShapeShapeInfo);
|
||
|
auto resultStride = shape::stride(tadShapeShapeInfo);
|
||
|
int rank = shape::rank(tadShapeShapeInfo);
|
||
|
auto xPointer = dx + offset;
|
||
|
auto resultPointer = result + offset;
|
||
|
auto maxValue = xPointer[0];
|
||
|
|
||
|
auto maxCursor = resultPointer;
|
||
|
Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);
|
||
|
if (PrepareTwoRawArrayIter<X,Z>(rank,
|
||
|
xShape,
|
||
|
xPointer,
|
||
|
xStride,
|
||
|
resultPointer,
|
||
|
resultStride,
|
||
|
&rank,
|
||
|
shapeIter,
|
||
|
&xPointer,
|
||
|
xStridesIter,
|
||
|
&resultPointer,
|
||
|
resultStridesIter) >= 0) {
|
||
|
ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); {
|
||
|
if (maxValue < xPointer[0]) {
|
||
|
maxCursor = resultPointer;
|
||
|
maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
|
||
|
maxValue = xPointer[0];
|
||
|
}
|
||
|
|
||
|
resultPointer[0] = static_cast<Z>(0);
|
||
|
}
|
||
|
ND4J_RAW_ITER_TWO_NEXT(dim,
|
||
|
rank,
|
||
|
coord,
|
||
|
shapeIter,
|
||
|
xPointer,
|
||
|
xStridesIter,
|
||
|
resultPointer,
|
||
|
resultStridesIter);
|
||
|
maxCursor = reinterpret_cast<Z *>(maxCursorLong);
|
||
|
maxCursor[0] = static_cast<Z>(1);;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
delete[] dimension;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
op_def static Z op(X d1, X *params) {
|
||
|
return nd4j::math::softplus<X,Z>(d1);
|
||
|
}
|
||
|
};
|
||
|
}
|