Added input checks for op.
parent
3a89e51811
commit
d0cbd33b0e
|
@ -28,22 +28,19 @@ namespace nd4j {
|
||||||
CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) {
|
CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto min = INPUT_VARIABLE(1);
|
||||||
NDArray* min;
|
auto max = INPUT_VARIABLE(2);
|
||||||
NDArray* max;
|
|
||||||
|
|
||||||
REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs");
|
REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs");
|
||||||
|
auto depth = x->sizeAt(-1);
|
||||||
|
REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0,
|
||||||
|
"fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length");
|
||||||
|
REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be"
|
||||||
|
" %lld, but %lld occurs.", depth, min->lengthOf());
|
||||||
|
|
||||||
NDArray m;
|
REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be"
|
||||||
NDArray m2;
|
"%lld, but %lld occurs.", depth, max->lengthOf());
|
||||||
if(block.width() == 3) {
|
if(block.width() == 3) {
|
||||||
min = INPUT_VARIABLE(1);
|
|
||||||
max = INPUT_VARIABLE(2);
|
|
||||||
} else if(block.getTArguments()->size() == 2){
|
|
||||||
m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext());
|
|
||||||
m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext());
|
|
||||||
min = &m;
|
|
||||||
max = &m2;
|
|
||||||
}
|
}
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
int numBits = 8;
|
int numBits = 8;
|
||||||
|
@ -54,7 +51,9 @@ namespace nd4j {
|
||||||
if (block.getIArguments()->size() == 2) {
|
if (block.getIArguments()->size() == 2) {
|
||||||
numBits = INT_ARG(0);
|
numBits = INT_ARG(0);
|
||||||
narrowed = INT_ARG(1);
|
narrowed = INT_ARG(1);
|
||||||
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits for quatization should be in between 2 and 16, but %i was given.", numBits);
|
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits"
|
||||||
|
" for quatization should be in between 2 and 16, but %i "
|
||||||
|
"was given.", numBits);
|
||||||
}
|
}
|
||||||
helpers::fakeQuantWithMinMaxVarsPerChannel(x, min, max, numBits, narrowed, output);
|
helpers::fakeQuantWithMinMaxVarsPerChannel(x, min, max, numBits, narrowed, output);
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
|
|
Loading…
Reference in New Issue