Shugeo solve linear (#191)
* linear equations systems solve op. Initial commit. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed compiling issues. Signed-off-by: shugeo <sgazeos@gmail.com> * Linear equations systems solve. The next stage commit. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for linear equations systems solve operation. Signed-off-by: shugeo <sgazeos@gmail.com> * Added additional test and fixed lower matrix retrievance. * Implementation for solve of the systems of linear equations." Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored permutation generation. Signed-off-by: shugeo <sgazeos@gmail.com> * Added restore for permutations batched with cuda helper for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Finished cuda implementation for solve op helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored cpu helpers for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fix gtest output on Windows * Fixed issue with permutation matrix for cuda implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed issue with permutation matrix for cpu implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Eliminated waste comments. Signed-off-by: shugeo <sgazeos@gmail.com> * LinearSolve added * Mapping added * Javadoc added * Refactored implementation of triangular_solve helpers and tests for solve matrix equations generally. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a test for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Solve test added * Fix for TF import Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com> Co-authored-by: raver119 <raver119@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
57d5eb473b
commit
41ff907bc6
|
@ -0,0 +1,75 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit, K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by GS <sgazeos@gmail.com> at 01/22/2020
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_solve)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/solve.h>
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
/**
 * solve op: solves batched systems of linear equations A * X = B.
 *
 * Inputs:
 *   0 - a: left parts, rank >= 2, last two dimensions must be equal (checked below)
 *   1 - b: right parts, rank >= 2, its second-to-last dimension must match a's last one
 * Boolean args:
 *   0 - (optional, default false) when true the system is solved against the
 *       matrix produced by helpers::adjointMatrix(a) instead of a itself
 * Output:
 *   0 - z: receives the solution written by helpers::solveFunctor
 *
 * Returns the status reported by helpers::solveFunctor.
 */
CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) {
    auto a = INPUT_VARIABLE(0);
    auto b = INPUT_VARIABLE(1);
    auto z = OUTPUT_VARIABLE(0);

    // The boolean argument is optional; absent means "use the matrix as given".
    const bool useAdjoint = block.numB() > 0 && B_ARG(0);

    REQUIRE_TRUE(a->rankOf() >=2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
    REQUIRE_TRUE(b->rankOf() >=2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());

    REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
    REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));

    if (useAdjoint) {
        // Keep the adjoint on the stack: it is released on every exit path,
        // including an exception thrown by one of the helpers.
        auto adjointA = a->ulike();
        helpers::adjointMatrix(block.launchContext(), a, &adjointA);
        // Propagate the helper status instead of unconditionally reporting success.
        return helpers::solveFunctor(block.launchContext(), &adjointA, b, useAdjoint, z);
    }

    return helpers::solveFunctor(block.launchContext(), a, b, useAdjoint, z);
}
|
||||||
|
|
||||||
|
/**
 * Shape function for the solve op.
 *
 * Builds the single output shape descriptor from the two input descriptors:
 * in1 (right parts, input 1) and in0 (left parts, input 0) are handed to
 * ShapeBuilders::copyShapeInfoAndType.
 * NOTE(review): presumably the first argument supplies the dimensions and the
 * second the data type - confirm against ShapeBuilders.
 */
DECLARE_SHAPE_FN(solve) {
    // in0 must refer to input 0; reading index 1 twice was a copy-paste slip
    // that made the left input's descriptor unused.
    auto in0 = inputShape->at(0);
    auto in1 = inputShape->at(1);
    auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());

    return SHAPELIST(CONSTANT(luShape));
}
|
||||||
|
|
||||||
|
/**
 * Type registration for the solve op: inputs and outputs are restricted to
 * ALL_FLOATS and same-mode is switched off.
 */
