Added tests for fake_quant_with_min_max_vars_per_channel op.
parent
cb56b0b06a
commit
3a89e51811
|
@ -2154,6 +2154,62 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1});
|
||||||
|
NDArray exp = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.75, -63.75, -63.251953, -63.251953, 0.0, 0.0});
|
||||||
|
NDArray min = NDArrayFactory::create<double>(-63.65);
|
||||||
|
NDArray max = NDArrayFactory::create<double>(0.1);
|
||||||
|
|
||||||
|
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
|
||||||
|
auto results = op.execute({&x, &min, &max}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
// result->printIndexedBuffer("Quantized2");
|
||||||
|
ASSERT_TRUE(exp.isSameShapeStrict(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<double>('c', {2,4,5,3});
|
||||||
|
NDArray exp = NDArrayFactory::create<double>('c', {2,4,5,3},
|
||||||
|
{1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824,
|
||||||
|
10.058824, 10.980392, 12.078432, 13.058824, 13.921569, 15.09804, 16.058825, 17.058825, 18.117647,
|
||||||
|
19.058825, 20., 21.137257, 22.058825, 22.941177, 23.882355, 25.058825, 26.078432, 26.901962,
|
||||||
|
28.058825, 29.019608, 29.92157, 31.058825, 31.960785, 32.941177, 34.058823, 35.09804, 35.960785,
|
||||||
|
37.058823, 38.039215, 38.980392, 40.058823, 40.980392, 42.000004, 43.058826, 43.92157, 45.01961,
|
||||||
|
45., 47.058823, 48.03922, 45., 50., 51.058826, 45., 50., 54.078434,
|
||||||
|
45., 50., 57.09804, 45., 50., 60.11765, 45., 50., 62.862747,
|
||||||
|
45., 50., 65.882355, 45., 50., 68.90196, 45., 50., 70.,
|
||||||
|
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||||
|
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||||
|
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||||
|
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||||
|
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||||
|
45., 50., 70.});
|
||||||
|
NDArray min = NDArrayFactory::create<double>({20., 20., 20.});
|
||||||
|
NDArray max = NDArrayFactory::create<double>({65., 70., 90.});
|
||||||
|
|
||||||
|
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
|
||||||
|
auto results = op.execute({&x, &min, &max}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
// result->printIndexedBuffer("Quantized2");
|
||||||
|
ASSERT_TRUE(exp.isSameShapeStrict(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
|
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue