From 9124974e3bf4bcaa8b698ed9acc21a87ddd93d6b Mon Sep 17 00:00:00 2001 From: shugeo Date: Tue, 5 Nov 2019 12:49:15 +0200 Subject: [PATCH] Fixed crash with strided_slice_bp op and tests. (#29) --- .../generic/parity_ops/strided_slice.cpp | 2 +- .../layers_tests/DeclarableOpsTests6.cpp | 46 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp index 44fe5999a..e0bc57923 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp @@ -623,7 +623,7 @@ namespace nd4j { //Zero output array, so unused elements have 0 gradient output->nullify(); - + std::sort(indices.begin(), indices.end()); if(indices.size() == 3 && (indices[1] - indices[0]) == 1) { output->p(indices[0], *epsNext); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index e6c692f5b..b1d080b20 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -241,6 +241,52 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) { delete result; } +TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) { + int zero = 0; + auto matrix = NDArrayFactory::create('c', {5, 4}); +// auto b = NDArrayFactory::create('c', {1}, {zero}); +// auto e = NDArrayFactory::create('c', {1}, {zero}); +// auto s = NDArrayFactory::create('c', {1}, {1}); + + auto grad = NDArrayFactory::create('c', {5,4}); + + matrix.linspace(1); + grad.linspace(1); + + nd4j::ops::strided_slice_bp op; + auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + z->printShapeInfo("Output shape"); + z->printIndexedBuffer("Output"); + //ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} +TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { + int zero = 0; + auto matrix = NDArrayFactory::create('c', {1, 2}); +// auto b = NDArrayFactory::create('c', {1}, {zero}); +// auto e = NDArrayFactory::create('c', {1}, {zero}); +// auto s = NDArrayFactory::create('c', {1}, {1}); + + auto grad = NDArrayFactory::create('c', {1}, {1.}); + + matrix.linspace(1); + //grad.linspace(1); + + nd4j::ops::strided_slice_bp op; + auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + z->printShapeInfo("Output shape"); + z->printIndexedBuffer("Output"); + //ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { auto x = NDArrayFactory::create('c', {1, 1}, {2.0f}); auto exp = NDArrayFactory::create('c', {1, 1}, {4.0f});