diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp new file mode 100644 index 000000000..b9fe7fef9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author sgazeos@gmail.com +// + +#include +#include +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector()); + BROADCAST_CHECK_EMPTY(x, y, (&z0)); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); + bitcast res; + auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false); + if (tZ != &z0) { + delete tZ; + } + + return status; + } + + DECLARE_TYPES(compare_and_bitpack) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::UINT8); + } + + DECLARE_SHAPE_FN(compare_and_bitpack) { + auto inShape = inputShape->at(0); + DataType newType = DataType::UINT8; + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); + } + + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 652c2be8c..605ce95d3 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1731,6 +1731,20 @@ namespace nd4j { DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); #endif + /** + * compare_and_bitpack - compare with greater and pack result with uint8 + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - threshold + * + * + * output: + * 0 - NDArray with the same shape as input and type uint8 + */ + #if NOT_EXCLUDED(OP_compare_and_bitpack) + DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); + #endif } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 1d649cd82..84a1f2dc9 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -2387,6 +2387,25 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto threshold = NDArrayFactory::create(2.0); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + + nd4j::ops::compare_and_bitpack op; + + auto result = op.execute({&x, &threshold}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto output = result->at(0); +// output->printIndexedBuffer("Packed to uint8"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {