diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp new file mode 100644 index 000000000..e5873d9dd --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -0,0 +1,71 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George Shulinok , created on 08.10.2019 +// + +#include +#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) + +#include +#include +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) { + + auto x = INPUT_VARIABLE(0); + auto min = INPUT_VARIABLE(1); + auto max = INPUT_VARIABLE(2); + + REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs"); + auto depth = x->sizeAt(-1); + REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0, + "fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length"); + REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be" + " %lld, but %lld occurs.", depth, min->lengthOf()); + + REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be" + "%lld, but %lld occurs.", depth, max->lengthOf()); + + auto output = OUTPUT_VARIABLE(0); + int numBits = 8; + if (block.getIArguments() && block.getIArguments()->size()) + numBits = INT_ARG(0); + bool narrowed = false; + //INT_ARG(1); + if (block.getIArguments()->size() == 2) { + numBits = INT_ARG(0); + narrowed = INT_ARG(1); + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits" + " for quatization should be in between 2 and 16, but %i " + "was given.", numBits); + } + helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output); + return ND4J_STATUS_OK; + } + + DECLARE_TYPES(fake_quant_with_min_max_vars_per_channel) { + getOpDescriptor() + -> setAllowedOutputTypes({ALL_FLOATS}) + -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); + } + + DECLARE_SYN(fake_quant_with_min_max_args_per_channel, fake_quant_with_min_max_vars_per_channel); + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 8f6849ef7..cbc7e56da 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1748,6 +1748,25 @@ namespace nd4j { 
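// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): how the new op is expected to be
// invoked, modeled directly on the DeclarableOpsTests10 cases added later in
// this diff. The include paths and the optional IArgs layout
// (num_bits, narrow_range) are assumptions read off the implementation above,
// not documented API.
// ---------------------------------------------------------------------------
#include <NDArrayFactory.h>                        // assumed include, as used by the test suite
#include <ops/declarable/CustomOperations.h>

void fakeQuantPerChannelSketch() {
    // channels come from the last dimension of the input;
    // min and max must be 1D arrays of exactly that length
    auto x   = NDArrayFactory::create<double>('c', {2, 4, 5, 3});
    x.linspace(1.);                                // fill with 1, 2, 3, ... as in Test_4
    auto min = NDArrayFactory::create<double>({20., 20., 20.});
    auto max = NDArrayFactory::create<double>({65., 70., 90.});

    nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
    // optional IArgs: num_bits in [2, 16] (default 8) and narrow_range (default 0)
    auto results = op.execute({&x, &min, &max}, {}, {8, 0});
    if (results->status() == ND4J_STATUS_OK) {
        auto quantized = results->at(0);           // same shape as x
        (void) quantized;
    }
    delete results;
}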
DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); #endif +/** + * fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel + * + * input params: + * 0 - NDArray (input) - at least 2D. + * 1 - 1D Tensor - min values (min length equals to last dim of input) + * 2 - 1D Tensor - max value (length equals to min) + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ + #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) + DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, -2); + #endif + /** * compare_and_bitpack - compare with greater and pack result with uint8 * diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp index df162474f..6ea2992b9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp @@ -25,74 +25,89 @@ namespace nd4j { namespace ops { namespace helpers { + // + // nudge - nudged min max over scale + // scale = (Max - Min) / (quantMax - quantMin) + // quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1 + // + template + static void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) { + // floating point instead integers + T quantMaxF = static_cast(quantMax); + T quantMinF = static_cast(quantMin); + // compute scale + *scale = (max - min) / (quantMaxF - quantMinF); + // compute left bound point + auto zeroPointFromMin = quantMinF - min / *scale; + // bound zero point to conform with range [0 or 1, 2^b - 1] + uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] { + if (zeroPointFromMin < quantMinF) { + return static_cast(quantMin); + } + if (zeroPointFromMin > quantMaxF) { + return static_cast(quantMax); + } + return nd4j::math::nd4j_round(zeroPointFromMin); + }(); + // compute nudged min and max with computed nudged zero point + *nudgedMin = (quantMinF - nudged_zero_point) * (*scale); + *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale); + } + + template + void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; // 0 or 1 + int upperIntBound = (1 << numBits) - 1; // 2^b - 1 + auto channels = input->sizeAt(-1); // last dimension + + PRAGMA_OMP_PARALLEL_FOR + for (auto i = 0; i < channels; i++) { + T scale, nudged_min, nudged_max; + // nudge min and max first, with scale computing + nudge(min->t(i), max->t(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max); + // slide using last dimension and process all for given channel + for (auto e = 0; e < input->lengthOf(); e += channels) { + T val = input->t(e + i); + if ( val <= nudged_min) + val = nudged_min; + else if (val >= nudged_max) + val = nudged_max; + // quantization itself + output->t(e + i) = math::nd4j_floor((val - nudged_min)/scale + T(0.5)) * scale + nudged_min; + } + } + } + template void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { int lowIntBound = narrowed ? 
1 : 0; - int upperIntBound = 1 << numBits - 1; + int upperIntBound = (1 << numBits) - 1; - const float quant_min_float = static_cast(lowIntBound); - const float quant_max_float = static_cast(upperIntBound); - T scale = (max->t(0) - min->t(0)) / (quant_max_float - quant_min_float); - const T zero_point_from_min = quant_min_float - min->e(0) / scale; - const uint16_t nudged_zero_point = [zero_point_from_min, lowIntBound, - quant_min_float, upperIntBound, - quant_max_float] { - if (zero_point_from_min < quant_min_float) { - return static_cast(lowIntBound); - } - if (zero_point_from_min > quant_max_float) { - return static_cast(upperIntBound); - } - return static_cast(roundf(zero_point_from_min)); - }(); - - auto nudged_min = (quant_min_float - nudged_zero_point) * (scale); - auto nudged_max = (quant_max_float - nudged_zero_point) * (scale); - //input->applyScalar(scalar::CompareAndSet, nudged_max, clamped, nullptr); //.cwiseMin(nudged_max).cwiseMax(nudged_min); - //input->applyScalar(scalar::CompareAndSet, nudged_min, clamped, nullptr); //.cwiseMin(nudged_max).cwiseMax(nudged_min); - auto wiseMax = LAMBDA_T(x, nudged_min) { - if (x < nudged_min) { - return nudged_min; + T nudgedMin, nudgedMax, scale; + // nudge with given min and max and compute scale and nudged min and max + nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); + // quantization as one + auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { + T val = x; // boundign value between nudged min and max + if (val < nudgedMin) { + val = nudgedMin; } - return x; - + else if (val > nudgedMax) + val = nudgedMax; + // converse value with scale and shifted with nudged min + return (nd4j::math::nd4j_floor((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin); }; - auto wiseMin = LAMBDA_T(x, nudged_max) { - if (x > nudged_max) { - return nudged_max; - } - return x; - }; - auto scaleTensor(*input); // = NDArrayFactory::create(input->ordering(), input->getShapeAsVector(), input->getWorkspace()); - auto clamped(*input); // = NDArrayFactory::create(input->ordering(), input->getShapeAsVector(), input->getWorkspace()); - scaleTensor.assign(scale); - input->applyLambda(wiseMin, &clamped); -// const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); - clamped.applyLambda(wiseMax, output); -// const auto clamped_shifted = clamped - nudged_min; - *output -= nudged_min; - // auto nudgedScale = scale; - (*output) /= scaleTensor; - (*output) += T(0.5f); - output->applyTransform(transform::Floor, nullptr, nullptr); - (*output) *= scaleTensor; - (*output) += nudged_min; - //output->printIndexedBuffer("FAKE QUANTED"); - /* - const auto nudged_scale_repl = inputs.constant(nudged_scale); - - const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); - const auto clamped_shifted = clamped - nudged_min; - *output = (clamped_shifted / nudged_scale_repl + 0.5f).floor() * - nudged_scale_repl + - nudged_min; -*/ + input->applyLambda(fakeQuantizationWithMinMax, output); } void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); } + void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (input, min, max, numBits, 
narrowed, output), FLOAT_TYPES); + } + BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu index 9bb331685..292b7e1c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu @@ -33,65 +33,105 @@ namespace helpers { // narrowed - shrink is true // output - output tensor // + template + static __host__ __device__ void + nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) { + T quantMaxF = static_cast(quantMax); + T quantMinF = static_cast(quantMin); + *scale = (max - min) / (quantMaxF - quantMinF); + auto zeroPointFromMin = quantMinF - min / *scale; + uint16_t const nudgedZeroPoint = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] { + if (zeroPointFromMin < quantMinF) { + return static_cast(quantMin); + } + if (zeroPointFromMin > quantMaxF) { + return static_cast(quantMax); + } + return nd4j::math::nd4j_round(zeroPointFromMin); + }(); + *nudgedMin = (quantMinF - nudgedZeroPoint) * (*scale); + *nudgedMax = (quantMaxF - nudgedZeroPoint) * (*scale); + } + template void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { int lowIntBound = narrowed?1:0; - int upperIntBound = 1 << numBits - 1; - min->syncToHost(); + int upperIntBound = (1 << numBits) - 1; + min->syncToHost(); // these are scalars, so nothing much happened max->syncToHost(); - const float quant_min_float = static_cast(lowIntBound); - const float quant_max_float = static_cast(upperIntBound); - T scale = (max->t(0) - min->t(0)) / (quant_max_float - quant_min_float); - const T zero_point_from_min = quant_min_float - min->t(0) / scale; + T scale, nudgedMin, nudgedMax; + nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - const uint16_t nudged_zero_point = [zero_point_from_min, lowIntBound, - quant_min_float, upperIntBound, - quant_max_float] { - if (zero_point_from_min < quant_min_float) { - return static_cast(lowIntBound); + auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { + T val = x; + if (x < nudgedMin) { + val = nudgedMin; } - if (zero_point_from_min > quant_max_float) { - return static_cast(upperIntBound); + else if (x > nudgedMax) { + val = nudgedMax; } - return static_cast(roundf(zero_point_from_min)); - }(); - - auto nudged_min = (quant_min_float - nudged_zero_point) * (scale); - auto nudged_max = (quant_max_float - nudged_zero_point) * (scale); - - auto wiseMax = LAMBDA_T(x, nudged_min) { - if (x < nudged_min) { - return nudged_min; - } - return x; + else + val = x; + return (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin); }; - auto wiseMin = LAMBDA_T(x, nudged_max) { - if (x > nudged_max) { - return nudged_max; - } - return x; - }; + input->applyLambda(wiseMinMaxAndSoOn, output); + } - auto scaleTensor(*input); - auto clamped(*input); - scaleTensor.assign(scale); - input->applyLambda(wiseMin, &clamped); + template + static __global__ void fakeQuantWithMinMaxKernel(T* input, Nd4jLong* inputShape, T* min, T* max, + int lowIntBound, int upperIntBound, Nd4jLong channels, + T* output, Nd4jLong* outputShape, Nd4jLong length) { + __shared__ int block; + if (threadIdx.x == 0) { + block = 
length / channels; // to loop with last dimension as block + } + __syncthreads(); - clamped.applyLambda(wiseMax, output); - *output -= nudged_min; + for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) { + T scale, nudgedMin, nudgedMax; + nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); + // loop over blocks to quantization between nudged min and max + for (auto b = threadIdx.x; b < block; b += blockDim.x) { + T val = input[shape::getIndexOffset(b * channels + i, inputShape)]; + if (val < nudgedMin) { + val = nudgedMin; + } else if (val > nudgedMax) { + val = nudgedMax; + } + output[shape::getIndexOffset(b * channels + i, outputShape)] = + (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin); + }; + } + } + + template + void fakeQuantWithMinMaxVarsPerChannel_(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { + int lowIntBound = narrowed?1:0; + int upperIntBound = (1 << numBits) - 1; + auto channels = min->lengthOf(); + auto length = input->lengthOf(); + NDArray::prepareSpecialUse({output}, {min, max, input}); + auto stream = context->getCudaStream(); + T* inputBuf = input->dataBuffer()->specialAsT(); + T* outputBuf = output->dataBuffer()->specialAsT(); + T* minBuf = min->dataBuffer()->specialAsT(); + T* maxBuf = max->dataBuffer()->specialAsT(); + fakeQuantWithMinMaxKernel<<<128, 256, 256, *stream>>>(inputBuf, input->specialShapeInfo(), + minBuf, maxBuf, lowIntBound, upperIntBound, channels, outputBuf, output->specialShapeInfo(), length); + NDArray::registerSpecialUse({output}, {min, max, input}); - (*output) /= scaleTensor; - (*output) += T(0.5f); - output->applyTransform(transform::Floor, nullptr, nullptr); - (*output) *= scaleTensor; - (*output) += nudged_min; } void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); } + void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (context, input, min, max, numBits, narrowed, output), FLOAT_TYPES); + } + BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES); + BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVarsPerChannel_, (LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/fake_quantization.h b/libnd4j/include/ops/declarable/helpers/fake_quantization.h index aa0941db4..cadd8be7c 100644 --- a/libnd4j/include/ops/declarable/helpers/fake_quantization.h +++ b/libnd4j/include/ops/declarable/helpers/fake_quantization.h @@ -27,6 +27,7 @@ namespace ops { namespace helpers { void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output); + void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output); } } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 446763096..0652a398e 100644 --- 
a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2117,7 +2117,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32); - NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32); + NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32); NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32); NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32); @@ -2127,7 +2127,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); - // result->printIndexedBuffer("Quantized"); +// result->printBuffer("Quantized"); +// exp.printBuffer("Expected"); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); @@ -2137,7 +2138,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { NDArray x = NDArrayFactory::create('c', {2,3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); - NDArray exp = NDArrayFactory::create('c', {2,3}, {-63.75, -63.75, -63.251953, -63.251953, 0.0, 0.0}); + NDArray exp = NDArrayFactory::create('c', {2,3}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. }); NDArray min = NDArrayFactory::create(-63.65); NDArray max = NDArrayFactory::create(0.1); @@ -2154,6 +2155,119 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { + + NDArray x = NDArrayFactory::create('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); + NDArray exp = NDArrayFactory::create('c', {1,2,3,1}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. 
}); + NDArray min = NDArrayFactory::create('c', {1},{-63.65}); + NDArray max = NDArrayFactory::create('c', {1}, {0.1}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); + // result->printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { + + NDArray x = NDArrayFactory::create('c', {2,4,5,3}); + NDArray exp = NDArrayFactory::create('c', {2,4,5,3},{ + 1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824, + 10.058824, 10.980392, 12.078432, 13.058824, 13.921569, 15.09804, 16.058825, 17.058825, 18.117647, + 19.058825, 20., 21.137257, 22.058825, 22.941177, 23.882355, 25.058825, 26.078432, 26.901962, + 28.058825, 29.019608, 29.92157, 31.058825, 31.960785, 32.941177, 34.058823, 35.09804, 35.960785, + 37.058823, 38.039215, 38.980392, 40.058823, 40.980392, 42.000004, 43.058826, 43.92157, 45.01961, + 45., 47.058823, 48.03922, 45., 50., 51.058826, 45., 50., 54.078434, + 45., 50., 57.09804, 45., 50., 60.11765, 45., 50., 62.862747, + 45., 50., 65.882355, 45., 50., 68.90196, 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70.}); + NDArray min = NDArrayFactory::create({20., 20., 20.}); + NDArray max = NDArrayFactory::create({65., 70., 90.}); + x.linspace(1.); + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); +// result->printBuffer("Quantized per channels 4"); +// exp.printBuffer("Quantized per channest E"); +// auto diff = *result - exp; +// diff.printIndexedBuffer("Difference"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { + NDArray x = NDArrayFactory::create('c', {2, 3, 5, 4}); + NDArray exp = NDArrayFactory::create('c', {2, 3, 5, 4},{ + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -19.92157 , -18.980392 , -18.039217 , -16.941177 , + -16. , -15.058824 , -13.960785 , -13.0196085 , + -11.92157 , -10.980392 , -10.039217 , -8.941177 , + -8.000001 , -7.0588236 , -5.960785 , -5.0196085 , + -3.9215698 , -2.9803925 , -2.039217 , -0.94117737, + 0. , 0.94117737, 2.039215 , 2.9803925 , + 4.07843 , 5.0196075 , 5.960783 , 7.0588226 , + 8. , 8.941177 , 10.039215 , 10.980392 , + 12.07843 , 13.019608 , 13.960783 , 15.058823 , + 16. 
, 16.941177 , 18.039217 , 18.980392 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 , + 20.07843 , 21.019608 , 21.960783 , 23.058823 + }); + NDArray min = NDArrayFactory::create({-20., -19., -18., -17}); + NDArray max = NDArrayFactory::create({20., 21., 22., 23}); + x.linspace(-60.); + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); +// result->printBuffer("Quantized per channels 5"); +// exp.printBuffer("Quantized per channest E"); +// auto diff = *result - exp; +// diff.printIndexedBuffer("Difference"); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
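// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): the per-channel fake-quant math the
// helpers above implement, written as standalone C++ so the expected values in
// the tests can be reproduced without libnd4j. Function and variable names are
// illustrative only; the numbers mirror FakeQuantWithMinMaxVars_Test_3
// (8 bits, narrow_range = false, a single channel).
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// scale = (max - min) / (quantMax - quantMin); the zero point is clamped into
// [quantMin, quantMax] and used to nudge min/max onto the quantization grid
static void nudgeSketch(float min, float max, int quantMin, int quantMax,
                        float* scale, float* nudgedMin, float* nudgedMax) {
    *scale = (max - min) / float(quantMax - quantMin);
    float zeroPointFromMin = float(quantMin) - min / *scale;
    float nudgedZeroPoint =
            zeroPointFromMin < float(quantMin) ? float(quantMin) :
            zeroPointFromMin > float(quantMax) ? float(quantMax) :
            std::round(zeroPointFromMin);
    *nudgedMin = (float(quantMin) - nudgedZeroPoint) * (*scale);
    *nudgedMax = (float(quantMax) - nudgedZeroPoint) * (*scale);
}

int main() {
    // quantMin is 0 (or 1 with narrow_range); quantMax is 2^numBits - 1 = 255 for 8 bits
    std::vector<float> x{-63.80f, -63.75f, -63.4f, -63.5f, 0.0f, 0.1f};
    float min = -63.65f, max = 0.1f;
    float scale, nudgedMin, nudgedMax;
    nudgeSketch(min, max, /*quantMin=*/0, /*quantMax=*/255, &scale, &nudgedMin, &nudgedMax);
    for (float v : x) {
        // clamp into [nudgedMin, nudgedMax], then snap to the nearest grid step
        float clamped = std::min(std::max(v, nudgedMin), nudgedMax);
        float q = std::floor((clamped - nudgedMin) / scale + 0.5f) * scale + nudgedMin;
        std::printf("%g -> %g\n", v, q);   // expected: -63.75 -63.75 -63.5 -63.5 0 0
    }
    return 0;
}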