Merge pull request #5 from KonduitAI/shugeo_divnonan_full
Added implementation for divide_no_nan op and tests.master
commit
b8f2a83a5a
|
@ -77,7 +77,8 @@
|
||||||
(27, LogicalOr) ,\
|
(27, LogicalOr) ,\
|
||||||
(28, LogicalXor) ,\
|
(28, LogicalXor) ,\
|
||||||
(29, LogicalNot) ,\
|
(29, LogicalNot) ,\
|
||||||
(30, LogicalAnd)
|
(30, LogicalAnd), \
|
||||||
|
(31, DivideNoNan)
|
||||||
|
|
||||||
// these ops return same data type as input
|
// these ops return same data type as input
|
||||||
#define TRANSFORM_SAME_OPS \
|
#define TRANSFORM_SAME_OPS \
|
||||||
|
@ -243,8 +244,8 @@
|
||||||
(42, LstmClip), \
|
(42, LstmClip), \
|
||||||
(43, TruncateMod) ,\
|
(43, TruncateMod) ,\
|
||||||
(44, SquaredReverseSubtract) ,\
|
(44, SquaredReverseSubtract) ,\
|
||||||
(45, ReversePow)
|
(45, ReversePow), \
|
||||||
|
(46, DivideNoNan)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -378,7 +379,8 @@
|
||||||
(34, AMaxPairwise), \
|
(34, AMaxPairwise), \
|
||||||
(35, AMinPairwise) ,\
|
(35, AMinPairwise) ,\
|
||||||
(36, TruncateMod), \
|
(36, TruncateMod), \
|
||||||
(37, ReplaceNans)
|
(37, ReplaceNans), \
|
||||||
|
(38, DivideNoNan)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,7 @@ namespace nd4j {
|
||||||
static BroadcastOpsTuple Add();
|
static BroadcastOpsTuple Add();
|
||||||
static BroadcastOpsTuple Assign();
|
static BroadcastOpsTuple Assign();
|
||||||
static BroadcastOpsTuple Divide();
|
static BroadcastOpsTuple Divide();
|
||||||
|
static BroadcastOpsTuple DivideNoNan();
|
||||||
static BroadcastOpsTuple Multiply();
|
static BroadcastOpsTuple Multiply();
|
||||||
static BroadcastOpsTuple Subtract();
|
static BroadcastOpsTuple Subtract();
|
||||||
};
|
};
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author George A. Shulinok <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_divide_no_nan)
|
||||||
|
|
||||||
|
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
BROADCASTABLE_OP_IMPL(divide_no_nan, 0, 0) {
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(!y->isB(), 0, "DIVIDE_NO_NAN OP: you can't divide by bool array!");
|
||||||
|
auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::DivideNoNan(), x, y, z);
|
||||||
|
if (tZ == nullptr)
|
||||||
|
return ND4J_STATUS_KERNEL_FAILURE;
|
||||||
|
else if (tZ != z) {
|
||||||
|
OVERWRITE_RESULT(tZ);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
DECLARE_SYN(Div, divide);
|
||||||
|
|
||||||
|
DECLARE_TYPES(divide_no_nan) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, DataType::ANY)
|
||||||
|
->setAllowedInputTypes(1, DataType::ANY)
|
||||||
|
->setAllowedOutputTypes(0, DataType::INHERIT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -156,6 +156,18 @@ namespace nd4j {
|
||||||
DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0);
|
DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
|
||||||
|
* 1) if shapes are equal that's pairwise operation, result will have the same shape.
|
||||||
|
* 2) if shape X is scalar and shape Y is array - result will have shape equal to Y.
|
||||||
|
* 3) if shape X is array and shape Y is scalar - result will have shape equal to X.
|
||||||
|
* 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result.
|
||||||
|
*
|
||||||
|
* This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_divide_no_nan)
|
||||||
|
DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0);
|
||||||
|
#endif
|
||||||
/**
|
/**
|
||||||
* This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
|
* This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
|
||||||
* 1) if shapes are equal that's pairwise operation, result will have the same shape.
|
* 1) if shapes are equal that's pairwise operation, result will have the same shape.
|
||||||
|
|
|
@ -37,6 +37,10 @@ namespace nd4j {
|
||||||
return custom(nd4j::scalar::Divide, nd4j::pairwise::Divide, nd4j::broadcast::Divide);
|
return custom(nd4j::scalar::Divide, nd4j::pairwise::Divide, nd4j::broadcast::Divide);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() {
|
||||||
|
return custom(nd4j::scalar::DivideNoNan, nd4j::pairwise::DivideNoNan, nd4j::broadcast::DivideNoNan);
|
||||||
|
}
|
||||||
|
|
||||||
BroadcastOpsTuple BroadcastOpsTuple::Multiply() {
|
BroadcastOpsTuple BroadcastOpsTuple::Multiply() {
|
||||||
return custom(nd4j::scalar::Multiply, nd4j::pairwise::Multiply, nd4j::broadcast::Multiply);
|
return custom(nd4j::scalar::Multiply, nd4j::pairwise::Multiply, nd4j::broadcast::Multiply);
|
||||||
}
|
}
|
||||||
|
|
|
@ -360,6 +360,34 @@ namespace simdOps {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename X, typename Y, typename Z>
|
||||||
|
class DivideNoNan {
|
||||||
|
public:
|
||||||
|
op_def static Z op(X d1, Y d2) {
|
||||||
|
if (d2 == (Y)0) return (Z)0;
|
||||||
|
return static_cast<Z>(d1 / d2);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static Z op(X d1, Y d2, Z *params) {
|
||||||
|
if (d2 == (Y)0) return (Z)0;
|
||||||
|
return static_cast<Z>(d1 / d2);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static Z op(X d1) {
|
||||||
|
return static_cast<Z>(d1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// op for MetaOps
|
||||||
|
op_def static Z op(X d1, Y *params) {
|
||||||
|
if (params[0] == (Y)0) return (Z)0;
|
||||||
|
return static_cast<Z>(d1 / params[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static X startingValue() {
|
||||||
|
return static_cast<X>(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
class SafeDivide {
|
class SafeDivide {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -1194,6 +1194,41 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
|
||||||
delete res;
|
delete res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3, 4, 5, 1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 6});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5, 6});
|
||||||
|
x.assign(6);
|
||||||
|
y.assign(2);
|
||||||
|
exp.assign(3);
|
||||||
|
|
||||||
|
nd4j::ops::divide_no_nan div;
|
||||||
|
auto res = div.execute({&x, &y}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
ASSERT_TRUE(res->at(0)->equalsTo(exp));
|
||||||
|
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>({6,6,6,6,6});
|
||||||
|
auto y = NDArrayFactory::create<float>({3,3,0,3,3});
|
||||||
|
auto exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2});
|
||||||
|
|
||||||
|
nd4j::ops::divide_no_nan div;
|
||||||
|
auto res = div.execute({&x, &y}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
ASSERT_TRUE(res->at(0)->equalsTo(exp));
|
||||||
|
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {
|
TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue