Added input checks for op.
parent
3a89e51811
commit
d0cbd33b0e
|
@ -28,22 +28,19 @@ namespace nd4j {
|
|||
CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
|
||||
NDArray* min;
|
||||
NDArray* max;
|
||||
auto min = INPUT_VARIABLE(1);
|
||||
auto max = INPUT_VARIABLE(2);
|
||||
|
||||
REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs");
|
||||
auto depth = x->sizeAt(-1);
|
||||
REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0,
|
||||
"fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length");
|
||||
REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be"
|
||||
" %lld, but %lld occurs.", depth, min->lengthOf());
|
||||
|
||||
NDArray m;
|
||||
NDArray m2;
|
||||
if(block.width() == 3){
|
||||
min = INPUT_VARIABLE(1);
|
||||
max = INPUT_VARIABLE(2);
|
||||
} else if(block.getTArguments()->size() == 2){
|
||||
m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext());
|
||||
m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext());
|
||||
min = &m;
|
||||
max = &m2;
|
||||
REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be"
|
||||
"%lld, but %lld occurs.", depth, max->lengthOf());
|
||||
if(block.width() == 3) {
|
||||
}
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
int numBits = 8;
|
||||
|
@ -54,7 +51,9 @@ namespace nd4j {
|
|||
if (block.getIArguments()->size() == 2) {
|
||||
numBits = INT_ARG(0);
|
||||
narrowed = INT_ARG(1);
|
||||
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits for quatization should be in between 2 and 16, but %i was given.", numBits);
|
||||
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits"
|
||||
" for quatization should be in between 2 and 16, but %i "
|
||||
"was given.", numBits);
|
||||
}
|
||||
helpers::fakeQuantWithMinMaxVarsPerChannel(x, min, max, numBits, narrowed, output);
|
||||
return ND4J_STATUS_OK;
|
||||
|
|
Loading…
Reference in New Issue