Shyrma softmax (#209)

* - provide new cuda kernel for softmax
* - further work on cuda kernel for softmax
* - correction cuda kernel for softmax

Signed-off-by: Yurii <yurii@skymind.io>

branch: master
parent bdc3eacafd
commit 00fd50cee2
@@ -24,6 +24,7 @@
 #include <ShapeUtils.h>
 #include <numeric>
 #include <PointersManager.h>
+#include <helpers/ConstantTadHelper.h>
 
 namespace nd4j {
 namespace ops {
@@ -196,7 +197,7 @@ void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray&
 
 ///////////////////////////////////////////////////////////////////
 template<typename T>
-__global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
+__device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
 
     // logic of this kernel is based on assumption gridDim = 1
 
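Changing softMaxForVectorCuda from __global__ to __device__ lets the same single-vector routine be reused in two ways later in this diff: launched directly through the thin softMaxForVectorCudaGlobal wrapper for the plain-vector case, and called per block from the new softMaxCuda kernel, where each block handles one TAD. For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of the same three-pass, single-block softmax (max, exp+sum, normalize) for a contiguous float vector. It is an illustration under simplifying assumptions (power-of-two blockDim.x, len <= blockDim.x, dense data), not the library code, which additionally loops over the vector in numOfIters chunks and resolves strided offsets via shape::getIndexOffset.

    #include <cfloat>   // FLT_MAX

    // Sketch only: single-block softmax over a contiguous float vector.
    // Launch: softmaxVectorSketch<<<1, threads, threads * sizeof(float)>>>(x, z, len);
    __global__ void softmaxVectorSketch(const float* x, float* z, int len) {

        extern __shared__ float smem[];
        const int tid = threadIdx.x;

        // pass 1: block-wide max reduction (identity element is -FLT_MAX)
        smem[tid] = (tid < len) ? x[tid] : -FLT_MAX;
        __syncthreads();
        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (tid < s && smem[tid + s] > smem[tid])
                smem[tid] = smem[tid + s];
            __syncthreads();
        }
        const float mx = smem[0];
        __syncthreads();

        // pass 2: exponentiate shifted values, then block-wide sum reduction
        float e = 0.f;
        if (tid < len) {
            e = expf(x[tid] - mx);   // max subtracted for numerical stability
            z[tid] = e;
        }
        smem[tid] = e;               // identity element for the sum is 0
        __syncthreads();
        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (tid < s)
                smem[tid] += smem[tid + s];
            __syncthreads();
        }

        // pass 3: normalize by the total sum held in smem[0]
        if (tid < len)
            z[tid] /= smem[0];
    }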
@@ -210,7 +211,7 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
     if (threadIdx.x == 0) {
         extern __shared__ char shared[];
         shmem = reinterpret_cast<T*>(shared);
-        len = shape::length(xzShapeInfo);
+        len = shape::length(xShapeInfo);
         numOfIters = (len + blockDim.x - 1) / blockDim.x;  // ceil (len / blockDim.x)
     }
     __syncthreads();
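A worked example of the ceiling division above, for concreteness: with len = 1000 and blockDim.x = 256, numOfIters = (1000 + 255) / 256 = 4 in integer arithmetic, matching ceil(1000 / 256), so each thread processes up to four elements of the vector.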
@@ -222,8 +223,8 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
         const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
         if(elemIdx < len) {
-            const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
-            shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max<T>(x[offset], temp);  // take into account max element evaluated on previous iteration and stored in temp
+            const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len);
+            shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : nd4j::math::nd4j_max<T>(x[xOffset], temp);  // take into account max element evaluated on previous iteration and stored in temp
         }
         else
             shmem[threadIdx.x] = -DataTypeUtils::max<T>();  // FIXME: what if T is unsigned ??
 
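For context on the max tracked above (not part of the diff): subtracting the maximum before exponentiation leaves the result unchanged, because the common factor e^{-m} cancels between numerator and denominator:

    softmax(x)_i = e^{x_i - m} / sum_j e^{x_j - m}
                 = (e^{-m} * e^{x_i}) / (e^{-m} * sum_j e^{x_j})
                 = e^{x_i} / sum_j e^{x_j},        where m = max_j x_j

Since every exponent x_i - m <= 0, each e^{x_i - m} lies in (0, 1], so the exponentiation in the next hunk cannot overflow.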
@@ -248,9 +249,10 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
         const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
         if(elemIdx < len) {
-            const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
-            z[offset] = nd4j::math::nd4j_exp<T, T>(x[offset] - max);
-            shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp);  // take into account sum element evaluated on previous iteration and stored in temp
+            const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len);
+            const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len);
+            z[zOffset] = nd4j::math::nd4j_exp<T, T>(x[xOffset] - max);
+            shmem[threadIdx.x] = (threadIdx.x != 0) ? z[zOffset] : (z[zOffset] + temp);  // take into account sum element evaluated on previous iteration and stored in temp
         }
         else
             shmem[threadIdx.x] = 0;
 
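Note how the out-of-range padding value differs between the two passes: threads beyond len store 0 here because 0 is the identity element for the upcoming sum reduction, whereas the first pass stored -DataTypeUtils::max<T>() as the identity for the max reduction (hence its FIXME: for an unsigned T the true minimum is 0, not the negated maximum).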
@@ -270,43 +272,87 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo
     for (int i = 0; i < numOfIters; ++i) {
         const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
         if(elemIdx >= len) continue;
-        const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len);
-        z[offset] /= shmem[0];
+        const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len);
+        z[zOffset] /= shmem[0];
     }
 }
+
+template<typename T>
+__global__ void softMaxForVectorCudaGlobal(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
+
+    softMaxForVectorCuda<T>(vx, xShapeInfo, vz, zShapeInfo);
+}
 
 ///////////////////////////////////////////////////////////////////
 template <typename T>
-linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
+linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
 
-    softMaxForVectorCuda<T><<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz);
+    softMaxForVectorCudaGlobal<T><<<1, MAX_NUM_THREADS / 4, (MAX_NUM_THREADS / 4) * sizeof(T) + 512, *stream>>>(vx, xShapeInfo, vz, zShapeInfo);
 }
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ static void softMaxCuda(const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
+                                         void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) {
+
+    const auto x = reinterpret_cast<const T*>(vx);
+          auto z = reinterpret_cast<T*>(vz);
+
+    const auto* xTad = x + xOffsets[blockIdx.x];
+          auto* zTad = z + zOffsets[blockIdx.x];
+
+    softMaxForVectorCuda<T>(xTad, xTadShapeInfo, zTad, zTadShapeInfo);
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+static void softMaxCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
+                                const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
+                                      void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) {
+
+    softMaxCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets);
+}
 
 //////////////////////////////////////////////////////////////////////////
 void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) {
 
-    if(!input.isActualOnDeviceSide()) input.syncToDevice();
     const int rank = input.rankOf();
 
+    PointersManager manager(context, "helpers::softmax");
+
     if(input.isVector()) {
 
         if(rank == 1 || input.sizeAt(dimension) != 1) {
-            BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer()), FLOAT_TYPES);
-            input.tickReadDevice();
+            NDArray::prepareSpecialUse({&output}, {&input});
+            BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), FLOAT_TYPES);
+            NDArray::registerSpecialUse({&output}, {&input});
         }
         else
             output = 1.;
     }
     else {
 
-        auto maxAlongDim = const_cast<NDArray&>(input).reduceAlongDims(reduce::Max, {dimension}, true);
-        (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily
-        auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true);
-        output /= sumAlongDim;
-        input.tickReadDevice();
+        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), {dimension});
+        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), {dimension});
+
+        const int threadsPerBlock = MAX_NUM_THREADS / 4;
+        const int blocksPerGrid = packZ.numberOfTads();
+        const int sharedMem = input.sizeOfT() * threadsPerBlock + 512;
+
+        NDArray::prepareSpecialUse({&output}, {&input});
+        BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES);
+        NDArray::registerSpecialUse({&output}, {&input});
+
+        // auto maxAlongDim = const_cast<NDArray&>(input).reduceAlongDims(reduce::Max, {dimension}, true);
+        // (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily
+        // auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true);
+        // output /= sumAlongDim;
+        // input.tickReadDevice();
     }
 
-    PointersManager manager(context, "helpers::softmax");
-
     manager.synchronize();
 
-    output.tickWriteDevice();
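To make the new launch configuration concrete, a hypothetical example (MAX_NUM_THREADS is a library constant defined elsewhere in libnd4j; 1024 is assumed here purely for illustration): for a float input of shape [32, 100] softmaxed along dimension = 1, the TAD helper yields 32 TADs, one row of 100 elements each, so:

    threadsPerBlock = 1024 / 4        = 256
    blocksPerGrid   = numberOfTads    = 32     (one block per row)
    sharedMem       = 4 * 256 + 512   = 1536   bytes per block

Each block then runs the __device__ softMaxForVectorCuda over its own row, located through xOffsets[blockIdx.x] and zOffsets[blockIdx.x]. This is what lets the old NDArray-level reduceAlongDims implementation be retired; it survives above only as commented-out reference code.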