diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp new file mode 100644 index 000000000..ba8eb9e7b --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -0,0 +1,73 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George Shulinok , created on 08.10.2019 +// + +#include +#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) + +#include +#include +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) { + + auto x = INPUT_VARIABLE(0); + + NDArray* min; + NDArray* max; + + REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs"); + + NDArray m; + NDArray m2; + if(block.width() == 3){ + min = INPUT_VARIABLE(1); + max = INPUT_VARIABLE(2); + } else if(block.getTArguments()->size() == 2){ + m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); + m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext()); + min = &m; + max = &m2; + } + auto output = OUTPUT_VARIABLE(0); + int numBits = 8; + if (block.getIArguments() && block.getIArguments()->size()) + numBits = INT_ARG(0); + bool narrowed = false; + //INT_ARG(1); + if (block.getIArguments()->size() == 2) { + numBits = INT_ARG(0); + narrowed = INT_ARG(1); + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits for quatization should be in between 2 and 16, but %i was given.", numBits); + } + helpers::fakeQuantWithMinMaxVarsPerChannel(x, min, max, numBits, narrowed, output); + return ND4J_STATUS_OK; + } + + DECLARE_TYPES(fake_quant_with_min_max_vars_per_channel) { + getOpDescriptor() + -> setAllowedOutputTypes({ALL_FLOATS}) + -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); + } + + DECLARE_SYN(fake_quant_with_min_max_args_per_channel, fake_quant_with_min_max_vars_per_channel); + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 8f6849ef7..e0dc55937 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1747,6 +1747,9 @@ namespace nd4j { #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); #endif + #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) + DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, -2); + #endif /** * compare_and_bitpack - compare with greater and pack result with uint8 diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp index df162474f..88c451ffb 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp @@ -93,6 +93,10 @@ namespace helpers { void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); } + void fakeQuantWithMinMaxVarsPerChannel(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); + } + BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu index 9bb331685..4e62aafa8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu @@ -91,6 +91,10 @@ namespace helpers { void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); } + void fakeQuantWithMinMaxVarsPerChannel(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); + } + BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/fake_quantization.h b/libnd4j/include/ops/declarable/helpers/fake_quantization.h index aa0941db4..7a43a15cb 100644 --- a/libnd4j/include/ops/declarable/helpers/fake_quantization.h +++ b/libnd4j/include/ops/declarable/helpers/fake_quantization.h @@ -27,6 +27,7 @@ namespace ops { namespace helpers { void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output); + void fakeQuantWithMinMaxVarsPerChannel(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output); } } }