From 3a89e518115b4123fb44ff014a14624b3fa84c71 Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 9 Oct 2019 13:38:18 +0300 Subject: [PATCH] Added tests for fake_quant_with_min_max_vars_per_channel op. --- .../layers_tests/DeclarableOpsTests10.cpp | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 446763096..191ee8524 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2154,6 +2154,62 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { + + NDArray x = NDArrayFactory::create('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); + NDArray exp = NDArrayFactory::create('c', {1,2,3,1}, {-63.75, -63.75, -63.251953, -63.251953, 0.0, 0.0}); + NDArray min = NDArrayFactory::create(-63.65); + NDArray max = NDArrayFactory::create(0.1); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); + // result->printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { + + NDArray x = NDArrayFactory::create('c', {2,4,5,3}); + NDArray exp = NDArrayFactory::create('c', {2,4,5,3}, + {1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824, + 10.058824, 10.980392, 12.078432, 13.058824, 13.921569, 15.09804, 16.058825, 17.058825, 18.117647, + 19.058825, 20., 21.137257, 22.058825, 22.941177, 23.882355, 25.058825, 26.078432, 26.901962, + 28.058825, 29.019608, 29.92157, 31.058825, 31.960785, 32.941177, 34.058823, 35.09804, 35.960785, + 37.058823, 38.039215, 38.980392, 40.058823, 40.980392, 42.000004, 43.058826, 43.92157, 45.01961, + 45., 47.058823, 48.03922, 45., 50., 51.058826, 45., 50., 54.078434, + 45., 50., 57.09804, 45., 50., 60.11765, 45., 50., 62.862747, + 45., 50., 65.882355, 45., 50., 68.90196, 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70., 45., 50., 70., 45., 50., 70., + 45., 50., 70.}); + NDArray min = NDArrayFactory::create({20., 20., 20.}); + NDArray max = NDArrayFactory::create({65., 70., 90.}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); + // result->printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {