Merge pull request #3 from KonduitAI/shugeo_div_no_nan

Implement divide_no_nan op.
master
raver119 2019-10-03 18:27:12 +03:00 committed by GitHub
commit bbfa41cbc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 79 additions and 4 deletions

View File

@ -77,7 +77,8 @@
(27, LogicalOr) ,\
(28, LogicalXor) ,\
(29, LogicalNot) ,\
(30, LogicalAnd)
(30, LogicalAnd), \
(31, DivideNoNan)
// these ops return same data type as input
#define TRANSFORM_SAME_OPS \
@ -243,8 +244,8 @@
(42, LstmClip), \
(43, TruncateMod) ,\
(44, SquaredReverseSubtract) ,\
(45, ReversePow)
(45, ReversePow), \
(46, DivideNoNan)
@ -378,7 +379,8 @@
(34, AMaxPairwise), \
(35, AMinPairwise) ,\
(36, TruncateMod), \
(37, ReplaceNans)
(37, ReplaceNans), \
(38, DivideNoNan)

View File

@ -46,6 +46,7 @@ namespace nd4j {
static BroadcastOpsTuple Add();
static BroadcastOpsTuple Assign();
static BroadcastOpsTuple Divide();
static BroadcastOpsTuple DivideNoNan();
static BroadcastOpsTuple Multiply();
static BroadcastOpsTuple Subtract();
};

View File

@ -156,6 +156,9 @@ namespace nd4j {
DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0);
#endif
#if NOT_EXCLUDED(OP_divide_no_nan)
DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0);
#endif
/**
* This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
* 1) if shapes are equal that's pairwise operation, result will have the same shape.

View File

@ -37,6 +37,10 @@ namespace nd4j {
return custom(nd4j::scalar::Divide, nd4j::pairwise::Divide, nd4j::broadcast::Divide);
}
BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() {
return custom(nd4j::scalar::DivideNoNan, nd4j::pairwise::DivideNoNan, nd4j::broadcast::DivideNoNan);
}
BroadcastOpsTuple BroadcastOpsTuple::Multiply() {
return custom(nd4j::scalar::Multiply, nd4j::pairwise::Multiply, nd4j::broadcast::Multiply);
}

View File

@ -360,6 +360,34 @@ namespace simdOps {
}
};
template <typename X, typename Y, typename Z>
class DivideNoNan {
public:
op_def static Z op(X d1, Y d2) {
if (d2 == (Y)0) return (Z)0;
return static_cast<Z>(d1 / d2);
}
op_def static Z op(X d1, Y d2, Z *params) {
if (d2 == (Y)0) return (Z)0;
return static_cast<Z>(d1 / d2);
}
op_def static Z op(X d1) {
return static_cast<Z>(d1);
}
// op for MetaOps
op_def static Z op(X d1, Y *params) {
if (params[0] == (Y)0) return (Z)0;
return static_cast<Z>(d1 / params[0]);
}
op_def static X startingValue() {
return static_cast<X>(1);
}
};
template <typename X, typename Y, typename Z>
class SafeDivide {
public:

View File

@ -1194,6 +1194,43 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
auto x = NDArrayFactory::create<float>('c', {3, 4, 5, 1});
auto y = NDArrayFactory::create<float>('c', {1, 6});
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5, 6});
x.assign(6);
y.assign(2);
exp.assign(3);
nd4j::ops::divide_no_nan div;
auto res = div.execute({&x, &y}, {}, {});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(exp));
delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
auto x = NDArrayFactory::create<float>({6,6,6,6,6});
auto y = NDArrayFactory::create<float>({3,3,0,3,3});
auto exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2});
nd4j::ops::divide_no_nan div;
auto res = div.execute({&x, &y}, {}, {});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(exp));
delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {