diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp
index b09587cf7..21163d44c 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp
@@ -26,7 +26,7 @@ namespace ops {
 namespace helpers {

     template <typename T>
-    static void Nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) {
+    static void nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) {
         T quant_max_float = static_cast<T>(quant_max);
         T quant_min_float = static_cast<T>(quant_min);
         *scale = (max - min) / (quant_max_float - quant_min_float);
@@ -53,7 +53,7 @@ namespace helpers {
         PRAGMA_OMP_PARALLEL_FOR
         for (auto i = 0; i < channels; i++) {
             T scale, nudged_min, nudged_max;
-            Nudge(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
+            nudge(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
             for (auto e = 0; e < input->lengthOf(); e += channels) {
                 T val = input->t<T>(e + i);
@@ -67,37 +67,26 @@ namespace helpers {
         }
     }

-    template <typename T>
-    static void WiseMinMax(NDArray* input, T min, T max, NDArray* output) {
-        auto wiseMinMax = LAMBDA_T(x, min, max) {
-            if (x < min) {
-                return min;
-            }
-            else if (x > max)
-                return max;
-            return x;
-        };
-
-        input->applyLambda(wiseMinMax, output);
-    }
-
     template <typename T>
     void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
         int lowIntBound = narrowed ? 1 : 0;
         int upperIntBound = (1 << numBits) - 1;
-        const float quant_min_float = static_cast<float>(lowIntBound);
-        const float quant_max_float = static_cast<float>(upperIntBound);
-        T nudged_min, nudged_max, scale;
+        T nudgedMin, nudgedMax, scale;

-        Nudge(min->t<T>(0), max->t<T>(0), quant_min_float, quant_max_float, &scale, &nudged_min, &nudged_max);
-        WiseMinMax(input, nudged_min, nudged_max, output);
-        *output -= nudged_min;
-        (*output) /= scale;
-        (*output) += T(0.5f);
-        output->applyTransform(transform::Floor, nullptr, nullptr);
-        (*output) *= scale;
-        (*output) += nudged_min;
+        nudge(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
+
+        auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) {
+            T val = x;
+            if (val < nudgedMin) {
+                val = nudgedMin;
+            }
+            else if (val > nudgedMax)
+                val = nudgedMax;
+            return (nd4j::math::nd4j_floor<T,T>((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin);
+        };
+
+        input->applyLambda(fakeQuantizationWithMinMax, output);
     }

     void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
index 893e016ed..70eaac67b 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
@@ -34,44 +34,45 @@ namespace helpers {
     // output - output tensor
     //
     template <typename T>
-    static __host__ __device__ void Nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) {
-        T quant_max_float = static_cast<T>(quant_max);
-        T quant_min_float = static_cast<T>(quant_min);
-        *scale = (max - min) / (quant_max_float - quant_min_float);
-        auto zero_point_from_min = quant_min_float - min / *scale;
-        uint16_t const nudged_zero_point = [zero_point_from_min, quant_min, quant_max, quant_max_float, quant_min_float] {
-            if (zero_point_from_min < quant_min_float) {
-                return static_cast<uint16_t>(quant_min);
+    static __host__ __device__ void
+    nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) {
+        T quantMaxF = static_cast<T>(quantMax);
+        T quantMinF = static_cast<T>(quantMin);
+        *scale = (max - min) / (quantMaxF - quantMinF);
+        auto zeroPointFromMin = quantMinF - min / *scale;
+        uint16_t const nudgedZeroPoint = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] {
+            if (zeroPointFromMin < quantMinF) {
+                return static_cast<uint16_t>(quantMin);
             }
-            if (zero_point_from_min > quant_max_float) {
-                return static_cast<uint16_t>(quant_max);
+            if (zeroPointFromMin > quantMaxF) {
+                return static_cast<uint16_t>(quantMax);
             }
-            return nd4j::math::nd4j_round<T,uint16_t>(zero_point_from_min);
+            return nd4j::math::nd4j_round<T,uint16_t>(zeroPointFromMin);
         }();
-        *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
-        *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
+        *nudgedMin = (quantMinF - nudgedZeroPoint) * (*scale);
+        *nudgedMax = (quantMaxF - nudgedZeroPoint) * (*scale);
     }

     template <typename T>
     void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
         int lowIntBound = narrowed?1:0;
         int upperIntBound = (1 << numBits) - 1;
-        min->syncToHost();
+        min->syncToHost(); // these are scalars, so nothing much happened
         max->syncToHost();
-        T scale, nudged_min, nudged_max;
-        Nudge(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
+        T scale, nudgedMin, nudgedMax;
+        nudge(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);

-        auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudged_min, nudged_max, scale) {
+        auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudgedMin, nudgedMax, scale) {
             T val = x;
-            if (x < nudged_min) {
-                val = nudged_min;
+            if (x < nudgedMin) {
+                val = nudgedMin;
             }
-            else if (x > nudged_max) {
-                val = nudged_max;
+            else if (x > nudgedMax) {
+                val = nudgedMax;
             }
             else
                 val = x;
-            return (math::nd4j_floor<T,T>((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
+            return (math::nd4j_floor<T,T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin);
         };

         input->applyLambda(wiseMinMaxAndSoOn, output);
@@ -88,20 +89,20 @@ namespace helpers {
         __syncthreads();

         for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) {
-            T scale, nudged_min, nudged_max;
-            Nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
-            //auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudged_min, nudged_max, scale) {
+            T scale, nudgedMin, nudgedMax;
+            nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
+
             for (auto e = threadIdx.x; e < block; e += blockDim.x) {
                 T val = input[shape::getIndexOffset(e * channels + i, inputShape)];
-                if (val < nudged_min) {
-                    val = nudged_min;
-                } else if (val > nudged_max) {
-                    val = nudged_max;
+                if (val < nudgedMin) {
+                    val = nudgedMin;
+                } else if (val > nudgedMax) {
+                    val = nudgedMax;
                 }
-                output[shape::getIndexOffset(e* channels + i, outputShape)] = (math::nd4j_floor<T,T>((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
+                output[shape::getIndexOffset(e* channels + i, outputShape)] =
+                        (math::nd4j_floor<T,T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin);
             };
         }
-    }

     template <typename T>
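
A note for readers outside the libnd4j tree: after this patch both backends converge on the same nudge-then-quantize arithmetic, which can be reproduced in a few lines of standalone C++. The sketch below is illustrative only; `nudge`, `fakeQuantize`, and the `main` driver are hypothetical stand-ins, with `std::round`/`std::floor` substituted for the `nd4j::math` helpers and a plain function in place of the `LAMBDA_T` macro.

    #include <cstdint>
    #include <cmath>
    #include <cstdio>

    // Stand-in for the nudge() helper above: shift the [min, max] range so
    // that zero lands exactly on a point of the quantized grid.
    template <typename T>
    void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) {
        T quantMinF = static_cast<T>(quantMin);
        T quantMaxF = static_cast<T>(quantMax);
        *scale = (max - min) / (quantMaxF - quantMinF);
        T zeroPointFromMin = quantMinF - min / *scale;
        uint16_t nudgedZeroPoint;
        if (zeroPointFromMin < quantMinF)
            nudgedZeroPoint = static_cast<uint16_t>(quantMin);
        else if (zeroPointFromMin > quantMaxF)
            nudgedZeroPoint = static_cast<uint16_t>(quantMax);
        else
            nudgedZeroPoint = static_cast<uint16_t>(std::round(zeroPointFromMin));
        *nudgedMin = (quantMinF - nudgedZeroPoint) * (*scale);
        *nudgedMax = (quantMaxF - nudgedZeroPoint) * (*scale);
    }

    // Per-element body of the lambda both backends now install: clamp to the
    // nudged range, then snap to the nearest grid point.
    template <typename T>
    T fakeQuantize(T x, T nudgedMin, T nudgedMax, T scale) {
        T val = x < nudgedMin ? nudgedMin : (x > nudgedMax ? nudgedMax : x);
        return std::floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin;
    }

    int main() {
        float scale, nudgedMin, nudgedMax;
        nudge(-1.0f, 1.0f, 0, 255, &scale, &nudgedMin, &nudgedMax); // numBits = 8, narrowed = false
        std::printf("scale=%g nudged=[%g, %g]\n", scale, nudgedMin, nudgedMax);
        std::printf("q(0.3) = %g\n", fakeQuantize(0.3f, nudgedMin, nudgedMax, scale));
    }

With min = -1, max = 1, and the full 8-bit range [0, 255], this prints a scale of 2/255 ≈ 0.00784 and a nudged range of roughly [-1.0039, 0.9961], so 0.3 quantizes to about 0.29804.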