cavis/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu

110 lines
4.6 KiB
Plaintext

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include<ops/declarable/helpers/addBias.h>
#include <PointersManager.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////
// Kernel: adds a per-channel bias to every element of the input,
//         z[coords] = x[coords] + bias[coords[channel]].
//
// vx / xShapeInfo : input buffer and its shape descriptor
// vy / yShapeInfo : bias buffer and its shape descriptor, bias is [oC]
// vz / zShapeInfo : output buffer and its shape descriptor (same rank as input);
//                   may be the very same buffer as vx, in which case the bias
//                   is accumulated in place
// isNCHW          : true  -> channel axis is dimension 1
//                   false -> channel axis is the last dimension
//
// Launch requirements:
//  - dynamic shared memory of at least blockDim.x * rank * sizeof(Nd4jLong)
//    bytes, because every thread keeps its own coordinate vector there;
//  - 1D grid / 1D block; the element loop is bounded by the array length and
//    advances by the total number of launched threads.
template<typename X, typename Y>
__global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
                                    const void* vy, const Nd4jLong* yShapeInfo,
                                          void* vz, const Nd4jLong* zShapeInfo,
                                    const bool isNCHW) {

    // bias [oC]

    // if(input_rank == 4)
        // input and output have same shapes: [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
    // if(input_rank == 5)
        // input and output have same shapes: [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

    const X* x = reinterpret_cast<const X*>(vx);
    const Y* y = reinterpret_cast<const Y*>(vy);
          X* z = reinterpret_cast<X*>(vz);

    __shared__ int rank, channelPosition;
    __shared__ Nd4jLong *sharedMem, len;
    __shared__ bool xzSameOffsets, xzAreSame;

    // Thread 0 fills the block-wide constants once; everybody else waits at the barrier.
    if (threadIdx.x == 0) {

        extern __shared__ unsigned char shmem[];
        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);

        rank = shape::rank(xShapeInfo);                         // xRank == zRank
        xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
        len = shape::length(xShapeInfo);
        channelPosition = isNCHW ? 1 : rank - 1;                // second or last
        xzAreSame = x == z;
    }
    __syncthreads();

    // Per-thread scratch slice (rank entries) inside the dynamic shared buffer.
    auto coords = sharedMem + threadIdx.x * rank;

    // Do the index arithmetic in 64 bits: blockIdx.x * blockDim.x and
    // blockDim.x * gridDim.x are 32-bit unsigned products and would wrap
    // around for very large launches before being widened to Nd4jLong.
    const Nd4jLong firstIndex = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong totalThreads = static_cast<Nd4jLong>(blockDim.x) * gridDim.x;

    for (Nd4jLong i = firstIndex; i < len; i += totalThreads) {

        shape::index2coords(i, xShapeInfo, coords);

        const auto xOffsets = shape::getOffset(xShapeInfo, coords);
        // Reuse the input offset when input and output share shape and strides.
        const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords);
        // Only the channel coordinate is handed to the bias shape descriptor.
        const auto yOffsets = shape::getOffset(yShapeInfo, coords + channelPosition);

        if(xzAreSame)
            z[zOffsets] += static_cast<X>(y[yOffsets]);
        else
            z[zOffsets] = x[xOffsets] + static_cast<X>(y[yOffsets]);
    }
}
//////////////////////////////////////////////////////////////////////////
// Typed launch wrapper: instantiates addBiasCuda for the (X, Y) pair and
// enqueues it on the given stream with the supplied launch configuration
// (grid size, block size, dynamic shared memory in bytes). All buffer and
// shape-descriptor arguments are forwarded to the kernel unchanged.
template<typename X, typename Y>
static void addBiasCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
                                const void* vx, const Nd4jLong* xShapeInfo,
                                const void* vy, const Nd4jLong* yShapeInfo,
                                      void* vz, const Nd4jLong* zShapeInfo,
                                const bool isNCHW) {

    // The stream is passed by pointer; resolve it once for the launch below.
    const cudaStream_t launchStream = *stream;

    addBiasCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, launchStream>>>(
        vx, xShapeInfo,
        vy, yShapeInfo,
        vz, zShapeInfo,
        isNCHW);
}
//////////////////////////////////////////////////////////////////////////
// Host entry point: adds `bias` to `input` along the channel axis and stores
// the result in `output`.
//
// block   : op context providing the launch context (CUDA stream)
// input   : array to which the bias is added
// bias    : per-channel bias values
// output  : destination array
// isNCHW  : forwarded to the kernel; selects which axis is treated as channels
//
// The kernel is dispatched for every (input type, bias type) pair drawn from
// FLOAT_TYPES. Each thread needs rank * sizeof(Nd4jLong) bytes of dynamic
// shared memory for its coordinate vector; 128 extra bytes are added on top.
void addBias(nd4j::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) {

    // An empty input would yield blocksPerGrid == 0, and a launch with a zero
    // grid dimension is an invalid configuration. There is nothing to compute.
    if (input.lengthOf() == 0)
        return;

    PointersManager manager(block.launchContext(), "addBias");

    const int threadsPerBlock = MAX_NUM_THREADS;
    const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;      // ceil-div
    const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;

    NDArray::prepareSpecialUse({&output}, {&input, &bias});
    BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBiasCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), bias.getSpecialBuffer(), bias.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), isNCHW), FLOAT_TYPES, FLOAT_TYPES);
    NDArray::registerSpecialUse({&output}, {&input, &bias});

    manager.synchronize();
}
}
}
}