Implement divide_no_nan op.
This commit is contained in:
		
							parent
							
								
									d2e98564d4
								
							
						
					
					
						commit
						6eaca179d6
					
				| @ -77,7 +77,8 @@ | ||||
|        (27, LogicalOr) ,\ | ||||
|        (28, LogicalXor) ,\ | ||||
|        (29, LogicalNot) ,\ | ||||
|        (30, LogicalAnd) | ||||
|        (30, LogicalAnd), \ | ||||
|        (31, DivideNoNan) | ||||
| 
 | ||||
| // these ops return same data type as input
 | ||||
| #define TRANSFORM_SAME_OPS \ | ||||
| @ -243,8 +244,8 @@ | ||||
|         (42, LstmClip), \ | ||||
|         (43, TruncateMod) ,\ | ||||
|         (44, SquaredReverseSubtract) ,\ | ||||
|         (45, ReversePow) | ||||
| 
 | ||||
|         (45, ReversePow), \ | ||||
|         (46, DivideNoNan) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| @ -378,7 +379,8 @@ | ||||
|         (34, AMaxPairwise), \ | ||||
|         (35, AMinPairwise) ,\ | ||||
|         (36, TruncateMod), \ | ||||
|         (37, ReplaceNans) | ||||
|         (37, ReplaceNans), \ | ||||
|         (38, DivideNoNan) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -46,6 +46,7 @@ namespace nd4j { | ||||
|         static BroadcastOpsTuple Add(); | ||||
|         static BroadcastOpsTuple Assign(); | ||||
|         static BroadcastOpsTuple Divide(); | ||||
|         static BroadcastOpsTuple DivideNoNan(); | ||||
|         static BroadcastOpsTuple Multiply(); | ||||
|         static BroadcastOpsTuple Subtract(); | ||||
|     }; | ||||
|  | ||||
| @ -156,6 +156,9 @@ namespace nd4j { | ||||
|         DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0); | ||||
|         #endif | ||||
| 
 | ||||
|         #if NOT_EXCLUDED(OP_divide_no_nan) | ||||
|         DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0); | ||||
|         #endif | ||||
|         /**
 | ||||
|          * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: | ||||
|          * 1) if shapes are equal that's pairwise operation, result will have the same shape. | ||||
|  | ||||
| @ -37,6 +37,10 @@ namespace nd4j { | ||||
|         return custom(nd4j::scalar::Divide, nd4j::pairwise::Divide, nd4j::broadcast::Divide); | ||||
|     } | ||||
| 
 | ||||
|     BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() { | ||||
|         return custom(nd4j::scalar::DivideNoNan, nd4j::pairwise::DivideNoNan, nd4j::broadcast::DivideNoNan); | ||||
|     } | ||||
| 
 | ||||
|     BroadcastOpsTuple BroadcastOpsTuple::Multiply() { | ||||
|         return custom(nd4j::scalar::Multiply, nd4j::pairwise::Multiply, nd4j::broadcast::Multiply); | ||||
|     } | ||||
|  | ||||
| @ -360,6 +360,34 @@ namespace simdOps { | ||||
| 		} | ||||
| 	}; | ||||
| 
 | ||||
|     template <typename X, typename Y, typename Z> | ||||
|     class DivideNoNan { | ||||
|     public: | ||||
|         op_def static Z op(X d1, Y d2) { | ||||
|             if (d2 == (Y)0) return (Z)0; | ||||
|             return static_cast<Z>(d1 / d2); | ||||
|         } | ||||
| 
 | ||||
|         op_def static Z op(X d1, Y d2, Z *params) { | ||||
|             if (d2 == (Y)0) return (Z)0; | ||||
|             return static_cast<Z>(d1 / d2); | ||||
|         } | ||||
| 
 | ||||
|         op_def static Z op(X d1) { | ||||
|             return static_cast<Z>(d1); | ||||
|         } | ||||
| 
 | ||||
|         // op for MetaOps
 | ||||
|         op_def static Z op(X d1, Y *params) { | ||||
|             if (params[0] == (Y)0) return (Z)0; | ||||
|             return static_cast<Z>(d1 / params[0]); | ||||
|         } | ||||
| 
 | ||||
|         op_def static X startingValue() { | ||||
|             return static_cast<X>(1); | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
| 	template <typename X, typename Y, typename Z> | ||||
| 	class SafeDivide { | ||||
| 	public: | ||||
|  | ||||
| @ -1194,6 +1194,43 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) { | ||||
|     delete res; | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { | ||||
| 
 | ||||
|     auto  x = NDArrayFactory::create<float>('c', {3, 4, 5, 1}); | ||||
|     auto  y = NDArrayFactory::create<float>('c', {1, 6}); | ||||
|     auto  exp = NDArrayFactory::create<float>('c', {3, 4, 5, 6}); | ||||
|     x.assign(6); | ||||
|     y.assign(2); | ||||
|     exp.assign(3); | ||||
| 
 | ||||
|     nd4j::ops::divide_no_nan div; | ||||
| 
 | ||||
|     auto res = div.execute({&x, &y}, {}, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(res->status(), ND4J_STATUS_OK); | ||||
|     ASSERT_TRUE(res->at(0)->equalsTo(exp)); | ||||
| 
 | ||||
|     delete res; | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { | ||||
| 
 | ||||
|     auto  x = NDArrayFactory::create<float>({6,6,6,6,6}); | ||||
|     auto  y = NDArrayFactory::create<float>({3,3,0,3,3}); | ||||
|     auto  exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2}); | ||||
| 
 | ||||
|     nd4j::ops::divide_no_nan div; | ||||
| 
 | ||||
|     auto res = div.execute({&x, &y}, {}, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(res->status(), ND4J_STATUS_OK); | ||||
|     ASSERT_TRUE(res->at(0)->equalsTo(exp)); | ||||
| 
 | ||||
|     delete res; | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user