Merge remote-tracking branch 'origin/master'

master
raver119 2019-08-17 14:52:13 +03:00
commit bb80fe4f94
7 changed files with 210 additions and 47 deletions

View File

@ -23,18 +23,41 @@
#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/helpers/roll.h>
#include <ops/declarable/helpers/axis.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 1) {
CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) {
auto output = OUTPUT_VARIABLE(0);
auto input = INPUT_VARIABLE(0);
bool shiftIsLinear = true;
//std::vector<int> axes(input->rankOf());
int shift = INT_ARG(0);
int inputLen = input->lengthOf();
if (block.isInplace()) output = input;
bool shiftIsLinear = block.width() == 1;
std::vector<int> axes;
std::vector<int> shifts;
if (block.width() > 1) {
REQUIRE_TRUE(block.width() == 3, 0, "roll: 3 arguments required for roll - input, shifts and axes. But %i given.", block.width());
auto axesI = INPUT_VARIABLE(2);
auto shiftsI = INPUT_VARIABLE(1);
REQUIRE_TRUE(axesI->rankOf() == shiftsI->rankOf(), 0, "roll: shifts and axes should be the same rank, but %i and %i given.", (int)shiftsI->rankOf(), (int)axesI->rankOf());
REQUIRE_TRUE(axesI->lengthOf() == shiftsI->lengthOf(), 0, "roll: shifts and axes should be the same length, but %i and %i given.", (int)shiftsI->lengthOf(), (int)axesI->lengthOf());
helpers::adjustAxis(axesI->lengthOf(), axesI, axes );
shifts.resize(shiftsI->lengthOf());
for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) {
auto shift = shiftsI->e<int>(i);
if (shift < 0) {
shift -= input->sizeAt(i) * (shift / inputLen - 1);
}
else {
shift %= input->sizeAt(i);
}
shifts[i] = shift;
}
}
else {
int shift = INT_ARG(0);
if (shift < 0) {
// convert shift to positive value between 1 and inputLen - 1
shift -= inputLen * (shift / inputLen - 1);
@ -42,21 +65,32 @@ namespace ops {
else
// cut shift to value between 1 and inputLen - 1
shift %= inputLen;
axes.resize(block.getIArguments()->size() - 1);
if (axes.size())
shifts.resize(axes.size());//emplace_back(shift);
else
shifts.push_back(shift);
for (auto& s: shifts)
s = shift;
for (unsigned e = 0; e < axes.size(); ++e) {
int axis = INT_ARG(e + 1);
REQUIRE_TRUE(axis < input->rankOf() && axis >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.",
input->rankOf(), input->rankOf() - 1, axis);
axes[e] = (axis < 0? (input->rankOf() + axis) : axis);
}
}
if (block.isInplace()) output = input;
shiftIsLinear = axes.size() == 0;
if (block.numI() > 1)
shiftIsLinear = false;
if (shiftIsLinear) {
helpers::rollFunctorLinear(block.launchContext(), input, output, shift, block.isInplace());
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());
}
else {
std::vector<int> axes(block.numI() - 1);
for (unsigned e = 0; e < axes.size(); ++e) {
int axe = INT_ARG(e + 1);
REQUIRE_TRUE(axe < input->rankOf() && axe >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.",
input->rankOf(), input->rankOf() - 1, axe);
axes[e] = (axe < 0? (input->rankOf() + axe) : axe);
}
helpers::rollFunctorFull(block.launchContext(), input, output, shift, axes, block.isInplace());
helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, block.isInplace());
}
return Status::OK();
@ -64,7 +98,9 @@ namespace ops {
DECLARE_TYPES(roll) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedInputTypes(0,nd4j::DataType::ANY)
->setAllowedInputTypes(1,nd4j::DataType::INT32) // TODO: all ints in future
->setAllowedInputTypes(2,nd4j::DataType::INT32)
->setAllowedOutputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}

View File

@ -43,7 +43,6 @@ namespace helpers {
axisVector[e] = a + rank;
}
}
}
}
}

View File

@ -85,18 +85,19 @@ namespace helpers {
}
}
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axes, bool inplace){
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
if (!inplace)
output->assign(input);
auto source = output; //input;
for (int axe: axes) {
for (auto i = 0; i < axes.size(); i++) {
int axe = axes[i];
if (axe == source->rankOf() - 1) {// last dimension
std::unique_ptr<ResultSet> listOfTensors(source->allTensorsAlongDimension({axe}));
std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe}));
int fullLen = listOfTensors->size();
int theShift = shift;
int theShift = shifts[i];
if (theShift > 0) {
theShift %= fullLen;
}
@ -118,7 +119,7 @@ namespace helpers {
int fullLen = listOfTensors->size();
int sizeAt = input->sizeAt(axe);
int theShift = shift;
int theShift = shifts[i];
if (theShift > 0) {
theShift %= sizeAt;

View File

@ -27,6 +27,8 @@ namespace helpers {
void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector<int>& output) {
output.resize(axisVector->lengthOf());
axisVector->tickReadDevice();
axisVector->syncToHost();
for (int e = 0; e < axisVector->lengthOf(); e++) {
auto ca = axisVector->e<int>(e);
if (ca < 0)

View File

@ -228,22 +228,23 @@ namespace helpers {
}
template <typename T>
static void rollFunctorFull_(NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace){
static void rollFunctorFull_(NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
if (!inplace)
output->assign(input);
for (int axe: axis) {
for (size_t i = 0; i < axes.size(); i++) {
int axe = axes[i];
if (axe == input->rankOf() - 1) { // last dimension
std::unique_ptr<ResultSet> listOfTensors(output->allTensorsAlongDimension({axe}));
std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe}));
int fullLen = listOfTensors->size();
int theShift = shift;
if (theShift > 0) {
theShift %= fullLen;
}
else {
theShift -= fullLen * (theShift / fullLen - 1);
}
int theShift = shifts[i];
// if (theShift > 0) {
// theShift %= fullLen;
// }
// else {
// theShift -= fullLen * (theShift / fullLen - 1);
// }
for (int k = 0; k < fullLen; k++) {
rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true);
}
@ -258,12 +259,12 @@ namespace helpers {
int sizeAt = input->sizeAt(axe);
auto tadLength = shape::length(packZ.primaryShapeInfo());
int theShift = shift;
int theShift = shifts[i];
if (theShift > 0)
theShift %= sizeAt;
else
theShift -= sizeAt * (theShift / sizeAt - 1);
// if (theShift > 0)
// theShift %= sizeAt;
// else
// theShift -= sizeAt * (theShift / sizeAt - 1);
if (theShift) {
for (int dim = 0; dim < numTads / sizeAt; ++dim) {
@ -307,10 +308,10 @@ namespace helpers {
}
}
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace){
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
input->syncToDevice();
BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shift, axis, inplace), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shifts, axes, inplace), LIBND4J_TYPES);
output->tickWriteDevice();
}
@ -324,7 +325,7 @@ namespace helpers {
}
BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace), LIBND4J_TYPES);
}
}
}

View File

@ -26,7 +26,7 @@ namespace ops {
namespace helpers {
void rollFunctorLinear(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false);
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axes, bool inplace = false);
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace = false);
}
}
}

View File

@ -3311,6 +3311,130 @@ auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
// delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_10) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
auto result = op.execute({&x}, {}, {3, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_11) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,2});
auto axis = NDArrayFactory::create<int>({0, 1});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_12) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,1,1});
auto axis = NDArrayFactory::create<int>({0, 1, 2});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_13) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>(3);
auto axis = NDArrayFactory::create<int>(2);
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x}, {}, {3,2}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_14) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,1,1});
auto axis = NDArrayFactory::create<int>({0, 1, 2});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, percentile_test1) {