Shyrma softmax (#209)
* - provide new cuda kernel for softmax Signed-off-by: Yurii <yurii@skymind.io> * - further work on cuda kernel for softmax Signed-off-by: Yurii <yurii@skymind.io> * - correction cuda kernel for softmax Signed-off-by: Yurii <yurii@skymind.io>master
parent
bdc3eacafd
commit
00fd50cee2
|
@ -24,6 +24,7 @@
|
||||||
#include <ShapeUtils.h>
|
#include <ShapeUtils.h>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <PointersManager.h>
|
#include <PointersManager.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -196,7 +197,7 @@ void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray&
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
|
__device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
// logic of this kernel is based on assumption gridDim = 1
|
// logic of this kernel is based on assumption gridDim = 1
|
||||||
|
|
||||||
|
@ -210,7 +211,7 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
extern __shared__ char shared[];
|
extern __shared__ char shared[];
|
||||||
shmem = reinterpret_cast<T*>(shared);
|
shmem = reinterpret_cast<T*>(shared);
|
||||||
len = shape::length(xzShapeInfo);
|
len = shape::length(xShapeInfo);
|
||||||
numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x)
|
numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x)
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
@ -222,8 +223,8 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
|
||||||
|
|
||||||
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
|
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
|
||||||
if(elemIdx < len) {
|
if(elemIdx < len) {
|
||||||
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
|
const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len);
|
||||||
shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max<T>(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp
|
shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : nd4j::math::nd4j_max<T>(x[xOffset], temp); // take into account max element evaluated on previous iteration and stored in temp
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
shmem[threadIdx.x] = -DataTypeUtils::max<T>(); // FIXME: what if T is unsigned ??
|
shmem[threadIdx.x] = -DataTypeUtils::max<T>(); // FIXME: what if T is unsigned ??
|
||||||
|
@ -248,9 +249,10 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
|
||||||
|
|
||||||
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
|
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
|
||||||
if(elemIdx < len) {
|
if(elemIdx < len) {
|
||||||
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
|
const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len);
|
||||||
z[offset] = nd4j::math::nd4j_exp<T, T>(x[offset] - max);
|
const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len);
|
||||||
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp
|
z[zOffset] = nd4j::math::nd4j_exp<T, T>(x[xOffset] - max);
|
||||||
|
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[zOffset] : (z[zOffset] + temp); // take into account sum element evaluated on previous iteration and stored in temp
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
shmem[threadIdx.x] = 0;
|
shmem[threadIdx.x] = 0;
|
||||||
|
@ -270,43 +272,87 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
|
||||||
for (int i = 0; i < numOfIters; ++i) {
|
for (int i = 0; i < numOfIters; ++i) {
|
||||||
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
|
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
|
||||||
if(elemIdx >= len) continue;
|
if(elemIdx >= len) continue;
|
||||||
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
|
const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len);
|
||||||
z[offset] /= shmem[0];
|
z[zOffset] /= shmem[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__global__ void softMaxForVectorCudaGlobal(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
softMaxForVectorCuda<T>(vx, xShapeInfo, vz, zShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
|
linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
softMaxForVectorCuda<T><<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz);
|
softMaxForVectorCudaGlobal<T><<<1, MAX_NUM_THREADS / 4 , (MAX_NUM_THREADS / 4) * sizeof(T) + 512, *stream>>>(vx, xShapeInfo, vz, zShapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ static void softMaxCuda(const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
|
||||||
|
void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) {
|
||||||
|
|
||||||
|
const auto x = reinterpret_cast<const T*>(vx);
|
||||||
|
auto z = reinterpret_cast<T*>(vz);
|
||||||
|
|
||||||
|
const auto* xTad = x + xOffsets[blockIdx.x];
|
||||||
|
auto* zTad = z + zOffsets[blockIdx.x];
|
||||||
|
|
||||||
|
softMaxForVectorCuda<T>(xTad, xTadShapeInfo, zTad, zTadShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void softMaxCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||||
|
const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
|
||||||
|
void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) {
|
||||||
|
|
||||||
|
softMaxCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) {
|
void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) {
|
||||||
|
|
||||||
if(!input.isActualOnDeviceSide()) input.syncToDevice();
|
if(!input.isActualOnDeviceSide()) input.syncToDevice();
|
||||||
const int rank = input.rankOf();
|
const int rank = input.rankOf();
|
||||||
|
|
||||||
|
PointersManager manager(context, "helpers::softmax");
|
||||||
|
|
||||||
if(input.isVector()) {
|
if(input.isVector()) {
|
||||||
|
|
||||||
if(rank == 1 || input.sizeAt(dimension) != 1) {
|
if(rank == 1 || input.sizeAt(dimension) != 1) {
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer()), FLOAT_TYPES);
|
NDArray::prepareSpecialUse({&output}, {&input});
|
||||||
input.tickReadDevice();
|
BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input});
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
output = 1.;
|
output = 1.;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
auto maxAlongDim = const_cast<NDArray&>(input).reduceAlongDims(reduce::Max, {dimension}, true);
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), {dimension});
|
||||||
(input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), {dimension});
|
||||||
auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true);
|
|
||||||
output /= sumAlongDim;
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
input.tickReadDevice();
|
const int blocksPerGrid = packZ.numberOfTads();
|
||||||
|
const int sharedMem = input.sizeOfT() * threadsPerBlock + 512;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input});
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input});
|
||||||
|
|
||||||
|
// auto maxAlongDim = const_cast<NDArray&>(input).reduceAlongDims(reduce::Max, {dimension}, true);
|
||||||
|
// (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily
|
||||||
|
// auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true);
|
||||||
|
// output /= sumAlongDim;
|
||||||
|
// input.tickReadDevice();
|
||||||
}
|
}
|
||||||
|
|
||||||
PointersManager manager(context, "helpers::softmax");
|
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
|
|
||||||
output.tickWriteDevice();
|
output.tickWriteDevice();
|
||||||
|
|
Loading…
Reference in New Issue