/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_SPECIAL_ACCUMULATION_OPS_H
#define LIBND4J_SPECIAL_ACCUMULATION_OPS_H

// include paths below are restored from the identifiers used in this file
// (Nd4jLong, shape::, nd4j::ConstantTadHelper); the original paths were lost
#include <pointercast.h>
#include <helpers/shape.h>
#include <helpers/ConstantTadHelper.h>
//#include
//#include

namespace simdOps {

    // LogSumExp is the second pass of a two-pass reduction: the result buffer is
    // expected to already hold the per-TAD maximum, this op accumulates
    // exp(x - max), and postProcess returns max + log(sum) -- the numerically
    // stable log-sum-exp.
    template <typename T, typename Z>
    class LogSumExp {
    public:
        static const bool requiresSpecialAccumulation = true;

        constexpr static functions::ReduceType reduceType = functions::ReduceType::SUM;

        op_def static T startingValue(const T *input) {
            return (T) 0.0f;
        }

        op_def static Z merge(T old, T opOutput, Z *extraParams) {
            return opOutput + old;
        }

        op_def static T update(T old, T opOutput, Z *extraParams) {
            return opOutput + old;
        }

        op_def static Z op(T d1, T d2) {
            return nd4j::math::nd4j_exp(d1 - d2);
        }

        op_def static Z op(T d1, Z* extraParams) {
            return nd4j::math::nd4j_exp(static_cast<Z>(d1) - extraParams[0]);
        }

        op_def static Z postProcess(T reduction, Nd4jLong n, Z *extraParams) {
            return extraParams[0] + nd4j::math::nd4j_log(reduction);
        }

#ifdef __CUDACC__
        __device__ static inline void aggregatePartials(Z *sPartials, int tid, int numItems, Z *extraParams) {
            // start the shared memory loop on the next power of 2 less
            // than the block size. If block size is not a power of 2,
            // accumulate the intermediate sums in the remainder range.
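            //
            // Illustrative trace: with numItems = 6, floorPow2 becomes 4, so
            // threads 4 and 5 first fold their partials into slots 0 and 1;
            // the remaining 4 slots are then combined by the tree reduction
            // below (activeThreads = 2, then 1).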
            int floorPow2 = numItems;

            if (floorPow2 & (floorPow2 - 1)) {
                while (floorPow2 & (floorPow2 - 1)) {
                    floorPow2 &= floorPow2 - 1;
                }

                if (tid >= floorPow2) {
                    sPartials[tid - floorPow2] = update(sPartials[tid - floorPow2], sPartials[tid], extraParams);
                }
                __syncthreads();
            }

            for (int activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) {
                if (tid < activeThreads && tid + activeThreads < numItems) {
                    sPartials[tid] = update(sPartials[tid], sPartials[tid + activeThreads], extraParams);
                }
                __syncthreads();
            }
        }

        static inline __device__ void execSpecialCuda(
                T *dx, Nd4jLong *xShapeInfo,
                Z *extraParams,
                Z *result, Nd4jLong *resultShapeInfo,
                int *dimension, int dimensionLength,
                Z *reductionBuffer,
                Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {

            // we assume that RESULT already holds the per-TAD max values

            // shared memory space for storing intermediate results
            __shared__ Z *sPartials;
            // __shared__ shape::TAD *tad;
            __shared__ Nd4jLong tadLength;
            __shared__ Nd4jLong numTads;

            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                sPartials = (Z *) shmem;
                tadLength = shape::length(tadOnlyShapeInfo); //shape::tadLength(xShapeInfo, dimension, dimensionLength);
                numTads = shape::length(xShapeInfo) / tadLength;
            }
            __syncthreads();

            for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
                auto tadOffsetForBlock = tadOffsets[r];

                sPartials[threadIdx.x] = startingValue(dx + tadOffsetForBlock);

                // each thread accumulates exp(x - max) over a strided slice of its TAD
                for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
                    auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
                    sPartials[threadIdx.x] = update(sPartials[threadIdx.x], op(dx[xOffset], result[r]), extraParams);
                }
                __syncthreads();

                // aggregate. do NOT reduce for elements > tadLength
                aggregatePartials(sPartials, threadIdx.x, nd4j::math::nd4j_min(blockDim.x, tadLength), &result[r]);

                __syncthreads();
                if (threadIdx.x == 0)
                    result[r] = postProcess(sPartials[threadIdx.x], tadLength, &result[r]);
            }
        }
#endif

        static void execSpecial(T *x, Nd4jLong *xShapeInfo,
                                Z *extraParams,
                                Z *result, Nd4jLong *resultShapeInfoBuffer,
                                int *dimension, int dimensionLength,
                                Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {

            Nd4jLong resultLength = shape::length(resultShapeInfoBuffer);

            auto tadOnlyShapeInfo = tadShapeInfo;
            auto tadOffsets = tadOffset;

            // build TAD metadata on demand if the caller did not supply it
            if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) {
                if (dimensionLength < 1)
                    return;

                auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
                tadOnlyShapeInfo = tadPack.primaryShapeInfo();
                tadOffsets = tadPack.primaryOffsets();
            }

            const Nd4jLong tadLength = shape::length(tadOnlyShapeInfo); //shape::tadLength(xShapeInfo, dimension, dimensionLength);
            auto numTads = shape::length(xShapeInfo) / tadLength;
            auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);

            int tadsPerThread = resultLength / TAD_THRESHOLD;
            int num_threads = nd4j::math::nd4j_max(1, tadsPerThread);
            num_threads = nd4j::math::nd4j_min(num_threads, omp_get_max_threads());

            if (tadEWS > 0 && (numTads == 1 || shape::isVector(tadOnlyShapeInfo) || shape::isScalar(tadOnlyShapeInfo))) {

                PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
                for (int i = 0; i < resultLength; i++) {
                    T *iter = x + tadOffsets[i];
                    T start = startingValue(iter);

                    if (tadEWS == 1) {
                        for (int j = 0; j < tadLength; j++) {
                            start = update(start, op(iter[j], result[i]), extraParams);
                        }
                    } else {
                        for (int j = 0; j < tadLength; j++) {
                            start = update(start, op(iter[j * tadEWS], result[i]), extraParams);
                        }
                    }
                    result[i] = postProcess(start, tadLength, &result[i]);
                }
            } else {
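                // generic path: the TAD has no usable element-wise stride, so every
                // element offset is resolved through shape::getIndexOffset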
                PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
                for (int i = 0; i < resultLength; i++) {
                    auto offset = tadOffsets[i];
                    T start = startingValue(x + offset);

                    for (int j = 0; j < tadLength; j++) {
                        auto xOffset = offset + shape::getIndexOffset(j, tadOnlyShapeInfo, tadLength);
                        start = update(start, op(x[xOffset], result[i]), extraParams);
                    }

                    result[i] = postProcess(start, tadLength, &result[i]);
                }
            }
        }
    };
}

#endif //LIBND4J_SPECIAL_ACCUMULATION_OPS_H