Fixed crash with strided_slice_bp op and tests. (#29)
parent
7b14a9f603
commit
9124974e3b
|
@ -623,7 +623,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//Zero output array, so unused elements have 0 gradient
|
//Zero output array, so unused elements have 0 gradient
|
||||||
output->nullify();
|
output->nullify();
|
||||||
|
std::sort(indices.begin(), indices.end());
|
||||||
if(indices.size() == 3 && (indices[1] - indices[0]) == 1) {
|
if(indices.size() == 3 && (indices[1] - indices[0]) == 1) {
|
||||||
output->p(indices[0], *epsNext);
|
output->p(indices[0], *epsNext);
|
||||||
}
|
}
|
||||||
|
|
|
@ -241,6 +241,52 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
|
||||||
|
int zero = 0;
|
||||||
|
auto matrix = NDArrayFactory::create<double>('c', {5, 4});
|
||||||
|
// auto b = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||||
|
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||||
|
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
auto grad = NDArrayFactory::create<double>('c', {5,4});
|
||||||
|
|
||||||
|
matrix.linspace(1);
|
||||||
|
grad.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::strided_slice_bp op;
|
||||||
|
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printShapeInfo("Output shape");
|
||||||
|
z->printIndexedBuffer("Output");
|
||||||
|
//ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
|
||||||
|
int zero = 0;
|
||||||
|
auto matrix = NDArrayFactory::create<double>('c', {1, 2});
|
||||||
|
// auto b = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||||
|
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||||
|
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
auto grad = NDArrayFactory::create<double>('c', {1}, {1.});
|
||||||
|
|
||||||
|
matrix.linspace(1);
|
||||||
|
//grad.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::strided_slice_bp op;
|
||||||
|
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printShapeInfo("Output shape");
|
||||||
|
z->printIndexedBuffer("Output");
|
||||||
|
//ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
|
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
|
auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});
|
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});
|
||||||
|
|
Loading…
Reference in New Issue