Shugeo sequence mask fix2 (#216)

* Fixed sequence_mask op and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cuda fix for sequence_mask op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed sequence_mask op for both platforms and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed solve and triangular_solve for more than 2D for adjoint cases.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added adjoint solve test again.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added a set of tests for triangual_solve and generic solve ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added a pair tests for triangular_solve

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added tests for triangular_solve op.

Signed-off-by: shugeo <sgazeos@gmail.com>
master
shugeo 2020-02-06 20:06:50 +02:00 committed by GitHub
parent 569a46f87d
commit 5ae40f6e38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 326 additions and 25 deletions

View File

@ -58,30 +58,31 @@ namespace nd4j {
int outRank = shape::rank(in) + 1;
auto input = INPUT_VARIABLE(0);
auto dtype = DataType::BOOL;
Nd4jLong maxInd = input->argMax();
Nd4jLong max = input->e<Nd4jLong>(maxInd);
auto argMaxInd = input->argMax();
Nd4jLong max = input->e<Nd4jLong>(argMaxInd);
Nd4jLong maxInd = max;
if (block.getIArguments()->size() > 0) {
if (block.width() < 2) {
maxInd = INT_ARG(0);
if (maxInd < max)
maxInd = static_cast<Nd4jLong>(max);
if (block.getIArguments()->size() > 1)
dtype = (DataType)INT_ARG(1);
}
else {
dtype = (DataType)INT_ARG(0);
}
}
if (block.numD() > 0)
dtype = D_ARG(0);
if (block.width() > 1) {
auto maxlen = INPUT_VARIABLE(1);
Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
if (tmaxlen > max)
maxInd = static_cast<Nd4jLong>(tmaxlen);
if (block.numI() > 0) {
dtype = (DataType) INT_ARG(0);
}
}
else {
if (block.numI() > 0) {
maxInd = INT_ARG(0);
}
if (maxInd < max)
maxInd = max;
if (block.numI() > 1)
dtype = (DataType)INT_ARG(1); // to work with legacy code
}
else
maxInd = static_cast<Nd4jLong>(max);
int lastDimension = maxInd;
ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);

View File

@ -38,10 +38,10 @@ namespace helpers {
}
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
}
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
}
}
}

View File

@ -36,10 +36,12 @@ namespace helpers {
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
auto inputPart = input->allTensorsAlongDimension({-2, -1});
auto outputPart = output->allTensorsAlongDimension({-2, -1});
auto rows = input->sizeAt(-2);
output->assign(input);
auto batchLoop = PRAGMA_THREADS_FOR {
for (auto batch = start; batch < stop; batch += increment) {
for (auto r = 0; r < input->rows(); r++) {
for (auto r = 0; r < rows; r++) {
for (auto c = 0; c < r; c++) {
math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
}

View File

@ -108,17 +108,20 @@ namespace helpers {
static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
auto inputPart = input->allTensorsAlongDimension({-2, -1});
auto outputPart = output->allTensorsAlongDimension({-2, -1});
auto cols = input->sizeAt(-1);
auto rows = input->sizeAt(-2);
auto batchLoop = PRAGMA_THREADS_FOR {
for (auto batch = start; batch < stop; batch += increment) {
if (!lower) {
for (auto r = 0; r < input->rows(); r++) {
for (auto r = 0; r < rows; r++) {
for (auto c = 0; c <= r; c++) {
outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
}
}
} else {
for (auto r = 0; r < input->rows(); r++) {
for (auto c = r; c < input->columns(); c++) {
for (auto r = 0; r < rows; r++) {
for (auto c = r; c < cols; c++) {
outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
}
}

View File

@ -55,10 +55,10 @@ namespace helpers {
}
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
}
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
}
}
}

View File

@ -1667,6 +1667,241 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4) {
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_4_1) {
auto a = NDArrayFactory::create<float>('c', {2, 2, 2}, {
0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f
});
auto b = NDArrayFactory::create<float>('c', {2, 2, 2}, {
0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f
});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f,
0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f
});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printBuffer("4 Solve 4x4");
// exp.printBuffer("4 Expec 4x4");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_4_2) {
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f
});
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
0.7717f, 0.9281f, 0.9846f,
0.4838f, 0.6433f, 0.6041f,
0.6501f, 0.7612f, 0.7605f
});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
0.99088347f, 1.1917052f, 1.2642528f,
0.35071516f, 0.50630623f, 0.42935497f,
-0.30013534f, -0.53690606f, -0.47959247f
});
nd4j::ops::triangular_solve op;
auto res = op.evaluate({&a, &b}, {true, false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printBuffer("4_2 Triangular_Solve 3x3");
// exp.printBuffer("4_2 Triangular_Expec 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_4_3) {
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f
});
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
0.7717f, 0.9281f, 0.9846f,
0.4838f, 0.6433f, 0.6041f,
0.6501f, 0.7612f, 0.7605f
});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
0.45400196f, 0.53174824f, 0.62064564f,
-0.79585856f, -0.82621557f, -0.87855506f,
1.1904413f, 1.3938838f, 1.3926021f
});
nd4j::ops::triangular_solve op;
auto res = op.evaluate({&a, &b}, {true, true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printBuffer("4_3 Triangular_Solve 3x3");
// exp.printBuffer("4_3 Triangular_Expec 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_4_4) {
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f
});
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
0.7717f, 0.9281f, 0.9846f,
0.4838f, 0.6433f, 0.6041f,
0.6501f, 0.7612f, 0.7605f
});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
0.8959121f, 1.6109066f, 1.7501404f,
0.49000582f, 0.66842675f, 0.5577021f,
-0.4398522f, -1.1899745f, -1.1392052f
});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b}, {false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printBuffer("4_4 Solve 3x3");
// exp.printBuffer("4_4 Expec 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_4_5) {
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f
});
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
0.7717f, 0.9281f, 0.9846f,
0.4838f, 0.6433f, 0.6041f,
0.6501f, 0.7612f, 0.7605f
});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
1.5504692f, 1.8953944f, 2.2765768f,
0.03399149f, 0.2883001f, 0.5377323f,
-0.8774802f, -1.2155888f, -1.8049058f
});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b}, {true, true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printBuffer("4_5 Solve 3x3");
// exp.printBuffer("4_5 Expec 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_4_6) {
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f
});
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
0.7717f, 0.9281f, 0.9846f,
0.4838f, 0.6433f, 0.6041f,
0.6501f, 0.7612f, 0.7605f
});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
0.99088347f, 1.1917052f, 1.2642528f,
-0.426483f, -0.42840624f, -0.5622601f,
0.01692283f, -0.04538865f, -0.09868701f
});
nd4j::ops::triangular_solve op;
auto res = op.evaluate({&a, &b}, {false, true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
z->printBuffer("4_6 Solve 3x3");
exp.printBuffer("4_6 Expec 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_4_7) {
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
// 0.7788f, 0.2309f, 0.5056f,
// 0.8012f, 0.7271f, 0.8925f,
// 0.7244f, 0.1804f, 0.5461f
0.7788f, 0.2309f, 0.5056f,
0.8012f, 0.7271f, 0.8925f,
0.7244f, 0.1804f, 0.5461f
});
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
0.7717f, 0.9281f, 0.9846f,
0.4838f, 0.6433f, 0.6041f,
0.6501f, 0.7612f, 0.7605f
});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
0.99088347f, 1.1917052f, 1.2642528f,
-0.426483f, -0.42840624f, -0.5622601f,
0.01692283f, -0.04538865f, -0.09868701f
});
nd4j::ops::triangular_solve op;
auto res = op.evaluate({&a, &b}, {true, false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
z->printBuffer("4_7 Solve 3x3");
exp.printBuffer("4_7 Expec 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_5) {

View File

@ -802,6 +802,66 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) {
delete result;
}
TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) {
auto input = NDArrayFactory::create<int>('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8});
auto exp = NDArrayFactory::create<int>('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
nd4j::ops::sequence_mask op;
auto result = op.evaluate({&input}, {nd4j::DataType::INT32});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printBuffer("Output");
// z->printShapeInfo("Shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) {
auto input = NDArrayFactory::create<int>({1, 3, 2});
auto maxLen = NDArrayFactory::create<int>(5);
auto exp = NDArrayFactory::create<float>('c', {3,5}, {
1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f
});
nd4j::ops::sequence_mask op;
auto result = op.evaluate({&input, &maxLen}, {nd4j::DataType::FLOAT32});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printBuffer("Output");
// z->printShapeInfo("Shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) {
auto input = NDArrayFactory::create<int>({1, 3, 2});
auto exp = NDArrayFactory::create<float>('c', {3,5}, {
1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f
});
nd4j::ops::sequence_mask op;
auto result = op.evaluate({&input}, {5, (int)nd4j::DataType::FLOAT32});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printBuffer("Output");
// z->printShapeInfo("Shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestSegmentMax_1) {
auto x = NDArrayFactory::create<double>({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.});