[WIP] Roll rewritten (#128)
* Process correct input vector.
* Added tests for roll.
* Refactored roll to conform with TF. Eliminated memory leaks with Roll op tests.
master
parent
62495dd77b
commit
e22880fd76
|
@ -23,18 +23,41 @@
|
|||
|
||||
#include <ops/declarable/headers/parity_ops.h>
|
||||
#include <ops/declarable/helpers/roll.h>
|
||||
#include <ops/declarable/helpers/axis.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 1) {
|
||||
CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) {
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
bool shiftIsLinear = true;
|
||||
//std::vector<int> axes(input->rankOf());
|
||||
int shift = INT_ARG(0);
|
||||
int inputLen = input->lengthOf();
|
||||
if (block.isInplace()) output = input;
|
||||
|
||||
bool shiftIsLinear = block.width() == 1;
|
||||
std::vector<int> axes;
|
||||
std::vector<int> shifts;
|
||||
if (block.width() > 1) {
|
||||
REQUIRE_TRUE(block.width() == 3, 0, "roll: 3 arguments required for roll - input, shifts and axes. But %i given.", block.width());
|
||||
auto axesI = INPUT_VARIABLE(2);
|
||||
auto shiftsI = INPUT_VARIABLE(1);
|
||||
REQUIRE_TRUE(axesI->rankOf() == shiftsI->rankOf(), 0, "roll: shifts and axes should be the same rank, but %i and %i given.", (int)shiftsI->rankOf(), (int)axesI->rankOf());
|
||||
REQUIRE_TRUE(axesI->lengthOf() == shiftsI->lengthOf(), 0, "roll: shifts and axes should be the same length, but %i and %i given.", (int)shiftsI->lengthOf(), (int)axesI->lengthOf());
|
||||
helpers::adjustAxis(axesI->lengthOf(), axesI, axes );
|
||||
shifts.resize(shiftsI->lengthOf());
|
||||
for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) {
|
||||
auto shift = shiftsI->e<int>(i);
|
||||
if (shift < 0) {
|
||||
shift -= input->sizeAt(i) * (shift / inputLen - 1);
|
||||
}
|
||||
else {
|
||||
shift %= input->sizeAt(i);
|
||||
}
|
||||
shifts[i] = shift;
|
||||
}
|
||||
|
||||
}
|
||||
else {
|
||||
int shift = INT_ARG(0);
|
||||
if (shift < 0) {
|
||||
// convert shift to positive value between 1 and inputLen - 1
|
||||
shift -= inputLen * (shift / inputLen - 1);
|
||||
|
@ -42,21 +65,32 @@ namespace ops {
|
|||
else
|
||||
// cut shift to value between 1 and inputLen - 1
|
||||
shift %= inputLen;
|
||||
axes.resize(block.getIArguments()->size() - 1);
|
||||
if (axes.size())
|
||||
shifts.resize(axes.size());//emplace_back(shift);
|
||||
else
|
||||
shifts.push_back(shift);
|
||||
|
||||
for (auto& s: shifts)
|
||||
s = shift;
|
||||
|
||||
for (unsigned e = 0; e < axes.size(); ++e) {
|
||||
int axis = INT_ARG(e + 1);
|
||||
REQUIRE_TRUE(axis < input->rankOf() && axis >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.",
|
||||
input->rankOf(), input->rankOf() - 1, axis);
|
||||
axes[e] = (axis < 0? (input->rankOf() + axis) : axis);
|
||||
}
|
||||
}
|
||||
|
||||
if (block.isInplace()) output = input;
|
||||
|
||||
shiftIsLinear = axes.size() == 0;
|
||||
|
||||
if (block.numI() > 1)
|
||||
shiftIsLinear = false;
|
||||
if (shiftIsLinear) {
|
||||
helpers::rollFunctorLinear(block.launchContext(), input, output, shift, block.isInplace());
|
||||
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());
|
||||
}
|
||||
else {
|
||||
std::vector<int> axes(block.numI() - 1);
|
||||
for (unsigned e = 0; e < axes.size(); ++e) {
|
||||
int axe = INT_ARG(e + 1);
|
||||
REQUIRE_TRUE(axe < input->rankOf() && axe >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.",
|
||||
input->rankOf(), input->rankOf() - 1, axe);
|
||||
axes[e] = (axe < 0? (input->rankOf() + axe) : axe);
|
||||
}
|
||||
helpers::rollFunctorFull(block.launchContext(), input, output, shift, axes, block.isInplace());
|
||||
helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, block.isInplace());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -64,7 +98,9 @@ namespace ops {
|
|||
|
||||
DECLARE_TYPES(roll) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedInputTypes(0,nd4j::DataType::ANY)
|
||||
->setAllowedInputTypes(1,nd4j::DataType::INT32) // TODO: all ints in future
|
||||
->setAllowedInputTypes(2,nd4j::DataType::INT32)
|
||||
->setAllowedOutputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
|
|
@ -43,7 +43,6 @@ namespace helpers {
|
|||
axisVector[e] = a + rank;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,18 +85,19 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
|
||||
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axes, bool inplace){
|
||||
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
|
||||
|
||||
if (!inplace)
|
||||
output->assign(input);
|
||||
|
||||
auto source = output; //input;
|
||||
for (int axe: axes) {
|
||||
for (auto i = 0; i < axes.size(); i++) {
|
||||
int axe = axes[i];
|
||||
if (axe == source->rankOf() - 1) {// last dimension
|
||||
std::unique_ptr<ResultSet> listOfTensors(source->allTensorsAlongDimension({axe}));
|
||||
std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe}));
|
||||
int fullLen = listOfTensors->size();
|
||||
int theShift = shift;
|
||||
int theShift = shifts[i];
|
||||
if (theShift > 0) {
|
||||
theShift %= fullLen;
|
||||
}
|
||||
|
@ -118,7 +119,7 @@ namespace helpers {
|
|||
int fullLen = listOfTensors->size();
|
||||
int sizeAt = input->sizeAt(axe);
|
||||
|
||||
int theShift = shift;
|
||||
int theShift = shifts[i];
|
||||
|
||||
if (theShift > 0) {
|
||||
theShift %= sizeAt;
|
||||
|
|
|
@ -27,6 +27,8 @@ namespace helpers {
|
|||
|
||||
void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector<int>& output) {
|
||||
output.resize(axisVector->lengthOf());
|
||||
axisVector->tickReadDevice();
|
||||
axisVector->syncToHost();
|
||||
for (int e = 0; e < axisVector->lengthOf(); e++) {
|
||||
auto ca = axisVector->e<int>(e);
|
||||
if (ca < 0)
|
||||
|
|
|
@ -228,22 +228,23 @@ namespace helpers {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
static void rollFunctorFull_(NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace){
|
||||
static void rollFunctorFull_(NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
|
||||
if (!inplace)
|
||||
output->assign(input);
|
||||
|
||||
for (int axe: axis) {
|
||||
for (size_t i = 0; i < axes.size(); i++) {
|
||||
int axe = axes[i];
|
||||
if (axe == input->rankOf() - 1) { // last dimension
|
||||
std::unique_ptr<ResultSet> listOfTensors(output->allTensorsAlongDimension({axe}));
|
||||
std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe}));
|
||||
int fullLen = listOfTensors->size();
|
||||
int theShift = shift;
|
||||
if (theShift > 0) {
|
||||
theShift %= fullLen;
|
||||
}
|
||||
else {
|
||||
theShift -= fullLen * (theShift / fullLen - 1);
|
||||
}
|
||||
int theShift = shifts[i];
|
||||
// if (theShift > 0) {
|
||||
// theShift %= fullLen;
|
||||
// }
|
||||
// else {
|
||||
// theShift -= fullLen * (theShift / fullLen - 1);
|
||||
// }
|
||||
for (int k = 0; k < fullLen; k++) {
|
||||
rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true);
|
||||
}
|
||||
|
@ -258,12 +259,12 @@ namespace helpers {
|
|||
int sizeAt = input->sizeAt(axe);
|
||||
auto tadLength = shape::length(packZ.primaryShapeInfo());
|
||||
|
||||
int theShift = shift;
|
||||
int theShift = shifts[i];
|
||||
|
||||
if (theShift > 0)
|
||||
theShift %= sizeAt;
|
||||
else
|
||||
theShift -= sizeAt * (theShift / sizeAt - 1);
|
||||
// if (theShift > 0)
|
||||
// theShift %= sizeAt;
|
||||
// else
|
||||
// theShift -= sizeAt * (theShift / sizeAt - 1);
|
||||
|
||||
if (theShift) {
|
||||
for (int dim = 0; dim < numTads / sizeAt; ++dim) {
|
||||
|
@ -307,10 +308,10 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
|
||||
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace){
|
||||
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
|
||||
input->syncToDevice();
|
||||
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shift, axis, inplace), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shifts, axes, inplace), LIBND4J_TYPES);
|
||||
|
||||
output->tickWriteDevice();
|
||||
}
|
||||
|
@ -324,7 +325,7 @@ namespace helpers {
|
|||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace), LIBND4J_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -26,7 +26,7 @@ namespace ops {
|
|||
namespace helpers {
|
||||
void rollFunctorLinear(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false);
|
||||
|
||||
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axes, bool inplace = false);
|
||||
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace = false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3311,6 +3311,130 @@ auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
|
|||
// delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestRoll_10) {
    // Integer args {3, 1}: shift of 3 along axis 1. Axis 1 of a {2, 3, 4}
    // array has extent 3, so this is a full rotation and the expected
    // output is identical to the input.
    auto source = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
    });

    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
    });

    nd4j::ops::roll rollOp;
    auto results = rollOp.execute({&source}, {}, {3, 1}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(results->status(), Status::OK());

    auto rolled = results->at(0);
    ASSERT_TRUE(expected.equalsTo(rolled));

    // The result set is heap-allocated by execute() and owned by the test.
    delete results;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestRoll_11) {
    // Shifts and axes are supplied as input arrays (inputs 1 and 2) rather
    // than integer args: roll by 1 along axis 0, then by 2 along axis 1.
    auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
    });
    auto shift = NDArrayFactory::create<int>({1, 2});
    auto axis = NDArrayFactory::create<int>({0, 1});

    // Axis 0 by 1 swaps the two {3, 4} slabs; axis 1 by 2 then rotates the
    // rows of each slab from [r0, r1, r2] to [r1, r2, r0].
    auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4
    });
    // ----------------------------------------------------------------
    nd4j::ops::roll op;
    auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(result->status(), Status::OK());
    auto out = result->at(0);

    // out->printIndexedBuffer("Output");
    // exp.printIndexedBuffer("Expect");

    ASSERT_TRUE(exp.equalsTo(out));

    // The result set is heap-allocated by execute() and owned by the test.
    delete result;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestRoll_12) {
    // Shifts and axes are supplied as input arrays: roll by 1 along each of
    // the three axes of a {2, 3, 4} array, applied in axis order 0, 1, 2.
    auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
    });
    auto shift = NDArrayFactory::create<int>({1, 1, 1});
    auto axis = NDArrayFactory::create<int>({0, 1, 2});

    auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7
    });
    // ----------------------------------------------------------------
    nd4j::ops::roll op;
    auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(result->status(), Status::OK());
    auto out = result->at(0);

    // Debug output is kept disabled so the test run stays quiet.
    // out->printIndexedBuffer("Output");
    // exp.printIndexedBuffer("Expect");

    ASSERT_TRUE(exp.equalsTo(out));

    // The result set is heap-allocated by execute() and owned by the test.
    delete result;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestRoll_13) {
    // Integer args {3, 2}: shift of 3 along axis 2 (extent 4), so every
    // innermost row [a, b, c, d] becomes [b, c, d, a].
    auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
    });

    auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        2, 3, 4, 1, 6, 7, 8, 5, 10, 11, 12, 9, 14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21
    });
    // ----------------------------------------------------------------
    nd4j::ops::roll op;
    // Only the data array is passed as an input; shift and axis travel as
    // integer args, so no shift/axis NDArrays are needed here.
    auto result = op.execute({&x}, {}, {3, 2}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(result->status(), Status::OK());
    auto out = result->at(0);

    // out->printIndexedBuffer("Output");
    // exp.printIndexedBuffer("Expect");

    ASSERT_TRUE(exp.equalsTo(out));

    // The result set is heap-allocated by execute() and owned by the test.
    delete result;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestRoll_14) {
    // Rolls a {2, 3, 4} array by 1 along axes 0, 1 and 2, with shifts and
    // axes supplied as input arrays (inputs 1 and 2).
    auto source = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
    });
    auto shiftsArr = NDArrayFactory::create<int>({1, 1, 1});
    auto axesArr = NDArrayFactory::create<int>({0, 1, 2});

    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4}, {
        24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7
    });

    nd4j::ops::roll rollOp;
    auto results = rollOp.execute({&source, &shiftsArr, &axesArr}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(results->status(), Status::OK());

    auto rolled = results->at(0);
    ASSERT_TRUE(expected.equalsTo(rolled));

    // The result set is heap-allocated by execute() and owned by the test.
    delete results;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, percentile_test1) {
|
||||
|
||||
|
|
Loading…
Reference in New Issue