/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

#pragma once

// NOTE(review): in this copy every `#include` target and every `<...>` template
// argument list is missing (class template parameter lists, static_cast target
// types, explicit template arguments) -- e.g. `template class Pooling2D` while its
// body uses T and Z. They cannot be reconstructed from this text alone and must be
// restored from version control before this header compiles.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#ifdef __CUDACC__
#include
#include
#include
#endif

namespace functions {
    namespace broadcast {
        template class Broadcast;
    }

    namespace transform {
        template class TransformStrict;
    }

    namespace scalar {
    }

    namespace reduce {
        template class ReduceFloatFunction;
        template class ReduceSameFunction;
    }
}

namespace simdOps {

    /**
     * 2D pooling over a rank-4 input [bS, iC, iH, iW] into [bS, iC, oH, oW].
     *
     * extraParams layout (as read below):
     *   [0] kH  [1] kW  [2] sH  [3] sW  [4] pH  [5] pW  [6] dH  [7] dW (dilation)
     *   [9] poolingMode: 0 = max, 1 = avg, 2 = pnorm
     *   [10] extraParam0: avg -> 0 excludes padding from the divisor, 1 includes it;
     *                     pnorm -> the exponent p
     */
    template class Pooling2D {
    public:
        static const bool requiresSpecial = true;

        /** Output spatial size for input `size`, kernel `k`, stride `s`, padding `p`;
         *  `coverAll` rounds up so that the last partial window is included. */
#ifdef __CUDACC__
        inline __host__ __device__
#elif defined(__GNUC__)

#endif
        static int outSize(int size, int k, int s, int p, bool coverAll) {
            if (coverAll)
                return (size + p * 2 - k + s - 1) / s + 1;
            else
                return (size + p * 2 - k) / s + 1;
        }

#ifdef __CUDACC__
        /**
         * Based on: https://github.com/pjreddie/darknet/blob/master/src/im2col_kernels.cu
         *
         * One output element per loop iteration; the loop advances by
         * blockDim.x * gridDim.x, so any launch configuration covers all outputs.
         * Thread 0 of each block decodes the parameters into shared memory first.
         */
        static inline __device__ void execSpecialCuda(
                T *dx, Nd4jLong *xShapeBuffer,
                Z *result, Nd4jLong *zShapeBuffer,
                Z *extraParams, int *allocationPointer,
                Z *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            __shared__ int kH;
            __shared__ int kW;
            __shared__ int sH;
            __shared__ int sW;
            __shared__ int pH;
            __shared__ int pW;
            __shared__ int dH;
            __shared__ int dW;
            __shared__ int poolingMode;
            __shared__ Z extraParam0;

            __shared__ int batchSize;
            __shared__ int inChannels;
            __shared__ int outH;
            __shared__ int outW;
            __shared__ int inH;
            __shared__ int inW;

            __shared__ int strideB;
            __shared__ int strideC;
            __shared__ int strideY;
            __shared__ int strideX;

            __shared__ int strideOB;
            __shared__ int strideOC;
            __shared__ int strideOY;
            __shared__ int strideOX;

            __shared__ int length;
            __shared__ int kHEff;
            __shared__ int kWEff;
            __shared__ bool fOrder;

            if (threadIdx.x == 0) {
                kH = (int)extraParams[0];
                kW = (int)extraParams[1];
                sH = (int)extraParams[2];
                sW = (int)extraParams[3];
                pH = (int)extraParams[4];
                pW = (int)extraParams[5];
                dH = (int)extraParams[6];       //Dilation, height dimension
                dW = (int)extraParams[7];       //Dilation, width dimension
                poolingMode = (int)extraParams[9];
                extraParam0 = extraParams[10];

                batchSize = shape::sizeAt(xShapeBuffer, 0);
                inChannels = shape::sizeAt(xShapeBuffer, 1);
                outH = shape::sizeAt(zShapeBuffer, 2);
                outW = shape::sizeAt(zShapeBuffer, 3);
                inH = shape::sizeAt(xShapeBuffer, 2);
                inW = shape::sizeAt(xShapeBuffer, 3);

                strideB = shape::stride(xShapeBuffer)[0];
                strideC = shape::stride(xShapeBuffer)[1];
                strideY = shape::stride(xShapeBuffer)[2];
                strideX = shape::stride(xShapeBuffer)[3];

                strideOB = shape::stride(zShapeBuffer)[0];
                strideOC = shape::stride(zShapeBuffer)[1];
                strideOY = shape::stride(zShapeBuffer)[2];
                strideOX = shape::stride(zShapeBuffer)[3];

                length = shape::length(zShapeBuffer);

                //Replace kernel H/W with *effective* kernel H/W accounting for dilation
                kHEff = kH + (kH-1)*(dH-1);
                kWEff = kW + (kW-1)*(dW-1);

                fOrder = shape::order(zShapeBuffer) == 'f';
            }
            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

            for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
                // decompose the flat output index as [n, c, ph, pw]
                const int pw = index % outW;
                const int ph = (index / outW) % outH;
                const int c = (index / outW / outH) % inChannels;
                const int n = index / outW / outH / inChannels;

                int hstart = sH * ph - pH;
                int wstart = sW * pw - pW;
                int hend = hstart + kHEff;
                int wend = wstart + kWEff;

                // clip the window to the image while staying on the dilation grid
                if(hstart < 0){
                    int f = nd4j::math::nd4j_ceil((Z) -hstart / (Z)dH);
                    hstart += f * dH;
                }
                if(wstart < 0){
                    int f = nd4j::math::nd4j_ceil((Z) -wstart / (Z) dW);
                    wstart += f * dW;
                }
                if(hend > inH){
                    int f = nd4j::math::nd4j_ceil((Z) (hend-inH) / (Z) dH);
                    hend -= f * dH;
                }
                if(wend > inW){
                    int f = nd4j::math::nd4j_ceil((Z) (wend-inW) / (Z) dW);
                    wend -= f * dW;
                }
                //Accounts for dilation
                int pool_size = nd4j::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * nd4j::math::nd4j_ceil((double) (wend-wstart) / (double) dW);

                Z sum = poolingMode == 0 ? -nd4j::DataTypeUtils::max() : static_cast(0.f);

                T *input_slice = dx + (n * strideB + c * strideC);
                if (poolingMode == 0) {
                    for (int h = hstart; h < hend; h += dH) {
                        for (int w = wstart; w < wend; w += dW) {
                            Z v = static_cast(input_slice[h * strideY + w * strideX]);
                            if (v > sum)
                                sum = v;
                        }
                    }
                } else if (poolingMode == 1) {
                    for (int h = hstart; h < hend; h += dH) {
                        for (int w = wstart; w < wend; w += dW) {
                            sum += static_cast(input_slice[h * strideY + w * strideX]);
                        }
                    }
                } else if (poolingMode == 2) {
                    for (int h = hstart; h < hend; h += dH) {
                        for (int w = wstart; w < wend; w += dW) {
                            sum += nd4j::math::nd4j_pow(static_cast(nd4j::math::nd4j_abs(input_slice[h * strideY + w * strideX])), extraParam0);
                        }
                    }
                }

                // initialised so an unknown poolingMode never stores an indeterminate value
                Z res = sum;

                if (poolingMode == 0) {
                    res = sum;
                } else if (poolingMode == 1) {
                    int divide_factor = pool_size;  //Case 0: exclude padding
                    if ((int) extraParam0 == 1)     //Case 1: include padding
                        divide_factor = kH * kW;

                    res = sum / static_cast(divide_factor);
                } else if (poolingMode == 2) {
                    res = nd4j::math::nd4j_pow(sum, (Z) 1.0f / extraParam0);
                }

                if (!fOrder) {
                    result[index] = res;
                } else {
                    result[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = res;
                }
            }

            __syncthreads();
        }
#endif

        /**
         * Host implementation. Parallelised over (batch, channel) pairs.
         * Throws a string literal when a dilation is zero or poolingMode is not 0, 1 or 2.
         */
        static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outShapeBuffer, Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
            // input is  [bS, iC, iH, iW]
            // output is [bS, iC, oH, oW]

            const Nd4jLong kH = (int)extraParams[0];
            const Nd4jLong kW = (int)extraParams[1];
            const Nd4jLong sH = (int)extraParams[2];
            const Nd4jLong sW = (int)extraParams[3];
            const Nd4jLong pH = (int)extraParams[4];
            const Nd4jLong pW = (int)extraParams[5];
            const Nd4jLong dH = (int)extraParams[6];
            const Nd4jLong dW = (int)extraParams[7];
            Nd4jLong poolingMode = (int)extraParams[9];
            T extraParam0 = extraParams[10];

            if(dH == 0 || dW == 0) {
                printf("Special_ops pooling2d:: dilation must not be zero, but got instead {%lld, %lld} \n", (long long) dH, (long long) dW);
                throw "";
            }

            const Nd4jLong kHEff = kH + (kH-1)*(dH-1);
            const Nd4jLong kWEff = kW + (kW-1)*(dW-1);

            const int bS = shape::sizeAt(inShapeBuffer, 0);
            const int iC = shape::sizeAt(inShapeBuffer, 1);
            const int iH = shape::sizeAt(inShapeBuffer, 2);
            const int iW = shape::sizeAt(inShapeBuffer, 3);
            const int oH = shape::sizeAt(outShapeBuffer, 2);
            const int oW = shape::sizeAt(outShapeBuffer, 3);
            const Nd4jLong iStride0 = shape::stride(inShapeBuffer)[0];
            const Nd4jLong iStride1 = shape::stride(inShapeBuffer)[1];
            const Nd4jLong iStride2 = shape::stride(inShapeBuffer)[2];
            const Nd4jLong iStride3 = shape::stride(inShapeBuffer)[3];
            const Nd4jLong oStride0 = shape::stride(outShapeBuffer)[0];
            const Nd4jLong oStride1 = shape::stride(outShapeBuffer)[1];
            const Nd4jLong oStride2 = shape::stride(outShapeBuffer)[2];
            const Nd4jLong oStride3 = shape::stride(outShapeBuffer)[3];

            // buffer offset between consecutive (dilated) window taps
            const Nd4jLong iStep2 = dH*iStride2;
            const Nd4jLong iStep3 = dW*iStride3;
            const int kProd = kH*kW;

            Nd4jLong hstart, wstart, hend, wend;
            T sum, *pIn;

            if(poolingMode == 0) {        // max
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
                for(int b = 0; b < bS; ++b) {
                    for(int c = 0; c < iC; ++c) {
                        for(int oh = 0; oh < oH; ++oh) {
                            for(int ow = 0; ow < oW; ++ow) {

                                pIn = in + b * iStride0 + c * iStride1;

                                hstart = oh * sH - pH;
                                wstart = ow * sW - pW;
                                hend = hstart + kHEff;
                                wend = wstart + kWEff;

                                if(hstart < 0)
                                    hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH));
                                if(wstart < 0)
                                    wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW));
                                if(hend > iH)
                                    hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH));
                                if(wend > iW)
                                    wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW));

                                // convert window bounds from indices to buffer offsets
                                hstart *= iStride2;
                                hend   *= iStride2;
                                wstart *= iStride3;
                                wend   *= iStride3;

                                sum = -nd4j::DataTypeUtils::max();

                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) {
                                        T val = pIn[kh + kw];
                                        if (val > sum)
                                            sum = val;
                                    }
                                out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                            }
                        }
                    }
                }
            }
/*************************************************************************/
            else if(poolingMode == 1) {      // avg
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
                for(int b = 0; b < bS; ++b) {
                    for(int c = 0; c < iC; ++c) {
                        for(int oh = 0; oh < oH; ++oh) {
                            for(int ow = 0; ow < oW; ++ow) {

                                pIn = in + b * iStride0 + c * iStride1;

                                hstart = oh * sH - pH;
                                wstart = ow * sW - pW;
                                hend = hstart + kHEff;
                                wend = wstart + kWEff;

                                if(hstart < 0)
                                    hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH));
                                if(wstart < 0)
                                    wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW));
                                if(hend > iH)
                                    hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH));
                                if(wend > iW)
                                    wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW));

                                hstart *= iStride2;
                                hend   *= iStride2;
                                wstart *= iStride3;
                                wend   *= iStride3;

                                sum = static_cast(0.);

                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                        sum += pIn[kh + kw];

                                if ((int) extraParam0 == 0)         //Exclude padding
                                    sum /= static_cast(nd4j::math::nd4j_ceil(static_cast(hend-hstart) / static_cast(iStep2))) * static_cast(nd4j::math::nd4j_ceil(static_cast(wend-wstart) / static_cast(iStep3)));   //Accounts for dilation
                                else if ((int) extraParam0 == 1)    //Include padding
                                    sum /= kProd;

                                out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                            }
                        }
                    }
                }
            }
/*************************************************************************/
            else if(poolingMode == 2) {  // pnorm
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
                for(int b = 0; b < bS; ++b) {
                    for(int c = 0; c < iC; ++c) {
                        for(int oh = 0; oh < oH; ++oh) {
                            for(int ow = 0; ow < oW; ++ow) {

                                pIn = in + b * iStride0 + c * iStride1;

                                hstart = oh * sH - pH;
                                wstart = ow * sW - pW;
                                hend = hstart + kHEff;
                                wend = wstart + kWEff;

                                if(hstart < 0)
                                    hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH));
                                if(wstart < 0)
                                    wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW));
                                if(hend > iH)
                                    hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH));
                                if(wend > iW)
                                    wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW));

                                hstart *= iStride2;
                                hend   *= iStride2;
                                wstart *= iStride3;
                                wend   *= iStride3;

                                sum = static_cast(0.);

                                for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                                    for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                        sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0);

                                sum = nd4j::math::nd4j_pow(sum, (T) 1. / extraParam0);

                                out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                            }
                        }
                    }
                }
            }
            else {
                // poolingMode is Nd4jLong, so it must be printed with %lld (was %i: undefined behaviour)
                nd4j_printf("Special_ops::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %lld instead !\n", (long long) poolingMode);
                throw "";
            }
        }

        /** Identity; present to satisfy the op interface. */
        op_def static T op(T d1, Z *params) {
            return d1;
        }

        /** Calculate buffer offset (like Shape.getOffset) without checking on input for negative indices etc
         *  normally negative indices are bad, OK here because of other checks on input indices
         *  Uses unrolled loop specifically for length 4
         */
        static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[2] != 1) offset += indices[2] * stride[2];
            if (shape[3] != 1) offset += indices[3] * stride[3];
            return offset;
        }

        /**
         * A version of Shape.getOffset without checking on input for negative indices etc
         * normally negative indices are bad, OK here because of other checks on input indices
         * Uses unrolled loop specifically for length 6, where indices[2] and indices[3] are zero (always are here)
         */
        static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[4] != 1) offset += indices[4] * stride[4];
            if (shape[5] != 1) offset += indices[5] * stride[5];
            return offset;
        }
    };

    /** Range check 0 <= a < b done as a single comparison; presumably both casts
     *  target an unsigned type so that negative `a` wraps to a large value
     *  (the cast type is missing in this copy -- TODO confirm). */
    FORCEINLINE bool is_a_ge_zero_and_a_lt_b(int a, int b) {
        return static_cast(a) < static_cast(b);
    }

    /**
     * im2col: [bS, iC, iH, iW] is unfolded to [bS, iC, kH, kW, oH, oW].
     *
     * extraParams layout (as read below):
     *   [0] kH  [1] kW  [2] sH  [3] sW  [4] pH  [5] pW  [6] dH  [7] dW (dilation)
     *   [9] value written for taps that fall into the padding
     */
    template class Im2col {
    public:
        static const bool requiresSpecial = true;

        /** Output spatial size for input `size`, kernel `k`, stride `s`, padding `p`;
         *  `coverAll` rounds up so that the last partial window is included. */
        static _CUDA_HD int outSize(int size, int k, int s, int p, bool coverAll) {
            if (coverAll)
                return (size + p * 2 - k + s - 1) / s + 1;
            else
                return (size + p * 2 - k) / s + 1;
        }

#ifdef __CUDACC__
        /**
         * Based on:  https://github.com/pjreddie/darknet/blob/master/src/im2col_kernels.cu
         *
         * Each loop iteration handles one (sample, channel, oH, oW) position and writes
         * all kH*kW taps for it; the loop advances by blockDim.x * gridDim.x.
         */
        static inline __device__ void execSpecialCuda(
                T *dx, Nd4jLong *xShapeBuffer,
                T *result, Nd4jLong *zShapeBuffer,
                T *extraParams, int *allocationPointer,
                T *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
            /*kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false*/
            __shared__ int kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, kSize, samples, depth, height, width, strideex, stridech, strideh, stridew, height_col, width_col, n;
            __shared__ T zeroPadVal;
            __shared__ Nd4jLong *outShape, *outStride, *inShape, *inStride;
            __shared__ char resultOrder;

            if (threadIdx.x == 0) {
                kernelHeight = (int) extraParams[0];
                kernelWidth = (int) extraParams[1];
                strideY = (int) extraParams[2];
                strideX = (int) extraParams[3];
                padHeight = (int) extraParams[4];
                padWidth = (int) extraParams[5];
                dY = (int) extraParams[6];          //Dilation, height/y dimension
                dX = (int) extraParams[7];          //Dilation, width/x dimension
                kSize = kernelWidth * kernelHeight;
                zeroPadVal = (T) extraParams[9];    //Value to use when value is padding. Usually 0 but not always

                outShape = shape::shapeOf(zShapeBuffer);
                resultOrder = shape::order(zShapeBuffer);
                outStride = shape::stride(zShapeBuffer);

                inShape = shape::shapeOf(xShapeBuffer);
                inStride = shape::stride(xShapeBuffer);

                samples = (int) inShape[0];
                depth = (int) inShape[1];
                height = (int) inShape[2];
                width = (int) inShape[3];

                strideex = (int) inStride[0];
                stridech = (int) inStride[1];
                strideh = (int) inStride[2];
                stridew = (int) inStride[3];

                // (height + 2 * padHeight - kernelHeight) / strideX + 1;
                // (width + 2 * padWidth - kernelWidth) / strideY + 1;
                height_col = (int) outShape[4];
                width_col = (int) outShape[5];

                n = samples * depth * height_col * width_col;
            }
            __syncthreads();

            int index = blockIdx.x * blockDim.x + threadIdx.x;
            for (; index < n; index += blockDim.x*gridDim.x) {
                int h_index = index / width_col;
                int h_col = h_index % height_col;
                int w_col = index % width_col;

                int c_im = h_index / height_col;
                int c_col = c_im * kSize;

                int depth_im = c_im % depth;
                int num_im = c_im / depth;
                int h_offset = h_col * strideY - padHeight;
                int w_offset = w_col * strideX - padWidth;

                // c-order linear index into the 6d col array; converted to a strided offset per tap below
                int i_c = (c_col * height_col + h_col) * width_col + w_col;

                T* data_im_ptr = dx;
                data_im_ptr += num_im * strideex + depth_im * stridech + h_offset * strideh + w_offset*stridew;

                for (int i = 0; i < kernelHeight; ++i) {
                    for (int j = 0; j < kernelWidth; ++j) {
                        int h_im = h_offset + i * dY;
                        int w_im = w_offset + j * dX;
                        int i_f = 0;
                        int i_c_temp = i_c;
                        for (int dim = 5; dim >= 0; dim--) {
                            i_f += (i_c_temp % outShape[dim]) * outStride[dim];
                            i_c_temp = i_c_temp / outShape[dim];
                        }
                        if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
                            result[i_f] = data_im_ptr[i * dY * strideh + j * dX * stridew];
                        } else {
                            result[i_f] = zeroPadVal;
                        }

                        i_c += height_col * width_col;
                    }
                }
            }
        }
