Shugeo solve ls (#203)

* lstsq op. Initial commit.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Least squares linear problem solve op (lstsq). Cpu draft implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed shape routine and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added test for lstsq op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Rectification for lstsq op implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected test to avoid numerical inconsistensy.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added prints for check computing.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected tests to use evalueate facility instead.

Signed-off-by: shugeo <sgazeos@gmail.com>

* CPU implementation of MatrixSolveLs op and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added cuda implementation for helpers with lstsq op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored tests for lstsq op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added processing for empty inputs.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Merged tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored lstsq op for fast case.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored lstsq op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed some issues with solve.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed lstsq op to avoid erros.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added kernel for giagonal factor

Signed-off-by: shugeo <sgazeos@gmail.com>

* lstsq wrapper and triangular_solve fixed

* Added proper processing empty inputs and test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* SequenceMask test

* Build fixed

* Added proper processing of empty inputs with solve op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Mapping added

* Added check of input shapes with solve op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added a couple of tests for lstsq op and minor changes with cuda helper for one.'

Signed-off-by: shugeo <sgazeos@gmail.com>

* Tests on

* Refactored test for lstsq op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed test

* Added another approach for lstsq op aka solve_ls.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Finished cpu part for solve_ls op helpers.

* Added helper for low triangular matrix inversion.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored alternate solve_ls cpu implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Removed alternate approach for solve_ls op. Added multithreading with matrix inversion.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Assert fixed

* Refactored multithreading for inverse matricies.

Signed-off-by: shugeo <sgazeos@gmail.com>

Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
shugeo 2020-02-28 10:37:26 +02:00 committed by GitHub
parent 358c650b62
commit 330a69d4e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 896 additions and 64 deletions

View File

@ -310,7 +310,7 @@ namespace nd4j {
* This method returns new uninitialized array with the same shape & data type
* @return
*/
NDArray ulike();
NDArray ulike() const;
/**

View File

@ -4725,7 +4725,7 @@ NDArray NDArray::like() {
}
////////////////////////////////////////////////////////////////////////
NDArray NDArray::ulike() {
NDArray NDArray::ulike() const{
return NDArray(this, false, getContext());
}

View File

@ -32,7 +32,7 @@ namespace nd4j {
REQUIRE_TRUE(input->rankOf() >=2, 0, "cholesky: The rank of input array should not less than 2, but %i is given", input->rankOf());
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "cholesky: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
REQUIRE_TRUE(helpers::checkCholeskyInput(block.launchContext(), input), 0, "cholesky: The input tensor should be positive-defined and symmetric.");
return helpers::cholesky(block.launchContext(), input, output);
return helpers::cholesky(block.launchContext(), input, output, block.isInplace());
}
DECLARE_TYPES(cholesky) {
getOpDescriptor()

View File

@ -0,0 +1,133 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by GS <sgazeos@gmail.com> at 01/28/2020
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstsq)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstsq.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) {
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
bool fastFlag = true;
double l2_factor = 0.;
if (block.numB() > 0) {
fastFlag = B_ARG(0);
}
if (block.numT() > 0) {
l2_factor = T_ARG(0);
}
REQUIRE_TRUE(a->rankOf() >=2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
REQUIRE_TRUE(b->rankOf() >=2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
// REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, "lstsq: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
//REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not finished for factor difference from 0.");
if (a->isEmpty() || b->isEmpty() || z->isEmpty())
return Status::OK();
auto res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z);
return res;
}
CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) {
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
bool fastFlag = true;
double l2_factor = 0.;
if (block.numB() > 0) {
fastFlag = B_ARG(0);
}
if (block.numT() > 0) {
l2_factor = T_ARG(0);
}
REQUIRE_TRUE(a->rankOf() >=2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
REQUIRE_TRUE(b->rankOf() >=2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
// REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, "lstsq: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
//REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not finished for factor difference from 0.");
auto res = Status::OK();
if (a->isEmpty() || b->isEmpty() || z->isEmpty())
return res;
res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z);
return res;
}
DECLARE_SYN(MatrixSolveLs, lstsq);
DECLARE_SHAPE_FN(lstsq) {
auto in0 = inputShape->at(0);
auto in1 = inputShape->at(1);
auto shapeOf = ShapeUtils::shapeAsVector(in1);
auto rank = shapeOf.size();
shapeOf[rank - 2] = shape::sizeAt(in0, -1);
if (shape::isEmpty(in0) || shape::isEmpty(in1)) {
shapeOf[rank - 1] = 0; // set output shape to empty
}
auto resShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in0), shape::order(in1), shapeOf);//ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());
if (shapeOf[rank - 1] == 0) {
ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY);
}
return SHAPELIST(resShape);
}
DECLARE_TYPES(lstsq) {
getOpDescriptor()
->setAllowedInputTypes({ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS})
->setSameMode(false);
}
DECLARE_SHAPE_FN(solve_ls) {
auto in0 = inputShape->at(0);
auto in1 = inputShape->at(1);
auto shapeOf = ShapeUtils::shapeAsVector(in1);
auto rank = shapeOf.size();
shapeOf[rank - 2] = shape::sizeAt(in0, -1);
if (shape::isEmpty(in0) || shape::isEmpty(in1)) {
shapeOf[rank - 1] = 0; // set output shape to empty
}
auto resShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in0), shape::order(in1), shapeOf);//ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());
if (shapeOf[rank - 1] == 0) {
ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY);
}
return SHAPELIST(resShape);
}
DECLARE_TYPES(solve_ls) {
getOpDescriptor()
->setAllowedInputTypes({ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS})
->setSameMode(false);
}
}
}
#endif

View File

@ -35,8 +35,9 @@ namespace nd4j {
REQUIRE_TRUE(input->rankOf() >=2, 0, "qr: The rank of input array should not be less than 2, but %i is given", input->rankOf());
REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || (!fullMatricies && outputQ->isSameShape(input)), 0, "qr: The last dimmensions should be equal to result Q, but %i and %i are given", outputQ->sizeAt(-1), input->sizeAt(-2));
REQUIRE_TRUE((fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), 0, "qr: The last dimmensions should be equal to result R, but %i and %i are given", outputR->sizeAt(-1), input->sizeAt(-1));
if (!input->isEmpty() && !outputQ->isEmpty() && !outputR->isEmpty())
helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies);
helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies);
return Status::OK();
}

View File

@ -35,12 +35,15 @@ namespace nd4j {
if (block.numB() > 0) {
useAdjoint = B_ARG(0);
}
REQUIRE_TRUE(shape::shapeEquals(a->rankOf() - 2, a->shapeInfo(), b->rankOf() - 2, b->shapeInfo()), 0, "solve: Input shapes should be alike.");
REQUIRE_TRUE(a->rankOf() >=2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
REQUIRE_TRUE(b->rankOf() >=2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
if (a->isEmpty() || b->isEmpty() || z->isEmpty())
return Status::OK();
auto input = a;
if (useAdjoint) {
auto adjointA = a->ulike();

View File

@ -1044,6 +1044,48 @@ namespace nd4j {
DECLARE_CUSTOM_OP(logdet, 1, 1, false, 0, 0);
#endif
/**
* matrix_solve_ls op (lstsq) - solves one or more linear least-squares problems.
*
* input params:
* 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations
* 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
*
* float args:
* 0 - l2_regularizer (default 0. and only for 0 implemented)
*
* boolean args:
* 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies.
*
* return value:
* tensor with dimension (x * y * z * ::: * N * K) with solutions
*
*/
#if NOT_EXCLUDED(OP_lstsq)
DECLARE_CUSTOM_OP(lstsq, 2, 1, false, 0, 0);
#endif
/* solve_ls - analog of lstsq op with another solution approach
*
* input params:
* 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations
* 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
*
* float args:
* 0 - l2_regularizer (default 0. and only for 0 implemented)
*
* boolean args:
* 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies.
*
* return value:
* tensor with dimension (x * y * z * ::: * N * K) with solutions
*
* Note: if fast is false - then l2_regularizer arg is ignored and used lstsq method due QR decomposition
* */
#if NOT_EXCLUDED(OP_solve_ls)
DECLARE_CUSTOM_OP(solve_ls, 2, 1, false, 0, 0);
#endif
/**
* matrix_inverse op. - make inverse for all 2D square matricies found in the input tensor
*
@ -1073,7 +1115,7 @@ namespace nd4j {
*
*/
#if NOT_EXCLUDED(OP_triangular_solve)
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, false, 0, 0);
#endif
/**

View File

@ -0,0 +1,108 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#include <NDArray.h>
#include <execution/Threads.h>
#include <MmulHelper.h>
#include <ShapeUtils.h>
#include "../lup.h"
#include "../triangular_solve.h"
#include "../lstsq.h"
#include "../qr.h"
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static void fillRegularizer(NDArray& ioMatrix, double const value) {
auto lastDims = ioMatrix.allTensorsAlongDimension({-2, -1});
auto rows = ioMatrix.sizeAt(-2);
//auto cols = ioMatrix.sizeAt(-1);
for (auto x = 0; x < lastDims.size(); x++) {
for (auto r = 0; r < rows; r++) {
lastDims[x]->t<T>(r,r) = (T)value;
}
}
}
template <typename T>
int leastSquaresSolveFunctor_(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
NDArray::preparePrimaryUse({output}, {leftInput, rightInput});
if (fast) { // Cholesky decomposition approach
// Equation for solve A^T * Ax = A^T * b, so
// 1. Computing A2:
auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->getShapeInfo(), leftInput->getShapeInfo(), true, false);
//tAtShape[tAtShape.size() - 2] = output->sizeAt(-2);
NDArray leftOutput('c', tAtShape, output->dataType(), context);
MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A
// 2. Computing B' = A^T * b
auto rightOutput = output->ulike();
MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b
// 3. due l2Regularizer = 0, skip regularization ( indeed A' = A2 - l2Regularizer * I)
auto regularizer = leftOutput.ulike();
fillRegularizer<T>(regularizer, l2Regularizer);https://mangapark.net/
// regularizer *= l2Regularizer;
leftOutput += regularizer;
// 4. Cholesky decomposition -- output matrix is square and lower triangular
// auto leftOutputT = leftOutput.ulike();
auto err = helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition
if (err) return err;
// alternate moment: inverse lower triangular matrix to solve equation A'x = b' => L^Tx = L^-1 * b'
// solve one upper triangular system (to avoid float problems)
// 5. Solve two triangular systems:
auto rightB = rightOutput.ulike();
helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB);
helpers::adjointMatrix(context, &leftOutput, true, &leftOutput); //.transposei();
helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output);
// All done
}
else { // QR decomposition approach
// Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal matrix, and R - upper triangular
// 1. QR decomposition
auto qShape = leftInput->getShapeAsVector();
auto rShape = leftInput->getShapeAsVector();
qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2);
NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context);// = leftInput->ulike();
NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike();
helpers::qr(context, leftInput, &Q, &R, true);
// 2. b` = Q^t * b:
auto rightOutput = rightInput->ulike();
MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false);
// 3. Solve triangular system
helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output);
}
NDArray::registerPrimaryUse({output}, {leftInput, rightInput});
return Status::OK();
}
int leastSquaresSolveFunctor(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES);
}
}
}
}

View File

@ -65,24 +65,30 @@ namespace helpers {
template <typename T>
static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
int n = inputMatrix->rows();
invertedMatrix->assign(0.f);
// PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
for (int i = 0; i < n; i++)
invertedMatrix->p(i, i, 1.0f);
invertedMatrix->setIdentity();
if (inputMatrix->isIdentityMatrix()) return;
//PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
for (int i = 1; i < n; i++)
invertedMatrix->t<T>(i, i - 1) = -inputMatrix->t<T>(i, i - 1);
auto invertDiagonals = PRAGMA_THREADS_FOR {
for (int i = start; i < stop; i += increment)
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
};
//PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int i = 2; i < n; i++) {
for (int j = i - 2; j > -1; --j)
auto invertSubDiagonals = PRAGMA_THREADS_FOR {
for (int i = start; i < stop; i += increment)
invertedMatrix->t<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i));
};
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1);
// PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int i = 1; i < n; i++) {
for (int j = 0; j < i - 1 ; j++)
for (int k = 0; k < i; k++)
invertedMatrix->t<T>(i, j) -= (invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k));
invertedMatrix->t<T>(i, j) -= ((invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k) / inputMatrix->t<T>(i, i)));
}
}
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
@ -100,18 +106,25 @@ namespace helpers {
return;
}
//PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
for (int i = 0; i < n; i++)
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
auto invertDiagonals = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment)
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
};
//PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
for (int i = 0; i < n - 1; i++)
invertedMatrix->t<T>(i, i + 1) -= (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i));
auto invertUpDiagonals = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment)
invertedMatrix->t<T>(i, i + 1) -= (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) /
inputMatrix->t<T>(i, i));
};
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1);
// PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int i = n - 2; i > - 1; i--) {
for (int j = i + 2; j < n; j++)
for (int k = i; k < n; k++)
for (auto i = n - 2; i >= 0; i--) {
for (auto j = i + 2; j < n; j++)
for (auto k = i; k < n; k++)
invertedMatrix->t<T>(i, j) -= ((invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k) / inputMatrix->t<T>(i, i)));
}
}
@ -420,10 +433,81 @@ template <typename T>
return Status::OK();
}
template <typename T>
static int lowerInverse_(LaunchContext *context, NDArray* input, NDArray* output) {
auto n = input->sizeAt(-1);
auto n2 = n * n;
auto totalCount = output->lengthOf() / n2;
output->assign(0.f); // fill up output tensor with zeros
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
// auto batchLoop = PRAGMA_THREADS_FOR {
for (int e = 0; e < totalCount; e++) {
if (e)
matrix.assign(0.f);
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
matrix.p(row++, input->e<T>(k));
}
T det = T(1.f);
for (auto i = 0; i < n; i++) {
det *= matrix. template t<T>(i, i);
}
// FIXME: and how this is going to work on float16?
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det);
matrix.printIndexedBuffer("Wrong matrix");
return ND4J_STATUS_VALIDATION;
}
lowerMatrix.nullify();
invertLowerMatrix(&matrix, &lowerMatrix);
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
output->t<T>(k) = lowerMatrix.template t<T>(row++);
}
}
return Status::OK();
}
template <typename T>
static int upperInverse_(LaunchContext *context, NDArray* input, NDArray* output) {
auto n = input->sizeAt(-1);
auto n2 = n * n;
output->nullify(); // fill up output tensor with zeros
// auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
// auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
// auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto inputPart = input->allTensorsAlongDimension({-2, -1});
auto outputPart = output->allTensorsAlongDimension({-2, -1});
auto totalCount = outputPart.size(); //lengthOf() / n2;
for (int e = 0; e < totalCount; e++) {
invertUpperMatrix(inputPart.at(e), outputPart.at(e));
}
return Status::OK();
}
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
}
int lowerInverseFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return lowerInverse_, (context, input, output), FLOAT_TYPES);
}
int upperInverseFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return upperInverse_, (context, input, output), FLOAT_TYPES);
}
template <typename T>
static bool checkCholeskyInput_(nd4j::LaunchContext * context, NDArray const* input) {
//std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType())); //, block.getWorkspace());

View File

@ -106,7 +106,7 @@ namespace helpers {
}
template <typename T>
void qr_(NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
void qr_(NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
Nd4jLong lastDim = input->rankOf() - 1;
Nd4jLong preLastDim = input->rankOf() - 2;
ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}));
@ -123,7 +123,7 @@ namespace helpers {
}
void qr(nd4j::LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
void qr(nd4j::LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (input, outputQ, outputR, fullMatricies), FLOAT_TYPES);
}

View File

@ -0,0 +1,115 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#include <NDArray.h>
#include <MmulHelper.h>
#include <ShapeUtils.h>
#include <ConstantTadHelper.h>
#include "../triangular_solve.h"
#include "../lup.h"
#include "../qr.h"
#include "../lstsq.h"
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static __global__ void fillRegularizerKernel(T* ioMatrixData, Nd4jLong* ioMatrixShape, Nd4jLong* ioMatrixTads, Nd4jLong* ioMatrixOffsets, Nd4jLong batchSize, Nd4jLong rows, T const value) {
for (auto x = blockIdx.x; x < batchSize; x += gridDim.x) {
auto z = ioMatrixData + ioMatrixOffsets[x];
for (auto r = threadIdx.x; r < rows; r += blockDim.x) {
Nd4jLong pos[] = {r,r};
auto zIndex = shape::getOffset(ioMatrixTads, pos);
z[zIndex] = value;
}
}
}
template <typename T>
static void fillRegularizer(nd4j::LaunchContext* context, NDArray& ioMatrix, double const value) {
auto lastDimsTads = ConstantTadHelper::getInstance()->tadForDimensions(ioMatrix.shapeInfo(), {-2, -1});
auto stream = context->getCudaStream();
auto rows = ioMatrix.sizeAt(-2);
//auto cols = ioMatrix.sizeAt(-1);
fillRegularizerKernel<T><<<256, 256, 128, *stream>>>(ioMatrix.dataBuffer()->specialAsT<T>(), ioMatrix.specialShapeInfo(), lastDimsTads.specialShapeInfo(), lastDimsTads.specialOffsets(), lastDimsTads.numberOfTads(), rows, (T)value);
}
template <typename T>
int leastSquaresSolveFunctor_(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
if (fast) { // Cholesky decomposition approach
// Equation for solve A^T * Ax = A^T * b, so
// 1. Computing A2:
auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->getShapeInfo(), leftInput->getShapeInfo(), true, false);
//tAtShape[tAtShape.size() - 2] = output->sizeAt(-2);
NDArray leftOutput(leftInput->ordering(), tAtShape, output->dataType(), context);
MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A
// 2. Computing B' = A^T * b
auto rightOutput = output->ulike();
MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b
// 3. Regularization ( indeed A' = A2 - l2Regularizer * I)
if (l2Regularizer != 0.0) {
auto regularizer = leftOutput.ulike(); regularizer.nullify();
fillRegularizer<T>(context, regularizer, (T)l2Regularizer);
leftOutput += regularizer;
}
// 4. Cholesky decomposition -- output matrix is square and lower triangular
helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition
// 5. Solve two triangular systems:
auto rightB = rightOutput.ulike(); rightB.nullify();
helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB);
helpers::adjointMatrix(context, &leftOutput, true, &leftOutput);
helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output);
// All done
}
else { // QR decomposition approach
// Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal matrix, and R - upper triangular
// 1. QR decomposition
auto qShape = leftInput->getShapeAsVector();
auto rShape = leftInput->getShapeAsVector();
qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2);
NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context);// = leftInput->ulike();
NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike();
helpers::qr(context, leftInput, &Q, &R, true);
// 2. b` = Q^t * b:
auto rightOutput = rightInput->ulike();
MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false);
// 3. Solve triangular system
helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output);
}
return Status::OK();
}
int leastSquaresSolveFunctor(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES);
}
}
}
}

View File

@ -151,7 +151,7 @@ namespace helpers {
}
template <typename T>
void qr_(LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
void qr_(LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
Nd4jLong lastDim = input->rankOf() - 1;
Nd4jLong preLastDim = input->rankOf() - 2;
@ -170,7 +170,7 @@ namespace helpers {
NDArray::registerSpecialUse({outputQ, outputR}, {input});
}
void qr(nd4j::LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
void qr(nd4j::LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (context, input, outputQ, outputR, fullMatricies), FLOAT_TYPES);
}

View File

@ -0,0 +1,33 @@
/*******************************************************************************
* Copyright (c) Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#ifndef __LST_SQ_SOLVE__H_HELPERS__
#define __LST_SQ_SOLVE__H_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
int leastSquaresSolveFunctor(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output);
}
}
}
#endif

View File

@ -32,6 +32,8 @@ namespace helpers {
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
int upperInverseFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output);
int lowerInverseFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output);
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input);
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace = false);

View File

@ -26,7 +26,7 @@ namespace nd4j {
namespace ops {
namespace helpers {
void qr(nd4j::LaunchContext * context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies);
void qr(nd4j::LaunchContext * context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies);
}

View File

@ -24,6 +24,7 @@
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <helpers/MmulHelper.h>
using namespace nd4j;
@ -1918,8 +1919,8 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4_6) {
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
z->printBuffer("4_6 Solve 3x3");
exp.printBuffer("4_6 Expec 3x3");
// z->printBuffer("4_6 Solve 3x3");
// exp.printBuffer("4_6 Expec 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
@ -1955,8 +1956,8 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4_7) {
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
z->printBuffer("4_7 Solve 3x3");
exp.printBuffer("4_7 Expec 3x3");
// z->printBuffer("4_7 Solve 3x3");
// exp.printBuffer("4_7 Expec 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
@ -1989,12 +1990,127 @@ TEST_F(DeclarableOpsTests11, Solve_Test_5) {
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
z->printBuffer("4 Solve 4x4");
exp.printBuffer("4 Expec 4x4");
// z->printBuffer("4 Solve 4x4");
// exp.printBuffer("4 Expec 4x4");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, SolveLS_Test_1) {
auto a = NDArrayFactory::create<double>('c', {2,2, 2}, {
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f
});
auto b = NDArrayFactory::create<double>('c', {2, 2, 1}, {
3.f, 7.f, 11.f, 15.f
});
auto exp = NDArrayFactory::create<double>('c', {2, 2, 1}, {
0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f
});
nd4j::ops::lstsq op;
auto res = op.evaluate({&a, &b}, {0.5}, {}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("LS Solve 2x2");
// exp.printIndexedBuffer("LS Expec 2x2");
ASSERT_TRUE(exp.equalsTo(z, 1.e-4));
delete res;
}
TEST_F(DeclarableOpsTests11, SolveLS_Test_2) {
auto a = NDArrayFactory::create<float>('c', {2,2, 2}, {
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f
});
auto b = NDArrayFactory::create<float>('c', {2, 2, 1}, {
3.f, 7.f, 11.f, 15.f
});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 1}, {
0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f
});
nd4j::ops::lstsq op;
auto res = op.evaluate({&a, &b}, {0.5}, {}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("2LS Solve 2x2");
// exp.printIndexedBuffer("2LS Expec 2x2");
ASSERT_TRUE(exp.equalsTo(z, 1.e-4));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2) {
auto a = NDArrayFactory::create<float>('c', {2,2, 2}, {
10.f, 14.f,
14.f, 20.f,
74.f, 86.f,
86.f, 100.f
});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
3.1622777f, 0.f, 4.427189f, 0.6324552f,
8.602325f, 0.f, 9.997296f, 0.23252854f
});
nd4j::ops::cholesky op;
auto res = op.evaluate({&a});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
z->printIndexedBuffer("L matrix is");
exp.printIndexedBuffer("L expected is");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2_2) {
auto a = NDArrayFactory::create<float>('c', {2,2, 2}, {
10.5f, 14.f,
14.f, 20.5f,
74.5f, 86.f,
86.f, 100.5f
});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
3.2403703f, 0.f, 4.3204937f, 1.3540066f,
8.631338f, 0.f, 9.963693f, 1.1067207f
});
nd4j::ops::cholesky op;
auto res = op.evaluate({&a});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("L matrix is");
// exp.printIndexedBuffer("L expected is");
MmulHelper::matmul(z, z, &exp, false, true);
ASSERT_TRUE(exp.equalsTo(a));
delete res;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {

View File

@ -26,6 +26,7 @@
#include <GradCheck.h>
#include <ConstantTadHelper.h>
#include <helpers/PointersManager.h>
#include <helpers/MmulHelper.h>
using namespace nd4j;
@ -2936,7 +2937,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) {
nd4j::ops::triangular_solve op;
auto res = op.evaluate({&a, &b}, {}, {}, {false});
auto res = op.evaluate({&a, &b}, {false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
@ -2966,7 +2967,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
nd4j::ops::triangular_solve op;
auto res = op.evaluate({&a, &b}, {}, {}, {false, true});
auto res = op.evaluate({&a, &b}, {false, true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
@ -2977,6 +2978,142 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, SolveLs_Test_1) {
auto a = NDArrayFactory::create<float>('c', {4, 4}, {
3.f, 0.f, 0.f, 0.f,
2.f, 1.f, 0.f, 0.f,
1.f, 0.f, 1.f, 0.f,
1.f, 1.f, 1.f, 1.f
});
auto b = NDArrayFactory::create<float>('c', {4, 1}, {
4.f, 2.f, 4.f, 2.f
});
auto exp = NDArrayFactory::create<float>('c', {4, 1}, {
1.333333f, -0.6666667f, 2.6666667f, -1.3333333f });
nd4j::ops::lstsq op;
auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("MatrixSolveLS");
MmulHelper::matmul(&a, z, &exp, false, false);
ASSERT_TRUE(exp.equalsTo(b));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, SolveLs_Test_2) {
auto a = NDArrayFactory::create<double>('c', {3, 3}, {
1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 11.f, 8.f, 21.f
});
auto b = NDArrayFactory::create<double>('c', {3, 1}, { 1.f, 2.f, 3.f });
auto exp = NDArrayFactory::create<double>('c', {3, 1}, { -0.24999914f, 0.4999994f, 0.08333314f });
nd4j::ops::lstsq op;
auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
MmulHelper::matmul(&a, z, &exp, false, false);
// z->printIndexedBuffer("MatrixSolveLS2");
ASSERT_TRUE(exp.equalsTo(b));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, SolveLs_Test_3) {
auto a = NDArrayFactory::create<float>('c', {3, 4}, {
1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f
});
auto b = NDArrayFactory::create<float>('c', {3, 1}, { 1.f, 2.f, 3.f });
auto exp = NDArrayFactory::create<float>('c', {3, 1}, { -0.5f, 1.5f, -2.f });
nd4j::ops::lstsq op;
auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("MatrixSolveLS3");
MmulHelper::matmul(&a, z, &exp, false, false);
ASSERT_TRUE(exp.equalsTo(b));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, SolveLs_Test_4) {
auto a = NDArrayFactory::create<float>('c', {3, 4}, {
1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f
});
auto b = NDArrayFactory::create<float>('c', {3, 1}, { 1.f, 2.f, 3.f });
auto exp = NDArrayFactory::create<float>('c', {4, 1}, { -0.5f, 1.5f, -2.f, 0.f});
nd4j::ops::lstsq op;
auto res = op.evaluate({&a, &b}, {false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("Output_12.4");
// z->printShapeInfo("Output_12.4 shape");
// MmulHelper::matmul(&a, z, &exp, false, false);
// z->printIndexedBuffer("MatrixSolveLS4");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, SolveLs_Test_5) {
auto a = NDArrayFactory::create<float>('c', {1, 0, 3, 4});
auto b = NDArrayFactory::create<float>('c', {1, 0, 3, 1});
nd4j::ops::lstsq op;
auto res = op.evaluate({&a, &b}, {false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
ASSERT_TRUE(z->isEmpty());
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, Solve_Test_6) {
auto a = NDArrayFactory::create<float>('c', {1, 0, 3, 3});
auto b = NDArrayFactory::create<float>('c', {1, 0, 3, 1});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
ASSERT_TRUE(z->isEmpty());
delete res;
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) {
auto a = NDArrayFactory::create<float>('c', {4, 4}, {
@ -3004,4 +3141,4 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) {
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
}

View File

@ -623,7 +623,8 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.Digamma.class,
org.nd4j.linalg.api.ops.custom.Lu.class,
org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
org.nd4j.linalg.api.ops.custom.LinearSolve.class
org.nd4j.linalg.api.ops.custom.LinearSolve.class,
org.nd4j.linalg.api.ops.custom.Lstsq.class
);
static {

View File

@ -0,0 +1,40 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@NoArgsConstructor
public class Lstsq extends DynamicCustomOp {
public Lstsq(@NonNull INDArray matrix, @NonNull INDArray rhs, double l2_regularizer, boolean fast) {
addInputArgument(matrix, rhs);
addTArgument(l2_regularizer);
addBArgument(fast);
}
public Lstsq(@NonNull INDArray matrix, @NonNull INDArray rhs) {
this(matrix, rhs, 0.0, true);
}
@Override
public String opName() {
return "lstsq";
}
}

View File

@ -34,19 +34,20 @@ public class TriangularSolve extends DynamicCustomOp {
addBArgument(lower, adjoint);
}
@Override
public String opName() {
return "triangular_solve";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if(attributesForNode.containsKey("adjoint")){
addBArgument(attributesForNode.get("adjoint").getB());
}
if(attributesForNode.containsKey("lower")){
addBArgument(attributesForNode.get("lower").getB());
}
if(attributesForNode.containsKey("adjoint")){
addBArgument(attributesForNode.get("adjoint").getB());
}
}
@Override
public String opName() {
return "triangular_solve";
}
@Override

View File

@ -67,10 +67,9 @@ public class SequenceMask extends DynamicCustomOp {
public SequenceMask(INDArray input, int maxLen, DataType dataType) {
addInputArgument(input);
addIArgument(maxLen);
//addIArgument(dataType.toInt());
addDArgument(dataType);
this.dataType = dataType;
}
addDArgument(dataType);
}
@Override

View File

@ -123,12 +123,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
//AB 2020/01/07 - Known issues
"bitcast/from_float64_to_int64",
"bitcast/from_rank2_float64_to_int64",
"bitcast/from_float64_to_uint64",
// 2020/02/14 - new ops which are not passing yet
"linear_solve/.*",
"triangular_solve/.*",
"lstsq/.*"
"bitcast/from_float64_to_uint64"
};
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have

View File

@ -1739,15 +1739,37 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, ret[0]);
}
@Test
public void testLstsq() {
INDArray a = Nd4j.createFromArray(new float[]{
1.f, 2.f, 3.f,
4.f, 5.f, 6.f,
11.f, 8.f, 21.f
}).reshape(3,3);
INDArray b = Nd4j.createFromArray(new float[]{ 1.f, 2.f, 3.f }).reshape(3,1);
val op = new Lstsq(a,b);
INDArray[] ret = Nd4j.exec(op);
DynamicCustomOp matmul = DynamicCustomOp.builder("matmul")
.addInputs(a, ret[0])
.build();
INDArray[] matres = Nd4j.exec(matmul);
for (int i = 0; i < 3; ++i) {
assertEquals(b.getFloat(i, 0), matres[0].getFloat(i, 0), 1e-4);
}
}
@Test
public void testSequenceMask() {
INDArray arr = Nd4j.createFromArray(new int[]{1, 3, 2});
// Test with static max len
int maxlen = 2;
INDArray expected = Nd4j.createFromArray(new int[]{
1,0,0,
1,1,1,
1,1,0
1, 0, 0,
1, 1, 1,
1, 1, 0
}).reshape(3, 3);
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32));

View File

@ -76,13 +76,13 @@ public class RngValidationTests extends BaseNd4jTest {
@Builder.Default private double stdRelativeErrorTolerance = 0.01;
private Double meanMinAbsErrorTolerance; //Consider relative error between 0 and 0.001: relative error is 1.0, but absolute error is small
private Double stdMinAbsErrorTolerance;
@Builder.Default private Map<String,Object> args = new LinkedHashMap<>();
@Builder.Default private static Map<String,Object> args = new LinkedHashMap<>();
public static class TestCaseBuilder {
public TestCaseBuilder arg(String arg, Object value){
if(args == null) {
args(new LinkedHashMap<>());
args = new LinkedHashMap<>();
}
args.put(arg, value);
return this;

View File

@ -335,7 +335,7 @@
<jackson.databind.version>2.10.1</jackson.databind.version>
<shaded.snakeyaml.version>1.24</shaded.snakeyaml.version>
<geo.jackson.version>2.8.7</geo.jackson.version>
<lombok.version>1.18.2</lombok.version>
<lombok.version>1.18.12</lombok.version>
<cleartk.version>2.0.0</cleartk.version>
<lucene-solr.version>7.7.1</lucene-solr.version>
<json.version>20131018</json.version>