diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp
index a2d0c3c59..28437359e 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp
@@ -74,6 +74,20 @@ namespace helpers {
         }
     }
 
+    template <typename T>
+    static void WiseMinMax(NDArray* input, T min, T max, NDArray* output) {
+        auto wiseMinMax = LAMBDA_T(x, min, max) {
+            if (x < min) {
+                return min;
+            }
+            else if (x > max)
+                return max;
+            return x;
+        };
+
+        input->applyLambda(wiseMinMax, output);
+    }
+
     template <typename T>
     void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
         int lowIntBound = narrowed ? 1 : 0;
@@ -81,62 +95,16 @@ namespace helpers {
 
         const float quant_min_float = static_cast<float>(lowIntBound);
         const float quant_max_float = static_cast<float>(upperIntBound);
-        T scale = (max->t<T>(0) - min->t<T>(0)) / (quant_max_float - quant_min_float);
-        const T zero_point_from_min = quant_min_float - min->e<T>(0) / scale;
-        const uint16_t nudged_zero_point = [zero_point_from_min, lowIntBound,
-                                            quant_min_float, upperIntBound,
-                                            quant_max_float] {
-            if (zero_point_from_min < quant_min_float) {
-                return static_cast<uint16_t>(lowIntBound);
-            }
-            if (zero_point_from_min > quant_max_float) {
-                return static_cast<uint16_t>(upperIntBound);
-            }
-            return static_cast<uint16_t>(roundf(zero_point_from_min));
-        }();
+        T nudged_min, nudged_max, scale;
 
-        auto nudged_min = (quant_min_float - nudged_zero_point) * (scale);
-        auto nudged_max = (quant_max_float - nudged_zero_point) * (scale);
-        //input->applyScalar(scalar::CompareAndSet, nudged_max, clamped, nullptr); //.cwiseMin(nudged_max).cwiseMax(nudged_min);
-        //input->applyScalar(scalar::CompareAndSet, nudged_min, clamped, nullptr); //.cwiseMin(nudged_max).cwiseMax(nudged_min);
-        auto wiseMax = LAMBDA_T(x, nudged_min) {
-            if (x < nudged_min) {
-                return nudged_min;
-            }
-            return x;
-
-        };
-        auto wiseMin = LAMBDA_T(x, nudged_max) {
-            if (x > nudged_max) {
-                return nudged_max;
-            }
-            return x;
-        };
-        auto scaleTensor(*input); // = NDArrayFactory::create<T>(input->ordering(), input->getShapeAsVector(), input->getWorkspace());
-        auto clamped(*input); // = NDArrayFactory::create<T>(input->ordering(), input->getShapeAsVector(), input->getWorkspace());
-        scaleTensor.assign(scale);
-        input->applyLambda(wiseMin, &clamped);
-//        const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
-        clamped.applyLambda(wiseMax, output);
-//        const auto clamped_shifted = clamped - nudged_min;
+        Nudge(min->t<T>(0), max->t<T>(0), quant_min_float, quant_max_float, &scale, &nudged_min, &nudged_max);
+        WiseMinMax(input, nudged_min, nudged_max, output);
         *output -= nudged_min;
-        // auto nudgedScale = scale;
-        (*output) /= scaleTensor;
-//        (*output) += T(0.5f);
-        output->applyTransform(transform::Round, nullptr, nullptr);
-        (*output) *= scaleTensor;
+        (*output) /= scale;
+        (*output) += T(0.5f);
+        output->applyTransform(transform::Floor, nullptr, nullptr);
+        (*output) *= scale;
         (*output) += nudged_min;
-        //output->printIndexedBuffer("FAKE QUANTED");
-        /*
-        const auto nudged_scale_repl = inputs.constant(nudged_scale);
-
-        const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
-        const auto clamped_shifted = clamped - nudged_min;
-        *output = (clamped_shifted / nudged_scale_repl + 0.5f).floor() *
-                      nudged_scale_repl +
-                  nudged_min;
-*/
-
     }
 
     void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {