Fixed fake_quant_with_min_max_vars op and tests.

master
shugeo 2019-10-10 13:23:11 +03:00
parent 3c0c59ab88
commit c13e945a96
2 changed files with 91 additions and 43 deletions

View File

@ -26,51 +26,44 @@ namespace ops {
namespace helpers {
template <typename T>
static void Nudge(T min, T max, T quant_min, T quant_max, T* scale, T* nudged_min, T* nudged_max) {
*scale = (max - min) / (quant_max - quant_min);
auto zero_point_from_min = quant_min - min / *scale;
uint16_t const nudged_zero_point = [zero_point_from_min, quant_min, quant_max] {
if (zero_point_from_min < quant_min) {
static void Nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) {
T quant_max_float = static_cast<T>(quant_max);
T quant_min_float = static_cast<T>(quant_min);
*scale = (max - min) / (quant_max_float - quant_min_float);
auto zero_point_from_min = quant_min_float - min / *scale;
uint16_t const nudged_zero_point = [zero_point_from_min, quant_min, quant_max, quant_max_float, quant_min_float] {
if (zero_point_from_min < quant_min_float) {
return static_cast<uint16_t>(quant_min);
}
if (zero_point_from_min > quant_max) {
if (zero_point_from_min > quant_max_float) {
return static_cast<uint16_t>(quant_max);
}
return nd4j::math::nd4j_round<T,uint16_t>(zero_point_from_min);
}();
*nudged_min = (quant_min - nudged_zero_point) * (*scale);
*nudged_max = (quant_max - nudged_zero_point) * (*scale);
*nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
*nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
}
template <typename T>
void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed ? 1 : 0;
int upperIntBound = 1 << numBits - 1;
int upperIntBound = (1 << numBits) - 1;
auto channels = input->sizeAt(-1);
const float quant_min_float = static_cast<float>(lowIntBound);
const float quant_max_float = static_cast<float>(upperIntBound);
// auto scaleTensor(*input); // = NDArrayFactory::create(input->ordering(), input->getShapeAsVector(), input->getWorkspace());
auto clamped(*input); // = NDArrayFactory::create(input->ordering(), input->getShapeAsVector(), input->getWorkspace());
for (auto i = 0; i < min->lengthOf(); i++) {
PRAGMA_OMP_PARALLEL_FOR
for (auto i = 0; i < channels; i++) {
T scale, nudged_min, nudged_max;
Nudge<T>(min->t<T>(i), max->t<T>(i), quant_min_float, quant_max_float, &scale, &nudged_min, &nudged_max);
auto wiseMinMax = LAMBDA_T(x, nudged_min, nudged_max) {
if (x < nudged_min) {
return nudged_min;
}
else if (x > nudged_max)
return nudged_max;
return x;
};
// scaleTensor.assign(scale);
input->applyLambda<T>(wiseMinMax, &clamped);
clamped -= nudged_min;
// auto nudgedScale = scale;
clamped /= scale;
clamped += T(0.5f);
clamped.applyTransform(transform::Floor, output, nullptr);
(*output) *= scale;
(*output) += nudged_min;
Nudge<T>(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
for (auto e = 0; e < input->lengthOf(); e += channels) {
T val = input->t<T>(e + i);
if ( val <= nudged_min)
val = nudged_min;
else if (val >= nudged_max)
val = nudged_max;
output->t<T>(e + i) = math::nd4j_floor<T,T>((val - nudged_min)/scale + T(0.5)) * scale + nudged_min;
}
}
}
@ -91,7 +84,7 @@ namespace helpers {
template <typename T>
void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed ? 1 : 0;
int upperIntBound = 1 << numBits - 1;
int upperIntBound = (1 << numBits) - 1;
const float quant_min_float = static_cast<float>(lowIntBound);
const float quant_max_float = static_cast<float>(upperIntBound);

View File

@ -2117,7 +2117,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32);
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
@ -2127,7 +2127,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer("Quantized");
result->printBuffer("Quantized");
exp.printBuffer("Expected");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
@ -2137,7 +2138,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
NDArray x = NDArrayFactory::create<double>('c', {2,3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1});
NDArray exp = NDArrayFactory::create<double>('c', {2,3}, {-63.75, -63.75, -63.251953, -63.251953, 0.0, 0.0});
NDArray exp = NDArrayFactory::create<double>('c', {2,3}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. });
NDArray min = NDArrayFactory::create<double>(-63.65);
NDArray max = NDArrayFactory::create<double>(0.1);
@ -2158,7 +2159,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {
NDArray x = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1});
NDArray exp = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.75, -63.75, -63.251953, -63.251953, 0.0, 0.0});
NDArray exp = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. });
NDArray min = NDArrayFactory::create<double>('c', {1},{-63.65});
NDArray max = NDArrayFactory::create<double>('c', {1}, {0.1});
@ -2179,8 +2180,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
NDArray x = NDArrayFactory::create<float>('c', {2,4,5,3});
NDArray exp = NDArrayFactory::create<float>('c', {2,4,5,3},
{1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824,
NDArray exp = NDArrayFactory::create<float>('c', {2,4,5,3},{
1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824,
10.058824, 10.980392, 12.078432, 13.058824, 13.921569, 15.09804, 16.058825, 17.058825, 18.117647,
19.058825, 20., 21.137257, 22.058825, 22.941177, 23.882355, 25.058825, 26.078432, 26.901962,
28.058825, 29.019608, 29.92157, 31.058825, 31.960785, 32.941177, 34.058823, 35.09804, 35.960785,
@ -2203,10 +2204,64 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
result->printBuffer("Quantized per channels 4");
exp.printBuffer("Quantized per channest E");
auto diff = *result - exp;
diff.printIndexedBuffer("Difference");
// result->printBuffer("Quantized per channels 4");
// exp.printBuffer("Quantized per channest E");
// auto diff = *result - exp;
// diff.printIndexedBuffer("Difference");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 5, 4});
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 5, 4},{
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-16. , -15.058824 , -13.960785 , -13.0196085 ,
-11.92157 , -10.980392 , -10.039217 , -8.941177 ,
-8.000001 , -7.0588236 , -5.960785 , -5.0196085 ,
-3.9215698 , -2.9803925 , -2.039217 , -0.94117737,
0. , 0.94117737, 2.039215 , 2.9803925 ,
4.07843 , 5.0196075 , 5.960783 , 7.0588226 ,
8. , 8.941177 , 10.039215 , 10.980392 ,
12.07843 , 13.019608 , 13.960783 , 15.058823 ,
16. , 16.941177 , 18.039217 , 18.980392 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823
});
NDArray min = NDArrayFactory::create<float>({-20., -19., -18., -17});
NDArray max = NDArrayFactory::create<float>({20., 21., 22., 23});
x.linspace(-60.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printBuffer("Quantized per channels 5");
// exp.printBuffer("Quantized per channest E");
// auto diff = *result - exp;
// diff.printIndexedBuffer("Difference");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));