Shugeo solve ls (#203)
* lstsq op. Initial commit. Signed-off-by: shugeo <sgazeos@gmail.com> * Least squares linear problem solve op (lstsq). Cpu draft implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed shape routine and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for lstsq op. Signed-off-by: shugeo <sgazeos@gmail.com> * Rectification for lstsq op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected test to avoid numerical inconsistensy. Signed-off-by: shugeo <sgazeos@gmail.com> * Added prints for check computing. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected tests to use evalueate facility instead. Signed-off-by: shugeo <sgazeos@gmail.com> * CPU implementation of MatrixSolveLs op and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Added cuda implementation for helpers with lstsq op. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tests for lstsq op. Signed-off-by: shugeo <sgazeos@gmail.com> * Added processing for empty inputs. Signed-off-by: shugeo <sgazeos@gmail.com> * Merged tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored lstsq op for fast case. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed test. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored lstsq op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed some issues with solve. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed lstsq op to avoid erros. Signed-off-by: shugeo <sgazeos@gmail.com> * Added kernel for giagonal factor Signed-off-by: shugeo <sgazeos@gmail.com> * lstsq wrapper and triangular_solve fixed * Added proper processing empty inputs and test. Signed-off-by: shugeo <sgazeos@gmail.com> * SequenceMask test * Build fixed * Added proper processing of empty inputs with solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Mapping added * Added check of input shapes with solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a couple of tests for lstsq op and minor changes with cuda helper for one.' Signed-off-by: shugeo <sgazeos@gmail.com> * Tests on * Refactored test for lstsq op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed test * Added another approach for lstsq op aka solve_ls. Signed-off-by: shugeo <sgazeos@gmail.com> * Finished cpu part for solve_ls op helpers. * Added helper for low triangular matrix inversion. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored alternate solve_ls cpu implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Removed alternate approach for solve_ls op. Added multithreading with matrix inversion. Signed-off-by: shugeo <sgazeos@gmail.com> * Assert fixed * Refactored multithreading for inverse matricies. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
358c650b62
commit
330a69d4e2
|
@ -310,7 +310,7 @@ namespace nd4j {
|
|||
* This method returns new uninitialized array with the same shape & data type
|
||||
* @return
|
||||
*/
|
||||
NDArray ulike();
|
||||
NDArray ulike() const;
|
||||
|
||||
|
||||
/**
|
||||
|
|
|
@ -4725,7 +4725,7 @@ NDArray NDArray::like() {
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::ulike() {
|
||||
NDArray NDArray::ulike() const{
|
||||
|
||||
return NDArray(this, false, getContext());
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(input->rankOf() >=2, 0, "cholesky: The rank of input array should not less than 2, but %i is given", input->rankOf());
|
||||
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "cholesky: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
|
||||
REQUIRE_TRUE(helpers::checkCholeskyInput(block.launchContext(), input), 0, "cholesky: The input tensor should be positive-defined and symmetric.");
|
||||
return helpers::cholesky(block.launchContext(), input, output);
|
||||
return helpers::cholesky(block.launchContext(), input, output, block.isInplace());
|
||||
}
|
||||
DECLARE_TYPES(cholesky) {
|
||||
getOpDescriptor()
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit, K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by GS <sgazeos@gmail.com> at 01/28/2020
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_lstsq)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/lstsq.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) {
|
||||
auto a = INPUT_VARIABLE(0);
|
||||
auto b = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
bool fastFlag = true;
|
||||
double l2_factor = 0.;
|
||||
if (block.numB() > 0) {
|
||||
fastFlag = B_ARG(0);
|
||||
}
|
||||
if (block.numT() > 0) {
|
||||
l2_factor = T_ARG(0);
|
||||
}
|
||||
REQUIRE_TRUE(a->rankOf() >=2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
|
||||
REQUIRE_TRUE(b->rankOf() >=2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
|
||||
|
||||
// REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
|
||||
REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, "lstsq: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
|
||||
//REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not finished for factor difference from 0.");
|
||||
if (a->isEmpty() || b->isEmpty() || z->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
auto res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) {
|
||||
auto a = INPUT_VARIABLE(0);
|
||||
auto b = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
bool fastFlag = true;
|
||||
double l2_factor = 0.;
|
||||
if (block.numB() > 0) {
|
||||
fastFlag = B_ARG(0);
|
||||
}
|
||||
if (block.numT() > 0) {
|
||||
l2_factor = T_ARG(0);
|
||||
}
|
||||
REQUIRE_TRUE(a->rankOf() >=2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
|
||||
REQUIRE_TRUE(b->rankOf() >=2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
|
||||
|
||||
// REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
|
||||
REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, "lstsq: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
|
||||
//REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not finished for factor difference from 0.");
|
||||
auto res = Status::OK();
|
||||
if (a->isEmpty() || b->isEmpty() || z->isEmpty())
|
||||
return res;
|
||||
|
||||
res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
DECLARE_SYN(MatrixSolveLs, lstsq);
|
||||
|
||||
DECLARE_SHAPE_FN(lstsq) {
|
||||
auto in0 = inputShape->at(0);
|
||||
auto in1 = inputShape->at(1);
|
||||
auto shapeOf = ShapeUtils::shapeAsVector(in1);
|
||||
auto rank = shapeOf.size();
|
||||
shapeOf[rank - 2] = shape::sizeAt(in0, -1);
|
||||
|
||||
if (shape::isEmpty(in0) || shape::isEmpty(in1)) {
|
||||
shapeOf[rank - 1] = 0; // set output shape to empty
|
||||
}
|
||||
auto resShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in0), shape::order(in1), shapeOf);//ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());
|
||||
if (shapeOf[rank - 1] == 0) {
|
||||
ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY);
|
||||
}
|
||||
return SHAPELIST(resShape);
|
||||
}
|
||||
|
||||
DECLARE_TYPES(lstsq) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS})
|
||||
->setSameMode(false);
|
||||
}
|
||||
DECLARE_SHAPE_FN(solve_ls) {
|
||||
auto in0 = inputShape->at(0);
|
||||
auto in1 = inputShape->at(1);
|
||||
auto shapeOf = ShapeUtils::shapeAsVector(in1);
|
||||
auto rank = shapeOf.size();
|
||||
shapeOf[rank - 2] = shape::sizeAt(in0, -1);
|
||||
|
||||
if (shape::isEmpty(in0) || shape::isEmpty(in1)) {
|
||||
shapeOf[rank - 1] = 0; // set output shape to empty
|
||||
}
|
||||
auto resShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in0), shape::order(in1), shapeOf);//ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());
|
||||
if (shapeOf[rank - 1] == 0) {
|
||||
ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY);
|
||||
}
|
||||
return SHAPELIST(resShape);
|
||||
}
|
||||
|
||||
DECLARE_TYPES(solve_ls) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS})
|
||||
->setSameMode(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -35,8 +35,9 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(input->rankOf() >=2, 0, "qr: The rank of input array should not be less than 2, but %i is given", input->rankOf());
|
||||
REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || (!fullMatricies && outputQ->isSameShape(input)), 0, "qr: The last dimmensions should be equal to result Q, but %i and %i are given", outputQ->sizeAt(-1), input->sizeAt(-2));
|
||||
REQUIRE_TRUE((fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), 0, "qr: The last dimmensions should be equal to result R, but %i and %i are given", outputR->sizeAt(-1), input->sizeAt(-1));
|
||||
if (!input->isEmpty() && !outputQ->isEmpty() && !outputR->isEmpty())
|
||||
helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies);
|
||||
|
||||
helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -35,12 +35,15 @@ namespace nd4j {
|
|||
if (block.numB() > 0) {
|
||||
useAdjoint = B_ARG(0);
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(shape::shapeEquals(a->rankOf() - 2, a->shapeInfo(), b->rankOf() - 2, b->shapeInfo()), 0, "solve: Input shapes should be alike.");
|
||||
REQUIRE_TRUE(a->rankOf() >=2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
|
||||
REQUIRE_TRUE(b->rankOf() >=2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
|
||||
|
||||
REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
|
||||
REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
|
||||
if (a->isEmpty() || b->isEmpty() || z->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
auto input = a;
|
||||
if (useAdjoint) {
|
||||
auto adjointA = a->ulike();
|
||||
|
|
|
@ -1044,6 +1044,48 @@ namespace nd4j {
|
|||
DECLARE_CUSTOM_OP(logdet, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* matrix_solve_ls op (lstsq) - solves one or more linear least-squares problems.
|
||||
*
|
||||
* input params:
|
||||
* 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations
|
||||
* 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
|
||||
*
|
||||
* float args:
|
||||
* 0 - l2_regularizer (default 0. and only for 0 implemented)
|
||||
*
|
||||
* boolean args:
|
||||
* 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies.
|
||||
*
|
||||
* return value:
|
||||
* tensor with dimension (x * y * z * ::: * N * K) with solutions
|
||||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_lstsq)
|
||||
DECLARE_CUSTOM_OP(lstsq, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/* solve_ls - analog of lstsq op with another solution approach
|
||||
*
|
||||
* input params:
|
||||
* 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations
|
||||
* 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
|
||||
*
|
||||
* float args:
|
||||
* 0 - l2_regularizer (default 0. and only for 0 implemented)
|
||||
*
|
||||
* boolean args:
|
||||
* 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies.
|
||||
*
|
||||
* return value:
|
||||
* tensor with dimension (x * y * z * ::: * N * K) with solutions
|
||||
*
|
||||
* Note: if fast is false - then l2_regularizer arg is ignored and used lstsq method due QR decomposition
|
||||
* */
|
||||
#if NOT_EXCLUDED(OP_solve_ls)
|
||||
DECLARE_CUSTOM_OP(solve_ls, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* matrix_inverse op. - make inverse for all 2D square matricies found in the input tensor
|
||||
*
|
||||
|
@ -1073,7 +1115,7 @@ namespace nd4j {
|
|||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_triangular_solve)
|
||||
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
|
||||
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit, K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
#include <op_boilerplate.h>
|
||||
#include <NDArray.h>
|
||||
#include <execution/Threads.h>
|
||||
#include <MmulHelper.h>
|
||||
#include <ShapeUtils.h>
|
||||
|
||||
#include "../lup.h"
|
||||
#include "../triangular_solve.h"
|
||||
#include "../lstsq.h"
|
||||
#include "../qr.h"
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
static void fillRegularizer(NDArray& ioMatrix, double const value) {
|
||||
auto lastDims = ioMatrix.allTensorsAlongDimension({-2, -1});
|
||||
auto rows = ioMatrix.sizeAt(-2);
|
||||
//auto cols = ioMatrix.sizeAt(-1);
|
||||
|
||||
for (auto x = 0; x < lastDims.size(); x++) {
|
||||
for (auto r = 0; r < rows; r++) {
|
||||
lastDims[x]->t<T>(r,r) = (T)value;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int leastSquaresSolveFunctor_(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
|
||||
NDArray::preparePrimaryUse({output}, {leftInput, rightInput});
|
||||
if (fast) { // Cholesky decomposition approach
|
||||
// Equation for solve A^T * Ax = A^T * b, so
|
||||
// 1. Computing A2:
|
||||
auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->getShapeInfo(), leftInput->getShapeInfo(), true, false);
|
||||
//tAtShape[tAtShape.size() - 2] = output->sizeAt(-2);
|
||||
NDArray leftOutput('c', tAtShape, output->dataType(), context);
|
||||
MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A
|
||||
// 2. Computing B' = A^T * b
|
||||
auto rightOutput = output->ulike();
|
||||
|
||||
MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b
|
||||
// 3. due l2Regularizer = 0, skip regularization ( indeed A' = A2 - l2Regularizer * I)
|
||||
auto regularizer = leftOutput.ulike();
|
||||
fillRegularizer<T>(regularizer, l2Regularizer);https://mangapark.net/
|
||||
// regularizer *= l2Regularizer;
|
||||
leftOutput += regularizer;
|
||||
// 4. Cholesky decomposition -- output matrix is square and lower triangular
|
||||
// auto leftOutputT = leftOutput.ulike();
|
||||
auto err = helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition
|
||||
if (err) return err;
|
||||
// alternate moment: inverse lower triangular matrix to solve equation A'x = b' => L^Tx = L^-1 * b'
|
||||
// solve one upper triangular system (to avoid float problems)
|
||||
|
||||
// 5. Solve two triangular systems:
|
||||
auto rightB = rightOutput.ulike();
|
||||
helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB);
|
||||
helpers::adjointMatrix(context, &leftOutput, true, &leftOutput); //.transposei();
|
||||
helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output);
|
||||
// All done
|
||||
}
|
||||
else { // QR decomposition approach
|
||||
// Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal matrix, and R - upper triangular
|
||||
// 1. QR decomposition
|
||||
auto qShape = leftInput->getShapeAsVector();
|
||||
auto rShape = leftInput->getShapeAsVector();
|
||||
qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2);
|
||||
|
||||
NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context);// = leftInput->ulike();
|
||||
NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike();
|
||||
helpers::qr(context, leftInput, &Q, &R, true);
|
||||
// 2. b` = Q^t * b:
|
||||
auto rightOutput = rightInput->ulike();
|
||||
MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false);
|
||||
// 3. Solve triangular system
|
||||
helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output);
|
||||
}
|
||||
NDArray::registerPrimaryUse({output}, {leftInput, rightInput});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int leastSquaresSolveFunctor(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -65,24 +65,30 @@ namespace helpers {
|
|||
template <typename T>
|
||||
static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||
int n = inputMatrix->rows();
|
||||
invertedMatrix->assign(0.f);
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
||||
for (int i = 0; i < n; i++)
|
||||
invertedMatrix->p(i, i, 1.0f);
|
||||
invertedMatrix->setIdentity();
|
||||
|
||||
if (inputMatrix->isIdentityMatrix()) return;
|
||||
|
||||
//PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
||||
for (int i = 1; i < n; i++)
|
||||
invertedMatrix->t<T>(i, i - 1) = -inputMatrix->t<T>(i, i - 1);
|
||||
auto invertDiagonals = PRAGMA_THREADS_FOR {
|
||||
for (int i = start; i < stop; i += increment)
|
||||
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
|
||||
};
|
||||
|
||||
//PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int i = 2; i < n; i++) {
|
||||
for (int j = i - 2; j > -1; --j)
|
||||
auto invertSubDiagonals = PRAGMA_THREADS_FOR {
|
||||
for (int i = start; i < stop; i += increment)
|
||||
invertedMatrix->t<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i));
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
|
||||
samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1);
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int i = 1; i < n; i++) {
|
||||
for (int j = 0; j < i - 1 ; j++)
|
||||
for (int k = 0; k < i; k++)
|
||||
invertedMatrix->t<T>(i, j) -= (invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k));
|
||||
invertedMatrix->t<T>(i, j) -= ((invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k) / inputMatrix->t<T>(i, i)));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
||||
|
@ -100,18 +106,25 @@ namespace helpers {
|
|||
return;
|
||||
}
|
||||
|
||||
//PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
||||
for (int i = 0; i < n; i++)
|
||||
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
|
||||
auto invertDiagonals = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment)
|
||||
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
|
||||
};
|
||||
|
||||
//PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
||||
for (int i = 0; i < n - 1; i++)
|
||||
invertedMatrix->t<T>(i, i + 1) -= (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i));
|
||||
auto invertUpDiagonals = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment)
|
||||
invertedMatrix->t<T>(i, i + 1) -= (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) /
|
||||
inputMatrix->t<T>(i, i));
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
|
||||
samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1);
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int i = n - 2; i > - 1; i--) {
|
||||
for (int j = i + 2; j < n; j++)
|
||||
for (int k = i; k < n; k++)
|
||||
for (auto i = n - 2; i >= 0; i--) {
|
||||
for (auto j = i + 2; j < n; j++)
|
||||
for (auto k = i; k < n; k++)
|
||||
invertedMatrix->t<T>(i, j) -= ((invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k) / inputMatrix->t<T>(i, i)));
|
||||
}
|
||||
}
|
||||
|
@ -420,10 +433,81 @@ template <typename T>
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int lowerInverse_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
||||
auto n = input->sizeAt(-1);
|
||||
auto n2 = n * n;
|
||||
auto totalCount = output->lengthOf() / n2;
|
||||
|
||||
output->assign(0.f); // fill up output tensor with zeros
|
||||
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
|
||||
// auto batchLoop = PRAGMA_THREADS_FOR {
|
||||
for (int e = 0; e < totalCount; e++) {
|
||||
if (e)
|
||||
matrix.assign(0.f);
|
||||
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||
matrix.p(row++, input->e<T>(k));
|
||||
}
|
||||
T det = T(1.f);
|
||||
for (auto i = 0; i < n; i++) {
|
||||
det *= matrix. template t<T>(i, i);
|
||||
}
|
||||
|
||||
// FIXME: and how this is going to work on float16?
|
||||
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
|
||||
nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det);
|
||||
matrix.printIndexedBuffer("Wrong matrix");
|
||||
return ND4J_STATUS_VALIDATION;
|
||||
}
|
||||
lowerMatrix.nullify();
|
||||
invertLowerMatrix(&matrix, &lowerMatrix);
|
||||
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||
output->t<T>(k) = lowerMatrix.template t<T>(row++);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int upperInverse_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
||||
auto n = input->sizeAt(-1);
|
||||
auto n2 = n * n;
|
||||
|
||||
output->nullify(); // fill up output tensor with zeros
|
||||
// auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
// auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
// auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
auto inputPart = input->allTensorsAlongDimension({-2, -1});
|
||||
auto outputPart = output->allTensorsAlongDimension({-2, -1});
|
||||
auto totalCount = outputPart.size(); //lengthOf() / n2;
|
||||
for (int e = 0; e < totalCount; e++) {
|
||||
invertUpperMatrix(inputPart.at(e), outputPart.at(e));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
int lowerInverseFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return lowerInverse_, (context, input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
int upperInverseFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return upperInverse_, (context, input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static bool checkCholeskyInput_(nd4j::LaunchContext * context, NDArray const* input) {
|
||||
//std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType())); //, block.getWorkspace());
|
||||
|
|
|
@ -106,7 +106,7 @@ namespace helpers {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void qr_(NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
|
||||
void qr_(NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
|
||||
Nd4jLong lastDim = input->rankOf() - 1;
|
||||
Nd4jLong preLastDim = input->rankOf() - 2;
|
||||
ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}));
|
||||
|
@ -123,7 +123,7 @@ namespace helpers {
|
|||
|
||||
}
|
||||
|
||||
void qr(nd4j::LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
|
||||
void qr(nd4j::LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (input, outputQ, outputR, fullMatricies), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit, K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
#include <op_boilerplate.h>
|
||||
#include <NDArray.h>
|
||||
#include <MmulHelper.h>
|
||||
#include <ShapeUtils.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
|
||||
#include "../triangular_solve.h"
|
||||
#include "../lup.h"
|
||||
#include "../qr.h"
|
||||
#include "../lstsq.h"
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
static __global__ void fillRegularizerKernel(T* ioMatrixData, Nd4jLong* ioMatrixShape, Nd4jLong* ioMatrixTads, Nd4jLong* ioMatrixOffsets, Nd4jLong batchSize, Nd4jLong rows, T const value) {
|
||||
|
||||
for (auto x = blockIdx.x; x < batchSize; x += gridDim.x) {
|
||||
auto z = ioMatrixData + ioMatrixOffsets[x];
|
||||
for (auto r = threadIdx.x; r < rows; r += blockDim.x) {
|
||||
Nd4jLong pos[] = {r,r};
|
||||
auto zIndex = shape::getOffset(ioMatrixTads, pos);
|
||||
z[zIndex] = value;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void fillRegularizer(nd4j::LaunchContext* context, NDArray& ioMatrix, double const value) {
|
||||
auto lastDimsTads = ConstantTadHelper::getInstance()->tadForDimensions(ioMatrix.shapeInfo(), {-2, -1});
|
||||
auto stream = context->getCudaStream();
|
||||
auto rows = ioMatrix.sizeAt(-2);
|
||||
//auto cols = ioMatrix.sizeAt(-1);
|
||||
fillRegularizerKernel<T><<<256, 256, 128, *stream>>>(ioMatrix.dataBuffer()->specialAsT<T>(), ioMatrix.specialShapeInfo(), lastDimsTads.specialShapeInfo(), lastDimsTads.specialOffsets(), lastDimsTads.numberOfTads(), rows, (T)value);
|
||||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int leastSquaresSolveFunctor_(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
|
||||
if (fast) { // Cholesky decomposition approach
|
||||
// Equation for solve A^T * Ax = A^T * b, so
|
||||
// 1. Computing A2:
|
||||
auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->getShapeInfo(), leftInput->getShapeInfo(), true, false);
|
||||
//tAtShape[tAtShape.size() - 2] = output->sizeAt(-2);
|
||||
NDArray leftOutput(leftInput->ordering(), tAtShape, output->dataType(), context);
|
||||
MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A
|
||||
// 2. Computing B' = A^T * b
|
||||
auto rightOutput = output->ulike();
|
||||
|
||||
MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b
|
||||
// 3. Regularization ( indeed A' = A2 - l2Regularizer * I)
|
||||
if (l2Regularizer != 0.0) {
|
||||
auto regularizer = leftOutput.ulike(); regularizer.nullify();
|
||||
fillRegularizer<T>(context, regularizer, (T)l2Regularizer);
|
||||
leftOutput += regularizer;
|
||||
}
|
||||
|
||||
// 4. Cholesky decomposition -- output matrix is square and lower triangular
|
||||
helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition
|
||||
// 5. Solve two triangular systems:
|
||||
auto rightB = rightOutput.ulike(); rightB.nullify();
|
||||
|
||||
helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB);
|
||||
|
||||
helpers::adjointMatrix(context, &leftOutput, true, &leftOutput);
|
||||
helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output);
|
||||
// All done
|
||||
}
|
||||
else { // QR decomposition approach
|
||||
// Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal matrix, and R - upper triangular
|
||||
// 1. QR decomposition
|
||||
auto qShape = leftInput->getShapeAsVector();
|
||||
auto rShape = leftInput->getShapeAsVector();
|
||||
qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2);
|
||||
|
||||
NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context);// = leftInput->ulike();
|
||||
NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike();
|
||||
helpers::qr(context, leftInput, &Q, &R, true);
|
||||
// 2. b` = Q^t * b:
|
||||
auto rightOutput = rightInput->ulike();
|
||||
MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false);
|
||||
// 3. Solve triangular system
|
||||
helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int leastSquaresSolveFunctor(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -151,7 +151,7 @@ namespace helpers {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void qr_(LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
|
||||
void qr_(LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
|
||||
Nd4jLong lastDim = input->rankOf() - 1;
|
||||
Nd4jLong preLastDim = input->rankOf() - 2;
|
||||
|
||||
|
@ -170,7 +170,7 @@ namespace helpers {
|
|||
NDArray::registerSpecialUse({outputQ, outputR}, {input});
|
||||
}
|
||||
|
||||
void qr(nd4j::LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
|
||||
void qr(nd4j::LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (context, input, outputQ, outputR, fullMatricies), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
#ifndef __LST_SQ_SOLVE__H_HELPERS__
|
||||
#define __LST_SQ_SOLVE__H_HELPERS__
|
||||
#include <op_boilerplate.h>
|
||||
#include <NDArray.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
int leastSquaresSolveFunctor(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
|
@ -32,6 +32,8 @@ namespace helpers {
|
|||
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
||||
|
||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
||||
int upperInverseFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output);
|
||||
int lowerInverseFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output);
|
||||
|
||||
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input);
|
||||
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace = false);
|
||||
|
|
|
@ -26,7 +26,7 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
void qr(nd4j::LaunchContext * context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies);
|
||||
void qr(nd4j::LaunchContext * context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies);
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <NDArray.h>
|
||||
#include <ops/ops.h>
|
||||
#include <GradCheck.h>
|
||||
#include <helpers/MmulHelper.h>
|
||||
|
||||
using namespace nd4j;
|
||||
|
||||
|
@ -1918,8 +1919,8 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4_6) {
|
|||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
z->printBuffer("4_6 Solve 3x3");
|
||||
exp.printBuffer("4_6 Expec 3x3");
|
||||
// z->printBuffer("4_6 Solve 3x3");
|
||||
// exp.printBuffer("4_6 Expec 3x3");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
delete res;
|
||||
|
@ -1955,8 +1956,8 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4_7) {
|
|||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
z->printBuffer("4_7 Solve 3x3");
|
||||
exp.printBuffer("4_7 Expec 3x3");
|
||||
// z->printBuffer("4_7 Solve 3x3");
|
||||
// exp.printBuffer("4_7 Expec 3x3");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
delete res;
|
||||
|
@ -1989,12 +1990,127 @@ TEST_F(DeclarableOpsTests11, Solve_Test_5) {
|
|||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
z->printBuffer("4 Solve 4x4");
|
||||
exp.printBuffer("4 Expec 4x4");
|
||||
// z->printBuffer("4 Solve 4x4");
|
||||
// exp.printBuffer("4 Expec 4x4");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
delete res;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, SolveLS_Test_1) {
|
||||
|
||||
auto a = NDArrayFactory::create<double>('c', {2,2, 2}, {
|
||||
1.f, 2.f, 3.f, 4.f,
|
||||
5.f, 6.f, 7.f, 8.f
|
||||
});
|
||||
|
||||
auto b = NDArrayFactory::create<double>('c', {2, 2, 1}, {
|
||||
3.f, 7.f, 11.f, 15.f
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<double>('c', {2, 2, 1}, {
|
||||
0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f
|
||||
});
|
||||
|
||||
nd4j::ops::lstsq op;
|
||||
|
||||
auto res = op.evaluate({&a, &b}, {0.5}, {}, {true});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
// z->printIndexedBuffer("LS Solve 2x2");
|
||||
// exp.printIndexedBuffer("LS Expec 2x2");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z, 1.e-4));
|
||||
delete res;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests11, SolveLS_Test_2) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {2,2, 2}, {
|
||||
1.f, 2.f, 3.f, 4.f,
|
||||
5.f, 6.f, 7.f, 8.f
|
||||
});
|
||||
|
||||
auto b = NDArrayFactory::create<float>('c', {2, 2, 1}, {
|
||||
3.f, 7.f, 11.f, 15.f
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 1}, {
|
||||
0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f
|
||||
});
|
||||
|
||||
nd4j::ops::lstsq op;
|
||||
|
||||
auto res = op.evaluate({&a, &b}, {0.5}, {}, {true});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
// z->printIndexedBuffer("2LS Solve 2x2");
|
||||
// exp.printIndexedBuffer("2LS Expec 2x2");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z, 1.e-4));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {2,2, 2}, {
|
||||
10.f, 14.f,
|
||||
14.f, 20.f,
|
||||
|
||||
74.f, 86.f,
|
||||
86.f, 100.f
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
|
||||
3.1622777f, 0.f, 4.427189f, 0.6324552f,
|
||||
8.602325f, 0.f, 9.997296f, 0.23252854f
|
||||
});
|
||||
|
||||
nd4j::ops::cholesky op;
|
||||
|
||||
auto res = op.evaluate({&a});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
z->printIndexedBuffer("L matrix is");
|
||||
exp.printIndexedBuffer("L expected is");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2_2) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {2,2, 2}, {
|
||||
10.5f, 14.f,
|
||||
14.f, 20.5f,
|
||||
|
||||
74.5f, 86.f,
|
||||
86.f, 100.5f
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
|
||||
3.2403703f, 0.f, 4.3204937f, 1.3540066f,
|
||||
8.631338f, 0.f, 9.963693f, 1.1067207f
|
||||
});
|
||||
|
||||
nd4j::ops::cholesky op;
|
||||
|
||||
auto res = op.evaluate({&a});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
// z->printIndexedBuffer("L matrix is");
|
||||
// exp.printIndexedBuffer("L expected is");
|
||||
MmulHelper::matmul(z, z, &exp, false, true);
|
||||
ASSERT_TRUE(exp.equalsTo(a));
|
||||
delete res;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include <GradCheck.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
#include <helpers/PointersManager.h>
|
||||
#include <helpers/MmulHelper.h>
|
||||
|
||||
using namespace nd4j;
|
||||
|
||||
|
@ -2936,7 +2937,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) {
|
|||
|
||||
nd4j::ops::triangular_solve op;
|
||||
|
||||
auto res = op.evaluate({&a, &b}, {}, {}, {false});
|
||||
auto res = op.evaluate({&a, &b}, {false});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
|
@ -2966,7 +2967,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
|
|||
|
||||
nd4j::ops::triangular_solve op;
|
||||
|
||||
auto res = op.evaluate({&a, &b}, {}, {}, {false, true});
|
||||
auto res = op.evaluate({&a, &b}, {false, true});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
|
@ -2977,6 +2978,142 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, SolveLs_Test_1) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {4, 4}, {
|
||||
3.f, 0.f, 0.f, 0.f,
|
||||
2.f, 1.f, 0.f, 0.f,
|
||||
1.f, 0.f, 1.f, 0.f,
|
||||
1.f, 1.f, 1.f, 1.f
|
||||
});
|
||||
|
||||
auto b = NDArrayFactory::create<float>('c', {4, 1}, {
|
||||
4.f, 2.f, 4.f, 2.f
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {4, 1}, {
|
||||
1.333333f, -0.6666667f, 2.6666667f, -1.3333333f });
|
||||
|
||||
nd4j::ops::lstsq op;
|
||||
|
||||
auto res = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
// z->printIndexedBuffer("MatrixSolveLS");
|
||||
MmulHelper::matmul(&a, z, &exp, false, false);
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(b));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, SolveLs_Test_2) {
|
||||
|
||||
auto a = NDArrayFactory::create<double>('c', {3, 3}, {
|
||||
1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 11.f, 8.f, 21.f
|
||||
});
|
||||
|
||||
auto b = NDArrayFactory::create<double>('c', {3, 1}, { 1.f, 2.f, 3.f });
|
||||
|
||||
auto exp = NDArrayFactory::create<double>('c', {3, 1}, { -0.24999914f, 0.4999994f, 0.08333314f });
|
||||
|
||||
nd4j::ops::lstsq op;
|
||||
|
||||
auto res = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
MmulHelper::matmul(&a, z, &exp, false, false);
|
||||
|
||||
// z->printIndexedBuffer("MatrixSolveLS2");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(b));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, SolveLs_Test_3) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {3, 4}, {
|
||||
1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f
|
||||
});
|
||||
|
||||
auto b = NDArrayFactory::create<float>('c', {3, 1}, { 1.f, 2.f, 3.f });
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {3, 1}, { -0.5f, 1.5f, -2.f });
|
||||
|
||||
nd4j::ops::lstsq op;
|
||||
|
||||
auto res = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
|
||||
// z->printIndexedBuffer("MatrixSolveLS3");
|
||||
MmulHelper::matmul(&a, z, &exp, false, false);
|
||||
ASSERT_TRUE(exp.equalsTo(b));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, SolveLs_Test_4) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {3, 4}, {
|
||||
1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f
|
||||
});
|
||||
|
||||
auto b = NDArrayFactory::create<float>('c', {3, 1}, { 1.f, 2.f, 3.f });
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {4, 1}, { -0.5f, 1.5f, -2.f, 0.f});
|
||||
|
||||
nd4j::ops::lstsq op;
|
||||
|
||||
auto res = op.evaluate({&a, &b}, {false});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
// z->printIndexedBuffer("Output_12.4");
|
||||
// z->printShapeInfo("Output_12.4 shape");
|
||||
// MmulHelper::matmul(&a, z, &exp, false, false);
|
||||
|
||||
// z->printIndexedBuffer("MatrixSolveLS4");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, SolveLs_Test_5) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {1, 0, 3, 4});
|
||||
auto b = NDArrayFactory::create<float>('c', {1, 0, 3, 1});
|
||||
|
||||
nd4j::ops::lstsq op;
|
||||
|
||||
auto res = op.evaluate({&a, &b}, {false});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
ASSERT_TRUE(z->isEmpty());
|
||||
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, Solve_Test_6) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {1, 0, 3, 3});
|
||||
auto b = NDArrayFactory::create<float>('c', {1, 0, 3, 1});
|
||||
|
||||
nd4j::ops::solve op;
|
||||
|
||||
auto res = op.evaluate({&a, &b}, {true});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
ASSERT_TRUE(z->isEmpty());
|
||||
|
||||
delete res;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) {
|
||||
|
||||
auto a = NDArrayFactory::create<float>('c', {4, 4}, {
|
||||
|
@ -3004,4 +3141,4 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
delete res;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -623,7 +623,8 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.custom.Digamma.class,
|
||||
org.nd4j.linalg.api.ops.custom.Lu.class,
|
||||
org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
|
||||
org.nd4j.linalg.api.ops.custom.LinearSolve.class
|
||||
org.nd4j.linalg.api.ops.custom.LinearSolve.class,
|
||||
org.nd4j.linalg.api.ops.custom.Lstsq.class
|
||||
);
|
||||
|
||||
static {
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class Lstsq extends DynamicCustomOp {
|
||||
|
||||
public Lstsq(@NonNull INDArray matrix, @NonNull INDArray rhs, double l2_regularizer, boolean fast) {
|
||||
addInputArgument(matrix, rhs);
|
||||
addTArgument(l2_regularizer);
|
||||
addBArgument(fast);
|
||||
}
|
||||
|
||||
public Lstsq(@NonNull INDArray matrix, @NonNull INDArray rhs) {
|
||||
this(matrix, rhs, 0.0, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "lstsq";
|
||||
}
|
||||
}
|
|
@ -34,19 +34,20 @@ public class TriangularSolve extends DynamicCustomOp {
|
|||
addBArgument(lower, adjoint);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "triangular_solve";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
if(attributesForNode.containsKey("adjoint")){
|
||||
addBArgument(attributesForNode.get("adjoint").getB());
|
||||
}
|
||||
if(attributesForNode.containsKey("lower")){
|
||||
addBArgument(attributesForNode.get("lower").getB());
|
||||
}
|
||||
|
||||
if(attributesForNode.containsKey("adjoint")){
|
||||
addBArgument(attributesForNode.get("adjoint").getB());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "triangular_solve";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -67,10 +67,9 @@ public class SequenceMask extends DynamicCustomOp {
|
|||
public SequenceMask(INDArray input, int maxLen, DataType dataType) {
|
||||
addInputArgument(input);
|
||||
addIArgument(maxLen);
|
||||
//addIArgument(dataType.toInt());
|
||||
addDArgument(dataType);
|
||||
this.dataType = dataType;
|
||||
}
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
|
|
|
@ -123,12 +123,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
//AB 2020/01/07 - Known issues
|
||||
"bitcast/from_float64_to_int64",
|
||||
"bitcast/from_rank2_float64_to_int64",
|
||||
"bitcast/from_float64_to_uint64",
|
||||
|
||||
// 2020/02/14 - new ops which are not passing yet
|
||||
"linear_solve/.*",
|
||||
"triangular_solve/.*",
|
||||
"lstsq/.*"
|
||||
"bitcast/from_float64_to_uint64"
|
||||
};
|
||||
|
||||
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
||||
|
|
|
@ -1739,15 +1739,37 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLstsq() {
|
||||
INDArray a = Nd4j.createFromArray(new float[]{
|
||||
1.f, 2.f, 3.f,
|
||||
4.f, 5.f, 6.f,
|
||||
11.f, 8.f, 21.f
|
||||
}).reshape(3,3);
|
||||
|
||||
INDArray b = Nd4j.createFromArray(new float[]{ 1.f, 2.f, 3.f }).reshape(3,1);
|
||||
|
||||
val op = new Lstsq(a,b);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
|
||||
DynamicCustomOp matmul = DynamicCustomOp.builder("matmul")
|
||||
.addInputs(a, ret[0])
|
||||
.build();
|
||||
INDArray[] matres = Nd4j.exec(matmul);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
assertEquals(b.getFloat(i, 0), matres[0].getFloat(i, 0), 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSequenceMask() {
|
||||
INDArray arr = Nd4j.createFromArray(new int[]{1, 3, 2});
|
||||
// Test with static max len
|
||||
int maxlen = 2;
|
||||
INDArray expected = Nd4j.createFromArray(new int[]{
|
||||
1,0,0,
|
||||
1,1,1,
|
||||
1,1,0
|
||||
1, 0, 0,
|
||||
1, 1, 1,
|
||||
1, 1, 0
|
||||
}).reshape(3, 3);
|
||||
|
||||
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32));
|
||||
|
|
|
@ -76,13 +76,13 @@ public class RngValidationTests extends BaseNd4jTest {
|
|||
@Builder.Default private double stdRelativeErrorTolerance = 0.01;
|
||||
private Double meanMinAbsErrorTolerance; //Consider relative error between 0 and 0.001: relative error is 1.0, but absolute error is small
|
||||
private Double stdMinAbsErrorTolerance;
|
||||
@Builder.Default private Map<String,Object> args = new LinkedHashMap<>();
|
||||
@Builder.Default private static Map<String,Object> args = new LinkedHashMap<>();
|
||||
|
||||
public static class TestCaseBuilder {
|
||||
|
||||
public TestCaseBuilder arg(String arg, Object value){
|
||||
if(args == null) {
|
||||
args(new LinkedHashMap<>());
|
||||
args = new LinkedHashMap<>();
|
||||
}
|
||||
args.put(arg, value);
|
||||
return this;
|
||||
|
|
2
pom.xml
2
pom.xml
|
@ -335,7 +335,7 @@
|
|||
<jackson.databind.version>2.10.1</jackson.databind.version>
|
||||
<shaded.snakeyaml.version>1.24</shaded.snakeyaml.version>
|
||||
<geo.jackson.version>2.8.7</geo.jackson.version>
|
||||
<lombok.version>1.18.2</lombok.version>
|
||||
<lombok.version>1.18.12</lombok.version>
|
||||
<cleartk.version>2.0.0</cleartk.version>
|
||||
<lucene-solr.version>7.7.1</lucene-solr.version>
|
||||
<json.version>20131018</json.version>
|
||||
|
|
Loading…
Reference in New Issue