From 753565145c7c4d1cdbfe3380a26e003083c29821 Mon Sep 17 00:00:00 2001
From: shugeo
Date: Thu, 10 Oct 2019 14:00:49 +0300
Subject: [PATCH] Refactored fake_quant_with_min_max_vars op cuda implementation.

---
 .../helpers/cuda/fake_quantization.cu | 71 ++++++++-----------
 1 file changed, 30 insertions(+), 41 deletions(-)

diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
index 4e62aafa8..9def35152 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
@@ -34,58 +34,47 @@ namespace helpers {
     //      output - output tensor
     //
     template <typename T>
-    void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
-        int lowIntBound = narrowed?1:0;
-        int upperIntBound = 1 << numBits - 1;
-        min->syncToHost();
-        max->syncToHost();
-        const float quant_min_float = static_cast<float>(lowIntBound);
-        const float quant_max_float = static_cast<float>(upperIntBound);
-        T scale = (max->t<T>(0) - min->t<T>(0)) / (quant_max_float - quant_min_float);
-        const T zero_point_from_min = quant_min_float - min->t<T>(0) / scale;
-
-        const uint16_t nudged_zero_point = [zero_point_from_min, lowIntBound,
-                                            quant_min_float, upperIntBound,
-                                            quant_max_float] {
+    static void Nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) {
+        T quant_max_float = static_cast<T>(quant_max);
+        T quant_min_float = static_cast<T>(quant_min);
+        *scale = (max - min) / (quant_max_float - quant_min_float);
+        auto zero_point_from_min = quant_min_float - min / *scale;
+        uint16_t const nudged_zero_point = [zero_point_from_min, quant_min, quant_max, quant_max_float, quant_min_float] {
             if (zero_point_from_min < quant_min_float) {
-                return static_cast<uint16_t>(lowIntBound);
+                return static_cast<uint16_t>(quant_min);
             }
             if (zero_point_from_min > quant_max_float) {
-                return static_cast<uint16_t>(upperIntBound);
+                return static_cast<uint16_t>(quant_max);
             }
-            return static_cast<uint16_t>(roundf(zero_point_from_min));
+            return nd4j::math::nd4j_round<T, uint16_t>(zero_point_from_min);
         }();
+        *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
+        *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
+    }
 
-        auto nudged_min = (quant_min_float - nudged_zero_point) * (scale);
-        auto nudged_max = (quant_max_float - nudged_zero_point) * (scale);
+    template <typename T>
+    void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
+        int lowIntBound = narrowed?1:0;
+        int upperIntBound = (1 << numBits) - 1;
+        min->syncToHost();
+        max->syncToHost();
+        T scale, nudged_min, nudged_max;
+        Nudge(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
 
-        auto wiseMax = LAMBDA_T(x, nudged_min) {
+        auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudged_min, nudged_max, scale) {
+            T val = x;
             if (x < nudged_min) {
-                return nudged_min;
+                val = nudged_min;
             }
-            return x;
+            else if (x > nudged_max) {
+                val = nudged_max;
+            }
+            else
+                val = x;
+            return (math::nd4j_floor<T, T>((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
         };
 
-        auto wiseMin = LAMBDA_T(x, nudged_max) {
-            if (x > nudged_max) {
-                return nudged_max;
-            }
-            return x;
-        };
-
-        auto scaleTensor(*input);
-        auto clamped(*input);
-        scaleTensor.assign(scale);
-        input->applyLambda(wiseMin, &clamped);
-
-        clamped.applyLambda(wiseMax, output);
-        *output -= nudged_min;
-
-        (*output) /= scaleTensor;
-        (*output) += T(0.5f);
-        output->applyTransform(transform::Floor, nullptr, nullptr);
-        (*output) *= scaleTensor;
-        (*output) += nudged_min;
+        input->applyLambda(wiseMinMaxAndSoOn, output);
     }
 
     void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
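
For reference, the transform both versions implement mirrors the TensorFlow FakeQuantWithMinMaxVars semantics: nudge the [min, max] range so that zero lands exactly on a quantization step, clamp each element into the nudged range, round it to the nearest step, and map it back to float. The sketch below is not part of the patch; it reproduces the same arithmetic on plain floats (the helper names nudge/fakeQuantize and the demo values are illustrative, assuming numBits = 8 and narrowed = false, i.e. an integer range of [0, 255]) so the lambda's math can be checked outside the NDArray/LAMBDA_T machinery.

// Standalone sketch of the nudge + fake-quantize math, not taken from the patch.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirror of the Nudge helper: pick the quantization scale and shift the
// float range so that the zero point falls on an exact integer step.
static void nudge(float min, float max, int quantMin, int quantMax,
                  float* scale, float* nudgedMin, float* nudgedMax) {
    const float quantMinF = static_cast<float>(quantMin);
    const float quantMaxF = static_cast<float>(quantMax);
    *scale = (max - min) / (quantMaxF - quantMinF);
    const float zeroPointFromMin = quantMinF - min / *scale;
    // Clamp the zero point into the representable integer range, then round.
    uint16_t nudgedZeroPoint;
    if (zeroPointFromMin < quantMinF)
        nudgedZeroPoint = static_cast<uint16_t>(quantMin);
    else if (zeroPointFromMin > quantMaxF)
        nudgedZeroPoint = static_cast<uint16_t>(quantMax);
    else
        nudgedZeroPoint = static_cast<uint16_t>(std::round(zeroPointFromMin));
    *nudgedMin = (quantMinF - nudgedZeroPoint) * (*scale);
    *nudgedMax = (quantMaxF - nudgedZeroPoint) * (*scale);
}

// Mirror of the per-element lambda: clamp to [nudgedMin, nudgedMax], snap to
// the nearest quantization step, and de-quantize back to the float range.
static float fakeQuantize(float x, float nudgedMin, float nudgedMax, float scale) {
    const float clamped = std::min(std::max(x, nudgedMin), nudgedMax);
    return std::floor((clamped - nudgedMin) / scale + 0.5f) * scale + nudgedMin;
}

int main() {
    float scale, nudgedMin, nudgedMax;
    // numBits = 8, narrowed = false -> integer range [0, 255].
    nudge(-1.0f, 1.0f, 0, 255, &scale, &nudgedMin, &nudgedMax);
    const float samples[] = {-1.5f, -0.3f, 0.0f, 0.7f, 2.0f};
    for (float x : samples)
        std::printf("%+.4f -> %+.4f\n", x, fakeQuantize(x, nudgedMin, nudgedMax, scale));
    return 0;
}

Folding the clamp, round and rescale into the single wiseMinMaxAndSoOn lambda lets applyLambda produce the result in one pass over the input, replacing the previous chain of temporaries (scaleTensor, clamped) and the follow-up whole-array arithmetic and Floor transform.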