Shugeo roll fix3 (#127)

* Added tests for roll with scalar shift and axis.

* Fixed problem with roll on 1D input with scalar axis and test.

* Only cosmetic changes.
master
shugeo 2019-12-19 12:10:06 +02:00 committed by raver119
parent f5068f3980
commit fc7c6d4e82
2 changed files with 23 additions and 17 deletions

View File

@ -84,7 +84,7 @@ namespace ops {
if (block.isInplace()) output = input; if (block.isInplace()) output = input;
shiftIsLinear = axes.size() == 0; shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1);
if (shiftIsLinear) { if (shiftIsLinear) {
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace()); helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());

View File

@ -3148,14 +3148,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) {
auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME"
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
auto output = result->at(0); auto output = result->at(0);
// output->printShapeInfo("Output shape");
// output->printBuffer("Output");
// exp.printBuffer("Expect");
// for (Nd4jLong e = 0; e < exp.lengthOf(); e++)
// if (exp.e<double>(e) != output->e<double>(e))
// printf("%lld ", e);
// printf("\n");
//result->at(1)->printBuffer("OUtput2");
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
@ -3240,10 +3233,6 @@ auto exp = NDArrayFactory::create<double>('c', {2, 2, 4, 2}, {
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
21.41, 21.42, 22.11, 22.12 21.41, 21.42, 22.11, 22.12
}); });
// 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
// 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32
// 21.41, 21.42, 22.11, 22.12
// ---------------------------------------------------------------- // ----------------------------------------------------------------
nd4j::ops::roll op; nd4j::ops::roll op;
@ -3269,10 +3258,6 @@ auto exp = NDArrayFactory::create<double>('c', {2, 2, 4, 2}, {
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
21.41, 21.42, 22.11, 22.12 21.41, 21.42, 22.11, 22.12
}); });
// 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
// 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32
// 21.41, 21.42, 22.11, 22.12
// ---------------------------------------------------------------- // ----------------------------------------------------------------
nd4j::ops::roll op; nd4j::ops::roll op;
NDArray* y = nullptr; NDArray* y = nullptr;
@ -3518,6 +3503,27 @@ TEST_F(DeclarableOpsTests7, TestRoll_14) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_15) {
auto x = NDArrayFactory::create<float>({0.7788f, 0.8012f, 0.7244f, 0.2309f });
auto shift = NDArrayFactory::create<int>(2);
auto axis = NDArrayFactory::create<int>(0);
auto exp = NDArrayFactory::create<float>({0.7244f, 0.2309f, 0.7788f, 0.8012f });
// ----------------------------------------------------------------
nd4j::ops::roll op;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output 15");
// exp.printIndexedBuffer("Expect 15");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, percentile_test1) { TEST_F(DeclarableOpsTests7, percentile_test1) {