Shyrma softmax (#209)

* - provide new cuda kernel for softmax

Signed-off-by: Yurii <yurii@skymind.io>

* - further work on cuda kernel for softmax

Signed-off-by: Yurii <yurii@skymind.io>

* - correction cuda kernel for softmax

Signed-off-by: Yurii <yurii@skymind.io>
master
Yurii Shyrma 2019-08-30 20:31:05 +03:00 committed by GitHub
parent bdc3eacafd
commit 00fd50cee2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 65 additions and 19 deletions

View File

@ -24,6 +24,7 @@
#include <ShapeUtils.h> #include <ShapeUtils.h>
#include <numeric> #include <numeric>
#include <PointersManager.h> #include <PointersManager.h>
#include <helpers/ConstantTadHelper.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -196,7 +197,7 @@ void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray&
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { __device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
// logic of this kernel is based on assumption gridDim = 1 // logic of this kernel is based on assumption gridDim = 1
@ -210,7 +211,7 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
extern __shared__ char shared[]; extern __shared__ char shared[];
shmem = reinterpret_cast<T*>(shared); shmem = reinterpret_cast<T*>(shared);
len = shape::length(xzShapeInfo); len = shape::length(xShapeInfo);
numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x)
} }
__syncthreads(); __syncthreads();
@ -222,8 +223,8 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) { if(elemIdx < len) {
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len);
shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max<T>(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : nd4j::math::nd4j_max<T>(x[xOffset], temp); // take into account max element evaluated on previous iteration and stored in temp
} }
else else
shmem[threadIdx.x] = -DataTypeUtils::max<T>(); // FIXME: what if T is unsigned ?? shmem[threadIdx.x] = -DataTypeUtils::max<T>(); // FIXME: what if T is unsigned ??
@ -248,9 +249,10 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) { if(elemIdx < len) {
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len);
z[offset] = nd4j::math::nd4j_exp<T, T>(x[offset] - max); const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len);
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp z[zOffset] = nd4j::math::nd4j_exp<T, T>(x[xOffset] - max);
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[zOffset] : (z[zOffset] + temp); // take into account sum element evaluated on previous iteration and stored in temp
} }
else else
shmem[threadIdx.x] = 0; shmem[threadIdx.x] = 0;
@ -270,43 +272,87 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
for (int i = 0; i < numOfIters; ++i) { for (int i = 0; i < numOfIters; ++i) {
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx >= len) continue; if(elemIdx >= len) continue;
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len);
z[offset] /= shmem[0]; z[zOffset] /= shmem[0];
} }
} }
template<typename T>
__global__ void softMaxForVectorCudaGlobal(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
softMaxForVectorCuda<T>(vx, xShapeInfo, vz, zShapeInfo);
}
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
softMaxForVectorCuda<T><<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz); softMaxForVectorCudaGlobal<T><<<1, MAX_NUM_THREADS / 4 , (MAX_NUM_THREADS / 4) * sizeof(T) + 512, *stream>>>(vx, xShapeInfo, vz, zShapeInfo);
} }
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void softMaxCuda(const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) {
const auto x = reinterpret_cast<const T*>(vx);
auto z = reinterpret_cast<T*>(vz);
const auto* xTad = x + xOffsets[blockIdx.x];
auto* zTad = z + zOffsets[blockIdx.x];
softMaxForVectorCuda<T>(xTad, xTadShapeInfo, zTad, zTadShapeInfo);
}
///////////////////////////////////////////////////////////////////
template<typename T>
static void softMaxCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) {
softMaxCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets);
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) {
if(!input.isActualOnDeviceSide()) input.syncToDevice(); if(!input.isActualOnDeviceSide()) input.syncToDevice();
const int rank = input.rankOf(); const int rank = input.rankOf();
PointersManager manager(context, "helpers::softmax");
if(input.isVector()) { if(input.isVector()) {
if(rank == 1 || input.sizeAt(dimension) != 1) { if(rank == 1 || input.sizeAt(dimension) != 1) {
BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer()), FLOAT_TYPES); NDArray::prepareSpecialUse({&output}, {&input});
input.tickReadDevice(); BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), FLOAT_TYPES);
NDArray::registerSpecialUse({&output}, {&input});
} }
else else
output = 1.; output = 1.;
} }
else { else {
auto maxAlongDim = const_cast<NDArray&>(input).reduceAlongDims(reduce::Max, {dimension}, true); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), {dimension});
(input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), {dimension});
auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true);
output /= sumAlongDim; const int threadsPerBlock = MAX_NUM_THREADS / 4;
input.tickReadDevice(); const int blocksPerGrid = packZ.numberOfTads();
const int sharedMem = input.sizeOfT() * threadsPerBlock + 512;
NDArray::prepareSpecialUse({&output}, {&input});
BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES);
NDArray::registerSpecialUse({&output}, {&input});
// auto maxAlongDim = const_cast<NDArray&>(input).reduceAlongDims(reduce::Max, {dimension}, true);
// (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily
// auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true);
// output /= sumAlongDim;
// input.tickReadDevice();
} }
PointersManager manager(context, "helpers::softmax");
manager.synchronize(); manager.synchronize();
output.tickWriteDevice(); output.tickWriteDevice();