diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp index e7694b409..477b298a3 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp @@ -58,30 +58,31 @@ namespace nd4j { int outRank = shape::rank(in) + 1; auto input = INPUT_VARIABLE(0); auto dtype = DataType::BOOL; - Nd4jLong maxInd = input->argMax(); - Nd4jLong max = input->e(maxInd); + auto argMaxInd = input->argMax(); + Nd4jLong max = input->e(argMaxInd); + Nd4jLong maxInd = max; - if (block.getIArguments()->size() > 0) { - if (block.width() < 2) { - maxInd = INT_ARG(0); - if (maxInd < max) - maxInd = static_cast(max); - if (block.getIArguments()->size() > 1) - dtype = (DataType)INT_ARG(1); - } - else { - dtype = (DataType)INT_ARG(0); - } - } + if (block.numD() > 0) + dtype = D_ARG(0); if (block.width() > 1) { auto maxlen = INPUT_VARIABLE(1); Nd4jLong tmaxlen = maxlen->e(0); if (tmaxlen > max) maxInd = static_cast(tmaxlen); + if (block.numI() > 0) { + dtype = (DataType) INT_ARG(0); + } + } + else { + if (block.numI() > 0) { + maxInd = INT_ARG(0); + } + if (maxInd < max) + maxInd = max; + if (block.numI() > 1) + dtype = (DataType)INT_ARG(1); // to work with legacy code } - else - maxInd = static_cast(max); int lastDimension = maxInd; ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp index bf3463afe..c175fd96d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp @@ -38,10 +38,10 @@ namespace helpers { } void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); } - BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp index 8583d9cba..48f7f0d9a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp @@ -36,10 +36,12 @@ namespace helpers { static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) { auto inputPart = input->allTensorsAlongDimension({-2, -1}); auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto rows = input->sizeAt(-2); output->assign(input); + auto batchLoop = PRAGMA_THREADS_FOR { for (auto batch = start; batch < stop; batch += increment) { - for (auto r = 0; r < input->rows(); r++) { + for (auto r = 0; r < rows; r++) { for (auto c = 0; c < r; c++) { math::nd4j_swap(outputPart[batch]->t(r, c) , outputPart[batch]->t(c, r)); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp index e904d219c..ceb228439 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp @@ -108,17 +108,20 @@ namespace helpers { static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { auto inputPart = input->allTensorsAlongDimension({-2, -1}); auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto cols = input->sizeAt(-1); + auto rows = input->sizeAt(-2); + auto batchLoop = PRAGMA_THREADS_FOR { for (auto batch = start; batch < stop; batch += increment) { if (!lower) { - for (auto r = 0; r < input->rows(); r++) { + for (auto r = 0; r < rows; r++) { for (auto c = 0; c <= r; c++) { outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r); } } } else { - for (auto r = 0; r < input->rows(); r++) { - for (auto c = r; c < input->columns(); c++) { + for (auto r = 0; r < rows; r++) { + for (auto c = r; c < cols; c++) { outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu index c07db1b95..6b33a384e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu @@ -55,10 +55,10 @@ namespace helpers { } void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); } - BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); } } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index de4bdc31b..465703768 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -1667,6 +1667,241 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4) { ASSERT_TRUE(exp.equalsTo(z)); delete res; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_1) { + + auto a = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto b = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, { + 1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f, + 0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4 Solve 4x4"); +// exp.printBuffer("4 Expec 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_2) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + 0.35071516f, 0.50630623f, 0.42935497f, + -0.30013534f, -0.53690606f, -0.47959247f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4_2 Triangular_Solve 3x3"); +// exp.printBuffer("4_2 Triangular_Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_3) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.45400196f, 0.53174824f, 0.62064564f, + -0.79585856f, -0.82621557f, -0.87855506f, + 1.1904413f, 1.3938838f, 1.3926021f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4_3 Triangular_Solve 3x3"); +// exp.printBuffer("4_3 Triangular_Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_4) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.8959121f, 1.6109066f, 1.7501404f, + 0.49000582f, 0.66842675f, 0.5577021f, + -0.4398522f, -1.1899745f, -1.1392052f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4_4 Solve 3x3"); +// exp.printBuffer("4_4 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_5) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 1.5504692f, 1.8953944f, 2.2765768f, + 0.03399149f, 0.2883001f, 0.5377323f, + -0.8774802f, -1.2155888f, -1.8049058f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4_5 Solve 3x3"); +// exp.printBuffer("4_5 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_6) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + -0.426483f, -0.42840624f, -0.5622601f, + 0.01692283f, -0.04538865f, -0.09868701f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {false, true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + + z->printBuffer("4_6 Solve 3x3"); + exp.printBuffer("4_6 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_7) { + + auto a = NDArrayFactory::create('c', {3, 3}, { +// 0.7788f, 0.2309f, 0.5056f, +// 0.8012f, 0.7271f, 0.8925f, +// 0.7244f, 0.1804f, 0.5461f + + 0.7788f, 0.2309f, 0.5056f, + 0.8012f, 0.7271f, 0.8925f, + 0.7244f, 0.1804f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + -0.426483f, -0.42840624f, -0.5622601f, + 0.01692283f, -0.04538865f, -0.09868701f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + + z->printBuffer("4_7 Solve 3x3"); + exp.printBuffer("4_7 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_5) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 39761ecb3..0a6f8e5e8 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -758,7 +758,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { auto input = NDArrayFactory::create('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, @@ -802,6 +802,66 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { delete result; } +TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) { + auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + nd4j::ops::sequence_mask op; + auto result = op.evaluate({&input}, {nd4j::DataType::INT32}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) { + auto input = NDArrayFactory::create({1, 3, 2}); + auto maxLen = NDArrayFactory::create(5); + auto exp = NDArrayFactory::create('c', {3,5}, { + 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f + }); + + nd4j::ops::sequence_mask op; + auto result = op.evaluate({&input, &maxLen}, {nd4j::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) { + auto input = NDArrayFactory::create({1, 3, 2}); + auto exp = NDArrayFactory::create('c', {3,5}, { + 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f + }); + + nd4j::ops::sequence_mask op; + auto result = op.evaluate({&input}, {5, (int)nd4j::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.});