From d5b352273dc17aa28bceeaf9707201274c6d8518 Mon Sep 17 00:00:00 2001
From: shugeo
Date: Thu, 10 Oct 2019 16:51:29 +0300
Subject: [PATCH] Implementation of cuda kernel for fake_quant_with_min_max_vars_per_channels op. Final revision.

---
 .../declarable/helpers/cuda/fake_quantization.cu    | 13 +++++++++----
 .../tests_cpu/layers_tests/DeclarableOpsTests10.cpp |  4 ++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
index 75a81f75a..893e016ed 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
@@ -81,19 +81,24 @@ namespace helpers {
     static __global__ void fakeQuantWithMinMaxKernel(T* input, Nd4jLong* inputShape, T* min, T* max,
             int lowIntBound, int upperIntBound, Nd4jLong channels, T* output, Nd4jLong* outputShape, Nd4jLong length) {
+        __shared__ int block;
+        if (threadIdx.x == 0) {
+            block = length / channels;
+        }
+        __syncthreads();
 
         for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) {
             T scale, nudged_min, nudged_max;
             Nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
             //auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudged_min, nudged_max, scale) {
-            for (auto e = threadIdx.x; e < (int)length; e += (int)channels) {
-                T val = input[shape::getIndexOffset(e + i, inputShape)];
+            for (auto e = threadIdx.x; e < block; e += blockDim.x) {
+                T val = input[shape::getIndexOffset(e * channels + i, inputShape)];
                 if (val < nudged_min) {
                     val = nudged_min;
                 }
                 else if (val > nudged_max) {
                     val = nudged_max;
                 }
-                output[shape::getIndexOffset(e + i, outputShape)] = (math::nd4j_floor((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
+                output[shape::getIndexOffset(e * channels + i, outputShape)] = (math::nd4j_floor((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
             };
         }
 
@@ -111,7 +116,7 @@ namespace helpers {
         T* outputBuf = output->dataBuffer()->specialAsT<T>();
         T* minBuf = min->dataBuffer()->specialAsT<T>();
         T* maxBuf = max->dataBuffer()->specialAsT<T>();
-        fakeQuantWithMinMaxKernel<T><<<1, 1, 256, *stream>>>(inputBuf, input->specialShapeInfo(),
+        fakeQuantWithMinMaxKernel<T><<<128, 256, 256, *stream>>>(inputBuf, input->specialShapeInfo(),
                 minBuf, maxBuf, lowIntBound, upperIntBound, channels, outputBuf, output->specialShapeInfo(), length);
         NDArray::registerSpecialUse({output}, {min, max, input});
 
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
index 6ae982cf8..0652a398e 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
@@ -2127,8 +2127,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     auto result = results->at(0);
-    result->printBuffer("Quantized");
-    exp.printBuffer("Expected");
+//    result->printBuffer("Quantized");
+//    exp.printBuffer("Expected");
 
     ASSERT_TRUE(exp.isSameShapeStrict(result));
     ASSERT_TRUE(exp.equalsTo(result));
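
Note on the quantization math: the kernel assigns one channel per block. Nudge() shifts the per-channel range [min[i], max[i]] so that zero falls exactly on the integer grid [lowIntBound, upperIntBound]; each element of that channel (stored channels-last, at flat offset e * channels + i) is then clamped to the nudged range and rounded to the nearest quantization step. Nudge() itself is not part of this diff, so the nudge() helper in the host-side C++ sketch below is a hypothetical reconstruction following TensorFlow's FakeQuantWithMinMaxVars semantics; the 8-bit bounds 0..255 and the sample values are assumptions for illustration only.

    // Host-side reference sketch of the per-channel fake quantization that
    // fakeQuantWithMinMaxKernel applies. nudge() is a HYPOTHETICAL
    // reconstruction of the Nudge() helper, which is not shown in this diff.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static void nudge(float min, float max, int lowIntBound, int upperIntBound,
                      float* scale, float* nudgedMin, float* nudgedMax) {
        // Step size of the integer grid, then snap the zero point onto it.
        *scale = (max - min) / float(upperIntBound - lowIntBound);
        float zeroPointFromMin = float(lowIntBound) - min / *scale;
        int nudgedZeroPoint = zeroPointFromMin < lowIntBound ? lowIntBound
                            : zeroPointFromMin > upperIntBound ? upperIntBound
                            : int(std::round(zeroPointFromMin));
        *nudgedMin = (float(lowIntBound) - nudgedZeroPoint) * *scale;
        *nudgedMax = (float(upperIntBound) - nudgedZeroPoint) * *scale;
    }

    int main() {
        const int channels = 2;
        const int block = 3;  // elements per channel, i.e. length / channels
        std::vector<float> input = {-0.1f, 0.3f, 0.6f, 0.9f, 1.2f, 1.5f};  // channels-last
        std::vector<float> min = {0.f, 0.f}, max = {1.f, 1.f};             // per-channel ranges
        std::vector<float> output(input.size());
        for (int i = 0; i < channels; ++i) {   // the kernel maps this loop onto blockIdx.x
            float scale, nudgedMin, nudgedMax;
            nudge(min[i], max[i], 0, 255, &scale, &nudgedMin, &nudgedMax);
            for (int e = 0; e < block; ++e) {  // and this one onto threadIdx.x
                // Clamp to the nudged range, then round to the nearest grid step.
                float val = input[e * channels + i];
                val = val < nudgedMin ? nudgedMin : val > nudgedMax ? nudgedMax : val;
                output[e * channels + i] = std::floor((val - nudgedMin) / scale + 0.5f) * scale + nudgedMin;
            }
        }
        for (float v : output) std::printf("%g\n", v);
        return 0;
    }

With min = 0, max = 1 over 256 levels this clamps -0.1 to 0 and 1.5 to 1, and snaps interior values to multiples of 1/255, which is the per-channel result the corrected indexing (e * channels + i) now produces on device.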