Shugeo roll fix3 (#127)
* Added tests for roll with scalar shift and axis. * Fixed problem with roll on 1D input with scalar axis and test. * Only cosmetic changes.master
parent
f5068f3980
commit
fc7c6d4e82
|
@ -84,7 +84,7 @@ namespace ops {
|
||||||
|
|
||||||
if (block.isInplace()) output = input;
|
if (block.isInplace()) output = input;
|
||||||
|
|
||||||
shiftIsLinear = axes.size() == 0;
|
shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1);
|
||||||
|
|
||||||
if (shiftIsLinear) {
|
if (shiftIsLinear) {
|
||||||
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());
|
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());
|
||||||
|
|
|
@ -3148,14 +3148,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) {
|
||||||
auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME"
|
auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME"
|
||||||
ASSERT_EQ(result->status(), Status::OK());
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
// output->printShapeInfo("Output shape");
|
|
||||||
// output->printBuffer("Output");
|
|
||||||
// exp.printBuffer("Expect");
|
|
||||||
// for (Nd4jLong e = 0; e < exp.lengthOf(); e++)
|
|
||||||
// if (exp.e<double>(e) != output->e<double>(e))
|
|
||||||
// printf("%lld ", e);
|
|
||||||
// printf("\n");
|
|
||||||
//result->at(1)->printBuffer("OUtput2");
|
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
@ -3240,10 +3233,6 @@ auto exp = NDArrayFactory::create<double>('c', {2, 2, 4, 2}, {
|
||||||
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
|
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
|
||||||
21.41, 21.42, 22.11, 22.12
|
21.41, 21.42, 22.11, 22.12
|
||||||
});
|
});
|
||||||
|
|
||||||
// 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
|
|
||||||
// 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32
|
|
||||||
// 21.41, 21.42, 22.11, 22.12
|
|
||||||
// ----------------------------------------------------------------
|
// ----------------------------------------------------------------
|
||||||
nd4j::ops::roll op;
|
nd4j::ops::roll op;
|
||||||
|
|
||||||
|
@ -3269,10 +3258,6 @@ auto exp = NDArrayFactory::create<double>('c', {2, 2, 4, 2}, {
|
||||||
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
|
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
|
||||||
21.41, 21.42, 22.11, 22.12
|
21.41, 21.42, 22.11, 22.12
|
||||||
});
|
});
|
||||||
|
|
||||||
// 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
|
|
||||||
// 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32
|
|
||||||
// 21.41, 21.42, 22.11, 22.12
|
|
||||||
// ----------------------------------------------------------------
|
// ----------------------------------------------------------------
|
||||||
nd4j::ops::roll op;
|
nd4j::ops::roll op;
|
||||||
NDArray* y = nullptr;
|
NDArray* y = nullptr;
|
||||||
|
@ -3518,6 +3503,27 @@ TEST_F(DeclarableOpsTests7, TestRoll_14) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests7, TestRoll_15) {
|
||||||
|
auto x = NDArrayFactory::create<float>({0.7788f, 0.8012f, 0.7244f, 0.2309f });
|
||||||
|
auto shift = NDArrayFactory::create<int>(2);
|
||||||
|
auto axis = NDArrayFactory::create<int>(0);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>({0.7244f, 0.2309f, 0.7788f, 0.8012f });
|
||||||
|
// ----------------------------------------------------------------
|
||||||
|
nd4j::ops::roll op;
|
||||||
|
|
||||||
|
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||||
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
|
auto out = result->at(0);
|
||||||
|
// out->printIndexedBuffer("Output 15");
|
||||||
|
// exp.printIndexedBuffer("Expect 15");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests7, percentile_test1) {
|
TEST_F(DeclarableOpsTests7, percentile_test1) {
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue