Shugeo strided slice bp fix2 (#33)
* Fixed crash and restored brocken functionality for strided slice. * Added comments for strided_slice_bp main step.
This commit is contained in:
		
							parent
							
								
									73b5a508fc
								
							
						
					
					
						commit
						679e42199a
					
				@ -620,14 +620,15 @@ namespace nd4j {
 | 
			
		||||
            // FIXME: remove this method once we get 1D vectors supported
 | 
			
		||||
            vectorize(input_shape);
 | 
			
		||||
            REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed");
 | 
			
		||||
 | 
			
		||||
			//Zero output array, so unused elements have 0 gradient
 | 
			
		||||
			output->nullify();
 | 
			
		||||
            std::sort(indices.begin(), indices.end());
 | 
			
		||||
            if(indices.size() == 3 && (indices[1] - indices[0]) == 1) {
 | 
			
		||||
            //REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op.");
 | 
			
		||||
            //Zero output array, so unused elements have 0 gradient
 | 
			
		||||
            output->nullify();
 | 
			
		||||
            //
 | 
			
		||||
            // the first case: only for scalar gradient step
 | 
			
		||||
            if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) {
 | 
			
		||||
                output->p(indices[0], *epsNext);
 | 
			
		||||
            }
 | 
			
		||||
            else {
 | 
			
		||||
            else { // else for other cases
 | 
			
		||||
                auto sub = (*output)(indices, true, true);
 | 
			
		||||
                sub.assign(epsNext);
 | 
			
		||||
            }           
 | 
			
		||||
 | 
			
		||||
@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
 | 
			
		||||
//    auto e = NDArrayFactory::create<int>('c', {1}, {zero});
 | 
			
		||||
//    auto s = NDArrayFactory::create<int>('c', {1}, {1});
 | 
			
		||||
 | 
			
		||||
    auto grad = NDArrayFactory::create<double>('c', {5,4});
 | 
			
		||||
    auto grad = NDArrayFactory::create<double>('c', {5});
 | 
			
		||||
 | 
			
		||||
    matrix.linspace(1);
 | 
			
		||||
    grad.linspace(1);
 | 
			
		||||
@ -264,6 +264,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
 | 
			
		||||
 | 
			
		||||
    delete result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
 | 
			
		||||
    int zero = 0;
 | 
			
		||||
    auto matrix = NDArrayFactory::create<double>('c', {1, 2});
 | 
			
		||||
@ -287,6 +288,31 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
 | 
			
		||||
 | 
			
		||||
    delete result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
 | 
			
		||||
    int zero = 0;
 | 
			
		||||
    auto matrix = NDArrayFactory::create<float>('c', {4, 8192});
 | 
			
		||||
//    auto b = NDArrayFactory::create<int>('c', {1}, {zero});
 | 
			
		||||
//    auto e = NDArrayFactory::create<int>('c', {1}, {zero});
 | 
			
		||||
//    auto s = NDArrayFactory::create<int>('c', {1}, {1});
 | 
			
		||||
 | 
			
		||||
    auto grad = NDArrayFactory::create<double>('c', {4, 256});
 | 
			
		||||
 | 
			
		||||
    matrix.linspace(1);
 | 
			
		||||
    grad.linspace(1);
 | 
			
		||||
 | 
			
		||||
    nd4j::ops::strided_slice_bp op;
 | 
			
		||||
    auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
 | 
			
		||||
    ASSERT_EQ(Status::OK(), result->status());
 | 
			
		||||
 | 
			
		||||
    auto z = result->at(0);
 | 
			
		||||
    z->printShapeInfo("Output shape");
 | 
			
		||||
    z->printIndexedBuffer("Output");
 | 
			
		||||
    //ASSERT_TRUE(exp.equalsTo(z));
 | 
			
		||||
 | 
			
		||||
    delete result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
 | 
			
		||||
    auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
 | 
			
		||||
    auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user