diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp index 7c4a52f9c..61f592f1d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp @@ -23,40 +23,74 @@ #include #include +#include namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 1) { + CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) { auto output = OUTPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0); - bool shiftIsLinear = true; - //std::vector axes(input->rankOf()); - int shift = INT_ARG(0); int inputLen = input->lengthOf(); - if (block.isInplace()) output = input; - if (shift < 0) { - // convert shift to positive value between 1 and inputLen - 1 - shift -= inputLen * (shift / inputLen - 1); - } - else - // cut shift to value between 1 and inputLen - 1 - shift %= inputLen; - if (block.numI() > 1) - shiftIsLinear = false; - if (shiftIsLinear) { - helpers::rollFunctorLinear(block.launchContext(), input, output, shift, block.isInplace()); + bool shiftIsLinear = block.width() == 1; + std::vector axes; + std::vector shifts; + if (block.width() > 1) { + REQUIRE_TRUE(block.width() == 3, 0, "roll: 3 arguments required for roll - input, shifts and axes. But %i given.", block.width()); + auto axesI = INPUT_VARIABLE(2); + auto shiftsI = INPUT_VARIABLE(1); + REQUIRE_TRUE(axesI->rankOf() == shiftsI->rankOf(), 0, "roll: shifts and axes should be the same rank, but %i and %i given.", (int)shiftsI->rankOf(), (int)axesI->rankOf()); + REQUIRE_TRUE(axesI->lengthOf() == shiftsI->lengthOf(), 0, "roll: shifts and axes should be the same length, but %i and %i given.", (int)shiftsI->lengthOf(), (int)axesI->lengthOf()); + helpers::adjustAxis(axesI->lengthOf(), axesI, axes ); + shifts.resize(shiftsI->lengthOf()); + for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) { + auto shift = shiftsI->e(i); + if (shift < 0) { + shift -= input->sizeAt(i) * (shift / inputLen - 1); + } + else { + shift %= input->sizeAt(i); + } + shifts[i] = shift; + } + } else { - std::vector axes(block.numI() - 1); - for (unsigned e = 0; e < axes.size(); ++e) { - int axe = INT_ARG(e + 1); - REQUIRE_TRUE(axe < input->rankOf() && axe >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.", - input->rankOf(), input->rankOf() - 1, axe); - axes[e] = (axe < 0? (input->rankOf() + axe) : axe); + int shift = INT_ARG(0); + if (shift < 0) { + // convert shift to positive value between 1 and inputLen - 1 + shift -= inputLen * (shift / inputLen - 1); } - helpers::rollFunctorFull(block.launchContext(), input, output, shift, axes, block.isInplace()); + else + // cut shift to value between 1 and inputLen - 1 + shift %= inputLen; + axes.resize(block.getIArguments()->size() - 1); + if (axes.size()) + shifts.resize(axes.size());//emplace_back(shift); + else + shifts.push_back(shift); + + for (auto& s: shifts) + s = shift; + + for (unsigned e = 0; e < axes.size(); ++e) { + int axis = INT_ARG(e + 1); + REQUIRE_TRUE(axis < input->rankOf() && axis >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.", + input->rankOf(), input->rankOf() - 1, axis); + axes[e] = (axis < 0? (input->rankOf() + axis) : axis); + } + } + + if (block.isInplace()) output = input; + + shiftIsLinear = axes.size() == 0; + + if (shiftIsLinear) { + helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace()); + } + else { + helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, block.isInplace()); } return Status::OK(); @@ -64,7 +98,9 @@ namespace ops { DECLARE_TYPES(roll) { getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) + ->setAllowedInputTypes(0,nd4j::DataType::ANY) + ->setAllowedInputTypes(1,nd4j::DataType::INT32) // TODO: all ints in future + ->setAllowedInputTypes(2,nd4j::DataType::INT32) ->setAllowedOutputTypes(nd4j::DataType::ANY) ->setSameMode(true); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp b/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp index 6f0ed5f27..eb56acb9c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp @@ -43,7 +43,6 @@ namespace helpers { axisVector[e] = a + rank; } } - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp index e363fd8fa..da3cb3259 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp @@ -85,18 +85,19 @@ namespace helpers { } } - void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector const& axes, bool inplace){ + void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ if (!inplace) output->assign(input); auto source = output; //input; - for (int axe: axes) { + for (auto i = 0; i < axes.size(); i++) { + int axe = axes[i]; if (axe == source->rankOf() - 1) {// last dimension std::unique_ptr listOfTensors(source->allTensorsAlongDimension({axe})); std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({axe})); int fullLen = listOfTensors->size(); - int theShift = shift; + int theShift = shifts[i]; if (theShift > 0) { theShift %= fullLen; } @@ -118,7 +119,7 @@ namespace helpers { int fullLen = listOfTensors->size(); int sizeAt = input->sizeAt(axe); - int theShift = shift; + int theShift = shifts[i]; if (theShift > 0) { theShift %= sizeAt; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/axis.cu b/libnd4j/include/ops/declarable/helpers/cuda/axis.cu index 6f0ed5f27..a3b2bcd32 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/axis.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/axis.cu @@ -27,6 +27,8 @@ namespace helpers { void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { output.resize(axisVector->lengthOf()); + axisVector->tickReadDevice(); + axisVector->syncToHost(); for (int e = 0; e < axisVector->lengthOf(); e++) { auto ca = axisVector->e(e); if (ca < 0) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu index 6bdd87650..216c6b7a0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu @@ -228,22 +228,23 @@ namespace helpers { } template - static void rollFunctorFull_(NDArray* input, NDArray* output, int shift, std::vector const& axis, bool inplace){ + static void rollFunctorFull_(NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ if (!inplace) output->assign(input); - for (int axe: axis) { + for (size_t i = 0; i < axes.size(); i++) { + int axe = axes[i]; if (axe == input->rankOf() - 1) { // last dimension std::unique_ptr listOfTensors(output->allTensorsAlongDimension({axe})); std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({axe})); int fullLen = listOfTensors->size(); - int theShift = shift; - if (theShift > 0) { - theShift %= fullLen; - } - else { - theShift -= fullLen * (theShift / fullLen - 1); - } + int theShift = shifts[i]; +// if (theShift > 0) { +// theShift %= fullLen; +// } +// else { +// theShift -= fullLen * (theShift / fullLen - 1); +// } for (int k = 0; k < fullLen; k++) { rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true); } @@ -258,12 +259,12 @@ namespace helpers { int sizeAt = input->sizeAt(axe); auto tadLength = shape::length(packZ.primaryShapeInfo()); - int theShift = shift; + int theShift = shifts[i]; - if (theShift > 0) - theShift %= sizeAt; - else - theShift -= sizeAt * (theShift / sizeAt - 1); +// if (theShift > 0) +// theShift %= sizeAt; +// else +// theShift -= sizeAt * (theShift / sizeAt - 1); if (theShift) { for (int dim = 0; dim < numTads / sizeAt; ++dim) { @@ -307,10 +308,10 @@ namespace helpers { } } - void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector const& axis, bool inplace){ + void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ input->syncToDevice(); - BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shift, axis, inplace), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shifts, axes, inplace), LIBND4J_TYPES); output->tickWriteDevice(); } @@ -324,7 +325,7 @@ namespace helpers { } BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, int shift, std::vector const& axis, bool inplace), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace), LIBND4J_TYPES); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/roll.h b/libnd4j/include/ops/declarable/helpers/roll.h index ff6c67a57..b20367c0d 100644 --- a/libnd4j/include/ops/declarable/helpers/roll.h +++ b/libnd4j/include/ops/declarable/helpers/roll.h @@ -26,7 +26,7 @@ namespace ops { namespace helpers { void rollFunctorLinear(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false); - void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector const& axes, bool inplace = false); + void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace = false); } } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 996dd4f23..2e1dab1a3 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -3311,6 +3311,130 @@ auto exp = NDArrayFactory::create('c', {2, 3, 3}, { // delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_10) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + auto result = op.execute({&x}, {}, {3, 1}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); + +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_11) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create({1,2}); + auto axis = NDArrayFactory::create({0, 1}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4 + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); + +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_12) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create({1,1,1}); + auto axis = NDArrayFactory::create({0, 1, 2}); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); + out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_13) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create(3); + auto axis = NDArrayFactory::create(2); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21 + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x}, {}, {3,2}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); + +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_14) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create({1,1,1}); + auto axis = NDArrayFactory::create({0, 1, 2}); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + + auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test1) {