parent
7fa01288bb
commit
1a0a8b1497
|
@ -130,26 +130,6 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) {
|
|
||||||
|
|
||||||
PointersManager manager(context, "pad");
|
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&output}, {&input, &paddings, &padValue});
|
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
|
||||||
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
|
||||||
const int sharedMem = 8 * threadsPerBlock * output.rankOf() + 128;
|
|
||||||
|
|
||||||
const auto xType = input.dataType();
|
|
||||||
const auto yType = paddings.dataType();
|
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.getSpecialBuffer(), input.getSpecialShapeInfo(), paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), padValue.getSpecialBuffer()), LIBND4J_TYPES, INTEGER_TYPES);
|
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue});
|
|
||||||
manager.synchronize();
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) {
|
void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) {
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue