parent
c32acb2ec7
commit
fbf7c9d38b
|
@ -28,10 +28,14 @@ namespace nd4j {
|
|||
CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
auto p = OUTPUT_VARIABLE(1);
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_inverse: The rank of input array should not less than 2, but %i is given", input->rankOf());
|
||||
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_inverse: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
|
||||
auto p = OUTPUT_VARIABLE(1);
|
||||
if (block.getIArguments()->size()) {
|
||||
DataType dtype = (DataType)INT_ARG(0);
|
||||
REQUIRE_TRUE(dtype == nd4j::DataType::INT32 || dtype == nd4j::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); }
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() >=2, 0, "lu: The rank of input array should not less than 2, but %i is given", input->rankOf());
|
||||
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "lu: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
|
||||
|
||||
helpers::lu(block.launchContext(), input, z, p);
|
||||
return Status::OK();
|
||||
|
@ -41,7 +45,12 @@ namespace nd4j {
|
|||
auto in = inputShape->at(0);
|
||||
auto shapeVector = ShapeUtils::shapeAsVector(in);
|
||||
auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace());
|
||||
auto luP = ShapeBuilders::createShapeInfo(nd4j::DataType::INT32, shape::order(in), shapeVector.size() - 1,
|
||||
auto dtype = nd4j::DataType::INT32;
|
||||
if (block.getIArguments()->size()) {
|
||||
dtype = (DataType)INT_ARG(0);
|
||||
REQUIRE_TRUE(dtype == nd4j::DataType::INT32 || dtype == nd4j::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str());
|
||||
}
|
||||
auto luP = ShapeBuilders::createShapeInfo(dtype, shape::order(in), shapeVector.size() - 1,
|
||||
shapeVector.data(), block.workspace());
|
||||
return SHAPELIST(CONSTANT(luShape), CONSTANT(luP));
|
||||
}
|
||||
|
|
|
@ -598,14 +598,13 @@ namespace helpers {
|
|||
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
|
||||
auto n = input->sizeAt(-1);
|
||||
auto stream = context->getCudaStream();
|
||||
auto iota = NDArrayFactory::create<int>('c', {n});
|
||||
NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // <int>('c', {n});
|
||||
iota.linspace(0); iota.syncToDevice();
|
||||
|
||||
output->assign(input); // copy the input into the output tensor
|
||||
output->tickWriteDevice();
|
||||
// output->tickWriteDevice();
|
||||
permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), iota, *permutationVectors, true, nullptr);
|
||||
permutationVectors->tickWriteDevice();
|
||||
|
||||
// permutationVectors->tickWriteDevice();
|
||||
auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1});
|
||||
auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-1});
|
||||
auto batchNum = tads.numberOfTads();
|
||||
|
|
|
@ -2677,3 +2677,57 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_3) {
|
|||
ASSERT_TRUE(expP.equalsTo(p));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_4_1) {
    // Batch of two 2x2 matrices, run through the lu op without any integer
    // arguments; the expected permutation output is built as plain int.
    auto input = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7788f, 0.8012f, 0.7244f, 0.2309f,
        0.7271f, 0.1804f, 0.5056f, 0.8925f
    });

    // Expected packed result per matrix: the first row is unchanged, the second
    // row holds the multiplier (e.g. 0.7244 / 0.7788 ~ 0.930149) followed by
    // the updated diagonal entry.
    auto expectedLU = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7788f, 0.8012f, 0.930149f, -0.514335f,
        0.7271f, 0.1804f, 0.695365f, 0.767056f
    });

    // Expected permutation is {0, 1} for each matrix in the batch.
    auto expectedP = NDArrayFactory::create<int>('c', {2, 2}, {0, 1, 0, 1});

    nd4j::ops::lu luOp;
    auto results = luOp.execute({&input}, {}, {});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto actualLU = results->at(0);
    auto actualP = results->at(1);

    ASSERT_TRUE(expectedLU.equalsTo(actualLU));
    ASSERT_TRUE(expectedP.equalsTo(actualP));

    delete results;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_4_2) {
    // Same batch of two 2x2 matrices as LU_Test_4_1, but the op receives
    // nd4j::DataType::INT64 as its integer argument, and the expected
    // permutation output is built as Nd4jLong.
    auto input = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7788f, 0.8012f, 0.7244f, 0.2309f,
        0.7271f, 0.1804f, 0.5056f, 0.8925f
    });

    // Expected packed result per matrix: the first row is unchanged, the second
    // row holds the multiplier followed by the updated diagonal entry.
    auto expectedLU = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7788f, 0.8012f, 0.930149f, -0.514335f,
        0.7271f, 0.1804f, 0.695365f, 0.767056f
    });

    // Expected permutation is {0, 1} for each matrix in the batch.
    auto expectedP = NDArrayFactory::create<Nd4jLong>('c', {2, 2}, {0, 1, 0, 1});

    nd4j::ops::lu luOp;
    auto results = luOp.execute({&input}, {}, {nd4j::DataType::INT64});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto actualLU = results->at(0);
    auto actualP = results->at(1);

    ASSERT_TRUE(expectedLU.equalsTo(actualLU));
    ASSERT_TRUE(expectedP.equalsTo(actualP));

    delete results;
}
|
||||
|
|
Loading…
Reference in New Issue