#endif

        /** Host implementation; parallelised over the two outermost loops. */
        static void execSpecial(
                T *imBuff,
                Nd4jLong *imShapeBuffer,
                T *colBuff,
                Nd4jLong *colShapeBuffer,
                T *extraParams,
                Nd4jLong *tadShapeInfo,
                Nd4jLong *tadOffsets) {
            /*kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false*/

            // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]

            int kH = (int)extraParams[0];
            int kW = (int)extraParams[1];
            int sH = (int)extraParams[2];
            int sW = (int)extraParams[3];
            int pH = (int)extraParams[4];
            int pW = (int)extraParams[5];
            int dH = (int)extraParams[6];            //Dilation, height/y dimension
            int dW = (int)extraParams[7];            //Dilation, width/x dimension
            T zeroPadVal = extraParams[9];

            auto colShape  = shape::shapeOf(colShapeBuffer);
            auto colStride = shape::stride(colShapeBuffer);
            auto imShape = shape::shapeOf(imShapeBuffer);
            auto imStride = shape::stride(imShapeBuffer);

            const int bS = imShape[0];
            const int iC = imShape[1];
            const int iH = imShape[2];
            const int iW = imShape[3];
            const int oH = colShape[4];
            const int oW = colShape[5];
            const Nd4jLong colStride0 = colStride[0];
            const Nd4jLong colStride1 = colStride[1];
            const Nd4jLong colStride2 = colStride[2];
            const Nd4jLong colStride3 = colStride[3];
            const Nd4jLong colStride4 = colStride[4];
            const Nd4jLong colStride5 = colStride[5];
            const Nd4jLong imStride0  = imStride[0];
            const Nd4jLong imStride1  = imStride[1];
            const Nd4jLong imStride2  = imStride[2];
            const Nd4jLong imStride3  = imStride[3];

            T *col, *im;
            int imRow, imCol;

            if (shape::order(imShapeBuffer) == 'c' &&  shape::order(colShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(imShapeBuffer) && shape::strideDescendingCAscendingF(colShapeBuffer)) {

                // both arrays c-ordered: iterate in col-buffer memory order
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, im, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {
                    for (int c = 0; c < iC; ++c) {
                        for (int kRow = 0; kRow < kH; ++kRow) {
                            for (int kCol = 0; kCol < kW; ++kCol) {
                                for (int colH = 0; colH < oH; ++colH) {
                                    for (int colW = 0; colW < oW; ++colW) {

                                        imRow = (-pH + kRow * dH) + colH*sH;
                                        imCol = (-pW + kCol * dW) + colW*sW;

                                        col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                        im  = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;

                                        if (static_cast(imRow) >= static_cast(iH) || static_cast(imCol) >= static_cast(iW))
                                            *col = zeroPadVal;
                                        else
                                            *col = *im;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            else {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(im, col, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {
                    for (int colH = 0; colH < oH; ++colH) {
                        for (int colW = 0; colW < oW; ++colW) {
                            for (int c = 0; c < iC; ++c) {
                                for (int kRow = 0; kRow < kH; ++kRow) {
                                    for (int kCol = 0; kCol < kW; ++kCol) {

                                        imRow = (-pH + kRow * dH) + colH*sH;
                                        imCol = (-pW + kCol * dW) + colW*sW;

                                        col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                        im  = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;

                                        if (static_cast(imRow) >= static_cast(iH) || static_cast(imCol) >= static_cast(iW))
                                            *col = zeroPadVal;
                                        else
                                            *col = *im;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        /** Identity; present to satisfy the op interface. */
        op_def static T op(T d1, T *params) {
            return d1;
        }

        /** Calculate buffer offset (like Shape.getOffset) without checking on input for negative indices etc
         *  normally negative indices are bad, OK here because of other checks on input indices
         *  Uses unrolled loop specifically for length 4
         */
        static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[2] != 1) offset += indices[2] * stride[2];
            if (shape[3] != 1) offset += indices[3] * stride[3];
            return offset;
        }

        /**
         * A version of Shape.getOffset without checking on input for negative indices etc
         * normally negative indices are bad, OK here because of other checks on input indices
         * Uses unrolled loop specifically for length 6, where indices[2] and indices[3] are zero (always are here)
         */
        static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[4] != 1) offset += indices[4] * stride[4];
            if (shape[5] != 1) offset += indices[5] * stride[5];
            return offset;
        }
    };

    /**
     * Histogram of dx into numBins equal-width bins over [min, max].
     * extraParams: [0] numBins, [1] min, [2] max. Values below min go to bin 0,
     * values at or above the last bin edge go to bin numBins - 1.
     */
    template class Histogram {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        /**
         * Requirements (from the indexing below):
         *  - dynamic shared memory of at least numBins * sizeof(Z) bytes;
         *  - allocationPointer: room for gridDim.x * numBins values of Z (per-block partials);
         *  - reductionPointer: at least 16385 unsigned ints; element 16384 is a block
         *    ticket counter that the last block resets to 0.
         * The last block to finish sums all partials and writes result[0..numBins).
         */
        static inline __device__ void execSpecialCuda(
                T *dx, Nd4jLong *xShapeBuffer,
                Z *result, Nd4jLong *zShapeBuffer,
                Z *extraParams, int *allocationPointer,
                Z *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            int numBins = (int) extraParams[0];
            Z min_val = extraParams[1];
            Z max_val = extraParams[2];

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

            __shared__ Z *bins;
            __shared__ int length;
            __shared__ Z *reductor;
            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                bins = (Z *) shmem;
                reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x);

                length = shape::length(xShapeBuffer);
            }
            __syncthreads();

            Z binSize = (max_val - min_val) / (numBins);

            for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                bins[e] = (Z) 0.0f;
            }
            __syncthreads();

            for (int e = tid; e < length; e += blockDim.x * gridDim.x) {
                int idx = (int) ((dx[e] - min_val) / binSize);
                if (idx < 0) idx = 0;
                else if (idx >= numBins) idx = numBins - 1;

                nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f);
            }
            __syncthreads();

            // transfer shared memory to reduction memory
            if (gridDim.x > 1) {
                unsigned int *tc = (unsigned int *)reductionPointer;
                __shared__ bool amLast;

                for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                    reductor[e] = bins[e];
                }
                // make this block's partials visible before taking a ticket
                __threadfence();
                __syncthreads();

                if (threadIdx.x == 0) {
                    unsigned int ticket = atomicInc(&tc[16384], gridDim.x);
                    amLast = (ticket == gridDim.x - 1);
                }
                __syncthreads();

                // amLast is block-uniform, so the barrier inside this branch is reached by all threads
                if (amLast) {
                    tc[16384] = 0;

                    // nullify shared memory for future accumulation
                    for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                        bins[e] = (Z) 0.0f;
                    }

                    // accumulate reduced bins
                    for (int r = 0; r < gridDim.x; r++) {
                        Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins);

                        for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                            bins[e] += ptrBuf[e];
                        }
                    }
                    __syncthreads();

                    // write them out to Z
                    for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                        result[e] = bins[e];
                    }
                }
            } else {
                // if there's only 1 block - just write away data
                for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                    result[e] = bins[e];
                }
            }
        }
#endif

        /**
         * Host implementation. Counts are ADDED to result[0..numBins); result is not
         * zeroed here (presumably the caller does that -- verify).
         */
        static void execSpecial(
                T *dx, Nd4jLong *xShapeBuffer,
                Z *result, Nd4jLong *zShapeBuffer,
                Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            int length = shape::length(xShapeBuffer);
            int _threads = 2;

            int numBins = (int) extraParams[0];
            int span = (length / _threads) + 8;

            // bin range [min, max] is supplied by the caller via extraParams[1..2]
            T min_val = extraParams[1];
            T max_val = extraParams[2];

            T binSize = (max_val - min_val) / (numBins);

            // NOTE(review): each thread handles [span * tid, span * (tid + 1)); this covers
            // the input only if all _threads threads actually run -- confirm for builds
            // where the parallel region may execute with fewer threads.
            PRAGMA_OMP_PARALLEL_THREADS(_threads)
            {
                int tid, start, end;

                int *bins = new int[numBins];
                std::memset(bins, 0, sizeof(int) * numBins);
                tid = omp_get_thread_num();
                start = span * tid;
                end = span * (tid + 1);
                if (end > length) end = length;

                // Deliberately NOT vectorised: bins[idx]++ is a scatter with possible
                // conflicting idx across iterations, which an `omp simd` loop would race on.
                for (int x = start; x < end; x++) {
                    int idx = (int) ((dx[x] - min_val) / binSize);
                    if (idx < 0)
                        idx = 0;
                    else if (idx >= numBins)
                        idx = numBins - 1;

                    bins[idx]++;
                }

                PRAGMA_OMP_CRITICAL
                {
                    PRAGMA_OMP_SIMD
                    for (int x = 0; x < numBins; x++) {
                        result[x] += bins[x];
                    }
                }

                delete[] bins;
            }
        }

        /** Identity; present to satisfy the op interface. */
        op_def static T op(T d1, Z *params) {
            return d1;
        }
    };

    /**
     * col2im: [bS, iC, kH, kW, oH, oW] is folded back to [bS, iC, iH, iW];
     * contributions of overlapping windows are summed.
     *
     * extraParams layout (as read below):
     *   [0] sH  [1] sW  [2] pH  [3] pW  [4] iH  [5] iW  [6] dH  [7] dW (dilation)
     * kH/kW/oH/oW are taken from the col array's shape.
     */
    template class Col2Im {
    public:
        static const bool requiresSpecial = true;
#ifdef __CUDACC__
        /**
         * https://github.com/pjreddie/darknet/blob/master/src/col2im_kernels.cu
         *
         * One image element per loop iteration (gathers every col entry that maps to it),
         * so no atomics are needed; the loop advances by blockDim.x * gridDim.x.
         */
        static inline __device__ void execSpecialCuda(
                X *dx, Nd4jLong *xShapeBuffer,
                X *result, Nd4jLong *zShapeBuffer,
                X *extraParams, int *allocationPointer,
                X *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
            __shared__ int strideex, stridech, stridekrow, stridekcol, striderow, stridecol, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX, samples, depth, imgH, imgW, height_col, width_col, n, kEffectiveW, kEffectiveH;
            __shared__ Nd4jLong *inShape, *inStride, *outShape, *outStride;
            __shared__ char resultOrder;

            if (threadIdx.x == 0) {
                inShape = shape::shapeOf(xShapeBuffer);
                inStride = shape::stride(xShapeBuffer);

                strideex = (int) inStride[0];
                stridech = (int) inStride[1];
                stridekrow = (int) inStride[2];
                stridekcol = (int) inStride[3];
                striderow = (int) inStride[4];
                stridecol = (int) inStride[5];

                kernelHeight = (int) inShape[2];
                kernelWidth = (int) inShape[3];

                strideY = (int) extraParams[0];
                strideX = (int) extraParams[1];
                padHeight = (int) extraParams[2];
                padWidth = (int) extraParams[3];
                imgHeight = (int) extraParams[4];
                imgWidth = (int) extraParams[5];
                dY = (int) extraParams[6];          //Dilation in height/y dimension
                dX = (int) extraParams[7];          //Dilation in width/x dimension

                outShape = shape::shapeOf(zShapeBuffer);
                resultOrder = shape::order(zShapeBuffer);
                outStride = shape::stride(zShapeBuffer);

                samples = (int) outShape[0];
                depth = (int) outShape[1];
                imgH = (int) outShape[2];
                imgW = (int) outShape[3];

                height_col = inShape[4];    //(imgHeight + 2 * padHeight - kernelHeight) / strideX + 1;
                width_col = inShape[5];     //(imgWidth + 2 * padWidth - kernelWidth) / strideY + 1;

                n = samples * depth * imgHeight * imgWidth;

                //Effective kernel size, accounting for dilation
                kEffectiveW = kernelWidth + (kernelWidth - 1) * (dX - 1);
                kEffectiveH = kernelHeight + (kernelHeight - 1) * (dY - 1);
            }
            __syncthreads();

            for (int i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
                X val = 0;
                // image coordinates shifted into padded space
                int w_im = i % imgWidth + padWidth;
                int h_im = (i / imgWidth) % imgHeight + padHeight;
                int c_im = i / (imgWidth * imgHeight);

                int num_im = c_im / depth;
                int depth_im = c_im % depth;

                // compute the start and end of the output
                // h_col / w_col range over dimensions 4 and 5 (oH, oW) of the 6d col matrix
                int w_col_start = (w_im < kEffectiveW) ? 0 : (w_im - kEffectiveW) / strideX + 1;
                int w_col_end = nd4j::math::nd4j_min(w_im / strideX + 1, width_col);

                int h_col_start = (h_im < kEffectiveH) ? 0 : (h_im - kEffectiveH) / strideY + 1;
                int h_col_end = nd4j::math::nd4j_min(h_im / strideY + 1, height_col);

                //Iterate over col entries in the 6d array... these are added up
                for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
                    for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
                        int h_k = (h_im - h_col * strideY);
                        int w_k = (w_im - w_col * strideX);

                        // only offsets lying on the dilation grid correspond to a kernel tap
                        if (h_k % dY == 0 && w_k % dX == 0) {
                            h_k /= dY;
                            w_k /= dX;

                            int data_col_index = num_im * strideex + depth_im * stridech + h_k * stridekrow + w_k * stridekcol + h_col * striderow + w_col * stridecol;
                            val += dx[data_col_index];
                        }
                    }
                }

                // convert the c-order linear index i into a strided offset in the output
                int i_f = 0;
                int i_c = i;
                for (int dim = 3; dim >= 0; dim--) {
                    i_f += (i_c % outShape[dim]) * outStride[dim];
                    i_c = i_c / outShape[dim];
                }
                result[i_f] = val;
            }
        }
#endif

        /** Host implementation: zeroes the image, then accumulates every col entry into it. */
        static void execSpecial(
                X *colBuff,
                Nd4jLong *colShapeBuffer,
                X *imBuff,
                Nd4jLong *imShapeBuffer,
                X *extraParams,
                Nd4jLong *tadShapeInfo,
                Nd4jLong *tadOffsets) {

            // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]

            auto colShape  = shape::shapeOf(colShapeBuffer);
            auto colStride = shape::stride(colShapeBuffer);
            auto imShape = shape::shapeOf(imShapeBuffer);
            auto imStride = shape::stride(imShapeBuffer);

            const int sH = (int)extraParams[0];
            const int sW = (int)extraParams[1];
            const int pH = (int)extraParams[2];
            const int pW = (int)extraParams[3];
            const int iH = (int)extraParams[4];
            const int iW = (int)extraParams[5];
            const int dH = (int)extraParams[6];
            const int dW = (int)extraParams[7];

            const int bS = imShape[0];
            const int iC = imShape[1];
            const int kH = colShape[2];
            const int kW = colShape[3];
            const int oH = colShape[4];
            const int oW = colShape[5];
            const Nd4jLong colStride0 = colStride[0];
            const Nd4jLong colStride1 = colStride[1];
            const Nd4jLong colStride2 = colStride[2];
            const Nd4jLong colStride3 = colStride[3];
            const Nd4jLong colStride4 = colStride[4];
            const Nd4jLong colStride5 = colStride[5];
            const Nd4jLong imStride0  = imStride[0];
            const Nd4jLong imStride1  = imStride[1];
            const Nd4jLong imStride2  = imStride[2];
            const Nd4jLong imStride3  = imStride[3];

            auto zLength = shape::length(imShapeBuffer);

            // initial zeroing of image content.
            // memset is only valid when the image occupies one dense run of zLength
            // elements (element-wise stride 1); otherwise zero it through its strides so
            // memory outside this view is left untouched and every element is cleared.
            if (shape::elementWiseStride(imShapeBuffer) == 1) {
                memset(imBuff, 0, zLength * sizeof(X));
            }
            else {
                const int imH = imShape[2];
                const int imW = imShape[3];
                for (int b = 0; b < bS; ++b)
                    for (int c = 0; c < iC; ++c)
                        for (int h = 0; h < imH; ++h)
                            for (int w = 0; w < imW; ++w)
                                imBuff[b*imStride0 + c*imStride1 + h*imStride2 + w*imStride3] = (X) 0;
            }

            X *col, *im;
            int imRow, imCol;

            if (shape::order(colShapeBuffer) == 'c' &&  shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) {

                // each (b, c) pair writes a disjoint image plane, so collapsing them is race-free
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, im, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {
                    for (int c = 0; c < iC; ++c) {
                        for (int kRow = 0; kRow < kH; ++kRow) {
                            for (int kCol = 0; kCol < kW; ++kCol) {
                                for (int colH = 0; colH < oH; ++colH) {
                                    for (int colW = 0; colW < oW; ++colW) {

                                        imRow = (-pH + kRow * dH) + colH*sH;
                                        imCol = (-pW + kCol * dW) + colW*sW;

                                        col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                        im  = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;

                                        if (static_cast(imRow) < static_cast(iH) && static_cast(imCol) < static_cast(iW))
                                            *im += *col;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            else {

                // only the batch loop is parallel: different (colH, colW) can hit the same image element
                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(im, col, imRow, imCol))
                for (int b = 0; b < bS; b++) {
                    for (int colH = 0; colH < oH; ++colH) {
                        for (int colW = 0; colW < oW; ++colW) {
                            for (int c = 0; c < iC; ++c) {
                                for (int kRow = 0; kRow < kH; ++kRow) {
                                    for (int kCol = 0; kCol < kW; ++kCol) {

                                        imRow = (-pH + kRow * dH) + colH*sH;
                                        imCol = (-pW + kCol * dW) + colW*sW;

                                        col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                        im  = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;

                                        if (static_cast(imRow) < static_cast(iH) && static_cast(imCol) < static_cast(iW))
                                            *im += *col;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        /** Identity; present to satisfy the op interface. */
        op_def static X op(X d1, X *params) {
            return d1;
        }

        /** Calculate buffer offset (like Shape.getOffset) without checking on input for negative indices etc
         *  normally negative indices are bad, OK here because of other checks on input indices
         *  Uses unrolled loop specifically for length 4
         */
        static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[2] != 1) offset += indices[2] * stride[2];
            if (shape[3] != 1) offset += indices[3] * stride[3];
            return offset;
        }

        /** A version of Shape.getOffset without checking on input for negative indices etc
         *  normally negative indices are bad, OK here because of other checks on input indices
         *  Uses unrolled loop specifically for length 6, where indices[2] and indices[3] are zero (always are here)
         */
        static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
            int offset = baseOffset;
            if (shape[0] != 1) offset += indices[0] * stride[0];
            if (shape[1] != 1) offset += indices[1] * stride[1];
            if (shape[4] != 1) offset += indices[4] * stride[4];
            if (shape[5] != 1) offset += indices[5] * stride[5];
            return offset;
        }
    };

    template class Reverse {
    public:
        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        static inline __device__ void execSpecialCuda(X *dx, Nd4jLong *xShapeBuffer, X *result, Nd4jLong *zShapeBuffer, X *extraParams, int *allocationPointer, X *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
            __shared__ Nd4jLong xLength;
            __shared__ int xEWS;
            __shared__ char xOrder;
            __shared__ Nd4jLong sLength;
            __shared__ X *shmem;
            int tid = threadIdx.x + blockIdx.x * blockDim.x;

            if (threadIdx.x == 0) {
                xLength = shape::length(xShapeBuffer);
                xEWS = shape::elementWiseStride(xShapeBuffer);
                xOrder = shape::order(xShapeBuffer);
                sLength = xLength - 1;

                extern __shared__ unsigned char shrd[];
                shmem = (X *) shrd;
            }
            __syncthreads();

            if (dx == result) {
                if
(xEWS == 1) { for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) { Nd4jLong idx = sLength - e; X tmp = dx[e]; dx[e] = dx[idx]; dx[idx] = tmp; } } else if (xEWS >= 1) { for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) { Nd4jLong idx1 = (sLength - e) * xEWS; Nd4jLong idx2 = e * xEWS; X tmp = dx[idx2]; dx[idx2] = dx[idx1]; dx[idx1] = tmp; } } else { for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) { auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength); auto zOffset = shape::getIndexOffset(sLength - e, xShapeBuffer, xLength); result[zOffset] = dx[xOffset]; } } } else { __shared__ int zEWS; __shared__ char zOrder; if (threadIdx.x == 0) { zEWS = shape::elementWiseStride(zShapeBuffer); zOrder = shape::order(zShapeBuffer); } __syncthreads(); if (xEWS == 1 && zEWS == 1 && xOrder == zOrder) { // loop for whole array for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) { result[sLength - e] = dx[e]; } } else if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) { for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) { result[(sLength - e) * zEWS] = dx[e * xEWS]; } } else { for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) { auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength); auto zOffset = shape::getIndexOffset(sLength - e, xShapeBuffer, xLength); result[zOffset] = dx[xOffset]; } } } } #endif static void execSpecial(X *dx, Nd4jLong *xShapeBuffer, X *result, Nd4jLong *zShapeBuffer, X *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong xLength = shape::length(xShapeBuffer); int xEWS = shape::elementWiseStride(xShapeBuffer); char xOrder = shape::order(xShapeBuffer); Nd4jLong sLength = xLength - 1; // two step phase here if (dx == result) { if (xEWS == 1) { PRAGMA_OMP_PARALLEL_FOR_SIMD for (Nd4jLong e = 0; e < xLength / 2; e++) { Nd4jLong idx = sLength - e; auto tmp = dx[e]; dx[e] = dx[idx]; dx[idx] = tmp; } } else if (xEWS > 1) { PRAGMA_OMP_PARALLEL_FOR_SIMD for (Nd4jLong e = 0; 
e < xLength / 2; e++) { Nd4jLong idx1 = (sLength - e) * xEWS; Nd4jLong idx2 = e * xEWS; auto tmp = dx[idx2]; dx[idx2] = dx[idx1]; dx[idx1] = tmp; } } else { PRAGMA_OMP_PARALLEL_FOR_SIMD for (Nd4jLong e = 0; e < xLength / 2; e++) { auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength); auto zOffset = shape::getIndexOffset(sLength - e, xShapeBuffer, xLength); result[zOffset] = dx[xOffset]; } } } else { // single step phase here auto zEWS = shape::elementWiseStride(zShapeBuffer); auto zOrder = shape::order(zShapeBuffer); if (xEWS == 1 && zEWS == 1 && xOrder == zOrder) { PRAGMA_OMP_PARALLEL_FOR_SIMD for (Nd4jLong e = 0; e < xLength; e++) { result[sLength - e] = dx[e]; } } else if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) { PRAGMA_OMP_PARALLEL_FOR_SIMD for (Nd4jLong e = 0; e < xLength; e++) { result[(sLength - e) * zEWS] = dx[e * xEWS]; } } else { PRAGMA_OMP_PARALLEL_FOR_SIMD for (Nd4jLong e = 0; e < xLength; e++) { auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength); auto zOffset = shape::getIndexOffset(sLength - e, zShapeBuffer, xLength); result[zOffset] = dx[xOffset]; } } } } op_def static X op(X d1, X *params) { return d1; } }; template class SoftMax { public: static const bool requiresSpecial = true; #ifdef __CUDACC__ /** * */ static inline __device__ void execSpecialCuda( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void *vextraParams, int *allocationPointer, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto dx = reinterpret_cast(vx); auto result = reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); auto shape = shape::shapeOf(xShapeBuffer); __shared__ X maxResult; __shared__ Nd4jLong *maxResultShapeBuffer; auto length = shape::length(xShapeBuffer); auto stride = shape::stride(xShapeBuffer); //compute the row wise maxes __shared__ Nd4jLong maxShape[2]; // it's always 2d here __shared__ Nd4jLong tempBuffer[8]; if (threadIdx.x == 0) { maxResult = (X) 0.0; 
maxShape[0] = shape[0]; maxShape[1] = 1; maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT(), maxShape, tempBuffer); } __syncthreads(); functions::reduce::ReduceSameInplace::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr); __syncthreads(); //subtract max of each row functions::scalar::ScalarInplace::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer); __syncthreads(); //after subtracting the row wise maxes take the exp functions::transform::TransformStrictInplace::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); __syncthreads(); //take the sum for the exponential functions::reduce::ReduceSameInplace::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr); __syncthreads(); //divide by the sum functions::scalar::ScalarInplace::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer); } #endif static void execSpecial( void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto x = reinterpret_cast(vx); auto z = reinterpret_cast(vz); auto extraParams = reinterpret_cast(vextraParams); if (shape::isMatrix(xShapeInfo)) { if(shape::equalsStrict(xShapeInfo, zShapeInfo)) { if (tadShapeInfo == nullptr) { auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, 1); tadShapeInfo = tadPack.primaryShapeInfo(); tadOffsets = tadPack.primaryOffsets(); } const uint tadLen = shape::length(tadShapeInfo); const uint numOfTads = shape::length(xShapeInfo) / tadLen; if(shape::elementWiseStride(tadShapeInfo) == 1) { PRAGMA_OMP_PARALLEL_FOR_SIMD 
for (uint i = 0; i < numOfTads; ++i) { X* inBuff = x + tadOffsets[i]; X* outBuff = z + tadOffsets[i]; X max = -nd4j::DataTypeUtils::max(); X sum = 0; for(uint j = 0; j < tadLen; ++j) max = nd4j::math::nd4j_max(max, inBuff[j]); for (uint j = 0; j < tadLen; ++j) { X temp = nd4j::math::nd4j_exp(inBuff[j] - max); outBuff[j] = temp; sum += temp; } for (uint j = 0; j < tadLen; ++j) outBuff[j] /= sum; } } else { uint xShapeInfoCast[MAX_RANK]; bool canCast = nd4j::DataTypeUtils::castShapeInfo(tadShapeInfo, xShapeInfoCast); auto offsets = new Nd4jLong[tadLen]; shape::calcOffsets(tadShapeInfo, offsets); PRAGMA_OMP_PARALLEL_FOR_SIMD for (uint i = 0; i < numOfTads; ++i) { X* inBuff = x + tadOffsets[i]; X* outBuff = z + tadOffsets[i]; X max = -nd4j::DataTypeUtils::max(); X sum = 0.f; for(uint j = 0; j < tadLen; ++j) max = nd4j::math::nd4j_max(max, inBuff[offsets[j]]); for (uint j = 0; j < tadLen; ++j) { X temp = nd4j::math::nd4j_exp(inBuff[offsets[j]] - max); outBuff[offsets[j]] = temp; sum += temp; } for (uint j = 0; j < tadLen; ++j) outBuff[offsets[j]] /= sum; } delete []offsets; } } else { auto shape = shape::shapeOf(xShapeInfo); //iterate along rows int dimension[1] = { 0 }; int maxDimension[1] = { 1 }; //compute the row wise maxes auto maxResult = new X[shape[0]]; for (int i = 0; i < shape[0]; i++) maxResult[i] = 0.0; Nd4jLong maxShape[2] = { shape[0], 1 }; auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT(), maxShape); functions::reduce::ReduceSameFunction::exec(nd4j::reduce::Max, x, xShapeInfo, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr); //subtract max of each row functions::broadcast::Broadcast::exec(nd4j::broadcast::Subtract, x, xShapeInfo, maxResult, maxResultShapeBuffer, z, zShapeInfo, dimension, 1, nullptr, nullptr, nullptr, nullptr); //after subtracting the row wise maxes take the exp functions::transform::TransformStrict::exec(nd4j::transform::Exp, z, zShapeInfo, z, zShapeInfo, extraParams, 
tadShapeInfo, tadOffsets); //take the sum for the exponential functions::reduce::ReduceSameFunction::exec(nd4j::reduce::Sum, z, zShapeInfo, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr); //divide by the sum functions::broadcast::Broadcast::exec(nd4j::broadcast::Divide, z, zShapeInfo, maxResult, maxResultShapeBuffer, z, zShapeInfo, dimension, 1, nullptr, nullptr, nullptr, nullptr); delete[] maxResultShapeBuffer; delete[] maxResult; } } else if (shape::isVector(xShapeInfo)) { auto max = -nd4j::DataTypeUtils::max(); X sum = 0; int elementWiseStride = shape::elementWiseStride(xShapeInfo); int resultElementWiseStride = shape::elementWiseStride(zShapeInfo); int length = shape::length(xShapeInfo); if (elementWiseStride >= 1 && resultElementWiseStride >= 1) { if (elementWiseStride == 1 && resultElementWiseStride == 1) { for (int i = 0; i < length; i++) { max = nd4j::math::nd4j_max(max, x[i]); } for (int i = 0; i < length; i++) { z[i] = nd4j::math::nd4j_exp(x[i] - max); sum += z[i]; } PRAGMA_OMP_SIMD for (int i = 0; i < length; i++) { z[i] /= sum; } } else { for (int i = 0; i < length; i++) { max = nd4j::math::nd4j_max(max, x[i * elementWiseStride]); } for (int i = 0; i < length; i++) { auto r = nd4j::math::nd4j_exp(x[i * elementWiseStride] - max); z[i * resultElementWiseStride] = r; sum += r; } for (int i = 0; i < length; i++) { z[i * resultElementWiseStride] /= sum; } } } } } op_def static X op(X d1, X *params) { return d1; } }; template class LogSoftMax { public: static const bool requiresSpecial = true; #ifdef __CUDACC__ /** * */ static inline __device__ void execSpecialCuda( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void *vextraParams, int *allocationPointer, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto shape = shape::shapeOf(xShapeBuffer); auto stride = shape::stride(xShapeBuffer); //iterate along rows auto dx = reinterpret_cast(vx); auto result = 
reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); __shared__ X maxResult; __shared__ Nd4jLong *maxResultShapeBuffer; if (threadIdx.x == 0) { maxResult = (X) 0.0; } __syncthreads(); //compute the row wise maxes Nd4jLong maxShape[2] = { shape[0], 1 }; __shared__ Nd4jLong tempBuffer[8]; if (threadIdx.x == 0) maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT(), maxShape, tempBuffer); __syncthreads(); functions::reduce::ReduceSameInplace::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr); __syncthreads(); //subtract max of each row functions::scalar::ScalarInplace::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer); __syncthreads(); //after subtracting the row wise maxes take the exp functions::transform::TransformStrictInplace::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); __syncthreads(); //take the sum for the exponential functions::reduce::ReduceSameInplace::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr); __syncthreads(); //divide by the sum functions::scalar::ScalarInplace::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer); __syncthreads(); functions::transform::TransformStrictInplace::transformCudaLegacy(nd4j::transform::Log, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); } #endif static void execSpecial( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void *vextraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto dx = reinterpret_cast(vx); auto result = 
reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); if (shape::isMatrix(xShapeBuffer, 2)) { auto shape = shape::shapeOf(xShapeBuffer); //iterate along rows int dimension[1] = { 0 }; int maxDimension[1] = { 1 }; //compute the row wise maxes auto maxResult = new X[shape[0]]; PRAGMA_OMP_SIMD for (int i = 0; i < shape[0]; i++) maxResult[i] = 0.0; Nd4jLong maxShape[2] = { shape[0], 1 }; auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT(), maxShape); functions::reduce::ReduceSameFunction::exec(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr); //subtract max of each row functions::broadcast::Broadcast::exec(nd4j::broadcast::Subtract, dx, xShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr); //after subtracting the row wise maxes take the exp functions::transform::TransformStrict::exec(nd4j::transform::Exp, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets); //take the sum for the exponential functions::reduce::ReduceSameFunction::exec(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr); //divide by the sum functions::broadcast::Broadcast::exec(nd4j::broadcast::Divide, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr); functions::transform::TransformStrict::exec(nd4j::transform::Log, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets); delete[] maxResultShapeBuffer; } else if (shape::isVector(xShapeBuffer, 2)) { auto max = -FLOAT_MAX_VALUE; X sum = 0; auto elementWiseStride = shape::elementWiseStride(xShapeBuffer); auto length = shape::length(xShapeBuffer); if (elementWiseStride == 1) { for (int i = 0; i < length; i++) { max = nd4j::math::nd4j_max(max, result[i]); } for (int i = 0; 
i < length; i++) { result[i] = nd4j::math::nd4j_exp(dx[i] - max); sum += result[i]; } PRAGMA_OMP_SIMD for (int i = 0; i < length; i++) { result[i] /= sum; result[i] = nd4j::math::nd4j_log(result[i]); } } else if (elementWiseStride > 1) { for (int i = 0; i < length; i++) { max = nd4j::math::nd4j_max(max, result[i * elementWiseStride]); } for (int i = 0; i < length; i++) { result[i * elementWiseStride] = nd4j::math::nd4j_exp(dx[i * elementWiseStride] - max); sum += result[i * elementWiseStride]; } for (int i = 0; i < length; i++) { result[i * elementWiseStride] /= sum; result[i * elementWiseStride] = nd4j::math::nd4j_log(result[i * elementWiseStride]); } } } } op_def static X op(X d1, X *params) { return d1; } }; /** * softmax(x) */ template class SoftMaxDerivative { public: static const bool requiresSpecial = true; #ifdef __CUDACC__ /** * */ static inline __device__ void execSpecialCuda( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void *vextraParams, int *allocationPointer, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto dx = reinterpret_cast(vx); auto result = reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); auto shape = shape::shapeOf(xShapeBuffer); __shared__ X maxResult; __shared__ Nd4jLong *maxResultShapeBuffer; __shared__ Nd4jLong resultEWS; auto length = shape::length(xShapeBuffer); if (threadIdx.x == 0) { resultEWS = shape::elementWiseStride(zShapeBuffer); maxResult = (X) 0.0; } __syncthreads(); auto tride = shape::stride(xShapeBuffer); Nd4jLong maxShape[2] = { shape[0], 1 }; __shared__ Nd4jLong tempBuffer[8]; if (threadIdx.x == 0) maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT(), maxShape, tempBuffer); __syncthreads(); functions::reduce::ReduceSameInplace::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr); __syncthreads(); //subtract max of each row 
functions::scalar::ScalarInplace::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer); __syncthreads(); //after subtracting the row wise maxes take the exp functions::transform::TransformStrictInplace::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); __syncthreads(); //take the sum for the exponential functions::reduce::ReduceSameInplace::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr); __syncthreads(); //divide by the sum functions::scalar::ScalarInplace::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer); __syncthreads(); if (resultEWS >= 1) { for (int i = threadIdx.x; i < length; i += blockDim.x) { result[i * resultEWS] = result[i * resultEWS] * ((X) 1.0 - result[i * resultEWS]); } } else { printf("Non element wise stride not supported right now\n"); } } #endif static void execSpecial( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void *vextraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto dx = reinterpret_cast(vx); auto result = reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); if (shape::isMatrix(xShapeBuffer, 2)) { auto shape = shape::shapeOf(xShapeBuffer); auto resultEleStide = shape::elementWiseStride(zShapeBuffer); //iterate along rows int dimension[1] = { 0 }; int maxDimension[1] = { 1 }; auto len = shape::length(xShapeBuffer); //compute the row wise maxes auto maxResult = new X[shape[0]]; PRAGMA_OMP_SIMD for (int i = 0; i < shape[0]; i++) maxResult[i] = 0.0f; Nd4jLong maxShape[2] = { shape[0], 1 }; auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT(), maxShape); 
functions::reduce::ReduceSameFunction::exec(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr); //subtract max of each row functions::broadcast::Broadcast::exec(nd4j::broadcast::Subtract, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr); //after subtracting the row wise maxes take the exp functions::transform::TransformStrict::exec(nd4j::transform::Exp, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets); //take the sum for the exponential functions::reduce::ReduceSameFunction::exec(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr); //divide by the sum functions::broadcast::Broadcast::exec(nd4j::broadcast::Divide, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr); if (resultEleStide >= 1) { if (resultEleStide == 1) { PRAGMA_OMP_SIMD for (int i = 0; i < len; i++) { result[i] = result[i] * (static_cast(1.0f) - result[i]); } } else { PRAGMA_OMP_SIMD for (int i = 0; i < len; i++) { result[i * resultEleStide] = result[i * resultEleStide] * (static_cast(1.0f) - result[i * resultEleStide]); } } } else { for (int i = 0; i < len; i++) { Nd4jLong zOffset = shape::getIndexOffset(i, zShapeBuffer, len); result[zOffset] = result[zOffset] * ((X) 1.0f - result[zOffset]); } } delete[] maxResultShapeBuffer; delete[] maxResult; } else if (shape::isVector(xShapeBuffer, 2)) { auto max = -nd4j::DataTypeUtils::max(); X sum = 0; auto elementWiseStride = shape::elementWiseStride(xShapeBuffer); auto length = shape::length(xShapeBuffer); if (elementWiseStride == 1) { for (int i = 0; i < length; i++) { max = nd4j::math::nd4j_max(max, result[i]); } for (int i = 0; i < length; i++) { result[i] -= max; result[i] = nd4j::math::nd4j_exp(result[i]); sum += result[i]; } for (int i 
= 0; i < length; i++) { result[i] /= sum; } for (int i = 0; i < length; i++) { result[i] = result[i] * ((X) 1.0f - result[i]); } } else if (elementWiseStride >= 1) { for (int i = 0; i < length; i++) { max = nd4j::math::nd4j_max(max, result[i * elementWiseStride]); } for (int i = 0; i < length; i++) { result[i * elementWiseStride] -= max; result[i * elementWiseStride] = nd4j::math::nd4j_exp(result[i * elementWiseStride]); sum += result[i * elementWiseStride]; } PRAGMA_OMP_SIMD for (int i = 0; i < length; i++) { result[i * elementWiseStride] /= sum; } PRAGMA_OMP_SIMD for (int i = 0; i < length; i++) { result[i * elementWiseStride] = result[i * elementWiseStride] * ((X) 1.0f - result[i * elementWiseStride]); } } else { printf("non-ews access on row not implemented yet"); } } } op_def static X op(X d1, X *params) { return d1; } }; template class IsMax { public: static const bool requiresSpecial = true; #ifdef __CUDACC__ static inline __device__ void doAllCuda( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void *vextraParams, int *allocationPointer, void *reductionPointer) { auto dx = reinterpret_cast(vx); auto result = reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); // this code is safe to delete, it's never used /* __shared__ int maxIdx; __shared__ int length; if (threadIdx.x == 0) { length = shape::length(zShapeBuffer); } __syncthreads(); functions::indexreduce::IndexReduce::template transform>( dx, xShapeBuffer, extraParams, result, zShapeBuffer, nullptr, 1, 1, allocationPointer, reductionPointer, nullptr, nullptr); __syncthreads(); if (threadIdx.x == 0) maxIdx = (int)result[0]; __syncthreads(); for (int i = threadIdx.x; i < length; i += blockDim.x) result[i] = 0; __syncthreads(); if (threadIdx.x == 0) { result[maxIdx] = 1.0; } */ } #endif #ifdef __CUDACC__ inline __host__ #elif defined(__GNUC__) #endif static void doAll( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void 
*vextraParams) { auto dx = reinterpret_cast(vx); auto result = reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); auto length = shape::length(xShapeBuffer); auto eleStride = shape::elementWiseStride(xShapeBuffer); auto resultEleStride = shape::elementWiseStride(zShapeBuffer); auto xOrder = shape::order(xShapeBuffer); auto resultOrder = shape::order(zShapeBuffer); if (xOrder == resultOrder && xOrder == 'c') { if (eleStride == 1 && resultEleStride == 1) { if (length < ELEMENT_THRESHOLD) { int maxIdx = 0; auto currMax = dx[0]; for (int i = 0; i < length; i++) { if (currMax < dx[i]) { currMax = dx[i]; maxIdx = i; } result[i] = static_cast(0); } result[maxIdx] = static_cast(1); } else { int maxIdx = 0; auto currMax = dx[0]; { int maxIdxLocal = maxIdx; auto currMaxLocal = currMax; for (int i = 0; i < length; i++) { if (currMaxLocal < dx[i]) { currMaxLocal = dx[i]; maxIdxLocal = i; } result[i] = static_cast(0); } PRAGMA_OMP_CRITICAL { if (currMax < currMaxLocal) { currMax = currMaxLocal; maxIdx = maxIdxLocal; } } } result[maxIdx] = static_cast(1); } } else { if (length < ELEMENT_THRESHOLD) { int maxIdx = 0; auto currMax = dx[0]; for (int i = 0; i < length; i++) { result[i * resultEleStride] = static_cast(0); if (currMax < dx[i * eleStride]) { currMax = dx[i * eleStride]; maxIdx = i; } } result[maxIdx * resultEleStride] = static_cast(1); } else { int maxIdx = 0; auto currMax = dx[0]; { int maxIdxLocal = maxIdx; auto currMaxLocal = currMax; for (int i = 0; i < length; i++) { result[i * resultEleStride] = static_cast(0); if (currMaxLocal < dx[i * eleStride]) { currMaxLocal = dx[i * eleStride]; maxIdxLocal = i; } } PRAGMA_OMP_CRITICAL { if (currMax < currMaxLocal) { currMax = currMaxLocal; maxIdx = maxIdxLocal; } } } result[maxIdx * resultEleStride] = static_cast(1); } } } else { Nd4jLong shapeIter[MAX_RANK]; Nd4jLong coord[MAX_RANK]; int dim; Nd4jLong xStridesIter[MAX_RANK]; Nd4jLong resultStridesIter[MAX_RANK]; auto xShape = 
shape::shapeOf(xShapeBuffer); auto xStride = shape::stride(xShapeBuffer); auto resultStride = shape::stride(zShapeBuffer); auto rank = shape::rank(xShapeBuffer); auto originalResult = result; if (PrepareTwoRawArrayIter(rank, xShape, dx, xStride, result, resultStride, &rank, shapeIter, &dx, xStridesIter, &result, resultStridesIter) >= 0) { auto value = dx[0]; int idx = 0; int maxIdx = 0; ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); { if (dx[0] > value) { value = dx[0]; maxIdx = idx; } idx++; result[0] = static_cast(0); } ND4J_RAW_ITER_TWO_NEXT( dim, rank, coord, shapeIter, dx, xStridesIter, result, resultStridesIter); //pointer to where max value would be if (shape::order(zShapeBuffer) == 'c' || (shape::order(zShapeBuffer) == 'f' && maxIdx * shape::stride(zShapeBuffer)[shape::rank(zShapeBuffer) - 1] >= shape::length(zShapeBuffer))) originalResult[maxIdx] = static_cast(1); else originalResult[maxIdx * shape::stride(zShapeBuffer)[shape::rank(zShapeBuffer) - 1]] = static_cast(1); } } } public: #ifdef __CUDACC__ /** * */ static inline __device__ void execSpecialCuda( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void *vextraParams, int *allocationPointer, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto dx = reinterpret_cast(vx); auto result = reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); // FIXME: MAX_DIMENSION is lower then FP16 frame if (extraParams == nullptr || (int) extraParams[0] == MAX_DIMENSION) { doAllCuda(dx, xShapeBuffer, result, zShapeBuffer, extraParams, allocationPointer, reductionPointer); } } #endif static void execSpecial( void *vx, Nd4jLong *xShapeBuffer, void *vresult, Nd4jLong *zShapeBuffer, void *vextraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto dx = reinterpret_cast(vx); auto result = reinterpret_cast(vresult); auto extraParams = reinterpret_cast(vextraParams); //FIXME: this op should be moved to CustomOps if (extraParams == nullptr || 
(int)extraParams[0] == 0 || ((int)extraParams[0] == 1 && (int)extraParams[1] == MAX_DIMENSION)) { doAll(dx, xShapeBuffer, result, zShapeBuffer, extraParams); } else if (shape::isVector(xShapeBuffer)) { auto dimensionLength = (int)extraParams[0]; auto dimension = new int[dimensionLength]; auto length = shape::length(xShapeBuffer); for (int i = 0; i < dimensionLength; i++) { dimension[i] = (int)extraParams[i + 1]; } if (shape::shapeOf(xShapeBuffer)[dimension[0]] == 1) { for (int i = 0; i < length; i++) { result[i] = static_cast(1); } } else { auto eleStride = shape::elementWiseStride(xShapeBuffer); if (eleStride == 1) { int maxIdx = 0; auto currMax = dx[0]; if (length < ELEMENT_THRESHOLD) { for (int i = 0; i < length; i++) { if (currMax < dx[i]) { currMax = dx[i]; maxIdx = i; } result[i] = static_cast(0); } } else { PRAGMA_OMP_PARALLEL { int maxIdxLocal = maxIdx; auto currMaxLocal = currMax; for (int i = 0; i < length; i++) { if (currMaxLocal < dx[i]) { currMaxLocal = dx[i]; maxIdxLocal = i; } result[i] = static_cast(0); } PRAGMA_OMP_CRITICAL { if (currMax < currMaxLocal) { currMax = currMaxLocal; maxIdx = maxIdxLocal; } } } } result[maxIdx] = static_cast(1); } else { int maxIdx = 0; auto currMax = dx[0]; if (length < ELEMENT_THRESHOLD) { for (int i = 0; i < length; i++) { if (currMax < dx[i * eleStride]) { currMax = dx[i * eleStride]; maxIdx = i; } result[i] = static_cast(0); } } else { { int maxIdxLocal = maxIdx; auto currMaxLocal = currMax; for (int i = 0; i < length; i++) { if (currMaxLocal < dx[i * eleStride]) { currMaxLocal = dx[i * eleStride]; maxIdxLocal = i; } result[i] = static_cast(0); } PRAGMA_OMP_CRITICAL { if (currMax < currMaxLocal) { currMax = currMaxLocal; maxIdx = maxIdxLocal; } } } } result[maxIdx] = static_cast(1); } } } else { auto dimensionLength = (int) extraParams[0]; auto dimension = new int[dimensionLength]; PRAGMA_OMP_SIMD for (int i = 0; i < dimensionLength; i++) { dimension[i] = (int) extraParams[i + 1]; } //decompose in to several sub 
tads after //moving all dimensions (in sorted order) //to the back. //permuted version of the x shape info for setting up the tad problem auto tadShapeShapeInfo = tadShapeInfo; if(tadShapeInfo==nullptr) { auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeBuffer, dimension, dimensionLength); tadShapeShapeInfo = tadPack.primaryShapeInfo(); tadOffsets = tadPack.primaryOffsets(); tadShapeInfo = tadShapeShapeInfo; } auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeBuffer, dimension, dimensionLength); auto tads = shape::length(xShapeBuffer) / tadLength; int tadsPerThread = tads / TAD_THRESHOLD; int num_threads = nd4j::math::nd4j_max(1, tadsPerThread); num_threads = nd4j::math::nd4j_min(num_threads, omp_get_max_threads()); auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo); auto zEWS = tadEWS; int span = (tads / num_threads) + 8; PRAGMA_OMP_PARALLEL_THREADS(num_threads) { int tid = omp_get_thread_num(); int start = span * tid; int end = span * (tid + 1); if (end > tads) end = tads; for (int r = start; r < end; r++) { if (tadEWS > 0 && zEWS > 0 && dimensionLength == 1) { auto rX = dx + tadOffsets[r]; auto rZ = result + tadOffsets[r]; auto maxValue = rX[0]; int maxIdx = 0; if (tadEWS == 1 && zEWS == 1) { for (int i = 0; i < tadLength; i++) { if (rX[i] > maxValue) { maxIdx = i; maxValue = rX[i]; } } for (int i = 0; i < tadLength; i++) { rZ[i] = static_cast(maxIdx == i); } } else { for (int i = 0; i < tadLength; i++) { if (rX[i * tadEWS] > maxValue) { maxIdx = i; maxValue = rX[i * tadEWS]; } } for (int i = 0; i < tadLength; i++) { rZ[i * zEWS] = static_cast(maxIdx == i); } } } else { int tadsPerThread = tads / TAD_THRESHOLD; int num_threads = nd4j::math::nd4j_max(1, tadsPerThread); num_threads = nd4j::math::nd4j_min(num_threads, omp_get_max_threads()); auto offset = tadOffsets[r]; Nd4jLong shapeIter[MAX_RANK]; Nd4jLong coord[MAX_RANK]; int dim; Nd4jLong xStridesIter[MAX_RANK]; Nd4jLong resultStridesIter[MAX_RANK]; auto 
xShape = shape::shapeOf(tadShapeShapeInfo); auto xStride = shape::stride(tadShapeShapeInfo); auto resultStride = shape::stride(tadShapeShapeInfo); int rank = shape::rank(tadShapeShapeInfo); auto xPointer = dx + offset; auto resultPointer = result + offset; auto maxValue = xPointer[0]; auto maxCursor = resultPointer; Nd4jPointer maxCursorLong = reinterpret_cast(maxCursor); if (PrepareTwoRawArrayIter(rank, xShape, xPointer, xStride, resultPointer, resultStride, &rank, shapeIter, &xPointer, xStridesIter, &resultPointer, resultStridesIter) >= 0) { ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); { if (maxValue < xPointer[0]) { maxCursor = resultPointer; maxCursorLong = reinterpret_cast(resultPointer); maxValue = xPointer[0]; } resultPointer[0] = static_cast(0); } ND4J_RAW_ITER_TWO_NEXT(dim, rank, coord, shapeIter, xPointer, xStridesIter, resultPointer, resultStridesIter); maxCursor = reinterpret_cast(maxCursorLong); maxCursor[0] = static_cast(1);; } } } } delete[] dimension; } } op_def static Z op(X d1, X *params) { return nd4j::math::softplus(d1); } }; }