Added input checks for op.

master
shugeo 2019-10-09 15:52:13 +03:00
parent 3a89e51811
commit d0cbd33b0e
1 changed files with 13 additions and 14 deletions

View File

@ -28,22 +28,19 @@ namespace nd4j {
CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) {
auto x = INPUT_VARIABLE(0);
NDArray* min;
NDArray* max;
auto min = INPUT_VARIABLE(1);
auto max = INPUT_VARIABLE(2);
REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs");
auto depth = x->sizeAt(-1);
REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0,
"fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length");
REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be"
" %lld, but %lld occurs.", depth, min->lengthOf());
NDArray m;
NDArray m2;
REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be"
"%lld, but %lld occurs.", depth, max->lengthOf());
if(block.width() == 3) {
min = INPUT_VARIABLE(1);
max = INPUT_VARIABLE(2);
} else if(block.getTArguments()->size() == 2){
m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext());
m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext());
min = &m;
max = &m2;
}
auto output = OUTPUT_VARIABLE(0);
int numBits = 8;
@ -54,7 +51,9 @@ namespace nd4j {
if (block.getIArguments()->size() == 2) {
numBits = INT_ARG(0);
narrowed = INT_ARG(1);
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits for quatization should be in between 2 and 16, but %i was given.", numBits);
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits"
" for quatization should be in between 2 and 16, but %i "
"was given.", numBits);
}
helpers::fakeQuantWithMinMaxVarsPerChannel(x, min, max, numBits, narrowed, output);
return ND4J_STATUS_OK;