Shugeo sequence mask fix2 (#216)
* Fixed sequence_mask op and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda fix for sequence_mask op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed sequence_mask op for both platforms and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed solve and triangular_solve for more than 2D for adjoint cases. Signed-off-by: shugeo <sgazeos@gmail.com> * Added adjoint solve test again. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for triangual_solve and generic solve ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a pair tests for triangular_solve Signed-off-by: shugeo <sgazeos@gmail.com> * Added tests for triangular_solve op. Signed-off-by: shugeo <sgazeos@gmail.com>master
parent
569a46f87d
commit
5ae40f6e38
|
@ -58,30 +58,31 @@ namespace nd4j {
|
||||||
int outRank = shape::rank(in) + 1;
|
int outRank = shape::rank(in) + 1;
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto dtype = DataType::BOOL;
|
auto dtype = DataType::BOOL;
|
||||||
Nd4jLong maxInd = input->argMax();
|
auto argMaxInd = input->argMax();
|
||||||
Nd4jLong max = input->e<Nd4jLong>(maxInd);
|
Nd4jLong max = input->e<Nd4jLong>(argMaxInd);
|
||||||
|
Nd4jLong maxInd = max;
|
||||||
|
|
||||||
if (block.getIArguments()->size() > 0) {
|
if (block.numD() > 0)
|
||||||
if (block.width() < 2) {
|
dtype = D_ARG(0);
|
||||||
maxInd = INT_ARG(0);
|
|
||||||
if (maxInd < max)
|
|
||||||
maxInd = static_cast<Nd4jLong>(max);
|
|
||||||
if (block.getIArguments()->size() > 1)
|
|
||||||
dtype = (DataType)INT_ARG(1);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
dtype = (DataType)INT_ARG(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (block.width() > 1) {
|
if (block.width() > 1) {
|
||||||
auto maxlen = INPUT_VARIABLE(1);
|
auto maxlen = INPUT_VARIABLE(1);
|
||||||
Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
|
Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
|
||||||
if (tmaxlen > max)
|
if (tmaxlen > max)
|
||||||
maxInd = static_cast<Nd4jLong>(tmaxlen);
|
maxInd = static_cast<Nd4jLong>(tmaxlen);
|
||||||
|
if (block.numI() > 0) {
|
||||||
|
dtype = (DataType) INT_ARG(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (block.numI() > 0) {
|
||||||
|
maxInd = INT_ARG(0);
|
||||||
|
}
|
||||||
|
if (maxInd < max)
|
||||||
|
maxInd = max;
|
||||||
|
if (block.numI() > 1)
|
||||||
|
dtype = (DataType)INT_ARG(1); // to work with legacy code
|
||||||
}
|
}
|
||||||
else
|
|
||||||
maxInd = static_cast<Nd4jLong>(max);
|
|
||||||
|
|
||||||
int lastDimension = maxInd;
|
int lastDimension = maxInd;
|
||||||
ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);
|
ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);
|
||||||
|
|
|
@ -38,10 +38,10 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -36,10 +36,12 @@ namespace helpers {
|
||||||
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
|
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
|
||||||
auto inputPart = input->allTensorsAlongDimension({-2, -1});
|
auto inputPart = input->allTensorsAlongDimension({-2, -1});
|
||||||
auto outputPart = output->allTensorsAlongDimension({-2, -1});
|
auto outputPart = output->allTensorsAlongDimension({-2, -1});
|
||||||
|
auto rows = input->sizeAt(-2);
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
auto batchLoop = PRAGMA_THREADS_FOR {
|
auto batchLoop = PRAGMA_THREADS_FOR {
|
||||||
for (auto batch = start; batch < stop; batch += increment) {
|
for (auto batch = start; batch < stop; batch += increment) {
|
||||||
for (auto r = 0; r < input->rows(); r++) {
|
for (auto r = 0; r < rows; r++) {
|
||||||
for (auto c = 0; c < r; c++) {
|
for (auto c = 0; c < r; c++) {
|
||||||
math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
|
math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,17 +108,20 @@ namespace helpers {
|
||||||
static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
|
static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
|
||||||
auto inputPart = input->allTensorsAlongDimension({-2, -1});
|
auto inputPart = input->allTensorsAlongDimension({-2, -1});
|
||||||
auto outputPart = output->allTensorsAlongDimension({-2, -1});
|
auto outputPart = output->allTensorsAlongDimension({-2, -1});
|
||||||
|
auto cols = input->sizeAt(-1);
|
||||||
|
auto rows = input->sizeAt(-2);
|
||||||
|
|
||||||
auto batchLoop = PRAGMA_THREADS_FOR {
|
auto batchLoop = PRAGMA_THREADS_FOR {
|
||||||
for (auto batch = start; batch < stop; batch += increment) {
|
for (auto batch = start; batch < stop; batch += increment) {
|
||||||
if (!lower) {
|
if (!lower) {
|
||||||
for (auto r = 0; r < input->rows(); r++) {
|
for (auto r = 0; r < rows; r++) {
|
||||||
for (auto c = 0; c <= r; c++) {
|
for (auto c = 0; c <= r; c++) {
|
||||||
outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
|
outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (auto r = 0; r < input->rows(); r++) {
|
for (auto r = 0; r < rows; r++) {
|
||||||
for (auto c = r; c < input->columns(); c++) {
|
for (auto c = r; c < cols; c++) {
|
||||||
outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
|
outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,10 +55,10 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1667,6 +1667,241 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4) {
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
delete res;
|
delete res;
|
||||||
}
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_1) {
|
||||||
|
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {2, 2, 2}, {
|
||||||
|
0.7788f, 0.8012f, 0.7244f, 0.2309f,
|
||||||
|
0.7271f, 0.1804f, 0.5056f, 0.8925f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {2, 2, 2}, {
|
||||||
|
0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
|
||||||
|
1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f,
|
||||||
|
0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::solve op;
|
||||||
|
|
||||||
|
auto res = op.evaluate({&a, &b}, {true});
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
auto z = res->at(0);
|
||||||
|
|
||||||
|
// z->printBuffer("4 Solve 4x4");
|
||||||
|
// exp.printBuffer("4 Expec 4x4");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_2) {
|
||||||
|
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7788f, 0.8012f, 0.7244f,
|
||||||
|
0.2309f, 0.7271f, 0.1804f,
|
||||||
|
0.5056f, 0.8925f, 0.5461f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7717f, 0.9281f, 0.9846f,
|
||||||
|
0.4838f, 0.6433f, 0.6041f,
|
||||||
|
0.6501f, 0.7612f, 0.7605f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.99088347f, 1.1917052f, 1.2642528f,
|
||||||
|
0.35071516f, 0.50630623f, 0.42935497f,
|
||||||
|
-0.30013534f, -0.53690606f, -0.47959247f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::triangular_solve op;
|
||||||
|
|
||||||
|
auto res = op.evaluate({&a, &b}, {true, false});
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
auto z = res->at(0);
|
||||||
|
|
||||||
|
// z->printBuffer("4_2 Triangular_Solve 3x3");
|
||||||
|
// exp.printBuffer("4_2 Triangular_Expec 3x3");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_3) {
|
||||||
|
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7788f, 0.8012f, 0.7244f,
|
||||||
|
0.2309f, 0.7271f, 0.1804f,
|
||||||
|
0.5056f, 0.8925f, 0.5461f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7717f, 0.9281f, 0.9846f,
|
||||||
|
0.4838f, 0.6433f, 0.6041f,
|
||||||
|
0.6501f, 0.7612f, 0.7605f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.45400196f, 0.53174824f, 0.62064564f,
|
||||||
|
-0.79585856f, -0.82621557f, -0.87855506f,
|
||||||
|
1.1904413f, 1.3938838f, 1.3926021f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::triangular_solve op;
|
||||||
|
|
||||||
|
auto res = op.evaluate({&a, &b}, {true, true});
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
auto z = res->at(0);
|
||||||
|
|
||||||
|
// z->printBuffer("4_3 Triangular_Solve 3x3");
|
||||||
|
// exp.printBuffer("4_3 Triangular_Expec 3x3");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_4) {
|
||||||
|
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7788f, 0.8012f, 0.7244f,
|
||||||
|
0.2309f, 0.7271f, 0.1804f,
|
||||||
|
0.5056f, 0.8925f, 0.5461f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7717f, 0.9281f, 0.9846f,
|
||||||
|
0.4838f, 0.6433f, 0.6041f,
|
||||||
|
0.6501f, 0.7612f, 0.7605f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.8959121f, 1.6109066f, 1.7501404f,
|
||||||
|
0.49000582f, 0.66842675f, 0.5577021f,
|
||||||
|
-0.4398522f, -1.1899745f, -1.1392052f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::solve op;
|
||||||
|
|
||||||
|
auto res = op.evaluate({&a, &b}, {false});
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
auto z = res->at(0);
|
||||||
|
|
||||||
|
// z->printBuffer("4_4 Solve 3x3");
|
||||||
|
// exp.printBuffer("4_4 Expec 3x3");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_5) {
|
||||||
|
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7788f, 0.8012f, 0.7244f,
|
||||||
|
0.2309f, 0.7271f, 0.1804f,
|
||||||
|
0.5056f, 0.8925f, 0.5461f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7717f, 0.9281f, 0.9846f,
|
||||||
|
0.4838f, 0.6433f, 0.6041f,
|
||||||
|
0.6501f, 0.7612f, 0.7605f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
1.5504692f, 1.8953944f, 2.2765768f,
|
||||||
|
0.03399149f, 0.2883001f, 0.5377323f,
|
||||||
|
-0.8774802f, -1.2155888f, -1.8049058f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::solve op;
|
||||||
|
|
||||||
|
auto res = op.evaluate({&a, &b}, {true, true});
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
auto z = res->at(0);
|
||||||
|
|
||||||
|
// z->printBuffer("4_5 Solve 3x3");
|
||||||
|
// exp.printBuffer("4_5 Expec 3x3");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_6) {
|
||||||
|
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7788f, 0.8012f, 0.7244f,
|
||||||
|
0.2309f, 0.7271f, 0.1804f,
|
||||||
|
0.5056f, 0.8925f, 0.5461f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7717f, 0.9281f, 0.9846f,
|
||||||
|
0.4838f, 0.6433f, 0.6041f,
|
||||||
|
0.6501f, 0.7612f, 0.7605f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.99088347f, 1.1917052f, 1.2642528f,
|
||||||
|
-0.426483f, -0.42840624f, -0.5622601f,
|
||||||
|
0.01692283f, -0.04538865f, -0.09868701f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::triangular_solve op;
|
||||||
|
|
||||||
|
auto res = op.evaluate({&a, &b}, {false, true});
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
auto z = res->at(0);
|
||||||
|
|
||||||
|
z->printBuffer("4_6 Solve 3x3");
|
||||||
|
exp.printBuffer("4_6 Expec 3x3");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_7) {
|
||||||
|
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
// 0.7788f, 0.2309f, 0.5056f,
|
||||||
|
// 0.8012f, 0.7271f, 0.8925f,
|
||||||
|
// 0.7244f, 0.1804f, 0.5461f
|
||||||
|
|
||||||
|
0.7788f, 0.2309f, 0.5056f,
|
||||||
|
0.8012f, 0.7271f, 0.8925f,
|
||||||
|
0.7244f, 0.1804f, 0.5461f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.7717f, 0.9281f, 0.9846f,
|
||||||
|
0.4838f, 0.6433f, 0.6041f,
|
||||||
|
0.6501f, 0.7612f, 0.7605f
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
|
||||||
|
0.99088347f, 1.1917052f, 1.2642528f,
|
||||||
|
-0.426483f, -0.42840624f, -0.5622601f,
|
||||||
|
0.01692283f, -0.04538865f, -0.09868701f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::triangular_solve op;
|
||||||
|
|
||||||
|
auto res = op.evaluate({&a, &b}, {true, false});
|
||||||
|
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||||
|
auto z = res->at(0);
|
||||||
|
|
||||||
|
z->printBuffer("4_7 Solve 3x3");
|
||||||
|
exp.printBuffer("4_7 Expec 3x3");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests11, Solve_Test_5) {
|
TEST_F(DeclarableOpsTests11, Solve_Test_5) {
|
||||||
|
|
|
@ -802,6 +802,66 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) {
|
||||||
|
auto input = NDArrayFactory::create<int>('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8});
|
||||||
|
auto exp = NDArrayFactory::create<int>('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||||
|
|
||||||
|
nd4j::ops::sequence_mask op;
|
||||||
|
auto result = op.evaluate({&input}, {nd4j::DataType::INT32});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printBuffer("Output");
|
||||||
|
// z->printShapeInfo("Shape");
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) {
|
||||||
|
auto input = NDArrayFactory::create<int>({1, 3, 2});
|
||||||
|
auto maxLen = NDArrayFactory::create<int>(5);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3,5}, {
|
||||||
|
1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::sequence_mask op;
|
||||||
|
auto result = op.evaluate({&input, &maxLen}, {nd4j::DataType::FLOAT32});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printBuffer("Output");
|
||||||
|
// z->printShapeInfo("Shape");
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) {
|
||||||
|
auto input = NDArrayFactory::create<int>({1, 3, 2});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3,5}, {
|
||||||
|
1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::sequence_mask op;
|
||||||
|
auto result = op.evaluate({&input}, {5, (int)nd4j::DataType::FLOAT32});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printBuffer("Output");
|
||||||
|
// z->printShapeInfo("Shape");
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests7, TestSegmentMax_1) {
|
TEST_F(DeclarableOpsTests7, TestSegmentMax_1) {
|
||||||
auto x = NDArrayFactory::create<double>({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.});
|
auto x = NDArrayFactory::create<double>({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.});
|
||||||
|
|
Loading…
Reference in New Issue