From 3504b0cda99fbb31fb2735a90c7b1d82fc4ae25c Mon Sep 17 00:00:00 2001
From: shugeo
Date: Thu, 10 Oct 2019 15:44:50 +0300
Subject: [PATCH] Implemented fake_quant_with_min_max_vars_per_channel fop cuda helper. The first working revision.

---
 .../helpers/cuda/fake_quantization.cu | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 51 insertions(+), 1 deletion(-)

diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
index 9def35152..d491d056a 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu
@@ -77,14 +77,64 @@ namespace helpers {
         input->applyLambda(wiseMinMaxAndSoOn, output);
     }
 
+    // Per-channel fake quantization, computed on the host.
+    //
+    // For every channel i, Nudge() derives scale / nudged_min / nudged_max from
+    // min[i], max[i] and the integer bounds [lowIntBound, upperIntBound]. Each input
+    // element of that channel is clamped to [nudged_min, nudged_max], rounded to the
+    // nearest multiple of scale above nudged_min, and stored at the same linear
+    // position of output.
+    //
+    // Channel i owns linear elements i, i + channels, i + 2 * channels, ...
+    // NOTE(review): presumably the channel is the innermost dimension of a c-order
+    // input and max/output match min/input in length -- confirm against the op's checks.
+    //
+    // numBits  - bit width; the upper integer bound is (1 << numBits) - 1
+    // narrowed - when true the lower integer bound is 1 instead of 0
+    template <typename T>
+    void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
+        int lowIntBound = narrowed?1:0;
+        int upperIntBound = (1 << numBits) - 1;
+        min->syncToHost();
+        max->syncToHost();
+        T scale, nudged_min, nudged_max;
+        // 64-bit counters: an int index would overflow once the array exceeds INT_MAX elements.
+        const Nd4jLong channels = min->lengthOf();
+        const Nd4jLong length = input->lengthOf();
+        input->syncToHost();
+        // NOTE(review): input is only read below, so this device sync looks redundant -- confirm before removing.
+        input->syncToDevice();
+        output->syncToHost();
+        // TODO: this first revision loops on the host; move the work into a device kernel.
+        for (Nd4jLong i = 0; i < channels; i++) {
+            Nudge(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
+
+            // Start each walk at the channel offset so the index can never reach or
+            // exceed length, even when length is not a multiple of the channel count.
+            for (Nd4jLong e = i; e < length; e += channels) {
+                T val = input->t<T>(e);
+                if (val < nudged_min) {
+                    val = nudged_min;
+                } else if (val > nudged_max) {
+                    val = nudged_max;
+                }
+
+                output->t<T>(e) = (math::nd4j_floor<T,T>((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
+            }
+        }
+        output->syncToDevice();
+        output->tickWriteDevice();
+    }
+
     void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
         BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
     }
     void fakeQuantWithMinMaxVarsPerChannel(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
-        BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
+        BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
     }
 
     BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
+    BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVarsPerChannel_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
 
 }
 }