diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp index ba8eb9e7b..4d719b38b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -28,22 +28,19 @@ namespace nd4j { CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) { auto x = INPUT_VARIABLE(0); - - NDArray* min; - NDArray* max; + auto min = INPUT_VARIABLE(1); + auto max = INPUT_VARIABLE(2); REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs"); + auto depth = x->sizeAt(-1); + REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0, + "fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length"); + REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be" + " %lld, but %lld occurs.", depth, min->lengthOf()); - NDArray m; - NDArray m2; - if(block.width() == 3){ - min = INPUT_VARIABLE(1); - max = INPUT_VARIABLE(2); - } else if(block.getTArguments()->size() == 2){ - m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); - m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext()); - min = &m; - max = &m2; + REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be" + "%lld, but %lld occurs.", depth, max->lengthOf()); + if(block.width() == 3) { } auto output = OUTPUT_VARIABLE(0); int numBits = 8; @@ -54,7 +51,9 @@ namespace nd4j { if (block.getIArguments()->size() == 2) { numBits = INT_ARG(0); narrowed = INT_ARG(1); - REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits for quatization should be in between 2 and 16, but %i was given.", numBits); + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits" + " for quatization should be in between 2 and 16, but %i " + "was given.", numBits); } helpers::fakeQuantWithMinMaxVarsPerChannel(x, min, max, numBits, narrowed, output); return ND4J_STATUS_OK;