diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp index 21163d44c..6ea2992b9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp @@ -25,43 +25,54 @@ namespace nd4j { namespace ops { namespace helpers { + // + // nudge - computes scale and nudged min/max for the quantization range + // scale = (Max - Min) / (quantMax - quantMin) + // quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1 + // template - static void nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) { - T quant_max_float = static_cast(quant_max); - T quant_min_float = static_cast(quant_min); - *scale = (max - min) / (quant_max_float - quant_min_float); - auto zero_point_from_min = quant_min_float - min / *scale; - uint16_t const nudged_zero_point = [zero_point_from_min, quant_min, quant_max, quant_max_float, quant_min_float] { - if (zero_point_from_min < quant_min_float) { - return static_cast(quant_min); + static void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) { + // floating point copies of the integer bounds + T quantMaxF = static_cast(quantMax); + T quantMinF = static_cast(quantMin); + // compute scale + *scale = (max - min) / (quantMaxF - quantMinF); + // compute zero point implied by min + auto zeroPointFromMin = quantMinF - min / *scale; + // bound zero point to the range [quantMin, quantMax] (i.e. [0 or 1, 2^b - 1]) and round it + uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] { + if (zeroPointFromMin < quantMinF) { + return static_cast(quantMin); } - if (zero_point_from_min > quant_max_float) { - return static_cast(quant_max); + if (zeroPointFromMin > quantMaxF) { + return static_cast(quantMax); } - return nd4j::math::nd4j_round(zero_point_from_min); - }(); - *nudged_min = (quant_min_float - nudged_zero_point) * (*scale); - *nudged_max = (quant_max_float - 
nudged_zero_point) * (*scale); + return nd4j::math::nd4j_round(zeroPointFromMin); + }(); + // compute nudged min and max with computed nudged zero point + *nudgedMin = (quantMinF - nudged_zero_point) * (*scale); + *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale); } template void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed ? 1 : 0; - int upperIntBound = (1 << numBits) - 1; - auto channels = input->sizeAt(-1); + int lowIntBound = narrowed ? 1 : 0; // 0 or 1 + int upperIntBound = (1 << numBits) - 1; // 2^b - 1 + auto channels = input->sizeAt(-1); // size of the last dimension PRAGMA_OMP_PARALLEL_FOR for (auto i = 0; i < channels; i++) { T scale, nudged_min, nudged_max; + // nudge min and max first, computing scale as well nudge(min->t(i), max->t(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max); - + // stride by the channel count to process all values of the given channel for (auto e = 0; e < input->lengthOf(); e += channels) { T val = input->t(e + i); if ( val <= nudged_min) val = nudged_min; else if (val >= nudged_max) val = nudged_max; - + // quantization itself output->t(e + i) = math::nd4j_floor((val - nudged_min)/scale + T(0.5)) * scale + nudged_min; } } @@ -73,16 +84,17 @@ namespace helpers { int upperIntBound = (1 << numBits) - 1; T nudgedMin, nudgedMax, scale; - + // nudge with given min and max and compute scale and nudged min and max nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - + // quantization as a single lambda auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { - T val = x; + T val = x; // bounding value between nudged min and max if (val < nudgedMin) { val = nudgedMin; } else if (val > nudgedMax) val = nudgedMax; + // convert value with scale and shift by nudged min return (nd4j::math::nd4j_floor((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin); };