From 908e4c4912fbb552c77a676d763a0016cf86ddeb Mon Sep 17 00:00:00 2001 From: shugeo Date: Fri, 4 Oct 2019 10:29:15 +0300 Subject: [PATCH] Added implementation for divide_no_nan op and tests. --- libnd4j/include/loops/legacy_ops.h | 10 ++-- libnd4j/include/ops/BroadcastOpsTuple.h | 1 + .../generic/broadcastable/divide_no_nan.cpp | 57 +++++++++++++++++++ .../ops/declarable/headers/broadcastable.h | 12 ++++ .../include/ops/impl/BroadcastOpsTuple.cpp | 4 ++ libnd4j/include/ops/ops.h | 28 +++++++++ .../layers_tests/DeclarableOpsTests1.cpp | 35 ++++++++++++ 7 files changed, 143 insertions(+), 4 deletions(-) create mode 100644 libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index c298dde3a..b803bdb8d 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -77,7 +77,8 @@ (27, LogicalOr) ,\ (28, LogicalXor) ,\ (29, LogicalNot) ,\ - (30, LogicalAnd) + (30, LogicalAnd), \ + (31, DivideNoNan) // these ops return same data type as input #define TRANSFORM_SAME_OPS \ @@ -243,8 +244,8 @@ (42, LstmClip), \ (43, TruncateMod) ,\ (44, SquaredReverseSubtract) ,\ - (45, ReversePow) - + (45, ReversePow), \ + (46, DivideNoNan) @@ -378,7 +379,8 @@ (34, AMaxPairwise), \ (35, AMinPairwise) ,\ (36, TruncateMod), \ - (37, ReplaceNans) + (37, ReplaceNans), \ + (38, DivideNoNan) diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index 0583a8e4a..c665a0abc 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -46,6 +46,7 @@ namespace nd4j { static BroadcastOpsTuple Add(); static BroadcastOpsTuple Assign(); static BroadcastOpsTuple Divide(); + static BroadcastOpsTuple DivideNoNan(); static BroadcastOpsTuple Multiply(); static BroadcastOpsTuple Subtract(); }; diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp new file mode 100644 index 000000000..3cf808b8e --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author George A. Shulinok +// + +#include +#if NOT_EXCLUDED(OP_divide_no_nan) + +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(divide_no_nan, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + REQUIRE_TRUE(!y->isB(), 0, "DIVIDE_NO_NAN OP: you can't divide by bool array!"); + auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::DivideNoNan(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); + } + DECLARE_SYN(Div, divide); + + DECLARE_TYPES(divide_no_nan) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/broadcastable.h b/libnd4j/include/ops/declarable/headers/broadcastable.h index 679a60254..b3b2463cd 100644 --- a/libnd4j/include/ops/declarable/headers/broadcastable.h +++ b/libnd4j/include/ops/declarable/headers/broadcastable.h @@ -156,6 +156,18 @@ namespace nd4j { DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0); #endif + /** + * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: + * 1) if shapes are equal that's pairwise operation, result will have the same shape. + * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. + * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. + * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. + * + * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 + */ + #if NOT_EXCLUDED(OP_divide_no_nan) + DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0); + #endif /** * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: * 1) if shapes are equal that's pairwise operation, result will have the same shape. diff --git a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp index f42228cfb..ca408e8dc 100644 --- a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp @@ -37,6 +37,10 @@ namespace nd4j { return custom(nd4j::scalar::Divide, nd4j::pairwise::Divide, nd4j::broadcast::Divide); } + BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() { + return custom(nd4j::scalar::DivideNoNan, nd4j::pairwise::DivideNoNan, nd4j::broadcast::DivideNoNan); + } + BroadcastOpsTuple BroadcastOpsTuple::Multiply() { return custom(nd4j::scalar::Multiply, nd4j::pairwise::Multiply, nd4j::broadcast::Multiply); } diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index a80e274ca..a738f0bdc 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -360,6 +360,34 @@ namespace simdOps { } }; + template + class DivideNoNan { + public: + op_def static Z op(X d1, Y d2) { + if (d2 == (Y)0) return (Z)0; + return static_cast(d1 / d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + if (d2 == (Y)0) return (Z)0; + return static_cast(d1 / d2); + } + + op_def static Z op(X d1) { + return static_cast(d1); + } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + if (params[0] == (Y)0) return (Z)0; + return static_cast(d1 / params[0]); + } + + op_def static X startingValue() { + return static_cast(1); + } + }; + template class SafeDivide { public: diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index d0c597cc5..b6f5f125d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -1194,6 +1194,41 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) { delete res; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { + + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(6); + y.assign(2); + exp.assign(3); + + nd4j::ops::divide_no_nan div; + auto res = div.execute({&x, &y}, {}, {}); + + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + ASSERT_TRUE(res->at(0)->equalsTo(exp)); + + delete res; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { + + auto x = NDArrayFactory::create({6,6,6,6,6}); + auto y = NDArrayFactory::create({3,3,0,3,3}); + auto exp = NDArrayFactory::create({2, 2, 0, 2, 2}); + + nd4j::ops::divide_no_nan div; + auto res = div.execute({&x, &y}, {}, {}); + + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + ASSERT_TRUE(res->at(0)->equalsTo(exp)); + + delete res; +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {