From 863ff76878e7969a98ee20f2ac2d0573d25089d2 Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 2 Oct 2019 12:17:00 +0300 Subject: [PATCH 1/6] Added declaration for bincast op. --- libnd4j/include/ops/declarable/headers/datatypes.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/libnd4j/include/ops/declarable/headers/datatypes.h b/libnd4j/include/ops/declarable/headers/datatypes.h index d8ff39d48..7c96ae4c7 100644 --- a/libnd4j/include/ops/declarable/headers/datatypes.h +++ b/libnd4j/include/ops/declarable/headers/datatypes.h @@ -99,6 +99,14 @@ namespace nd4j { #if NOT_EXCLUDED(OP_cast) DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1); #endif + /** + * This operation change type of input and modified shape of output to conform with given data type + * + * all as above op + * */ + #if NOT_EXCLUDED(OP_bincast) + DECLARE_CUSTOM_OP(bincast, 1, 1, false, 0, 1); + #endif } } From a27e61553a32ec279062ad9a4687f17f4fafc57f Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 2 Oct 2019 15:04:28 +0300 Subject: [PATCH 2/6] Added tests and fixed op name. --- .../ops/declarable/headers/datatypes.h | 4 +-- .../layers_tests/DeclarableOpsTests15.cpp | 26 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/ops/declarable/headers/datatypes.h b/libnd4j/include/ops/declarable/headers/datatypes.h index 7c96ae4c7..b82ab4ad6 100644 --- a/libnd4j/include/ops/declarable/headers/datatypes.h +++ b/libnd4j/include/ops/declarable/headers/datatypes.h @@ -104,8 +104,8 @@ namespace nd4j { * * all as above op * */ - #if NOT_EXCLUDED(OP_bincast) - DECLARE_CUSTOM_OP(bincast, 1, 1, false, 0, 1); + #if NOT_EXCLUDED(OP_bitcast) + DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1); #endif } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 0bd05cec3..685953f2d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -228,6 +228,32 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { ASSERT_TRUE(e.equalsTo(out)); delete result; } +TEST_F(DeclarableOpsTests15, Test_BinCast_1) { + auto x = NDArrayFactory::create('c', {2, 2, 2}); + auto e = NDArrayFactory::create('c', {2, 2}, {2., 512., 8192., 131072.032 }); + x.linspace(1.); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto out = result->at(0); +// out->printIndexedBuffer("Casted result"); + ASSERT_TRUE(e.equalsTo(out)); + delete result; +} + +TEST_F(DeclarableOpsTests15, Test_BinCast_2) { + auto x = NDArrayFactory::create('c', {2, 4}); + auto e = NDArrayFactory::create('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25, + 0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5}); + x.linspace(1.); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto out = result->at(0); + out->printIndexedBuffer("Casted result"); + ASSERT_TRUE(e.equalsTo(out)); + delete result; +} TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); From 1c6173d21882bea8dfc5e7ab77104b1e3e8f7bca Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 2 Oct 2019 15:04:59 +0300 Subject: [PATCH 3/6] Added implementation of bitcast op. --- .../declarable/generic/datatypes/bitcast.cpp | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp diff --git a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp new file mode 100644 index 000000000..6cbdfdae2 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp @@ -0,0 +1,97 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// + +#include +#if NOT_EXCLUDED(OP_bitcast) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if(input->isEmpty()){ + REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty."); + return Status::OK(); + } + +// if (!block.isInplace()) +// output->assign(input); + input->syncToHost(); + output->syncToHost(); + memcpy(output->buffer(), input->buffer(), input->lengthOf() * input->sizeOfT()); + output->syncToDevice(); + output->tickWriteDevice(); + //STORE_RESULT(output); + return Status::OK(); + } + DECLARE_SYN(BitCast, bitcast); + + DECLARE_SHAPE_FN(bitcast) { + auto inShape = inputShape->at(0); + auto inputRank = shape::rank(inShape); + auto it = INT_ARG(0); + DataType newType = DataTypeUtils::fromInt(it); + DataType oldType = ArrayOptions::dataType(inShape); + // correct output shape to conform with output data type + auto inputSize = DataTypeUtils::sizeOf(oldType); + auto outputSize = DataTypeUtils::sizeOf(newType); + + if (shape::length(inShape) == 0) + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); + + if (inputSize == outputSize) { + // only type should be changed + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); + } + else if (inputSize > outputSize) { + // range of output increased by 1 with inputSize / outputSize as last dimension + std::vector shapeOf(inputRank + 1); + int i; + for (i = 0; i < inputRank; ++i) { + shapeOf[i] = inShape[i + 1]; + } + shapeOf[i] = inputSize / outputSize; + auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf); + return SHAPELIST(outputShape); + } + REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %ull > %ull. So last dimension should be %ull, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1)); + std::vector shapeOf(inputRank - 1); + + for (auto i = 0; i < shapeOf.size(); ++i) { + shapeOf[i] = inShape[i + 1]; + } + + auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf); + return SHAPELIST(outputShape); + } + + DECLARE_TYPES(bitcast) { + getOpDescriptor() + ->setAllowedInputTypes(nd4j::DataType::ANY) + ->setAllowedOutputTypes(nd4j::DataType::ANY); + } + } +} + +#endif \ No newline at end of file From f3e42173ef731d67453041a8ea50ba8416928921 Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 2 Oct 2019 16:51:09 +0300 Subject: [PATCH 4/6] Refactored buffer copying to avoid wrong usage of buffers. --- .../ops/declarable/generic/datatypes/bitcast.cpp | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp index 6cbdfdae2..4e54e541a 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp @@ -29,20 +29,15 @@ namespace nd4j { CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - + // when empty - nothing to do if(input->isEmpty()){ REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty."); return Status::OK(); } + // buffers for both input and output should be equals + DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType()); + *(output->dataBuffer()) = buf; -// if (!block.isInplace()) -// output->assign(input); - input->syncToHost(); - output->syncToHost(); - memcpy(output->buffer(), input->buffer(), input->lengthOf() * input->sizeOfT()); - output->syncToDevice(); - output->tickWriteDevice(); - //STORE_RESULT(output); return Status::OK(); } DECLARE_SYN(BitCast, bitcast); From 75ad3c8153a8392f9aa85604b5f977bf09be0a40 Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 2 Oct 2019 19:05:26 +0300 Subject: [PATCH 5/6] Fixed test names. --- libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 685953f2d..fc7f29e3c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -228,7 +228,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { ASSERT_TRUE(e.equalsTo(out)); delete result; } -TEST_F(DeclarableOpsTests15, Test_BinCast_1) { +TEST_F(DeclarableOpsTests15, Test_BitCast_1) { auto x = NDArrayFactory::create('c', {2, 2, 2}); auto e = NDArrayFactory::create('c', {2, 2}, {2., 512., 8192., 131072.032 }); x.linspace(1.); @@ -241,7 +241,7 @@ TEST_F(DeclarableOpsTests15, Test_BinCast_1) { delete result; } -TEST_F(DeclarableOpsTests15, Test_BinCast_2) { +TEST_F(DeclarableOpsTests15, Test_BitCast_2) { auto x = NDArrayFactory::create('c', {2, 4}); auto e = NDArrayFactory::create('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25, 0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5}); From 130ee25682abdf432217b0873d35bf110b96d612 Mon Sep 17 00:00:00 2001 From: shugeo Date: Thu, 3 Oct 2019 10:57:48 +0300 Subject: [PATCH 6/6] Implemented compare_and_bitpack op. --- .../parity_ops/compare_and_bitpack.cpp | 60 +++++++++++++++++++ .../ops/declarable/headers/parity_ops.h | 14 +++++ .../layers_tests/DeclarableOpsTests9.cpp | 19 ++++++ 3 files changed, 93 insertions(+) create mode 100644 libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp new file mode 100644 index 000000000..b9fe7fef9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author sgazeos@gmail.com +// + +#include +#include +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector()); + BROADCAST_CHECK_EMPTY(x, y, (&z0)); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); + bitcast res; + auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false); + if (tZ != &z0) { + delete tZ; + } + + return status; + } + + DECLARE_TYPES(compare_and_bitpack) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::UINT8); + } + + DECLARE_SHAPE_FN(compare_and_bitpack) { + auto inShape = inputShape->at(0); + DataType newType = DataType::UINT8; + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); + } + + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 652c2be8c..605ce95d3 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1731,6 +1731,20 @@ namespace nd4j { DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); #endif + /** + * compare_and_bitpack - compare with greater and pack result with uint8 + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - threshold + * + * + * output: + * 0 - NDArray with the same shape as input and type uint8 + */ + #if NOT_EXCLUDED(OP_compare_and_bitpack) + DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); + #endif } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 1d649cd82..84a1f2dc9 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -2387,6 +2387,25 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto threshold = NDArrayFactory::create(2.0); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + + nd4j::ops::compare_and_bitpack op; + + auto result = op.execute({&x, &threshold}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto output = result->at(0); +// output->printIndexedBuffer("Packed to uint8"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {