/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by raver on 4/9/2018.
//

// NOTE(review): the names of the four angle-bracket includes below were not
// recoverable from the reviewed text and are reconstructed from the symbols
// this file uses -- confirm against the repository before merging.
#include <system/Environment.h>
#include "../indexreduce.h"
#include <system/op_boilerplate.h>
#include <helpers/DebugHelper.h>
#include <types/types.h>

#include "../legacy_ops.h"

using namespace simdOps;

/**
 * Kernel entry point: forwards every argument except the two rank values to
 * IndexReduce<X, Z>::transform, which selects the concrete op from `op`.
 *
 * xRank and zRank are accepted but not used by this kernel.
 */
template <typename X, typename Z>
static __global__ void simpleIndexReduceGeneric(const int op,
                                                void *dx,
                                                Nd4jLong *xShapeInfo, int xRank,
                                                void *extraParams,
                                                void *result,
                                                Nd4jLong *zShapeInfo, int zRank,
                                                int *dimension,
                                                int dimensionLength,
                                                int postProcessOrNot,
                                                int *allocationBuffer,
                                                void *reductionBuffer,
                                                Nd4jLong *tadOnlyShapeInfo,
                                                Nd4jLong *tadOffsets) {
    functions::indexreduce::IndexReduce<X, Z>::transform(op, dx, xShapeInfo, extraParams,
                                                         result, zShapeInfo,
                                                         dimension, dimensionLength,
                                                         postProcessOrNot,
                                                         allocationBuffer, reductionBuffer,
                                                         tadOnlyShapeInfo, tadOffsets);
}

namespace functions {
    namespace indexreduce {

        /**
         * Launches the index-reduce kernel for a whole-array reduction.
         *
         * The caller's zRank, dimension, dimensionLength and postProcessOrNot
         * are ignored: the kernel always receives zRank = 0,
         * dimension = nullptr, dimensionLength = 0 and postProcessOrNot = 1.
         *
         * launchDims.x / .y / .z are used as grid size, block size and dynamic
         * shared memory size (bytes) respectively. The launch is asynchronous
         * on *stream; no error check is performed here.
         */
        template <typename X, typename Z>
        _CUDA_H void IndexReduce<X, Z>::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream,
                                                                 const int opNum,
                                                                 void *dx, Nd4jLong *xShapeInfo, int xRank,
                                                                 void *extraParams,
                                                                 void *result, Nd4jLong *zShapeInfo, int zRank,
                                                                 int *dimension, int dimensionLength,
                                                                 int postProcessOrNot,
                                                                 int *allocationBuffer, void *reductionBuffer,
                                                                 Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
            simpleIndexReduceGeneric<X, Z><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(
                    opNum,
                    dx, xShapeInfo, xRank,
                    extraParams,
                    result, zShapeInfo, 0,
                    nullptr, 0,
                    1,
                    allocationBuffer, reductionBuffer,
                    tadOnlyShapeInfo, tadOffsets);
        }

        /**
         * Launches the index-reduce kernel for a reduction along `dimension`.
         *
         * The caller's postProcessOrNot is ignored; the kernel always
         * receives 1. launchDims.x / .y / .z are used as grid size, block size
         * and dynamic shared memory size (bytes) respectively. The launch is
         * asynchronous on *stream; no error check is performed here.
         */
        template <typename X, typename Z>
        _CUDA_H void IndexReduce<X, Z>::executeIndexReduce(dim3 launchDims, cudaStream_t *stream,
                                                           const int opNum,
                                                           void *dx, Nd4jLong *xShapeInfo, int xRank,
                                                           void *extraParams,
                                                           void *result, Nd4jLong *zShapeInfo, int zRank,
                                                           int *dimension, int dimensionLength,
                                                           int postProcessOrNot,
                                                           int *allocationBuffer, void *reductionBuffer,
                                                           Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
            simpleIndexReduceGeneric<X, Z><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(
                    opNum,
                    dx, xShapeInfo, xRank,
                    extraParams,
                    result, zShapeInfo, zRank,
                    dimension, dimensionLength,
                    1,
                    allocationBuffer, reductionBuffer,
                    tadOnlyShapeInfo, tadOffsets);
        }

        // This is the un-specialized struct. Note that we prevent instantiation of this
        // struct by putting an undefined symbol in the function body so it won't compile.
        template <typename T>
        struct SharedIndexValue {
            // Ensure that we won't compile any un-specialized types
            __device__ T * getPointer() {
                extern __device__ void error(void);
                error();
                return 0;
            }
        };

        // Only two specializations are provided in this file; each returns the
        // block's dynamic shared memory viewed as an array of IndexValue.
        // NOTE(review): the specialized element types (float / double) were not
        // recoverable from the reviewed text -- confirm against the repository.
        template <>
        struct SharedIndexValue<float> {
            __device__ IndexValue<float> * getPointer() {
                extern __shared__ IndexValue<float> s_int2[];
                return s_int2;
            }
        };

        template <>
        struct SharedIndexValue<double> {
            __device__ IndexValue<double> * getPointer() {
                extern __shared__ IndexValue<double> s_int6[];
                return s_int6;
            }
        };

        /**
         * Combines the per-thread partial results held in (*sPartialsRef)[0 .. blockDim.x)
         * into (*sPartialsRef)[0] using OpType::update.
         *
         * Must be called by every thread of the block: it contains
         * __syncthreads() barriers.
         *
         * @param sPartialsRef  pointer to the block's partial-result array
         * @param tid           index of the calling thread inside the block
         * @param numElements   number of leading entries that take part in the tree phase
         * @param vextraParams  op-specific parameters, forwarded to OpType::update
         */
        template <typename X, typename Z>
        template <typename OpType>
        __device__ void IndexReduce<X, Z>::aggregatePartials(IndexValue<X> **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) {
            // start the shared memory loop on the next power of 2 less
            // than the block size. If block size is not a power of 2,
            // accumulate the intermediate sums in the remainder range.
            auto extraParams = static_cast<X*>(vextraParams);
            IndexValue<X> *sPartials = *sPartialsRef;
            Nd4jLong floorPow2 = blockDim.x;

            if (floorPow2 & (floorPow2 - 1)) {
                // clear the lowest set bit until a single bit (the largest
                // power of two <= blockDim.x) remains
                while (floorPow2 & (floorPow2 - 1)) {
                    floorPow2 &= floorPow2 - 1;
                }

                // fold the tail [floorPow2, blockDim.x) onto the head
                if (tid >= floorPow2) {
                    IndexValue<X> prev = sPartials[tid - floorPow2];
                    IndexValue<X> curr = sPartials[tid];
                    sPartials[tid - floorPow2] = OpType::update(prev, curr, extraParams);
                }
                __syncthreads();
            }

            // halving tree over the power-of-two prefix
            for (int activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) {
                if (tid < activeThreads && tid + activeThreads < numElements) {
                    IndexValue<X> curr = sPartials[tid];
                    IndexValue<X> next = sPartials[tid + activeThreads];
                    sPartials[tid] = OpType::update(curr, next, extraParams);
                }
                __syncthreads();
            }
        }

        /**
         * Selects the concrete op for opNum from INDEX_REDUCE_OPS and forwards
         * to the OpType-templated transform below.
         */
        template <typename X, typename Y>
        __device__ void IndexReduce<X, Y>::transform(const int opNum,
                                                     void *x, Nd4jLong *xShapeInfo,
                                                     void *extraParams,
                                                     void *result, Nd4jLong *zShapeInfo,
                                                     int *dimension, int dimensionLength,
                                                     int postProcessOrNot,
                                                     int *allocationBuffer, void *reductionBuffer,
                                                     Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
            DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
        }

        /**
         * Device-side index reduction: writes to z the index selected by
         * OpType::update, either once for the whole input (when the output
         * length is 1) or once per TAD.
         *
         * Paths:
         *  - x is EMPTY: every z element receives the index of
         *    OpType::startingIndexValue (nothing is written if z is EMPTY too).
         *  - output length != 1: each block walks TADs blockIdx.x,
         *    blockIdx.x + gridDim.x, ...; its threads stride over the TAD and
         *    the block's partials are merged with aggregatePartials.
         *  - output length == 1: all threads stride over x; with more than one
         *    block, each block stores its partial in reductionBuffer and the
         *    block that draws the last ticket merges them.
         *
         * Requirements visible from the code: dynamic shared memory must hold
         * at least blockDim.x IndexValue<X> entries; reductionBuffer must be
         * addressable at unsigned-int index 16384 and, when gridDim.x > 1,
         * hold gridDim.x IndexValue<X> entries starting at its base.
         * NOTE(review): those entries must not reach the counter slot, and the
         * counter is presumably zero on entry -- confirm with the code that
         * allocates the buffer.
         *
         * postProcessOrNot and allocationBuffer are not used here.
         */
        template <typename X, typename Z>
        template <typename OpType>
        __device__ void IndexReduce<X, Z>::transform(void *vdx, Nd4jLong *xShapeInfo,
                                                     void *vextraParams,
                                                     void *vz, Nd4jLong *zShapeInfo,
                                                     int *dimension, int dimensionLength,
                                                     int postProcessOrNot,
                                                     int *allocationBuffer, void *vreductionBuffer,
                                                     Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
            // unsigned-int index inside reductionBuffer of the inter-block ticket counter
            constexpr unsigned int kTicketCounterIndex = 16384;

            auto dx = reinterpret_cast<X*>(vdx);
            auto z = reinterpret_cast<Z*>(vz);
            auto extraParams = static_cast<X*>(vextraParams);
            auto reductionBuffer = static_cast<unsigned int*>(vreductionBuffer);
            auto order = shape::order(xShapeInfo);

            // widened before multiplying so the global thread index cannot wrap
            Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
            __shared__ volatile int resultScalar;

            // shared memory space for storing intermediate results
            __shared__ IndexValue<X>* sPartials;
            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                sPartials = reinterpret_cast<IndexValue<X>*>(shmem);
            }
            __syncthreads();

            sPartials[threadIdx.x] = OpType::startingIndexValue(dx);

            // number of output elements
            __shared__ volatile Nd4jLong zLen;

            IndexValue<X> reduction = OpType::startingIndexValue(dx);

            if (threadIdx.x == 0) {
                if (zShapeInfo != nullptr)
                    zLen = shape::length(zShapeInfo);
                else
                    zLen = 1;

                if (dimensionLength == 1) {
                    if (zLen == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION))
                        resultScalar = 1;
                    else
                        resultScalar = 0;
                }
                else
                    resultScalar = 0;

                if (zLen == 1)
                    resultScalar = 1;
            }
            __syncthreads();

            if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) {
                // zShapeInfo may legitimately be null (see zLen above), so it is
                // checked before being inspected
                if (zShapeInfo != nullptr && sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY)
                    return;

                const Nd4jLong stride = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;
                for (Nd4jLong i = tid; i < zLen; i += stride)
                    z[i] = (Z) reduction.index;

                return;
            }

            if (!resultScalar) {
                __shared__ Nd4jLong tadLength;
                __shared__ int tadEWS;
                __shared__ int numTads;

                if (threadIdx.x == 0) {
                    tadLength = shape::length(tadOnlyShapeInfo);
                    tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
                    numTads = shape::length(xShapeInfo) / tadLength;
                }
                __syncthreads();

                if (dimensionLength > 1 || tadEWS < 1) {
                    // general case: element offsets come from the TAD shape info
                    for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
                        auto tadOffsetForBlock = tadOffsets[r];
                        sPartials[threadIdx.x] = OpType::startingIndexValue(dx);

                        // 64-bit index: tadLength is Nd4jLong
                        for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) {
                            auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo);
                            IndexValue<X> comp {dx[xOffset], i};
                            sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams);
                        }
                        __syncthreads();

                        aggregatePartials<OpType>(&sPartials, threadIdx.x, sd::math::nd4j_min<Nd4jLong>(blockDim.x, tadLength), extraParams);
                        __syncthreads();

                        if (threadIdx.x == 0) {
                            z[r] = (Z) sPartials[threadIdx.x].index;
                        }
                        __syncthreads();
                    }
                } else {
                    // single dimension with a positive element-wise stride
                    for (int i = blockIdx.x; i < numTads; i += gridDim.x) {
                        Nd4jLong tadOffsetForBlock = tadOffsets[i];
                        sPartials[threadIdx.x] = OpType::startingIndexValue(dx);

                        // 64-bit index keeps x * tadEWS from overflowing int
                        for (Nd4jLong x = threadIdx.x; x < tadLength; x += blockDim.x) {
                            IndexValue<X> comp {dx[tadOffsetForBlock + x * tadEWS], x};
                            sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams);
                        }
                        __syncthreads();

                        aggregatePartials<OpType>(&sPartials, threadIdx.x, sd::math::nd4j_min<Nd4jLong>(blockDim.x, tadLength), extraParams);
                        __syncthreads();

                        if (threadIdx.x == 0) {
                            z[i] = (Z) sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams);
                        }
                        __syncthreads();
                    }
                }
            } else {
                auto n = shape::length(xShapeInfo);
                auto xElementWiseStride = shape::elementWiseStride(xShapeInfo);

                if (xElementWiseStride >= 1 && order == 'c') {
                    for (Nd4jLong i = tid; i < n; i += (blockDim.x * gridDim.x)) {
                        IndexValue<X> indexVal = {dx[i * xElementWiseStride], i};
                        reduction = OpType::update(reduction, indexVal, extraParams);
                    }
                } else {
                    for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) {
                        auto offset = shape::getIndexOffset(i, xShapeInfo);
                        IndexValue<X> indexVal = {dx[offset], i};
                        reduction = OpType::update(reduction, indexVal, extraParams);
                    }
                }

                sPartials[threadIdx.x] = reduction;
                __syncthreads();

                // compare in 64 bits: narrowing n to int could yield zero or a
                // negative count and switch the tree phase off entirely
                aggregatePartials<OpType>(&sPartials, threadIdx.x, sd::math::nd4j_min<Nd4jLong>(blockDim.x, n), extraParams);
                __syncthreads();

                if (gridDim.x > 1) {
                    __shared__ bool amLast;
                    unsigned int *tc = reductionBuffer;
                    tid = threadIdx.x;

                    if (threadIdx.x == 0) {
                        auto pBuffer = reinterpret_cast<IndexValue<X> *>(reductionBuffer);
                        pBuffer[blockIdx.x] = {sPartials[0].value, sPartials[0].index};
                    }
                    // publish this block's partial before taking a ticket
                    __threadfence();
                    __syncthreads();

                    if (tid == 0) {
                        unsigned int ticket = atomicInc(&tc[kTicketCounterIndex], gridDim.x);
                        amLast = (ticket == gridDim.x - 1);
                    }
                    __syncthreads();

                    // amLast is block-wide, so the barriers below are reached
                    // by every thread of the block or by none
                    if (amLast) {
                        tc[kTicketCounterIndex] = 0;
                        IndexValue<X> *pBuffer = (IndexValue<X> *) reductionBuffer;

                        sPartials[threadIdx.x] = OpType::startingIndexValue(dx);

                        for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) {
                            sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], pBuffer[i], extraParams);
                        }
                        __syncthreads();

                        aggregatePartials<OpType>(&sPartials, threadIdx.x, sd::math::nd4j_min<int>(gridDim.x, blockDim.x), extraParams);
                        __syncthreads();

                        if (tid == 0) {
                            z[0] = (Z) sPartials[0].index;
                        }
                    }
                } else {
                    if (tid == 0) {
                        auto tc = reinterpret_cast<unsigned int *>(reductionBuffer);
                        tc[kTicketCounterIndex] = 0;
                        z[0] = (Z) sPartials[0].index;
                    }
                }
            }
        }

        BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES);
    }
}