Added tests and fixed op name.
This commit is contained in:
		
							parent
							
								
									863ff76878
								
							
						
					
					
						commit
						a27e61553a
					
				| @ -104,8 +104,8 @@ namespace nd4j { | ||||
|          * | ||||
|          * all as above op | ||||
|          * */ | ||||
|         #if NOT_EXCLUDED(OP_bincast) | ||||
|                 DECLARE_CUSTOM_OP(bincast, 1, 1, false, 0, 1); | ||||
|         #if NOT_EXCLUDED(OP_bitcast) | ||||
|                 DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1); | ||||
|         #endif | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -228,6 +228,32 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { | ||||
|     ASSERT_TRUE(e.equalsTo(out)); | ||||
|     delete result; | ||||
| } | ||||
| TEST_F(DeclarableOpsTests15, Test_BinCast_1) { | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 2, 2}); | ||||
|     auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 }); | ||||
|     x.linspace(1.); | ||||
|     nd4j::ops::bitcast op; | ||||
|     auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {}); | ||||
|     ASSERT_EQ(Status::OK(), result->status()); | ||||
|     auto out = result->at(0); | ||||
| //    out->printIndexedBuffer("Casted result");
 | ||||
|     ASSERT_TRUE(e.equalsTo(out)); | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests15, Test_BinCast_2) { | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 4}); | ||||
|     auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0, 1.875, 0, 2.,    0, 2.125, 0,  2.25, | ||||
|                                                                               0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5}); | ||||
|     x.linspace(1.); | ||||
|     nd4j::ops::bitcast op; | ||||
|     auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {}); | ||||
|     ASSERT_EQ(Status::OK(), result->status()); | ||||
|     auto out = result->at(0); | ||||
|     out->printIndexedBuffer("Casted result"); | ||||
|     ASSERT_TRUE(e.equalsTo(out)); | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { | ||||
|     auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64}); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user