Fixed lu for cuda platform and tests. (#158)

Signed-off-by: shugeo <sgazeos@gmail.com>
master
shugeo 2020-01-02 22:25:41 +02:00 committed by raver119
parent c32acb2ec7
commit fbf7c9d38b
3 changed files with 70 additions and 8 deletions

View File

@ -28,10 +28,14 @@ namespace nd4j {
CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
auto p = OUTPUT_VARIABLE(1);
REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_inverse: The rank of input array should not less than 2, but %i is given", input->rankOf());
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_inverse: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
auto p = OUTPUT_VARIABLE(1);
if (block.getIArguments()->size()) {
DataType dtype = (DataType)INT_ARG(0);
REQUIRE_TRUE(dtype == nd4j::DataType::INT32 || dtype == nd4j::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); }
REQUIRE_TRUE(input->rankOf() >=2, 0, "lu: The rank of input array should not less than 2, but %i is given", input->rankOf());
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "lu: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
helpers::lu(block.launchContext(), input, z, p);
return Status::OK();
@ -41,7 +45,12 @@ namespace nd4j {
auto in = inputShape->at(0);
auto shapeVector = ShapeUtils::shapeAsVector(in);
auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace());
auto luP = ShapeBuilders::createShapeInfo(nd4j::DataType::INT32, shape::order(in), shapeVector.size() - 1,
auto dtype = nd4j::DataType::INT32;
if (block.getIArguments()->size()) {
dtype = (DataType)INT_ARG(0);
REQUIRE_TRUE(dtype == nd4j::DataType::INT32 || dtype == nd4j::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str());
}
auto luP = ShapeBuilders::createShapeInfo(dtype, shape::order(in), shapeVector.size() - 1,
shapeVector.data(), block.workspace());
return SHAPELIST(CONSTANT(luShape), CONSTANT(luP));
}

View File

@ -598,14 +598,13 @@ namespace helpers {
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
auto n = input->sizeAt(-1);
auto stream = context->getCudaStream();
auto iota = NDArrayFactory::create<int>('c', {n});
NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // <int>('c', {n});
iota.linspace(0); iota.syncToDevice();
output->assign(input); // fill up output tensor with zeros
output->tickWriteDevice();
// output->tickWriteDevice();
permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), iota, *permutationVectors, true, nullptr);
permutationVectors->tickWriteDevice();
// permutationVectors->tickWriteDevice();
auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1});
auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-1});
auto batchNum = tads.numberOfTads();

View File

@ -2677,3 +2677,57 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_3) {
ASSERT_TRUE(expP.equalsTo(p));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_4_1) {
auto in = NDArrayFactory::create<float>('c', {2, 2,2}, {0.7788f, 0.8012f,
0.7244f, 0.2309f,
0.7271f, 0.1804f,
0.5056f, 0.8925f});
auto expLU = NDArrayFactory::create<float>('c', {2, 2,2}, {
0.7788f, 0.8012f, 0.930149f, -0.514335f,
0.7271f, 0.1804f, 0.695365f, 0.767056f
});
auto expP = NDArrayFactory::create<int>('c', {2,2}, {0, 1, 0, 1});
nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
auto p = res->at(1);
// z->printIndexedBuffer("Triangulars4_1");
// p->printIndexedBuffer("Permutaions4_1");
ASSERT_TRUE(expLU.equalsTo(z));
ASSERT_TRUE(expP.equalsTo(p));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_4_2) {
auto in = NDArrayFactory::create<float>('c', {2, 2,2}, {0.7788f, 0.8012f,
0.7244f, 0.2309f,
0.7271f, 0.1804f,
0.5056f, 0.8925f});
auto expLU = NDArrayFactory::create<float>('c', {2, 2,2}, {
0.7788f, 0.8012f, 0.930149f, -0.514335f,
0.7271f, 0.1804f, 0.695365f, 0.767056f
});
auto expP = NDArrayFactory::create<Nd4jLong>('c', {2,2}, {0, 1, 0, 1});
nd4j::ops::lu op;
auto res = op.execute({&in}, {}, {nd4j::DataType::INT64});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
auto p = res->at(1);
// z->printIndexedBuffer("Triangulars4_2");
// p->printIndexedBuffer("Permutaions4_2");
ASSERT_TRUE(expLU.equalsTo(z));
ASSERT_TRUE(expP.equalsTo(p));
delete res;
}