From 1e9ff114aa7d51617478a21a36d66074a74426d1 Mon Sep 17 00:00:00 2001 From: shugeo Date: Mon, 2 Dec 2019 20:40:54 +0200 Subject: [PATCH] Shugeo atomic tests (#97) * Added atomic tests for atomicAdd, atomicSub and atomicDiv. * Fixed atomicAdd for 16bit ints. * Fixed atomicMul for 16 floats. * Eliminated waste prints. * Fixed problems with double type on matrix inverse helepers. * Eliminated commented wrong code. * Refactored atomicMul for 16bit types. * few more minor tweaks Signed-off-by: raver119 * Fixed fake_quant_with_min_max_vars_per_channel args processing. --- ...ke_quant_with_min_max_vars_per_channel.cpp | 15 +- .../ops/declarable/helpers/cuda/lup.cu | 32 +++- libnd4j/include/templatemath.h | 81 ++++++--- libnd4j/tests_cpu/layers_tests/AtomicTests.cu | 157 +++++++++++++++++- .../layers_tests/DeclarableOpsTests10.cpp | 94 +++++++++++ 5 files changed, 338 insertions(+), 41 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp index 5874d2f81..8f379911b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -25,13 +25,12 @@ #include namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) { + CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, 0) { auto x = INPUT_VARIABLE(0); auto min = INPUT_VARIABLE(1); auto max = INPUT_VARIABLE(2); - REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs"); auto depth = x->sizeAt(-1); REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length"); @@ -49,13 +48,13 @@ namespace nd4j { numBits = INT_ARG(0); bool narrowed = false; //INT_ARG(1); - if (block.getIArguments()->size() == 2) { - numBits = INT_ARG(0); - narrowed = INT_ARG(1); - REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits" - " for quatization should be in between 2 and 16, but %i " - "was given.", numBits); + if (block.getBArguments() && block.getBArguments()->size()) { + narrowed = B_ARG(0); } + + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits" + " for quatization should be in between 2 and 16, but %i " + "was given.", numBits); helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output); return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 42acf9c09..568b9a9bc 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -110,12 +110,21 @@ namespace helpers { template static __global__ void invertLowKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) { + T *inverted = reinterpret_cast(invertedBuf); T *input = reinterpret_cast(inputBuf); + if (threadIdx.x == 0) { + inverted = reinterpret_cast(invertedBuf); + input = reinterpret_cast(inputBuf); + } + __syncthreads(); - for (int 
i = blockIdx.x + 2; i < n; i += gridDim.x) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (int i = tid + 2; i < n; i += step) { for (int j = i - 2; j >= 0; --j) - for (int k = threadIdx.x; k < i; k += blockDim.x) { + for (int k = 0; k < i; k++) { Nd4jLong posZ[] = {i, j}; Nd4jLong posY[] = {k, j}; Nd4jLong posX[] = {i, k}; @@ -144,10 +153,12 @@ namespace helpers { input = reinterpret_cast(inputBuf); } __syncthreads(); + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - for (int i = (int)n - blockIdx.x - 2; i >= 0; i -= gridDim.x) { + for (int i = (int)n - tid - 2; i >= 0; i -= step) { for (int j = i + 2; j < (int)n; j++) - for (int k = i + threadIdx.x; k < (int)n; k += blockDim.x) { + for (int k = i; k < (int)n; k++) { Nd4jLong posZ[] = {i, j}; Nd4jLong posY[] = {k, j}; Nd4jLong posX[] = {i, k}; @@ -498,8 +509,6 @@ namespace helpers { fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); - -// if (matrix.dataType() == input->dataType()) lup_(context, &matrix, nullptr, nullptr); // else // lup_(context, &matrix, nullptr, nullptr); @@ -627,9 +636,14 @@ namespace helpers { for (auto i = 0LL; i < packX.numberOfTads(); i++) { fillMatrix<<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); matrix.tickWriteDevice(); - compound.assign(matrix); - lup_(context, &compound, nullptr, nullptr); - fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); + //compound.assign(matrix); +// if (matrix.dataType() == input->dataType()) + lup_(context, &matrix, nullptr, nullptr); + fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); + lower.tickWriteDevice(); + upper.tickWriteDevice(); +// lower.printIndexedBuffer("LOWER"); +// upper.printIndexedBuffer("UPPER"); matrix.assign(0); invertUpperMatrix(context, &upper, &matrix); // U^{-1} matrix.tickWriteDevice(); diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index 7aa4dbbe6..b412befd8 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -1305,15 +1305,65 @@ inline __device__ bfloat16 nd4j_atomicAdd(bfloat16* address, bfloat16 else return old.B.L; } +template +static inline __device__ T internal_16bit_atomicAdd(T* address, T val) { + size_t shift = ((size_t)address & 2); + int *base_address = (int *)((char*)address - shift); + + union I16PAIR { + struct { + T H; + T L; + } B; + int W; + + __host__ __device__ + I16PAIR() {}; + + __host__ __device__ + ~I16PAIR() {}; + }; + + I16PAIR pairNew, pairOld, pairAssumed; + + if (reinterpret_cast(address) == base_address) { + pairOld.B.L = val; + do { + + pairNew.B.L = pairOld.B.L; + pairNew.B.H = pairOld.B.H + val; + pairAssumed.W = pairOld.W; + + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + } while (pairAssumed.W != pairOld.W); + + return (T) pairOld.B.H; + } else { + pairOld.B.H = val; + do { + + pairNew.B.H = pairOld.B.H; + pairNew.B.L = pairOld.B.L + val; + pairAssumed.W = pairOld.W; + pairOld.W = 
atomicCAS(base_address, pairAssumed.W, pairNew.W); + + } while (pairAssumed.W != pairOld.W); + + return (T) pairOld.B.L; + } + +} + template <> inline __device__ int16_t nd4j_atomicAdd(int16_t* address, int16_t val) { - return nd4j_atomicAdd((bfloat16*)address, (bfloat16)val); + return internal_16bit_atomicAdd(address, val); } template <> inline __device__ uint16_t nd4j_atomicAdd(uint16_t* address, uint16_t val) { - return nd4j_atomicAdd((bfloat16*)address, (bfloat16)val); + return internal_16bit_atomicAdd(address, val); } + template <> inline __device__ int8_t nd4j_atomicAdd(int8_t* address, int8_t val) { int res = *address; @@ -1447,7 +1497,7 @@ inline __device__ unsigned char nd4j_atomicMul(unsigned char* add } template -static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { +static inline __device__ T internal_16bit_atomicMul(T* address, T val) { size_t shift = ((size_t)address & 2); int *base_address = (int *)((char*)address - shift); @@ -1467,10 +1517,9 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { I16PAIR pairNew, pairOld, pairAssumed; - pairOld.W = (int) val; if (reinterpret_cast(address) == base_address) { + pairOld.B.L = val; do { - pairNew.B.L = pairOld.B.L; pairNew.B.H = pairOld.B.H * val; pairAssumed.W = pairOld.W; @@ -1480,8 +1529,8 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { return (T) pairOld.B.H; } else { + pairOld.B.H = val; do { - pairNew.B.H = pairOld.B.H; pairNew.B.L = pairOld.B.L * val; pairAssumed.W = pairOld.W; @@ -1491,10 +1540,8 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { return (T) pairOld.B.L; } - } - template <> inline __device__ int16_t nd4j_atomicMul(int16_t* address, int16_t val) { return internal_16bit_atomicMul(address, val); @@ -1549,17 +1596,6 @@ inline __device__ uint64_t nd4j_atomicMul(uint64_t* address, uint64_t return (uint64_t)old; } -//template <> -//inline __device__ unsigned long long nd4j_atomicMul(unsigned long long* address, unsigned long long val) { -// unsigned long long int* res_address = address; -// unsigned long long int old = *res_address, assumed; -// do { -// assumed = old; -// old = atomicCAS(res_address, assumed, val * assumed); -// } while (assumed != old); -// return old; -//} - #if !defined(_WIN32) && !defined(_WIN64) template <> inline __device__ Nd4jLong nd4j_atomicMul(Nd4jLong* address, Nd4jLong val) { @@ -1585,22 +1621,21 @@ inline __device__ float16 nd4j_atomicMul(float16* address, float16 val) template <> inline __device__ float nd4j_atomicDiv(float* address, float val) { - return nd4j_atomicMul(address, (float) 1.f / val); + return nd4j_atomicMul(address, 1.f / val); } template <> inline __device__ float16 nd4j_atomicDiv(float16* address, float16 val) { - return nd4j_atomicMul(address, (float16) 1.f / val); + return internal_16bit_atomicMul(address, (float16) 1.f / val); } template <> inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 val) { - return nd4j_atomicMul(address, (bfloat16) 1.f / val); + return internal_16bit_atomicMul(address, (bfloat16) 1 / val); } } #endif } - } #ifdef _OPENMP diff --git a/libnd4j/tests_cpu/layers_tests/AtomicTests.cu b/libnd4j/tests_cpu/layers_tests/AtomicTests.cu index 0ede6398c..fdf543026 100644 --- a/libnd4j/tests_cpu/layers_tests/AtomicTests.cu +++ b/libnd4j/tests_cpu/layers_tests/AtomicTests.cu @@ -60,12 +60,93 @@ static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) { nd4j::cuda_exception::build("multiply 
failed", err); } +template +static _CUDA_G void sumKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + nd4j::math::atomics::nd4j_atomicAdd(&result[i], buffer[e]); + } +} + +template +static void sumLauncher(void *vbuffer, uint64_t length, void *vresult) { + sumKernel<<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + nd4j::cuda_exception::build("sum failed", err); +} + +template +static _CUDA_G void subKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + nd4j::math::atomics::nd4j_atomicSub(&result[i], buffer[e]); + } +} + +template +static void subLauncher(void *vbuffer, uint64_t length, void *vresult) { + subKernel<<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + nd4j::cuda_exception::build("sub failed", err); +} + +template +static _CUDA_G void divKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + nd4j::math::atomics::nd4j_atomicDiv(&result[i], buffer[e]); + } +} + +template +static void divLauncher(void *vbuffer, uint64_t length, void *vresult) { + divKernel<<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + nd4j::cuda_exception::build("div failed", err); +} + static void multiplyHost(NDArray &input, NDArray &output) { BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); } +static void sumHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), sumLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); +} + +static void subHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), subLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); +} + +static void divHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), divLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); +} + TEST_F(AtomicTests, test_multiply) { - std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16}; + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16, nd4j::DataType::HALF}; for (auto t:dtypes) { nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); @@ -80,7 +161,81 @@ TEST_F(AtomicTests, 
test_multiply) { multiplyHost(input, output); ASSERT_EQ(exp, output); } +} +TEST_F(AtomicTests, test_multiply_2) { + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::HALF, nd4j::DataType::BFLOAT16}; + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + input.assign(1.5); + output.assign(2); + exp.assign(10.125); + + multiplyHost(input, output); +// output.printBuffer("multiply 2"); + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_sum) { + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16, nd4j::DataType::HALF, nd4j::DataType::INT16}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1); + output.assign(1); + exp.assign(5); + + sumHost(input, output); +// output.printIndexedBuffer("Sum"); + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_sub) { + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::HALF}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1); + output.assign(5); + exp.assign(1); + + subHost(input, output); +// output.printBuffer("Sub"); + + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_div) { + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16, nd4j::DataType::HALF}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(2); + output.assign(32); + exp.assign(2); + + divHost(input, output); +// output.printBuffer("Div"); + ASSERT_EQ(exp, output); + } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 7bea1e820..6375d935c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2785,6 +2785,100 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { delete results; } +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.777002f, 0.596913f, 0.72314f, 0.231040f, 0.509824f, + 0.179308f, 0.505282f, 0.86846f, 0.349958f, 0.509824f, + 0.087355f, 0.596913f, 0.65740f, 0.349958f, 0.159745f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); +// result->printIndexedBuffer("Quantized03"); + 
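    // No IArgs/BArgs are passed here, so the op runs with its defaults: numBits keeps the
    // op's built-in default and the "narrowed" flag stays false (with this change the
    // narrow-range flag is read from B_ARG(0) rather than INT_ARG(1)).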
ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.780061f, 0.596635f, 0.725987f, 0.231950f, 0.508419f, + 0.180014f, 0.504643f, 0.868406f, 0.351335f, 0.508419f, + 0.087699f, 0.596635f, 0.659988f, 0.351335f, 0.160374f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {8}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); +// result->printIndexedBuffer("Quantized03_1"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.775297f, 0.592226f, 0.725763f, 0.237561f, 0.503245f, + 0.189097f, 0.506084f, 0.868069f, 0.349355f, 0.503245f, + 0.094548f, 0.592226f, 0.654610f, 0.349355f, 0.153769f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); + result->printIndexedBuffer("Quantized03_2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.781600f, 0.593422f, 0.728248f, 0.233790f, 0.509014f, 0.186095f, 0.508648f, 0.868295f, 0.343809f, + 0.509014f, 0.093048f, 0.593422f, 0.658224f, 0.343809f, 0.165086f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {6}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); + result->printIndexedBuffer("Quantized03_3"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {