diff --git a/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu b/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu
index 3fc5d42c9..dad5a5b06 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu
@@ -91,19 +91,55 @@ static void addBiasCudaLauncher(const int blocksPerGrid, const int threadsPerBlo
     addBiasCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, isNCHW);
 }
 
+template<typename X, typename Y>
+__global__ static void addBias2DCuda( const void* vx,
+                                      const void* vy,
+                                            void* vz,
+                                      uint32_t blocks, uint32_t length) {
+
+    auto y = reinterpret_cast<const Y*>(vy);
+
+    for (uint32_t b = blockIdx.x; b < blocks; b += gridDim.x) {
+        auto x = reinterpret_cast<const X*>(vx) + length * b;
+        auto z = reinterpret_cast<X*>(vz) + length * b;
+
+        for (uint32_t e = threadIdx.x; e < length; e += blockDim.x) {
+            z[e] = x[e] + y[e];
+        }
+    }
+}
+
+template<typename X, typename Y>
+static void addBias2DCudaLauncher(const cudaStream_t *stream, const void* vx,
+                                  const void* vy,
+                                        void* vz,
+                                  uint32_t blocks, uint32_t length) {
+
+    addBias2DCuda<X, Y><<<256, 1024, 128, *stream>>>(vx, vy, vz, blocks, length);
+}
+
 //////////////////////////////////////////////////////////////////////////
 void addBias(sd::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) {
 
     PointersManager manager(block.launchContext(), "addBias");
-
-    const int threadsPerBlock = MAX_NUM_THREADS/2;
-    const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-    const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
-
     NDArray::prepareSpecialUse({&output}, {&input, &bias});
-    BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBiasCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), bias.getSpecialBuffer(), bias.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), isNCHW), FLOAT_TYPES, FLOAT_TYPES);
-    NDArray::registerSpecialUse({&output}, {&input, &bias});
+    if (input.rankOf() == 2 && bias.rankOf() == 1 && input.ordering() == 'c' && output.ordering() == 'c' && input.ews() == 1 && bias.ews() == 1 && input.sizeAt(1) == bias.sizeAt(0)) {
+        BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias2DCudaLauncher,
+                              (block.launchContext()->getCudaStream(), input.getSpecialBuffer(), bias.getSpecialBuffer(), output.specialBuffer(), input.sizeAt(0), bias.sizeAt(0)),
+                              FLOAT_TYPES, FLOAT_TYPES);
+    } else {
+        // default case
+        const int threadsPerBlock = MAX_NUM_THREADS / 2;
+        const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+        const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
+
+
+        BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBiasCudaLauncher,
+                              (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), bias.getSpecialBuffer(), bias.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), isNCHW),
+                              FLOAT_TYPES, FLOAT_TYPES);
+    }
+    NDArray::registerSpecialUse({&output}, {&input, &bias});
 
     manager.synchronize();
 }
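
Notes for review context: the new fast path only fires for rank-2, 'c'-ordered, contiguous (ews() == 1) inputs whose trailing dimension matches the bias length; every other layout still takes the generic shape-info kernel. Two details worth calling out: the launcher hard-codes a <<<256, 1024, 128, *stream>>> geometry, which works for any problem size because the kernel walks rows with a block-stride loop (b += gridDim.x) and columns with a thread-stride loop (e += blockDim.x), keeping loads and stores coalesced; and the 128 bytes of dynamic shared memory requested at launch are never referenced in the kernel body.

Below is a minimal standalone sketch of the same technique, reduced to float and stripped of the libnd4j plumbing (NDArray, BUILD_DOUBLE_SELECTOR, PointersManager). The names rows/cols and the host-side harness are illustrative assumptions, not from the codebase.

// addbias2d_sketch.cu -- hypothetical standalone reduction of addBias2DCuda.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// One block per group of rows (block-stride loop), one thread per group of
// columns within a row (thread-stride loop), mirroring addBias2DCuda above.
__global__ void addBias2D(const float* x, const float* bias, float* z,
                          uint32_t rows, uint32_t cols) {
    for (uint32_t r = blockIdx.x; r < rows; r += gridDim.x) {
        const float* xRow = x + (size_t)r * cols;
        float*       zRow = z + (size_t)r * cols;
        for (uint32_t c = threadIdx.x; c < cols; c += blockDim.x)
            zRow[c] = xRow[c] + bias[c];   // contiguous -> coalesced accesses
    }
}

int main() {
    const uint32_t rows = 512, cols = 1024;
    std::vector<float> hx(rows * cols, 1.0f), hb(cols, 0.5f), hz(rows * cols);

    float *dx, *db, *dz;
    cudaMalloc(&dx, hx.size() * sizeof(float));
    cudaMalloc(&db, hb.size() * sizeof(float));
    cudaMalloc(&dz, hz.size() * sizeof(float));
    cudaMemcpy(dx, hx.data(), hx.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb.data(), hb.size() * sizeof(float), cudaMemcpyHostToDevice);

    // Same fixed geometry the patched launcher uses (minus the unused smem).
    addBias2D<<<256, 1024>>>(dx, db, dz, rows, cols);
    cudaMemcpy(hz.data(), dz, hz.size() * sizeof(float), cudaMemcpyDeviceToHost);

    printf("z[0] = %f (expected 1.500000)\n", hz[0]);
    cudaFree(dx); cudaFree(db); cudaFree(dz);
    return 0;
}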