DECLARE_TYPES(solve) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes({ALL_FLOATS});
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
    descriptor->setSameMode(false);
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -1076,6 +1076,24 @@ namespace nd4j {
|
||||||
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
|
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* solve op. - solve systems of linear equations - general method.
|
||||||
|
*
|
||||||
|
* input params:
|
||||||
|
* 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
|
||||||
|
* 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
|
||||||
|
*
|
||||||
|
* boolean args:
|
||||||
|
* 0 - adjoint - optional, default is false - indicates whether the input matrix itself (false) or its adjoint, i.e. Hermitian conjugate (true), should be used
|
||||||
|
*
|
||||||
|
* return value:
|
||||||
|
* tensor with dimension (x * y * z * ::: * M * K) with solutions
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_solve)
|
||||||
|
// Arguments mirror CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0): 2 inputs, 1 output, not inplaceable.
DECLARE_CUSTOM_OP(solve, 2, 1, false, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* lu op. - make LUP decomposition of given batch of 2D square matricies
|
* lu op. - make LUP decomposition of given batch of 2D square matricies
|
||||||
*
|
*
|
||||||
|
|
|
@ -237,25 +237,65 @@ namespace helpers {
|
||||||
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
|
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void doolitleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) {
|
||||||
|
auto input = compound->dup();
|
||||||
|
compound->nullify();
|
||||||
|
|
||||||
|
// Decomposing matrix into Upper and Lower
|
||||||
|
// triangular matrix
|
||||||
|
for (auto i = 0; i < rowNum; i++) {
|
||||||
|
|
||||||
|
// Upper Triangular
|
||||||
|
for (auto k = i; k < rowNum; k++) {
|
||||||
|
|
||||||
|
// Summation of L(i, j) * U(j, k)
|
||||||
|
int sum = 0;
|
||||||
|
for (int j = 0; j < i; j++)
|
||||||
|
sum += compound->t<T>(i,j) * compound->t<T>(j,k);
|
||||||
|
|
||||||
|
// Evaluating U(i, k)
|
||||||
|
compound->t<T>(i, k) = input.t<T>(i, k) - sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lower Triangular
|
||||||
|
for (int k = i + 1; k < rowNum; k++) {
|
||||||
|
// Summation of L(k, j) * U(j, i)
|
||||||
|
int sum = 0;
|
||||||
|
for (int j = 0; j < i; j++)
|
||||||
|
sum += compound->t<T>(k,j) * compound->t<T>(j, i);
|
||||||
|
|
||||||
|
// Evaluating L(k, i)
|
||||||
|
compound->t<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i,i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename I>
|
template <typename T, typename I>
|
||||||
static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
|
static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
|
||||||
|
|
||||||
//const int rowNum = compound->rows();
|
//const int rowNum = compound->rows();
|
||||||
// const int columnNum = output->columns();
|
// const int columnNum = output->columns();
|
||||||
permutation->linspace(0);
|
if (permutation) { // LUP algorithm
|
||||||
auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
|
permutation->linspace(0);
|
||||||
auto compoundBuf = compound->bufferAsT<T>();
|
auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
|
||||||
auto compoundShape = compound->shapeInfo();
|
auto compoundBuf = compound->bufferAsT<T>();
|
||||||
auto permutationShape = permutation->shapeInfo();
|
auto compoundShape = compound->shapeInfo();
|
||||||
for (auto i = 0; i < rowNum - 1; i++) {
|
auto permutationShape = permutation->shapeInfo();
|
||||||
auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
|
for (auto i = 0; i < rowNum - 1; i++) {
|
||||||
if (pivotIndex < 0) {
|
auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
|
||||||
throw std::runtime_error("helpers::luNN_: input matrix is singular.");
|
if (pivotIndex < 0) {
|
||||||
}
|
throw std::runtime_error("helpers::luNN_: input matrix is singular.");
|
||||||
math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
|
}
|
||||||
swapRows(compoundBuf, compoundShape, i, pivotIndex);
|
math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)],
|
||||||
|
permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
|
||||||
|
swapRows(compoundBuf, compoundShape, i, pivotIndex);
|
||||||
|
|
||||||
processColumns(i, rowNum, compoundBuf, compoundShape);
|
processColumns(i, rowNum, compoundBuf, compoundShape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else { // Doolitle algorithm with LU decomposition
|
||||||
|
doolitleLU<T>(context, compound, rowNum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,17 +305,20 @@ namespace helpers {
|
||||||
|
|
||||||
output->assign(input); // fill up output tensor with zeros
|
output->assign(input); // fill up output tensor with zeros
|
||||||
ResultSet outputs = output->allTensorsAlongDimension({-2, -1});
|
ResultSet outputs = output->allTensorsAlongDimension({-2, -1});
|
||||||
ResultSet permutations = permutationVectors->allTensorsAlongDimension({-1});
|
ResultSet permutations;
|
||||||
|
if (permutationVectors)
|
||||||
|
permutations = permutationVectors->allTensorsAlongDimension({-1});
|
||||||
|
|
||||||
auto loop = PRAGMA_THREADS_FOR {
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
luNN_<T, I>(context, outputs.at(i), permutations.at(i), n);
|
luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
|
samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
|
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation?permutation->dataType():DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
|
// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
|
|
@ -0,0 +1,100 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit, K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <helpers/MmulHelper.h>
|
||||||
|
|
||||||
|
#include "../triangular_solve.h"
|
||||||
|
#include "../lup.h"
|
||||||
|
#include "../solve.h"
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T>
|
||||||
|
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
|
||||||
|
auto inputPart = input->allTensorsAlongDimension({-2, -1});
|
||||||
|
auto outputPart = output->allTensorsAlongDimension({-2, -1});
|
||||||
|
output->assign(input);
|
||||||
|
auto batchLoop = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto batch = start; batch < stop; batch += increment) {
|
||||||
|
for (auto r = 0; r < input->rows(); r++) {
|
||||||
|
for (auto c = 0; c < r; c++) {
|
||||||
|
math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T>
|
||||||
|
static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
|
||||||
|
|
||||||
|
// stage 1: LU decomposition batched
|
||||||
|
auto leftOutput = leftInput->ulike();
|
||||||
|
auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back();
|
||||||
|
auto permutations = NDArrayFactory::create<int>('c', permuShape, context);
|
||||||
|
helpers::lu(context, leftInput, &leftOutput, &permutations);
|
||||||
|
auto P = leftInput->ulike(); //permutations batched matrix
|
||||||
|
P.nullify(); // to fill up matricies with zeros
|
||||||
|
auto PPart = P.allTensorsAlongDimension({-2,-1});
|
||||||
|
auto permutationsPart = permutations.allTensorsAlongDimension({-1});
|
||||||
|
|
||||||
|
for (auto batch = 0; batch < permutationsPart.size(); ++batch) {
|
||||||
|
for (auto row = 0; row < PPart[batch]->rows(); ++row) {
|
||||||
|
PPart[batch]->t<T>(row, permutationsPart[batch]->t<int>(row)) = T(1.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto leftLower = leftOutput.dup();
|
||||||
|
auto rightOutput = rightInput->ulike();
|
||||||
|
auto rightPermuted = rightOutput.ulike();
|
||||||
|
MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0);
|
||||||
|
ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1});
|
||||||
|
for (auto i = 0; i < leftLowerPart.size(); i++) {
|
||||||
|
for (auto r = 0; r < leftLowerPart[i]->rows(); r++)
|
||||||
|
leftLowerPart[i]->t<T>(r,r) = (T)1.f;
|
||||||
|
}
|
||||||
|
// stage 2: triangularSolveFunctor for Lower with given b
|
||||||
|
helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput);
|
||||||
|
// stage 3: triangularSolveFunctor for Upper with output of previous stage
|
||||||
|
helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||||
|
/**
 * Public CPU entry point of the general linear solve: picks the typed
 * implementation by the data type of leftInput (FLOAT_TYPES only) and returns
 * its status.
 */
int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
    const auto leftType = leftInput->dataType();
    BUILD_SINGLE_SELECTOR(leftType, return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
}
|
||||||
|
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||||
|
/**
 * Public CPU entry point for the adjoint helper: picks the typed
 * implementation by the data type of input (FLOAT_TYPES only).
 */
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    const auto inputType = input->dataType();
    BUILD_SINGLE_SELECTOR(inputType, adjointMatrix_, (context, input, output), FLOAT_TYPES);
}
|
||||||
|
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -41,13 +41,16 @@ namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
|
static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
|
||||||
auto rows = leftInput->rows();
|
auto rows = leftInput->rows();
|
||||||
|
auto cols = rightInput->columns();
|
||||||
//output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
|
//output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
|
||||||
for (auto r = 0; r < rows; r++) {
|
for (auto r = 0; r < rows; r++) {
|
||||||
auto sum = rightInput->t<T>(r, 0);
|
for (auto j = 0; j < cols; j++) {
|
||||||
for (auto c = 0; c < r; c++) {
|
auto sum = rightInput->t<T>(r, j);
|
||||||
sum -= leftInput->t<T>(r,c) * output->t<T>(c, 0);
|
for (auto c = 0; c < r; c++) {
|
||||||
|
sum -= leftInput->t<T>(r, c) * output->t<T>(c, j);
|
||||||
|
}
|
||||||
|
output->t<T>(r, j) = sum / leftInput->t<T>(r, r);
|
||||||
}
|
}
|
||||||
output->t<T>(r, 0) = sum / leftInput->t<T>(r, r);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,13 +71,15 @@ namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
|
static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
|
||||||
auto rows = leftInput->rows();
|
auto rows = leftInput->rows();
|
||||||
|
auto cols = rightInput->columns();
|
||||||
for (auto r = rows; r > 0; r--) {
|
for (auto r = rows; r > 0; r--) {
|
||||||
auto sum = rightInput->t<T>(r - 1, 0);
|
for (auto j = 0; j < cols; j++) {
|
||||||
for (auto c = r; c < rows; c++) {
|
auto sum = rightInput->t<T>(r - 1, j);
|
||||||
sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, 0);
|
for (auto c = r; c < rows; c++) {
|
||||||
|
sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, j);
|
||||||
|
}
|
||||||
|
output->t<T>(r - 1, j) = sum / leftInput->t<T>(r - 1, r - 1);
|
||||||
}
|
}
|
||||||
output->t<T>(r - 1, 0) = sum / leftInput->t<T>(r - 1, r - 1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,140 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit, K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <MmulHelper.h>
|
||||||
|
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
#include "../triangular_solve.h"
|
||||||
|
#include "../lup.h"
|
||||||
|
#include "../solve.h"
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong* ioShape, Nd4jLong* tadShape, Nd4jLong* tadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) {
|
||||||
|
for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) {
|
||||||
|
auto matrixPart = ioBuf + tadOffsets[i];
|
||||||
|
for (auto j = threadIdx.x; j < rowNum; j += blockDim.x) {
|
||||||
|
Nd4jLong pos[] = {j, j};
|
||||||
|
auto offset = shape::getOffset(tadShape, pos);
|
||||||
|
|
||||||
|
matrixPart[offset] = T(1.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void restorePermutationsKernel(T* PBuf, Nd4jLong* PShapeInfo, int const* permutationsBuf,
|
||||||
|
Nd4jLong* PTadShapeInfo, Nd4jLong* PTadSOffsets, Nd4jLong* permutationsTadShapeInfo, Nd4jLong* permutationsTadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) {
|
||||||
|
for (auto batch = blockIdx.x; batch < batchNum; batch += gridDim.x) {
|
||||||
|
auto permutations = permutationsBuf + permutationsTadOffsets[batch];
|
||||||
|
auto P = PBuf + PTadSOffsets[batch];
|
||||||
|
|
||||||
|
for (auto row = threadIdx.x; row < rowNum; row += blockDim.x) {
|
||||||
|
//auto posX[] = {row};
|
||||||
|
Nd4jLong posZ[] = {row, permutations[row]};
|
||||||
|
auto zOffset = shape::getOffset(PTadShapeInfo, posZ);
|
||||||
|
P[zOffset] = T(1.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput,
|
||||||
|
bool adjoint, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {leftInput, rightInput});
|
||||||
|
// stage 1: LU decomposition batched
|
||||||
|
auto leftOutput = leftInput->ulike();
|
||||||
|
auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back();
|
||||||
|
auto permutations = NDArrayFactory::create<int>('c', permuShape, context);
|
||||||
|
helpers::lu(context, leftInput, &leftOutput, &permutations);
|
||||||
|
auto leftLower = leftOutput.dup();
|
||||||
|
auto rightOutput = rightInput->ulike();
|
||||||
|
auto leftLowerTad = ConstantTadHelper::getInstance()->tadForDimensions(leftLower.getShapeInfo(), {-2, -1});
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
oneOnDiagonalKernel<T><<<128, 256, 256, *stream>>>(leftLower.dataBuffer()->specialAsT<T>(), leftLower.specialShapeInfo(), leftLowerTad.specialShapeInfo(), leftLowerTad.specialOffsets(), leftLowerTad.numberOfTads(), leftLower.sizeAt(-1));
|
||||||
|
auto P = leftOutput.ulike(); P.nullify();
|
||||||
|
auto PTad = ConstantTadHelper::getInstance()->tadForDimensions(P.getShapeInfo(), {-2, -1});
|
||||||
|
auto permutationsTad = ConstantTadHelper::getInstance()->tadForDimensions(permutations.getShapeInfo(), {-1});
|
||||||
|
restorePermutationsKernel<T><<<128, 256, 256, *stream>>>(P.dataBuffer()->specialAsT<T>(), P.specialShapeInfo(), permutations.dataBuffer()->specialAsT<int>(),
|
||||||
|
PTad.specialShapeInfo(), PTad.specialOffsets(), permutationsTad.specialShapeInfo(), permutationsTad.specialOffsets(), permutationsTad.numberOfTads(), permutations.sizeAt(-1));
|
||||||
|
P.tickWriteDevice();
|
||||||
|
auto rightPart = rightInput->ulike();
|
||||||
|
MmulHelper::matmul(&P, rightInput, &rightPart, 0, 0);
|
||||||
|
|
||||||
|
// stage 2: triangularSolveFunctor for Lower with given b
|
||||||
|
helpers::triangularSolveFunctor(context, &leftLower, &rightPart, true, false, &rightOutput);
|
||||||
|
// stage 3: triangularSolveFunctor for Upper with output of previous stage
|
||||||
|
helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output);
|
||||||
|
NDArray::registerSpecialUse({output}, {leftInput, rightInput});
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Public CUDA entry point of the general linear solve: picks the typed
 * implementation by the data type of leftInput (FLOAT_TYPES only) and returns
 * its status.
 */
int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
    const auto leftType = leftInput->dataType();
    BUILD_SINGLE_SELECTOR(leftType, return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void adjointKernel(T* output, Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, Nd4jLong* outputTads,
|
||||||
|
Nd4jLong* outputOffsets) {
|
||||||
|
|
||||||
|
for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) {
|
||||||
|
auto outputPart = output + outputOffsets[b];
|
||||||
|
for (auto r = threadIdx.x; r < rows; r += blockDim.x) {
|
||||||
|
for (auto c = threadIdx.y; c < r; c += blockDim.y) {
|
||||||
|
Nd4jLong zPos[] = {r, c};
|
||||||
|
Nd4jLong xPos[] = {c, r};
|
||||||
|
auto zIndex = shape::getOffset(outputTads, zPos);
|
||||||
|
auto xIndex = shape::getOffset(outputTads, xPos);
|
||||||
|
math::nd4j_swap(outputPart[zIndex], outputPart[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {-2, -1});
|
||||||
|
auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {-2, -1});
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer());
|
||||||
|
auto rows = input->sizeAt(-2);
|
||||||
|
auto columns = input->sizeAt(-1);
|
||||||
|
output->assign(input);
|
||||||
|
adjointKernel<T><<<128, 256, 256, *stream>>>(outputBuf, outputTads.numberOfTads(), rows, columns, outputTads.specialShapeInfo(), outputTads.specialOffsets());
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Public CUDA entry point for the adjoint helper: picks the typed
 * implementation by the data type of input (FLOAT_TYPES only).
 */
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    const auto inputType = input->dataType();
    BUILD_SINGLE_SELECTOR(inputType, adjointMatrix_, (context, input, output), FLOAT_TYPES);
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -44,24 +44,26 @@ namespace nd4j {
|
||||||
static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
|
static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
|
||||||
T const* rightInput, Nd4jLong const* rightInputShape,
|
T const* rightInput, Nd4jLong const* rightInputShape,
|
||||||
bool const adjoint, T* output, Nd4jLong* outputShape,
|
bool const adjoint, T* output, Nd4jLong* outputShape,
|
||||||
Nd4jLong rows) {
|
Nd4jLong rows, Nd4jLong cols) {
|
||||||
|
|
||||||
for (auto r = 0; r < rows; r++) {
|
for (auto r = 0; r < rows; r++) {
|
||||||
Nd4jLong posY[] = {r, 0};
|
for (auto j = 0; j < cols; j++) {
|
||||||
Nd4jLong posX[] = {r, r};
|
Nd4jLong posY[] = {r, j};
|
||||||
auto xIndex = shape::getOffset(leftInputShape, posX, 0);
|
Nd4jLong posX[] = {r, r};
|
||||||
auto yIndex = shape::getOffset(rightInputShape, posY, 0);
|
auto xIndex = shape::getOffset(leftInputShape, posX, 0);
|
||||||
auto zIndex = shape::getOffset(outputShape, posY, 0);
|
auto yIndex = shape::getOffset(rightInputShape, posY, 0);
|
||||||
|
auto zIndex = shape::getOffset(outputShape, posY, 0);
|
||||||
|
|
||||||
auto sum = rightInput[yIndex];
|
auto sum = rightInput[yIndex];
|
||||||
for (auto c = 0; c < r; c++) {
|
for (auto c = 0; c < r; c++) {
|
||||||
Nd4jLong posZ[] = {c, 0};
|
Nd4jLong posZ[] = {c, j};
|
||||||
Nd4jLong pos[] = {r, c};
|
Nd4jLong pos[] = {r, c};
|
||||||
auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
|
auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
|
||||||
auto zcIndex = shape::getOffset(outputShape, posZ, 0);
|
auto zcIndex = shape::getOffset(outputShape, posZ, 0);
|
||||||
sum -= leftInput[xcIndex] * output[zcIndex];
|
sum -= leftInput[xcIndex] * output[zcIndex];
|
||||||
|
}
|
||||||
|
output[zIndex] = sum / leftInput[xIndex];
|
||||||
}
|
}
|
||||||
output[zIndex] = sum / leftInput[xIndex];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,23 +84,25 @@ namespace nd4j {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
|
static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
|
||||||
T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output,
|
T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output,
|
||||||
Nd4jLong* outputShape, Nd4jLong rows) {
|
Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) {
|
||||||
|
|
||||||
for (auto r = rows; r > 0; r--) {
|
for (auto r = rows; r > 0; r--) {
|
||||||
Nd4jLong posY[] = {r - 1, 0};
|
for (auto j = 0; j < cols; j++) {
|
||||||
Nd4jLong posX[] = {r - 1, r - 1};
|
Nd4jLong posY[] = {r - 1, j};
|
||||||
auto xIndex = shape::getOffset(leftInputShape, posX, 0);
|
Nd4jLong posX[] = {r - 1, r - 1};
|
||||||
auto yIndex = shape::getOffset(rightInputShape, posY, 0);
|
auto xIndex = shape::getOffset(leftInputShape, posX, 0);
|
||||||
auto zIndex = shape::getOffset(outputShape, posY, 0);
|
auto yIndex = shape::getOffset(rightInputShape, posY, 0);
|
||||||
auto sum = rightInput[yIndex];
|
auto zIndex = shape::getOffset(outputShape, posY, 0);
|
||||||
for (auto c = r; c < rows; c++) {
|
auto sum = rightInput[yIndex];
|
||||||
Nd4jLong posZ[] = {c, 0};
|
for (auto c = r; c < rows; c++) {
|
||||||
Nd4jLong pos[] = {r - 1, c};
|
Nd4jLong posZ[] = {c, j};
|
||||||
auto zcIndex = shape::getOffset(outputShape, posZ, 0);
|
Nd4jLong pos[] = {r - 1, c};
|
||||||
auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
|
auto zcIndex = shape::getOffset(outputShape, posZ, 0);
|
||||||
sum -= leftInput[xcIndex] * output[zcIndex];
|
auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
|
||||||
|
sum -= leftInput[xcIndex] * output[zcIndex];
|
||||||
|
}
|
||||||
|
output[zIndex] = sum / leftInput[xIndex];
|
||||||
}
|
}
|
||||||
output[zIndex] = sum / leftInput[xIndex];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,8 +113,11 @@ namespace nd4j {
|
||||||
Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) {
|
Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) {
|
||||||
|
|
||||||
__shared__ Nd4jLong rows;
|
__shared__ Nd4jLong rows;
|
||||||
|
__shared__ Nd4jLong cols;
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
rows = shape::sizeAt(leftPartShape, -2);
|
rows = shape::sizeAt(leftPartShape, -2);
|
||||||
|
cols = shape::sizeAt(rightPartShape, -1);
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
@ -123,9 +130,9 @@ namespace nd4j {
|
||||||
auto pRightPart = rightInput + tadRightOffset[i];
|
auto pRightPart = rightInput + tadRightOffset[i];
|
||||||
auto pOutputPart = output + tadOutputOffset[i];
|
auto pOutputPart = output + tadOutputOffset[i];
|
||||||
if (lower) {
|
if (lower) {
|
||||||
lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows);
|
lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
|
||||||
} else {
|
} else {
|
||||||
upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows);
|
upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
#ifndef __SOLVE__H_HELPERS__
|
||||||
|
#define __SOLVE__H_HELPERS__
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
|
||||||
|
// Declarations of the helper routines for the linear-system "solve" functionality.
// NOTE(review): only the prototypes are visible here; the semantic notes below are
// inferred from names and must be confirmed against the definitions.
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
/**
 * Solve helper entry point.
 * Presumably solves leftInput * X = rightInput for X and stores X in output - TODO confirm.
 *
 * @param context    launch context handed through to the implementation
 * @param leftInput  NOTE(review): looks like the coefficient matrix (or batch of matrices) - confirm
 * @param rightInput NOTE(review): looks like the right-hand side(s) - confirm
 * @param adjoint    NOTE(review): presumably selects solving with the adjoint of leftInput - confirm
 * @param output     array that receives the result
 * @return int result code; NOTE(review): presumably a status value - confirm against the definition
 */
int solveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output);
|
||||||
|
/**
 * Adjoint helper.
 * NOTE(review): by its name this writes the adjoint of input into output; whether that is a
 * plain transpose or a conjugate transpose cannot be told from the prototype - confirm.
 *
 * @param context launch context handed through to the implementation
 * @param input   source array (not modified: passed as pointer-to-const)
 * @param output  array that receives the result
 */
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -1541,6 +1541,166 @@ TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
|
||||||
delete []arr;
|
delete []arr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_1) {
    // Single 3x3 system a * x = b with a general (non-triangular) matrix.
    auto coefficients = NDArrayFactory::create<float>('c', {3, 3}, {
         2.f, -1.f, -2.f,
        -4.f,  6.f,  3.f,
        -4.f, -2.f,  8.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 1}, {2.f, 4.f, 3.f});

    // Exact solution: substituting it back into a * x reproduces b.
    auto expected = NDArrayFactory::create<float>('c', {3, 1}, {7.625f, 3.25f, 5.f});

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&coefficients, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
//    solution->printIndexedBuffer("Solve of 3x3");

    ASSERT_TRUE(expected.equalsTo(solution));
    delete results;
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_2) {
    // Single 4x4 system whose matrix happens to be upper triangular.
    auto coefficients = NDArrayFactory::create<float>('c', {4, 4}, {
        1.f, 1.f, 1.f, 1.f,
        0.f, 1.f, 1.f, 0.f,
        0.f, 0.f, 2.f, 1.f,
        0.f, 0.f, 0.f, 3.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {4, 1}, {2.f, 4.f, 2.f, 4.f});

    // Back-substitution gives x = (-10/3, 11/3, 1/3, 4/3).
    auto expected = NDArrayFactory::create<float>('c', {4, 1}, {
        -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f
    });

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&coefficients, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
//    solution->printIndexedBuffer("Solve 4x4");

    ASSERT_TRUE(expected.equalsTo(solution));
    delete results;
}
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_3) {
    // Batch of two 4x4 systems: the first matrix is upper triangular, the second lower triangular.
    auto coefficients = NDArrayFactory::create<float>('c', {2, 4, 4}, {
        // batch 0
        1.f, 1.f, 1.f, 1.f,
        0.f, 1.f, 1.f, 0.f,
        0.f, 0.f, 2.f, 1.f,
        0.f, 0.f, 0.f, 3.f,
        // batch 1
        3.f, 0.f, 0.f, 0.f,
        2.f, 1.f, 0.f, 0.f,
        1.f, 0.f, 1.f, 0.f,
        1.f, 1.f, 1.f, 1.f
    });

    // One right-hand-side column per batch entry.
    auto rhs = NDArrayFactory::create<float>('c', {2, 4, 1}, {
        2.f, 4.f, 2.f, 4.f,
        4.f, 2.f, 4.f, 2.f
    });

    // Per-batch solutions of a * x = b.
    auto expected = NDArrayFactory::create<float>('c', {2, 4, 1}, {
        -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f,
        1.333333f, -0.6666667f, 2.6666667f, -1.3333333f
    });

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&coefficients, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
//    solution->printIndexedBuffer("Solve 4x4");

    ASSERT_TRUE(expected.equalsTo(solution));
    delete results;
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_4) {
    // Batch of two dense 2x2 systems, each with a 2-column right-hand side.
    auto coefficients = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7788f, 0.8012f, 0.7244f, 0.2309f,
        0.7271f, 0.1804f, 0.5056f, 0.8925f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7717f, 0.9281f, 0.9846f, 0.4838f,
        0.6433f, 0.6041f, 0.6501f, 0.7612f
    });

    // An alternative set of reference values is kept below (disabled) for comparison:
    //     1.524494767f, 0.432706356f, -0.518630624f, 0.737760842f,
    //     0.819143713f, 0.720401764f,  0.264349997f, 0.444699198f
    auto expected = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f,
        0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f
    });

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&coefficients, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
//    solution->printBuffer("4 Solve 4x4");
//    expected.printBuffer("4 Expec 4x4");

    ASSERT_TRUE(expected.equalsTo(solution));
    delete results;
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests11, Solve_Test_5) {
    // Dense 3x3 system with a 3-column right-hand side, solved with the boolean
    // argument set to true (adjoint mode).
    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.8012f, 0.7244f,
        0.2309f, 0.7271f, 0.1804f,
        0.5056f, 0.8925f, 0.5461f
    });

    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });

    // Reference solution of the transposed system: transpose(a) * x = b.
    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
        1.5504692f, 1.8953944f, 2.2765768f,
        0.03399149f, 0.2883001f, 0.5377323f,
        -0.8774802f, -1.2155888f, -1.8049058f
    });

    nd4j::ops::solve op;

    auto res = op.evaluate({&a, &b}, {true});
    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
    auto z = res->at(0);

    // Debug output stays disabled, as in the other Solve tests, so a passing run is
    // silent. The labels now describe this 3x3 case (they previously said "4x4").
//    z->printBuffer("5 Solve 3x3");
//    exp.printBuffer("5 Expec 3x3");

    ASSERT_TRUE(exp.equalsTo(z));
    delete res;
}
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
|
||||||
|
|
||||||
|
|
|
@ -3008,3 +3008,33 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
delete res;
|
delete res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) {
    // Upper-triangular 4x4 matrix with a 2-column right-hand side.
    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
        5.f, 1.f, -3.f, 3.f,
        0.f, 1.f, 1.f, -1.f,
        0.f, 0.f, 2.f, -9.f,
        0.f, 0.f, 0.f, 4.f
    });

    auto b = NDArrayFactory::create<float>('c', {4, 2}, {
        5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f
    });

    // Expected x satisfies transpose(a) * x = b, i.e. the system is solved with the
    // adjoint of the stored upper-triangular matrix.
    auto exp = NDArrayFactory::create<float>('c', {4, 2}, {
        1.f, 0.2f, 1.f, 0.8f, 1.f, 0.4f, 1.f, 1.2f
    });

    nd4j::ops::triangular_solve op;

    // Boolean arguments {false, true}.
    // NOTE(review): presumably {lower, adjoint}, which matches the expected values above - confirm.
    auto res = op.evaluate({&a, &b}, {}, {}, {false, true});
    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
    auto z = res->at(0);

    // Debug output is disabled so that a passing run is silent; re-enable locally when needed.
//    z->printIndexedBuffer("TriangularSolve with adjoint");

    ASSERT_TRUE(exp.equalsTo(z));
    delete res;
}
|
|
@ -39,7 +39,7 @@ do
|
||||||
done
|
done
|
||||||
|
|
||||||
CHIP="${CHIP:-cpu}"
|
CHIP="${CHIP:-cpu}"
|
||||||
export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml"
|
export GTEST_OUTPUT="xml:surefire-reports/TEST-${CHIP}-results.xml"
|
||||||
|
|
||||||
# On Mac, make sure it can find libraries for GCC
|
# On Mac, make sure it can find libraries for GCC
|
||||||
export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
|
export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
|
||||||
|
@ -48,9 +48,11 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/
|
||||||
if [ -n "$BUILD_PATH" ]; then
|
if [ -n "$BUILD_PATH" ]; then
|
||||||
if which cygpath; then
|
if which cygpath; then
|
||||||
BUILD_PATH=$(cygpath -p $BUILD_PATH)
|
BUILD_PATH=$(cygpath -p $BUILD_PATH)
|
||||||
export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'"
|
|
||||||
fi
|
fi
|
||||||
export PATH="$PATH:$BUILD_PATH"
|
export PATH="$PATH:$BUILD_PATH"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
|
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
|
||||||
|
|
||||||
|
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
|
||||||
|
[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
|
||||||
|
|
|
@ -623,7 +623,8 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.custom.Igammac.class,
|
org.nd4j.linalg.api.ops.custom.Igammac.class,
|
||||||
org.nd4j.linalg.api.ops.custom.Digamma.class,
|
org.nd4j.linalg.api.ops.custom.Digamma.class,
|
||||||
org.nd4j.linalg.api.ops.custom.Lu.class,
|
org.nd4j.linalg.api.ops.custom.Lu.class,
|
||||||
org.nd4j.linalg.api.ops.custom.TriangularSolve.class
|
org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.LinearSolve.class
|
||||||
);
|
);
|
||||||
|
|
||||||
static {
|
static {
|
||||||
|
|
|
@ -16,26 +16,48 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.buffer;
|
package org.nd4j.linalg.api.buffer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Enum lists supported data types.
|
||||||
|
*
|
||||||
|
*/
|
||||||
public enum DataType {
|
public enum DataType {
|
||||||
|
|
||||||
DOUBLE,
|
DOUBLE,
|
||||||
FLOAT,
|
FLOAT,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Replaced by {@link DataType#FLOAT16}, use that instead
|
||||||
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
HALF,
|
HALF,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Replaced by {@link DataType#INT64}, use that instead
|
||||||
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
LONG,
|
LONG,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Replaced by {@link DataType#INT32}, use that instead
|
||||||
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
INT,
|
INT,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Replaced by {@link DataType#INT16}, use that instead
|
||||||
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
SHORT,
|
SHORT,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Replaced by {@link DataType#UINT8}, use that instead
|
||||||
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
UBYTE,
|
UBYTE,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Replaced by {@link DataType#INT8}, use that instead
|
||||||
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
BYTE,
|
BYTE,
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class LinearSolve extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public LinearSolve(INDArray a, INDArray b, boolean adjoint) {
|
||||||
|
addInputArgument(a, b);
|
||||||
|
addBArgument(adjoint);
|
||||||
|
}
|
||||||
|
|
||||||
|
public LinearSolve(INDArray a, INDArray b) {
|
||||||
|
this(a,b,false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, SDVariable adjoint) {
|
||||||
|
super(sameDiff, new SDVariable[] {a, b, adjoint});
|
||||||
|
}
|
||||||
|
|
||||||
|
public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, boolean adjoint) {
|
||||||
|
super(sameDiff, new SDVariable[] {a, b});
|
||||||
|
addBArgument(adjoint);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "solve";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "MatrixSolve";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
boolean adjoint = attributesForNode.containsKey("adjoint") ? attributesForNode.get("adjoint").getB() : false;
|
||||||
|
addBArgument(adjoint);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(dataTypes != null && dataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), dataTypes);
|
||||||
|
return Collections.singletonList(dataTypes.get(0));
|
||||||
|
}
|
||||||
|
}
|
|
@ -1691,4 +1691,50 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
|
|
||||||
assertEquals(e, x);
|
assertEquals(e, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLinearSolve() {
|
||||||
|
INDArray a = Nd4j.createFromArray(new float[]{
|
||||||
|
2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f
|
||||||
|
}).reshape(3, 3);
|
||||||
|
|
||||||
|
INDArray b = Nd4j.createFromArray(new float[]{
|
||||||
|
2.f, 4.f, 3.f
|
||||||
|
}).reshape(3, 1);
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||||
|
7.625f, 3.25f, 5.f
|
||||||
|
}).reshape(3, 1);
|
||||||
|
|
||||||
|
val op = new LinearSolve(a, b);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
|
||||||
|
assertEquals(expected, ret[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLinearSolveAdjust() {
|
||||||
|
INDArray a = Nd4j.createFromArray(new float[]{
|
||||||
|
0.7788f, 0.8012f, 0.7244f,
|
||||||
|
0.2309f, 0.7271f, 0.1804f,
|
||||||
|
0.5056f, 0.8925f, 0.5461f
|
||||||
|
}).reshape(3, 3);
|
||||||
|
|
||||||
|
INDArray b = Nd4j.createFromArray(new float[]{
|
||||||
|
0.7717f, 0.9281f, 0.9846f,
|
||||||
|
0.4838f, 0.6433f, 0.6041f,
|
||||||
|
0.6501f, 0.7612f, 0.7605f
|
||||||
|
}).reshape(3, 3);
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||||
|
1.5504692f, 1.8953944f, 2.2765768f,
|
||||||
|
0.03399149f, 0.2883001f , 0.5377323f,
|
||||||
|
-0.8774802f, -1.2155888f, -1.8049058f
|
||||||
|
}).reshape(3, 3);
|
||||||
|
|
||||||
|
val op = new LinearSolve(a, b, true);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
|
||||||
|
assertEquals(expected, ret[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue