parent
11d148a5eb
commit
f990b2486d
|
@ -91,19 +91,55 @@ static void addBiasCudaLauncher(const int blocksPerGrid, const int threadsPerBlo
|
||||||
addBiasCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, isNCHW);
|
addBiasCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, isNCHW);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Row-wise bias addition on dense buffers: z[b * length + e] = x[b * length + e] + y[e].
 *
 * Work distribution: block indices stride over rows (step gridDim.x) and the threads of a block
 * stride over the elements of one row (step blockDim.x), so any 1D launch configuration
 * produces the same result. The kernel uses no shared memory.
 *
 * vx and vz are deliberately not marked __restrict__: nothing here prevents the output buffer
 * from being the input buffer.
 *
 * @tparam X element type the input and output buffers are read/written as
 * @tparam Y element type the bias buffer is read as
 * @param vx     input buffer, accessed as blocks * length contiguous X values
 * @param vy     bias buffer, accessed as length contiguous Y values
 * @param vz     output buffer, accessed as blocks * length contiguous X values
 * @param blocks number of rows
 * @param length number of elements per row (and in the bias)
 */
template<typename X, typename Y>
__global__ static void addBias2DCuda(const void* vx,
                                     const void* vy,
                                     void* vz,
                                     uint32_t blocks, uint32_t length) {

    const auto y = reinterpret_cast<const Y*>(vy);

    for (uint32_t b = blockIdx.x; b < blocks; b += gridDim.x) {
        // Widen before multiplying: the uint32_t product length * b wraps around as soon as
        // the buffer holds 2^32 or more elements, which would address the wrong row.
        const size_t rowOffset = static_cast<size_t>(length) * b;

        auto x = reinterpret_cast<const X*>(vx) + rowOffset;
        auto z = reinterpret_cast<X*>(vz) + rowOffset;

        for (uint32_t e = threadIdx.x; e < length; e += blockDim.x) {
            z[e] = x[e] + y[e];
        }
    }
}
|
||||||
|
|
||||||
|
/**
 * Host-side launcher for the addBias2DCuda<X, Y> kernel.
 *
 * Enqueues the kernel on *stream with a fixed configuration (256 blocks, 1024 threads per
 * block, 128 bytes of dynamic shared memory) and forwards all buffer and size arguments
 * unchanged. The launch is not followed by any synchronization or error query here.
 *
 * @param stream pointer to the CUDA stream the kernel is enqueued on
 * @param vx     input buffer, forwarded to the kernel
 * @param vy     bias buffer, forwarded to the kernel
 * @param vz     output buffer, forwarded to the kernel
 * @param blocks number of rows, forwarded to the kernel
 * @param length number of elements per row, forwarded to the kernel
 */
template<typename X, typename Y>
static void addBias2DCudaLauncher(const cudaStream_t *stream, const void* vx,
                                  const void* vy,
                                  void* vz,
                                  uint32_t blocks, uint32_t length) {

    // Fixed launch configuration, independent of the problem size.
    const int gridSize       = 256;
    const int blockSize      = 1024;
    const int sharedMemBytes = 128;

    addBias2DCuda<X, Y><<<gridSize, blockSize, sharedMemBytes, *stream>>>(vx, vy, vz, blocks, length);
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
/**
 * Adds `bias` to `input` and writes the result to `output`, dispatching on the
 * (input, bias) data-type pair over FLOAT_TYPES x FLOAT_TYPES.
 *
 * Two execution paths:
 *  - fast path: rank-2 c-ordered input, rank-1 bias whose length equals input.sizeAt(1),
 *    c-ordered output, and ews() == 1 for all three arrays. The row-wise kernel is launched
 *    on the raw buffers with (rows = input.sizeAt(0), row length = bias.sizeAt(0)).
 *  - default path: every other layout goes through addBiasCudaLauncher, which receives the
 *    shape infos of all three arrays together with isNCHW.
 *
 * NOTE(review): the fast-path kernel writes the output as the input's element type; this
 * presumably relies on output and input sharing a data type - confirm against the op
 * declaration.
 *
 * @param block  graph context supplying the launch context and CUDA stream
 * @param input  array the bias is added to
 * @param bias   bias array
 * @param output array receiving the result
 * @param isNCHW layout flag, used by the default path only
 */
void addBias(sd::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) {

    PointersManager manager(block.launchContext(), "addBias");
    NDArray::prepareSpecialUse({&output}, {&input, &bias});

    // The 2D kernel indexes all three buffers densely, so every array (output included) must
    // have unit element-wise stride; an output with any other stride takes the default path.
    const bool isDense2D = input.rankOf() == 2 && bias.rankOf() == 1
                        && input.ordering() == 'c' && output.ordering() == 'c'
                        && input.ews() == 1 && bias.ews() == 1 && output.ews() == 1
                        && input.sizeAt(1) == bias.sizeAt(0);

    if (isDense2D) {
        BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias2DCudaLauncher,
                              (block.launchContext()->getCudaStream(), input.getSpecialBuffer(), bias.getSpecialBuffer(), output.specialBuffer(), input.sizeAt(0), bias.sizeAt(0)),
                              FLOAT_TYPES, FLOAT_TYPES);
    } else {
        // default case
        const int threadsPerBlock = MAX_NUM_THREADS / 2;
        const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;    // ceil-div over all input elements
        const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;         // rankOf() Nd4jLong slots per thread, plus 128 bytes

        BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBiasCudaLauncher,
                              (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), bias.getSpecialBuffer(), bias.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), isNCHW),
                              FLOAT_TYPES, FLOAT_TYPES);
    }

    NDArray::registerSpecialUse({&output}, {&input, &bias});
    manager.synchronize();
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue