Fixed crash with strided_slice_bp op and tests. (#29)
This commit is contained in:
		
							parent
							
								
									7b14a9f603
								
							
						
					
					
						commit
						9124974e3b
					
				| @ -623,7 +623,7 @@ namespace nd4j { | ||||
| 
 | ||||
| 			//Zero output array, so unused elements have 0 gradient
 | ||||
| 			output->nullify(); | ||||
| 
 | ||||
|             std::sort(indices.begin(), indices.end()); | ||||
|             if(indices.size() == 3 && (indices[1] - indices[0]) == 1) { | ||||
|                 output->p(indices[0], *epsNext); | ||||
|             } | ||||
|  | ||||
| @ -241,6 +241,52 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) { | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) { | ||||
|     int zero = 0; | ||||
|     auto matrix = NDArrayFactory::create<double>('c', {5, 4}); | ||||
| //    auto b = NDArrayFactory::create<int>('c', {1}, {zero});
 | ||||
| //    auto e = NDArrayFactory::create<int>('c', {1}, {zero});
 | ||||
| //    auto s = NDArrayFactory::create<int>('c', {1}, {1});
 | ||||
| 
 | ||||
|     auto grad = NDArrayFactory::create<double>('c', {5,4}); | ||||
| 
 | ||||
|     matrix.linspace(1); | ||||
|     grad.linspace(1); | ||||
| 
 | ||||
|     nd4j::ops::strided_slice_bp op; | ||||
|     auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); | ||||
|     ASSERT_EQ(Status::OK(), result->status()); | ||||
| 
 | ||||
|     auto z = result->at(0); | ||||
|     z->printShapeInfo("Output shape"); | ||||
|     z->printIndexedBuffer("Output"); | ||||
|     //ASSERT_TRUE(exp.equalsTo(z));
 | ||||
| 
 | ||||
|     delete result; | ||||
| } | ||||
| TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { | ||||
|     int zero = 0; | ||||
|     auto matrix = NDArrayFactory::create<double>('c', {1, 2}); | ||||
| //    auto b = NDArrayFactory::create<int>('c', {1}, {zero});
 | ||||
| //    auto e = NDArrayFactory::create<int>('c', {1}, {zero});
 | ||||
| //    auto s = NDArrayFactory::create<int>('c', {1}, {1});
 | ||||
| 
 | ||||
|     auto grad = NDArrayFactory::create<double>('c', {1}, {1.}); | ||||
| 
 | ||||
|     matrix.linspace(1); | ||||
|     //grad.linspace(1);
 | ||||
| 
 | ||||
|     nd4j::ops::strided_slice_bp op; | ||||
|     auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); | ||||
|     ASSERT_EQ(Status::OK(), result->status()); | ||||
| 
 | ||||
|     auto z = result->at(0); | ||||
|     z->printShapeInfo("Output shape"); | ||||
|     z->printIndexedBuffer("Output"); | ||||
|     //ASSERT_TRUE(exp.equalsTo(z));
 | ||||
| 
 | ||||
|     delete result; | ||||
| } | ||||
| TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { | ||||
|     auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f}); | ||||
|     auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f}); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user