From 679e42199ae30f06561fa682946e8a3e254e22c1 Mon Sep 17 00:00:00 2001 From: shugeo Date: Thu, 7 Nov 2019 12:44:02 +0200 Subject: [PATCH] Shugeo strided slice bp fix2 (#33) * Fixed crash and restored brocken functionality for strided slice. * Added comments for strided_slice_bp main step. --- .../generic/parity_ops/strided_slice.cpp | 13 +++++---- .../layers_tests/DeclarableOpsTests6.cpp | 28 ++++++++++++++++++- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp index e0bc57923..4b622e821 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp @@ -620,14 +620,15 @@ namespace nd4j { // FIXME: remove this method once we get 1D vectors supported vectorize(input_shape); REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed"); - - //Zero output array, so unused elements have 0 gradient - output->nullify(); - std::sort(indices.begin(), indices.end()); - if(indices.size() == 3 && (indices[1] - indices[0]) == 1) { + //REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op."); + //Zero output array, so unused elements have 0 gradient + output->nullify(); + // + // the first case: only for scalar gradient step + if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) { output->p(indices[0], *epsNext); } - else { + else { // else for other cases auto sub = (*output)(indices, true, true); sub.assign(epsNext); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index b1d080b20..34b66c61a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) { // auto e = NDArrayFactory::create('c', {1}, {zero}); // auto s = NDArrayFactory::create('c', {1}, {1}); - auto grad = NDArrayFactory::create('c', {5,4}); + auto grad = NDArrayFactory::create('c', {5}); matrix.linspace(1); grad.linspace(1); @@ -264,6 +264,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) { delete result; } + TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { int zero = 0; auto matrix = NDArrayFactory::create('c', {1, 2}); @@ -287,6 +288,31 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { delete result; } + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) { + int zero = 0; + auto matrix = NDArrayFactory::create('c', {4, 8192}); +// auto b = NDArrayFactory::create('c', {1}, {zero}); +// auto e = NDArrayFactory::create('c', {1}, {zero}); +// auto s = NDArrayFactory::create('c', {1}, {1}); + + auto grad = NDArrayFactory::create('c', {4, 256}); + + matrix.linspace(1); + grad.linspace(1); + + nd4j::ops::strided_slice_bp op; + auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + z->printShapeInfo("Output shape"); + z->printIndexedBuffer("Output"); + //ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { auto x = NDArrayFactory::create('c', {1, 1}, {2.0f}); auto exp = NDArrayFactory::create('c', {1, 1}, {4.0f});