Merge remote-tracking branch 'origin/master'

master
raver119 2019-08-17 14:52:13 +03:00
commit bb80fe4f94
7 changed files with 210 additions and 47 deletions

View File

@ -23,40 +23,74 @@
#include <ops/declarable/headers/parity_ops.h> #include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/helpers/roll.h> #include <ops/declarable/helpers/roll.h>
#include <ops/declarable/helpers/axis.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 1) { CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
bool shiftIsLinear = true;
//std::vector<int> axes(input->rankOf());
int shift = INT_ARG(0);
int inputLen = input->lengthOf(); int inputLen = input->lengthOf();
if (block.isInplace()) output = input;
if (shift < 0) {
// convert shift to positive value between 1 and inputLen - 1
shift -= inputLen * (shift / inputLen - 1);
}
else
// cut shift to value between 1 and inputLen - 1
shift %= inputLen;
if (block.numI() > 1) bool shiftIsLinear = block.width() == 1;
shiftIsLinear = false; std::vector<int> axes;
if (shiftIsLinear) { std::vector<int> shifts;
helpers::rollFunctorLinear(block.launchContext(), input, output, shift, block.isInplace()); if (block.width() > 1) {
REQUIRE_TRUE(block.width() == 3, 0, "roll: 3 arguments required for roll - input, shifts and axes. But %i given.", block.width());
auto axesI = INPUT_VARIABLE(2);
auto shiftsI = INPUT_VARIABLE(1);
REQUIRE_TRUE(axesI->rankOf() == shiftsI->rankOf(), 0, "roll: shifts and axes should be the same rank, but %i and %i given.", (int)shiftsI->rankOf(), (int)axesI->rankOf());
REQUIRE_TRUE(axesI->lengthOf() == shiftsI->lengthOf(), 0, "roll: shifts and axes should be the same length, but %i and %i given.", (int)shiftsI->lengthOf(), (int)axesI->lengthOf());
helpers::adjustAxis(axesI->lengthOf(), axesI, axes );
shifts.resize(shiftsI->lengthOf());
for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) {
auto shift = shiftsI->e<int>(i);
if (shift < 0) {
shift -= input->sizeAt(i) * (shift / inputLen - 1);
}
else {
shift %= input->sizeAt(i);
}
shifts[i] = shift;
}
} }
else { else {
std::vector<int> axes(block.numI() - 1); int shift = INT_ARG(0);
for (unsigned e = 0; e < axes.size(); ++e) { if (shift < 0) {
int axe = INT_ARG(e + 1); // convert shift to positive value between 1 and inputLen - 1
REQUIRE_TRUE(axe < input->rankOf() && axe >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.", shift -= inputLen * (shift / inputLen - 1);
input->rankOf(), input->rankOf() - 1, axe);
axes[e] = (axe < 0? (input->rankOf() + axe) : axe);
} }
helpers::rollFunctorFull(block.launchContext(), input, output, shift, axes, block.isInplace()); else
// cut shift to value between 1 and inputLen - 1
shift %= inputLen;
axes.resize(block.getIArguments()->size() - 1);
if (axes.size())
shifts.resize(axes.size());//emplace_back(shift);
else
shifts.push_back(shift);
for (auto& s: shifts)
s = shift;
for (unsigned e = 0; e < axes.size(); ++e) {
int axis = INT_ARG(e + 1);
REQUIRE_TRUE(axis < input->rankOf() && axis >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.",
input->rankOf(), input->rankOf() - 1, axis);
axes[e] = (axis < 0? (input->rankOf() + axis) : axis);
}
}
if (block.isInplace()) output = input;
shiftIsLinear = axes.size() == 0;
if (shiftIsLinear) {
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());
}
else {
helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, block.isInplace());
} }
return Status::OK(); return Status::OK();
@ -64,7 +98,9 @@ namespace ops {
DECLARE_TYPES(roll) { DECLARE_TYPES(roll) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(0,nd4j::DataType::ANY)
->setAllowedInputTypes(1,nd4j::DataType::INT32) // TODO: all ints in future
->setAllowedInputTypes(2,nd4j::DataType::INT32)
->setAllowedOutputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::ANY)
->setSameMode(true); ->setSameMode(true);
} }

View File

@ -43,7 +43,6 @@ namespace helpers {
axisVector[e] = a + rank; axisVector[e] = a + rank;
} }
} }
} }
} }
} }

View File

@ -85,18 +85,19 @@ namespace helpers {
} }
} }
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axes, bool inplace){ void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
if (!inplace) if (!inplace)
output->assign(input); output->assign(input);
auto source = output; //input; auto source = output; //input;
for (int axe: axes) { for (auto i = 0; i < axes.size(); i++) {
int axe = axes[i];
if (axe == source->rankOf() - 1) {// last dimension if (axe == source->rankOf() - 1) {// last dimension
std::unique_ptr<ResultSet> listOfTensors(source->allTensorsAlongDimension({axe})); std::unique_ptr<ResultSet> listOfTensors(source->allTensorsAlongDimension({axe}));
std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe})); std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe}));
int fullLen = listOfTensors->size(); int fullLen = listOfTensors->size();
int theShift = shift; int theShift = shifts[i];
if (theShift > 0) { if (theShift > 0) {
theShift %= fullLen; theShift %= fullLen;
} }
@ -118,7 +119,7 @@ namespace helpers {
int fullLen = listOfTensors->size(); int fullLen = listOfTensors->size();
int sizeAt = input->sizeAt(axe); int sizeAt = input->sizeAt(axe);
int theShift = shift; int theShift = shifts[i];
if (theShift > 0) { if (theShift > 0) {
theShift %= sizeAt; theShift %= sizeAt;

View File

@ -27,6 +27,8 @@ namespace helpers {
void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector<int>& output) { void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector<int>& output) {
output.resize(axisVector->lengthOf()); output.resize(axisVector->lengthOf());
axisVector->tickReadDevice();
axisVector->syncToHost();
for (int e = 0; e < axisVector->lengthOf(); e++) { for (int e = 0; e < axisVector->lengthOf(); e++) {
auto ca = axisVector->e<int>(e); auto ca = axisVector->e<int>(e);
if (ca < 0) if (ca < 0)

View File

@ -228,22 +228,23 @@ namespace helpers {
} }
template <typename T> template <typename T>
static void rollFunctorFull_(NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace){ static void rollFunctorFull_(NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
if (!inplace) if (!inplace)
output->assign(input); output->assign(input);
for (int axe: axis) { for (size_t i = 0; i < axes.size(); i++) {
int axe = axes[i];
if (axe == input->rankOf() - 1) { // last dimension if (axe == input->rankOf() - 1) { // last dimension
std::unique_ptr<ResultSet> listOfTensors(output->allTensorsAlongDimension({axe})); std::unique_ptr<ResultSet> listOfTensors(output->allTensorsAlongDimension({axe}));
std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe})); std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe}));
int fullLen = listOfTensors->size(); int fullLen = listOfTensors->size();
int theShift = shift; int theShift = shifts[i];
if (theShift > 0) { // if (theShift > 0) {
theShift %= fullLen; // theShift %= fullLen;
} // }
else { // else {
theShift -= fullLen * (theShift / fullLen - 1); // theShift -= fullLen * (theShift / fullLen - 1);
} // }
for (int k = 0; k < fullLen; k++) { for (int k = 0; k < fullLen; k++) {
rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true); rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true);
} }
@ -258,12 +259,12 @@ namespace helpers {
int sizeAt = input->sizeAt(axe); int sizeAt = input->sizeAt(axe);
auto tadLength = shape::length(packZ.primaryShapeInfo()); auto tadLength = shape::length(packZ.primaryShapeInfo());
int theShift = shift; int theShift = shifts[i];
if (theShift > 0) // if (theShift > 0)
theShift %= sizeAt; // theShift %= sizeAt;
else // else
theShift -= sizeAt * (theShift / sizeAt - 1); // theShift -= sizeAt * (theShift / sizeAt - 1);
if (theShift) { if (theShift) {
for (int dim = 0; dim < numTads / sizeAt; ++dim) { for (int dim = 0; dim < numTads / sizeAt; ++dim) {
@ -307,10 +308,10 @@ namespace helpers {
} }
} }
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace){ void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
input->syncToDevice(); input->syncToDevice();
BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shift, axis, inplace), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shifts, axes, inplace), LIBND4J_TYPES);
output->tickWriteDevice(); output->tickWriteDevice();
} }
@ -324,7 +325,7 @@ namespace helpers {
} }
BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace), LIBND4J_TYPES);
} }
} }
} }

View File

@ -26,7 +26,7 @@ namespace ops {
namespace helpers { namespace helpers {
void rollFunctorLinear(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false); void rollFunctorLinear(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false);
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axes, bool inplace = false); void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace = false);
} }
} }
} }

View File

@ -3311,6 +3311,130 @@ auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
// delete result; // delete result;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_10) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
auto result = op.execute({&x}, {}, {3, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_11) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,2});
auto axis = NDArrayFactory::create<int>({0, 1});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_12) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,1,1});
auto axis = NDArrayFactory::create<int>({0, 1, 2});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_13) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>(3);
auto axis = NDArrayFactory::create<int>(2);
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x}, {}, {3,2}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_14) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,1,1});
auto axis = NDArrayFactory::create<int>({0, 1, 2});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, percentile_test1) { TEST_F(DeclarableOpsTests7, percentile_test1) {