Shyrma softmax (#209)

* - provide new cuda kernel for softmax

Signed-off-by: Yurii <yurii@skymind.io>

* - further work on cuda kernel for softmax

Signed-off-by: Yurii <yurii@skymind.io>

* - correction cuda kernel for softmax

Signed-off-by: Yurii <yurii@skymind.io>
master
Yurii Shyrma 2019-08-30 20:31:05 +03:00 committed by GitHub
parent bdc3eacafd
commit 00fd50cee2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 65 additions and 19 deletions

View File

@ -24,6 +24,7 @@
#include <ShapeUtils.h>
#include <numeric>
#include <PointersManager.h>
#include <helpers/ConstantTadHelper.h>
namespace nd4j {
namespace ops {
@ -196,7 +197,7 @@ void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray&
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
__device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
// logic of this kernel is based on assumption gridDim = 1
@ -210,7 +211,7 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
if (threadIdx.x == 0) {
extern __shared__ char shared[];
shmem = reinterpret_cast<T*>(shared);
len = shape::length(xzShapeInfo);
len = shape::length(xShapeInfo);
numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x)
}
__syncthreads();
@ -222,8 +223,8 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) {
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max<T>(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp
const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len);
shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : nd4j::math::nd4j_max<T>(x[xOffset], temp); // take into account max element evaluated on previous iteration and stored in temp
}
else
shmem[threadIdx.x] = -DataTypeUtils::max<T>(); // FIXME: what if T is unsigned ??
@ -248,9 +249,10 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) {
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
z[offset] = nd4j::math::nd4j_exp<T, T>(x[offset] - max);
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp
const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len);
const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len);
z[zOffset] = nd4j::math::nd4j_exp<T, T>(x[xOffset] - max);
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[zOffset] : (z[zOffset] + temp); // take into account sum element evaluated on previous iteration and stored in temp
}
else
shmem[threadIdx.x] = 0;
@ -270,43 +272,87 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
for (int i = 0; i < numOfIters; ++i) {
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx >= len) continue;
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
z[offset] /= shmem[0];
const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len);
z[zOffset] /= shmem[0];
}
}
template<typename T>
__global__ void softMaxForVectorCudaGlobal(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
softMaxForVectorCuda<T>(vx, xShapeInfo, vz, zShapeInfo);
}
///////////////////////////////////////////////////////////////////
template <typename T>
linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
softMaxForVectorCuda<T><<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz);
softMaxForVectorCudaGlobal<T><<<1, MAX_NUM_THREADS / 4 , (MAX_NUM_THREADS / 4) * sizeof(T) + 512, *stream>>>(vx, xShapeInfo, vz, zShapeInfo);
}
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void softMaxCuda(const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) {
const auto x = reinterpret_cast<const T*>(vx);
auto z = reinterpret_cast<T*>(vz);
const auto* xTad = x + xOffsets[blockIdx.x];
auto* zTad = z + zOffsets[blockIdx.x];
softMaxForVectorCuda<T>(xTad, xTadShapeInfo, zTad, zTadShapeInfo);
}
///////////////////////////////////////////////////////////////////
template<typename T>
static void softMaxCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) {
softMaxCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets);
}
//////////////////////////////////////////////////////////////////////////
void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) {
if(!input.isActualOnDeviceSide()) input.syncToDevice();
const int rank = input.rankOf();
PointersManager manager(context, "helpers::softmax");
if(input.isVector()) {
if(rank == 1 || input.sizeAt(dimension) != 1) {
BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer()), FLOAT_TYPES);
input.tickReadDevice();
NDArray::prepareSpecialUse({&output}, {&input});
BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), FLOAT_TYPES);
NDArray::registerSpecialUse({&output}, {&input});
}
else
output = 1.;
}
else {
auto maxAlongDim = const_cast<NDArray&>(input).reduceAlongDims(reduce::Max, {dimension}, true);
(input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily
auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true);
output /= sumAlongDim;
input.tickReadDevice();
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), {dimension});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), {dimension});
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = packZ.numberOfTads();
const int sharedMem = input.sizeOfT() * threadsPerBlock + 512;
NDArray::prepareSpecialUse({&output}, {&input});
BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES);
NDArray::registerSpecialUse({&output}, {&input});
// auto maxAlongDim = const_cast<NDArray&>(input).reduceAlongDims(reduce::Max, {dimension}, true);
// (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily
// auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true);
// output /= sumAlongDim;
// input.tickReadDevice();
}
PointersManager manager(context, "helpers::softmax");
manager.synchronize();
output.tickWriteDevice();