From 1a0a8b1497631690f7a53ea3b8d79a9fc97efb4c Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 8 Aug 2019 20:28:14 +0300 Subject: [PATCH] duplicate pad impl Signed-off-by: raver119 --- .../ops/declarable/helpers/cuda/pad.cu | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu index f6b9d27fa..b268e6366 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu @@ -130,26 +130,6 @@ namespace nd4j { } BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES); -/////////////////////////////////////////////////////////////////// - void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) { - - PointersManager manager(context, "pad"); - - NDArray::prepareSpecialUse({&output}, {&input, &paddings, &padValue}); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = 8 * threadsPerBlock * output.rankOf() + 128; - - const auto xType = input.dataType(); - const auto yType = paddings.dataType(); - - BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.getSpecialBuffer(), input.getSpecialShapeInfo(), paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), padValue.getSpecialBuffer()), LIBND4J_TYPES, INTEGER_TYPES); - - NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue}); - manager.synchronize(); - } - /////////////////////////////////////////////////////////////////// void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) {