Shugeo solve linear (#191)
* linear equations systems solve op. Initial commit. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed compiling issues. Signed-off-by: shugeo <sgazeos@gmail.com> * Linear equations systems solve. The next stage commit. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for linear equations systems solve operation. Signed-off-by: shugeo <sgazeos@gmail.com> * Added additional test and fixed lower matrix retrievance. * Implementation for solve of the systems of linear equations." Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored permutation generation. Signed-off-by: shugeo <sgazeos@gmail.com> * Added restore for permutations batched with cuda helper for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Finished cuda implementation for solve op helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored cpu helpers for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fix gtest output on Windows * Fixed issue with permutation matrix for cuda implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed issue with permutation matrix for cpu implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Eliminated waste comments. Signed-off-by: shugeo <sgazeos@gmail.com> * LinearSolve added * Mapping added * Javadoc added * Refactored implementation of triangular_solve helpers and tests for solve matrix equations generally. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a test for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Solve test added * Fix for TF import Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com> Co-authored-by: raver119 <raver119@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
57d5eb473b
commit
41ff907bc6
|
@ -0,0 +1,75 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit, K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by GS <sgazeos@gmail.com> at 01/22/2020
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_solve)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/solve.h>
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
/**
 * solve op: solves A * X = B for a batch of square systems.
 *
 * Inputs:
 *   0 - left-hand tensor; rank >= 2 and the last two dimensions must be equal
 *   1 - right-hand tensor; rank >= 2 and its second-to-last dimension must match
 *       the last dimension of input 0
 * Boolean args:
 *   0 - (optional, default false) use the adjoint of input 0 instead of input 0
 * Output:
 *   0 - tensor receiving the solutions
 *
 * Returns the status reported by helpers::solveFunctor.
 */
CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) {
    auto a = INPUT_VARIABLE(0);
    auto b = INPUT_VARIABLE(1);
    auto z = OUTPUT_VARIABLE(0);

    // The boolean argument is optional; B_ARG(0) is only read when it is present.
    const bool useAdjoint = block.numB() > 0 && B_ARG(0);

    REQUIRE_TRUE(a->rankOf() >=2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
    REQUIRE_TRUE(b->rankOf() >=2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());

    REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
    REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));

    if (useAdjoint) {
        // Keep the adjoint in automatic storage: no extra heap copy is made and
        // nothing leaks if the solver throws (e.g. on a singular matrix).
        auto adjointA = a->ulike();
        helpers::adjointMatrix(block.launchContext(), a, &adjointA);
        return helpers::solveFunctor(block.launchContext(), &adjointA, b, useAdjoint, z);
    }

    // Propagate the helper's status instead of unconditionally reporting success.
    return helpers::solveFunctor(block.launchContext(), a, b, useAdjoint, z);
}
|
||||
|
||||
/**
 * Shape function for the solve op.
 *
 * Builds the output shape descriptor from the right-hand input (input 1) and the
 * left-hand input (input 0) via ShapeBuilders::copyShapeInfoAndType.
 * NOTE(review): presumably the first argument supplies the dimensions and the
 * second the data type - confirm against ShapeBuilders. The helpers dispatch on
 * the left input's data type, so the type donor should be input 0.
 */
DECLARE_SHAPE_FN(solve) {
    auto in0 = inputShape->at(0); // left-hand matrices (was read from index 1 by mistake)
    auto in1 = inputShape->at(1); // right-hand sides
    auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());

    return SHAPELIST(CONSTANT(luShape));
}
|
||||
|
||||
// Type constraints of the solve op: floating-point inputs and outputs,
// with "same mode" switched off.
DECLARE_TYPES(solve) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes({ALL_FLOATS});
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
    descriptor->setSameMode(false);
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -1076,6 +1076,24 @@ namespace nd4j {
|
|||
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* solve op. - solve systems of linear equations - general method.
|
||||
*
|
||||
* input params:
|
||||
* 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
|
||||
* 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
|
||||
*
|
||||
* boolean args:
|
||||
* 0 - adjoint - default is false (optional) - indicates whether the input matrix or its adjoint (Hermitian transpose) should be used
|
||||
*
|
||||
* return value:
|
||||
* tensor with dimension (x * y * z * ::: * M * K) with solutions
|
||||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_solve)
|
||||
// In-place flag is false to agree with CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0).
DECLARE_CUSTOM_OP(solve, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* lu op. - make LUP decomposition of given batch of 2D square matrices
|
||||
*
|
||||
|
|
|
@ -237,25 +237,65 @@ namespace helpers {
|
|||
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void doolitleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) {
|
||||
auto input = compound->dup();
|
||||
compound->nullify();
|
||||
|
||||
// Decomposing matrix into Upper and Lower
|
||||
// triangular matrix
|
||||
for (auto i = 0; i < rowNum; i++) {
|
||||
|
||||
// Upper Triangular
|
||||
for (auto k = i; k < rowNum; k++) {
|
||||
|
||||
// Summation of L(i, j) * U(j, k)
|
||||
int sum = 0;
|
||||
for (int j = 0; j < i; j++)
|
||||
sum += compound->t<T>(i,j) * compound->t<T>(j,k);
|
||||
|
||||
// Evaluating U(i, k)
|
||||
compound->t<T>(i, k) = input.t<T>(i, k) - sum;
|
||||
}
|
||||
|
||||
// Lower Triangular
|
||||
for (int k = i + 1; k < rowNum; k++) {
|
||||
// Summation of L(k, j) * U(j, i)
|
||||
int sum = 0;
|
||||
for (int j = 0; j < i; j++)
|
||||
sum += compound->t<T>(k,j) * compound->t<T>(j, i);
|
||||
|
||||
// Evaluating L(k, i)
|
||||
compound->t<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i,i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
|
||||
|
||||
//const int rowNum = compound->rows();
|
||||
// const int columnNum = output->columns();
|
||||
permutation->linspace(0);
|
||||
auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
|
||||
auto compoundBuf = compound->bufferAsT<T>();
|
||||
auto compoundShape = compound->shapeInfo();
|
||||
auto permutationShape = permutation->shapeInfo();
|
||||
for (auto i = 0; i < rowNum - 1; i++) {
|
||||
auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
|
||||
if (pivotIndex < 0) {
|
||||
throw std::runtime_error("helpers::luNN_: input matrix is singular.");
|
||||
}
|
||||
math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
|
||||
swapRows(compoundBuf, compoundShape, i, pivotIndex);
|
||||
if (permutation) { // LUP algorithm
|
||||
permutation->linspace(0);
|
||||
auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
|
||||
auto compoundBuf = compound->bufferAsT<T>();
|
||||
auto compoundShape = compound->shapeInfo();
|
||||
auto permutationShape = permutation->shapeInfo();
|
||||
for (auto i = 0; i < rowNum - 1; i++) {
|
||||
auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
|
||||
if (pivotIndex < 0) {
|
||||
throw std::runtime_error("helpers::luNN_: input matrix is singular.");
|
||||
}
|
||||
math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)],
|
||||
permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
|
||||
swapRows(compoundBuf, compoundShape, i, pivotIndex);
|
||||
|
||||
processColumns(i, rowNum, compoundBuf, compoundShape);
|
||||
processColumns(i, rowNum, compoundBuf, compoundShape);
|
||||
}
|
||||
}
|
||||
else { // Doolitle algorithm with LU decomposition
|
||||
doolitleLU<T>(context, compound, rowNum);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -265,17 +305,20 @@ namespace helpers {
|
|||
|
||||
output->assign(input); // fill up output tensor with zeros
|
||||
ResultSet outputs = output->allTensorsAlongDimension({-2, -1});
|
||||
ResultSet permutations = permutationVectors->allTensorsAlongDimension({-1});
|
||||
ResultSet permutations;
|
||||
if (permutationVectors)
|
||||
permutations = permutationVectors->allTensorsAlongDimension({-1});
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
luNN_<T, I>(context, outputs.at(i), permutations.at(i), n);
|
||||
luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n);
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
|
||||
}
|
||||
|
||||
/**
 * Type-dispatching entry point for the batched LU decomposition.
 *
 * @param context     launch context forwarded to lu_
 * @param input       tensor to decompose
 * @param output      tensor receiving the decomposition
 * @param permutation optional tensor for the permutation vectors; may be null
 */
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
    // `permutation` is optional: when it is absent, INT32 is used only to pick
    // an index-type instantiation, and the pointer is never dereferenced here.
    BUILD_DOUBLE_SELECTOR(input->dataType(), permutation?permutation->dataType():DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
}
|
||||
|
||||
// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit, K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
#include <op_boilerplate.h>
|
||||
#include <NDArray.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <execution/Threads.h>
|
||||
#include <helpers/MmulHelper.h>
|
||||
|
||||
#include "../triangular_solve.h"
|
||||
#include "../lup.h"
|
||||
#include "../solve.h"
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T>
|
||||
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
|
||||
auto inputPart = input->allTensorsAlongDimension({-2, -1});
|
||||
auto outputPart = output->allTensorsAlongDimension({-2, -1});
|
||||
output->assign(input);
|
||||
auto batchLoop = PRAGMA_THREADS_FOR {
|
||||
for (auto batch = start; batch < stop; batch += increment) {
|
||||
for (auto r = 0; r < input->rows(); r++) {
|
||||
for (auto c = 0; c < r; c++) {
|
||||
math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T>
|
||||
static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
|
||||
|
||||
// stage 1: LU decomposition batched
|
||||
auto leftOutput = leftInput->ulike();
|
||||
auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back();
|
||||
auto permutations = NDArrayFactory::create<int>('c', permuShape, context);
|
||||
helpers::lu(context, leftInput, &leftOutput, &permutations);
|
||||
auto P = leftInput->ulike(); //permutations batched matrix
|
||||
P.nullify(); // to fill up matricies with zeros
|
||||
auto PPart = P.allTensorsAlongDimension({-2,-1});
|
||||
auto permutationsPart = permutations.allTensorsAlongDimension({-1});
|
||||
|
||||
for (auto batch = 0; batch < permutationsPart.size(); ++batch) {
|
||||
for (auto row = 0; row < PPart[batch]->rows(); ++row) {
|
||||
PPart[batch]->t<T>(row, permutationsPart[batch]->t<int>(row)) = T(1.f);
|
||||
}
|
||||
}
|
||||
|
||||
auto leftLower = leftOutput.dup();
|
||||
auto rightOutput = rightInput->ulike();
|
||||
auto rightPermuted = rightOutput.ulike();
|
||||
MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0);
|
||||
ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1});
|
||||
for (auto i = 0; i < leftLowerPart.size(); i++) {
|
||||
for (auto r = 0; r < leftLowerPart[i]->rows(); r++)
|
||||
leftLowerPart[i]->t<T>(r,r) = (T)1.f;
|
||||
}
|
||||
// stage 2: triangularSolveFunctor for Lower with given b
|
||||
helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput);
|
||||
// stage 3: triangularSolveFunctor for Upper with output of previous stage
|
||||
helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||
// Dispatches to solveFunctor_<T>, with T selected from the data type of the left input.
int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
    const auto leftType = leftInput->dataType();
    BUILD_SINGLE_SELECTOR(leftType, return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
}
|
||||
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||
// Dispatches to adjointMatrix_<T>, with T selected from the data type of the input.
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    const auto inputType = input->dataType();
    BUILD_SINGLE_SELECTOR(inputType, adjointMatrix_, (context, input, output), FLOAT_TYPES);
}
|
||||
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||
}
|
||||
}
|
||||
}
|
|
@ -41,13 +41,16 @@ namespace helpers {
|
|||
template <typename T>
|
||||
static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
|
||||
auto rows = leftInput->rows();
|
||||
auto cols = rightInput->columns();
|
||||
//output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
|
||||
for (auto r = 0; r < rows; r++) {
|
||||
auto sum = rightInput->t<T>(r, 0);
|
||||
for (auto c = 0; c < r; c++) {
|
||||
sum -= leftInput->t<T>(r,c) * output->t<T>(c, 0);
|
||||
for (auto j = 0; j < cols; j++) {
|
||||
auto sum = rightInput->t<T>(r, j);
|
||||
for (auto c = 0; c < r; c++) {
|
||||
sum -= leftInput->t<T>(r, c) * output->t<T>(c, j);
|
||||
}
|
||||
output->t<T>(r, j) = sum / leftInput->t<T>(r, r);
|
||||
}
|
||||
output->t<T>(r, 0) = sum / leftInput->t<T>(r, r);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,13 +71,15 @@ namespace helpers {
|
|||
template <typename T>
|
||||
static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
|
||||
auto rows = leftInput->rows();
|
||||
|
||||
auto cols = rightInput->columns();
|
||||
for (auto r = rows; r > 0; r--) {
|
||||
auto sum = rightInput->t<T>(r - 1, 0);
|
||||
for (auto c = r; c < rows; c++) {
|
||||
sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, 0);
|
||||
for (auto j = 0; j < cols; j++) {
|
||||
auto sum = rightInput->t<T>(r - 1, j);
|
||||
for (auto c = r; c < rows; c++) {
|
||||
sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, j);
|
||||
}
|
||||
output->t<T>(r - 1, j) = sum / leftInput->t<T>(r - 1, r - 1);
|
||||
}
|
||||
output->t<T>(r - 1, 0) = sum / leftInput->t<T>(r - 1, r - 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit, K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#include <NDArray.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <MmulHelper.h>
|
||||
|
||||
#include <execution/Threads.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
#include "../triangular_solve.h"
|
||||
#include "../lup.h"
|
||||
#include "../solve.h"
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
/**
 * Writes T(1.f) to every diagonal element of each matrix in a batch.
 *
 * @param ioBuf      buffer holding all matrices
 * @param ioShape    shape info of the whole tensor (not read by the kernel)
 * @param tadShape   shape info of a single matrix, used to compute element offsets
 * @param tadOffsets offset of every matrix inside `ioBuf`
 * @param batchNum   number of matrices
 * @param rowNum     number of diagonal elements to set per matrix
 */
template <typename T>
static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong* ioShape, Nd4jLong* tadShape, Nd4jLong* tadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) {
    const auto batchStep = gridDim.x;
    const auto diagStep = blockDim.x;

    // Blocks step over matrices, threads step over diagonal positions.
    for (auto batch = blockIdx.x; batch < batchNum; batch += batchStep) {
        T* matrix = ioBuf + tadOffsets[batch];
        for (auto diag = threadIdx.x; diag < rowNum; diag += diagStep) {
            Nd4jLong coords[] = {diag, diag};
            matrix[shape::getOffset(tadShape, coords)] = T(1.f);
        }
    }
}
|
||||
|
||||
/**
 * Expands permutation vectors into permutation matrices: for every batch and row,
 * the element (row, permutations[row]) of the matrix is set to T(1.f).
 * Other elements are not written.
 *
 * @param PBuf                     buffer of the permutation matrices
 * @param PShapeInfo               shape info of the whole matrix tensor (not read)
 * @param permutationsBuf          buffer of the permutation vectors
 * @param PTadShapeInfo            shape info of a single matrix
 * @param PTadSOffsets             offset of every matrix inside `PBuf`
 * @param permutationsTadShapeInfo shape info of a single vector (not read)
 * @param permutationsTadOffsets   offset of every vector inside `permutationsBuf`
 * @param batchNum                 number of matrices / vectors
 * @param rowNum                   number of rows per matrix
 */
template <typename T>
static __global__ void restorePermutationsKernel(T* PBuf, Nd4jLong* PShapeInfo, int const* permutationsBuf,
        Nd4jLong* PTadShapeInfo, Nd4jLong* PTadSOffsets, Nd4jLong* permutationsTadShapeInfo, Nd4jLong* permutationsTadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) {
    const auto batchStep = gridDim.x;
    const auto rowStep = blockDim.x;

    for (auto batch = blockIdx.x; batch < batchNum; batch += batchStep) {
        int const* pivotVector = permutationsBuf + permutationsTadOffsets[batch];
        T* matrix = PBuf + PTadSOffsets[batch];

        for (auto row = threadIdx.x; row < rowNum; row += rowStep) {
            // The vector is indexed directly by row, without its shape info.
            Nd4jLong coords[] = {row, pivotVector[row]};
            const auto target = shape::getOffset(PTadShapeInfo, coords);
            matrix[target] = T(1.f);
        }
    }
}
|
||||
|
||||
/**
 * CUDA solver for batched linear systems: LU decomposition with permutations,
 * followed by a lower and an upper triangular solve.
 *
 * @param context    launch context supplying the CUDA stream
 * @param leftInput  batch of coefficient matrices
 * @param rightInput batch of right-hand sides
 * @param adjoint    not read by this routine
 * @param output     tensor receiving the solutions
 * @return Status::OK()
 */
template <typename T>
static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput,
                         bool adjoint, NDArray* output) {
    // Launch configuration shared by both helper kernels below.
    constexpr int kGridSize = 128;
    constexpr int kBlockSize = 256;
    constexpr int kSharedBytes = 256;

    NDArray::prepareSpecialUse({output}, {leftInput, rightInput});

    // Stage 1: batched LU decomposition together with the permutation vectors.
    auto luCompound = leftInput->ulike();
    auto pivotDims = rightInput->getShapeAsVector();
    pivotDims.pop_back();
    auto pivots = NDArrayFactory::create<int>('c', pivotDims, context);
    helpers::lu(context, leftInput, &luCompound, &pivots);

    // Copy of the decomposition whose diagonal is overwritten with ones.
    auto unitLower = luCompound.dup();
    auto lowerSolution = rightInput->ulike();
    auto unitLowerTad = ConstantTadHelper::getInstance()->tadForDimensions(unitLower.getShapeInfo(), {-2, -1});
    auto stream = context->getCudaStream();
    oneOnDiagonalKernel<T><<<kGridSize, kBlockSize, kSharedBytes, *stream>>>(
            unitLower.dataBuffer()->specialAsT<T>(), unitLower.specialShapeInfo(),
            unitLowerTad.specialShapeInfo(), unitLowerTad.specialOffsets(),
            unitLowerTad.numberOfTads(), unitLower.sizeAt(-1));

    // Zero-filled tensor turned into permutation matrices on the device.
    auto permutationMatrix = luCompound.ulike();
    permutationMatrix.nullify();
    auto permutationMatrixTad = ConstantTadHelper::getInstance()->tadForDimensions(permutationMatrix.getShapeInfo(), {-2, -1});
    auto pivotsTad = ConstantTadHelper::getInstance()->tadForDimensions(pivots.getShapeInfo(), {-1});
    restorePermutationsKernel<T><<<kGridSize, kBlockSize, kSharedBytes, *stream>>>(
            permutationMatrix.dataBuffer()->specialAsT<T>(), permutationMatrix.specialShapeInfo(),
            pivots.dataBuffer()->specialAsT<int>(),
            permutationMatrixTad.specialShapeInfo(), permutationMatrixTad.specialOffsets(),
            pivotsTad.specialShapeInfo(), pivotsTad.specialOffsets(),
            pivotsTad.numberOfTads(), pivots.sizeAt(-1));
    permutationMatrix.tickWriteDevice();

    // Apply the permutation to the right-hand sides.
    auto permutedRight = rightInput->ulike();
    MmulHelper::matmul(&permutationMatrix, rightInput, &permutedRight, 0, 0);

    // Stage 2: lower triangular solve against the permuted right-hand sides.
    helpers::triangularSolveFunctor(context, &unitLower, &permutedRight, true, false, &lowerSolution);
    // Stage 3: upper triangular solve against the result of stage 2.
    helpers::triangularSolveFunctor(context, &luCompound, &lowerSolution, false, false, output);
    NDArray::registerSpecialUse({output}, {leftInput, rightInput});

    return Status::OK();
}
|
||||
|
||||
// Dispatches to solveFunctor_<T>, with T selected from the data type of the left input.
int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
    const auto leftType = leftInput->dataType();
    BUILD_SINGLE_SELECTOR(leftType, return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
}
|
||||
|
||||
/**
 * Exchanges the (r, c) and (c, r) elements of every matrix in a batch, in place.
 *
 * @param output        buffer holding all matrices
 * @param batchSize     number of matrices
 * @param rows          number of rows per matrix
 * @param columns       number of columns per matrix (not read by the kernel)
 * @param outputTads    shape info of a single matrix
 * @param outputOffsets offset of every matrix inside `output`
 */
template <typename T>
static __global__ void adjointKernel(T* output, Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, Nd4jLong* outputTads,
                                     Nd4jLong* outputOffsets) {

    for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x) {
        T* matrix = output + outputOffsets[batch];
        for (auto row = threadIdx.x; row < rows; row += blockDim.x) {
            // Only columns below the diagonal are visited, so each pair is swapped once.
            for (auto col = threadIdx.y; col < row; col += blockDim.y) {
                Nd4jLong lowerCoords[] = {row, col};
                Nd4jLong upperCoords[] = {col, row};
                const auto lowerOffset = shape::getOffset(outputTads, lowerCoords);
                const auto upperOffset = shape::getOffset(outputTads, upperCoords);
                math::nd4j_swap(matrix[lowerOffset], matrix[upperOffset]);
            }
        }
    }

}
|
||||
|
||||
/**
 * CUDA helper: copies `input` into `output`, then launches adjointKernel to
 * exchange the (r, c) and (c, r) elements of every matrix in `output`.
 *
 * @param context launch context supplying the CUDA stream
 * @param input   source tensor
 * @param output  destination tensor
 */
template <typename T>
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    constexpr int kGridSize = 128;
    constexpr int kBlockSize = 256;
    constexpr int kSharedBytes = 256;

    NDArray::prepareSpecialUse({output}, {input});
    auto sourcePack = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {-2, -1});
    auto targetPack = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {-2, -1});
    auto stream = context->getCudaStream();
    // The device pointer is taken before the copy, as in the original sequence.
    auto targetBuffer = reinterpret_cast<T*>(output->specialBuffer());
    auto matrixRows = input->sizeAt(-2);
    auto matrixColumns = input->sizeAt(-1);
    output->assign(input);
    adjointKernel<T><<<kGridSize, kBlockSize, kSharedBytes, *stream>>>(
            targetBuffer, targetPack.numberOfTads(), matrixRows, matrixColumns,
            targetPack.specialShapeInfo(), targetPack.specialOffsets());
    NDArray::registerSpecialUse({output}, {input});
}
|
||||
|
||||
// Dispatches to adjointMatrix_<T>, with T selected from the data type of the input.
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    const auto inputType = input->dataType();
    BUILD_SINGLE_SELECTOR(inputType, adjointMatrix_, (context, input, output), FLOAT_TYPES);
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -44,24 +44,26 @@ namespace nd4j {
|
|||
static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
                                            T const* rightInput, Nd4jLong const* rightInputShape,
                                            bool const adjoint, T* output, Nd4jLong* outputShape,
                                            Nd4jLong rows, Nd4jLong cols) {
    // Forward substitution for a lower triangular system, solved for each of the
    // `cols` right-hand columns. `adjoint` is not read here.
    //   leftInput / leftInputShape   - triangular matrix and its shape info
    //   rightInput / rightInputShape - right-hand sides and their shape info
    //   output / outputShape         - solution buffer and its shape info
    for (auto r = 0; r < rows; r++) {
        for (auto j = 0; j < cols; j++) {
            Nd4jLong posY[] = {r, j};
            Nd4jLong posX[] = {r, r};
            auto xIndex = shape::getOffset(leftInputShape, posX, 0);
            auto yIndex = shape::getOffset(rightInputShape, posY, 0);
            auto zIndex = shape::getOffset(outputShape, posY, 0);

            auto sum = rightInput[yIndex];
            // Subtract the contribution of the rows already solved for column j.
            for (auto c = 0; c < r; c++) {
                Nd4jLong posZ[] = {c, j};
                Nd4jLong pos[] = {r, c};
                auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
                auto zcIndex = shape::getOffset(outputShape, posZ, 0);
                sum -= leftInput[xcIndex] * output[zcIndex];
            }
            output[zIndex] = sum / leftInput[xIndex];
        }
    }
}
|
||||
|
||||
|
@ -82,23 +84,25 @@ namespace nd4j {
|
|||
/**
 * Back substitution for an upper triangular system, solved for each of the
 * `cols` right-hand columns.
 *
 * @param leftInput       triangular matrix
 * @param leftInputShape  shape info of `leftInput`
 * @param rightInput      right-hand sides
 * @param rightInputShape shape info of `rightInput`
 * @param adjoint         not read by this routine
 * @param output          solution buffer
 * @param outputShape     shape info of `output`
 * @param rows            number of rows of the triangular matrix
 * @param cols            number of right-hand columns
 */
template <typename T>
static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
                                            T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output,
                                            Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) {

    // `r` runs from rows down to 1, so the row actually processed is r - 1.
    for (auto r = rows; r > 0; r--) {
        for (auto j = 0; j < cols; j++) {
            Nd4jLong posY[] = {r - 1, j};
            Nd4jLong posX[] = {r - 1, r - 1};
            auto xIndex = shape::getOffset(leftInputShape, posX, 0);
            auto yIndex = shape::getOffset(rightInputShape, posY, 0);
            auto zIndex = shape::getOffset(outputShape, posY, 0);

            auto sum = rightInput[yIndex];
            // Subtract the contribution of the rows below, which are already solved.
            for (auto c = r; c < rows; c++) {
                Nd4jLong posZ[] = {c, j};
                Nd4jLong pos[] = {r - 1, c};
                auto zcIndex = shape::getOffset(outputShape, posZ, 0);
                auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
                sum -= leftInput[xcIndex] * output[zcIndex];
            }
            output[zIndex] = sum / leftInput[xIndex];
        }
    }
}
|
||||
|
||||
|
@ -109,8 +113,11 @@ namespace nd4j {
|
|||
Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) {
|
||||
|
||||
__shared__ Nd4jLong rows;
|
||||
__shared__ Nd4jLong cols;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
rows = shape::sizeAt(leftPartShape, -2);
|
||||
cols = shape::sizeAt(rightPartShape, -1);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
@ -123,9 +130,9 @@ namespace nd4j {
|
|||
auto pRightPart = rightInput + tadRightOffset[i];
|
||||
auto pOutputPart = output + tadOutputOffset[i];
|
||||
if (lower) {
|
||||
lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows);
|
||||
lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
|
||||
} else {
|
||||
upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows);
|
||||
upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
#ifndef __SOLVE__H_HELPERS__
|
||||
#define __SOLVE__H_HELPERS__
|
||||
#include <op_boilerplate.h>
|
||||
#include <NDArray.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
/**
 * Solves a batch of linear systems: LU decomposition of `leftInput`, then a lower
 * and an upper triangular solve against `rightInput`, writing the result to `output`.
 * The implementations accompanying this header return Status::OK() and do not
 * read `adjoint`.
 */
int solveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output);
|
||||
/**
 * Copies `input` into `output` and exchanges the (r, c) and (c, r) elements of
 * every matrix spanned by the last two dimensions. The accompanying
 * implementations only swap elements; no conjugation is applied.
 */
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
|
@ -1541,6 +1541,166 @@ TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
|
|||
delete []arr;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Single 3x3 system with one right-hand column.
TEST_F(DeclarableOpsTests11, Solve_Test_1) {
    auto lhs = NDArrayFactory::create<float>('c', {3, 3}, {
         2.f, -1.f, -2.f,
        -4.f,  6.f,  3.f,
        -4.f, -2.f,  8.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 1}, {2.f, 4.f, 3.f});
    auto expected = NDArrayFactory::create<float>('c', {3, 1}, {7.625f, 3.25f, 5.f});

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&lhs, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
    ASSERT_TRUE(expected.equalsTo(solution));

    delete results;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Single 4x4 upper triangular system with one right-hand column.
TEST_F(DeclarableOpsTests11, Solve_Test_2) {
    auto lhs = NDArrayFactory::create<float>('c', {4, 4}, {
        1.f, 1.f, 1.f, 1.f,
        0.f, 1.f, 1.f, 0.f,
        0.f, 0.f, 2.f, 1.f,
        0.f, 0.f, 0.f, 3.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {4, 1}, {2.f, 4.f, 2.f, 4.f});
    auto expected = NDArrayFactory::create<float>('c', {4, 1}, {
        -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f
    });

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&lhs, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
    ASSERT_TRUE(expected.equalsTo(solution));

    delete results;
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, Solve_Test_3) {
    // Batch of two 4x4 systems: the first is upper triangular, the second
    // lower triangular. Each has its own 4x1 right-hand side.
    auto coeffs = NDArrayFactory::create<float>('c', {2, 4, 4}, {
        // batch 0
        1.f, 1.f, 1.f, 1.f,
        0.f, 1.f, 1.f, 0.f,
        0.f, 0.f, 2.f, 1.f,
        0.f, 0.f, 0.f, 3.f,
        // batch 1
        3.f, 0.f, 0.f, 0.f,
        2.f, 1.f, 0.f, 0.f,
        1.f, 0.f, 1.f, 0.f,
        1.f, 1.f, 1.f, 1.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 4, 1}, {
        2.f, 4.f, 2.f, 4.f,
        4.f, 2.f, 4.f, 2.f
    });
    auto expected = NDArrayFactory::create<float>('c', {2, 4, 1}, {
        -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f,
        1.333333f, -0.6666667f, 2.6666667f, -1.3333333f
    });

    nd4j::ops::solve op;
    auto results = op.evaluate({&coeffs, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
    ASSERT_TRUE(expected.equalsTo(solution));

    delete results;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, Solve_Test_4) {
    // Batch of two dense 2x2 systems, each with a 2x2 right-hand side
    // (i.e. two right-hand-side columns per system).
    auto coeffs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7788f, 0.8012f, 0.7244f, 0.2309f,
        0.7271f, 0.1804f, 0.5056f, 0.8925f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7717f, 0.9281f, 0.9846f, 0.4838f,
        0.6433f, 0.6041f, 0.6501f, 0.7612f
    });
    auto expected = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f,
        0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f
    });

    nd4j::ops::solve op;
    auto results = op.evaluate({&coeffs, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
    ASSERT_TRUE(expected.equalsTo(solution));

    delete results;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, Solve_Test_5) {
    // Dense 3x3 system with a 3x3 right-hand side, evaluated with the extra
    // argument `true` (presumably the op's adjoint switch - confirm against
    // the solve op implementation).
    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.8012f, 0.7244f,
        0.2309f, 0.7271f, 0.1804f,
        0.5056f, 0.8925f, 0.5461f
    });

    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });

    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
        1.5504692f, 1.8953944f, 2.2765768f,
        0.03399149f, 0.2883001f, 0.5377323f,
        -0.8774802f, -1.2155888f, -1.8049058f
    });

    nd4j::ops::solve op;

    auto res = op.evaluate({&a, &b}, {true});
    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
    auto z = res->at(0);

    // Debug output is kept disabled, as in the sibling Solve tests, so the
    // test run stays quiet. Re-enable locally when diagnosing a mismatch.
    // z->printBuffer("5 Solve 3x3");
    // exp.printBuffer("5 Expec 3x3");

    ASSERT_TRUE(exp.equalsTo(z));
    delete res;
}
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
|
||||
|
||||
|
|
|
@ -3008,3 +3008,33 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
|
|||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) {
    // Upper-triangular 4x4 matrix with a 4x2 right-hand side, evaluated with
    // the boolean arguments {false, true}. The expected values satisfy
    // transpose(a) * exp == b, i.e. the adjoint form of the system.
    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
        5.f, 1.f, -3.f, 3.f,
        0.f, 1.f, 1.f, -1.f,
        0.f, 0.f, 2.f, -9.f,
        0.f, 0.f, 0.f, 4.f
    });

    auto b = NDArrayFactory::create<float>('c', {4, 2}, {
        5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f
    });

    auto exp = NDArrayFactory::create<float>('c', {4, 2}, {
        1.f, 0.2f, 1.f, 0.8f, 1.f, 0.4f, 1.f, 1.2f
    });

    nd4j::ops::triangular_solve op;

    auto res = op.evaluate({&a, &b}, {}, {}, {false, true});
    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
    auto z = res->at(0);

    // Debug output is kept disabled so the test run stays quiet.
    // Re-enable locally when diagnosing a mismatch.
    // z->printIndexedBuffer("TriangularSolve with adjoint");

    ASSERT_TRUE(exp.equalsTo(z));
    delete res;
}
|
|
@ -39,7 +39,7 @@ do
|
|||
done
|
||||
|
||||
CHIP="${CHIP:-cpu}"
|
||||
export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml"
|
||||
export GTEST_OUTPUT="xml:surefire-reports/TEST-${CHIP}-results.xml"
|
||||
|
||||
# On Mac, make sure it can find libraries for GCC
|
||||
export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
|
||||
|
@ -48,9 +48,11 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/
|
|||
# When BUILD_PATH is set, append it to PATH so the test binary can locate the
# freshly built libraries.
if [ -n "$BUILD_PATH" ]; then
    if which cygpath; then
        # cygpath is available: convert the path list. The argument is quoted
        # so that entries containing spaces (e.g. "Program Files") are not
        # word-split into several arguments.
        BUILD_PATH=$(cygpath -p "$BUILD_PATH")
        export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'"
    fi
    export PATH="$PATH:$BUILD_PATH"
fi
|
||||
|
||||
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
|
||||
|
||||
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
|
||||
[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
|
||||
|
|
|
@ -623,7 +623,8 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.custom.Igammac.class,
|
||||
org.nd4j.linalg.api.ops.custom.Digamma.class,
|
||||
org.nd4j.linalg.api.ops.custom.Lu.class,
|
||||
org.nd4j.linalg.api.ops.custom.TriangularSolve.class
|
||||
org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
|
||||
org.nd4j.linalg.api.ops.custom.LinearSolve.class
|
||||
);
|
||||
|
||||
static {
|
||||
|
|
|
@ -16,26 +16,48 @@
|
|||
|
||||
package org.nd4j.linalg.api.buffer;
|
||||
|
||||
/**
|
||||
* Enum lists supported data types.
|
||||
*
|
||||
*/
|
||||
public enum DataType {
|
||||
|
||||
DOUBLE,
|
||||
FLOAT,
|
||||
|
||||
/**
|
||||
* @deprecated Replaced by {@link DataType#FLOAT16}, use that instead
|
||||
*/
|
||||
@Deprecated
|
||||
HALF,
|
||||
|
||||
/**
|
||||
* @deprecated Replaced by {@link DataType#INT64}, use that instead
|
||||
*/
|
||||
@Deprecated
|
||||
LONG,
|
||||
|
||||
/**
|
||||
* @deprecated Replaced by {@link DataType#INT32}, use that instead
|
||||
*/
|
||||
@Deprecated
|
||||
INT,
|
||||
|
||||
/**
|
||||
* @deprecated Replaced by {@link DataType#INT16}, use that instead
|
||||
*/
|
||||
@Deprecated
|
||||
SHORT,
|
||||
|
||||
/**
|
||||
* @deprecated Replaced by {@link DataType#UINT8}, use that instead
|
||||
*/
|
||||
@Deprecated
|
||||
UBYTE,
|
||||
|
||||
/**
|
||||
* @deprecated Replaced by {@link DataType#INT8}, use that instead
|
||||
*/
|
||||
@Deprecated
|
||||
BYTE,
|
||||
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class LinearSolve extends DynamicCustomOp {
|
||||
|
||||
public LinearSolve(INDArray a, INDArray b, boolean adjoint) {
|
||||
addInputArgument(a, b);
|
||||
addBArgument(adjoint);
|
||||
}
|
||||
|
||||
public LinearSolve(INDArray a, INDArray b) {
|
||||
this(a,b,false);
|
||||
}
|
||||
|
||||
public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, SDVariable adjoint) {
|
||||
super(sameDiff, new SDVariable[] {a, b, adjoint});
|
||||
}
|
||||
|
||||
public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, boolean adjoint) {
|
||||
super(sameDiff, new SDVariable[] {a, b});
|
||||
addBArgument(adjoint);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "solve";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "MatrixSolve";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
boolean adjoint = attributesForNode.containsKey("adjoint") ? attributesForNode.get("adjoint").getB() : false;
|
||||
addBArgument(adjoint);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||
int n = args().length;
|
||||
Preconditions.checkState(dataTypes != null && dataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), dataTypes);
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -1691,4 +1691,50 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
|
||||
assertEquals(e, x);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLinearSolve() {
|
||||
INDArray a = Nd4j.createFromArray(new float[]{
|
||||
2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f
|
||||
}).reshape(3, 3);
|
||||
|
||||
INDArray b = Nd4j.createFromArray(new float[]{
|
||||
2.f, 4.f, 3.f
|
||||
}).reshape(3, 1);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||
7.625f, 3.25f, 5.f
|
||||
}).reshape(3, 1);
|
||||
|
||||
val op = new LinearSolve(a, b);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLinearSolveAdjust() {
|
||||
INDArray a = Nd4j.createFromArray(new float[]{
|
||||
0.7788f, 0.8012f, 0.7244f,
|
||||
0.2309f, 0.7271f, 0.1804f,
|
||||
0.5056f, 0.8925f, 0.5461f
|
||||
}).reshape(3, 3);
|
||||
|
||||
INDArray b = Nd4j.createFromArray(new float[]{
|
||||
0.7717f, 0.9281f, 0.9846f,
|
||||
0.4838f, 0.6433f, 0.6041f,
|
||||
0.6501f, 0.7612f, 0.7605f
|
||||
}).reshape(3, 3);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||
1.5504692f, 1.8953944f, 2.2765768f,
|
||||
0.03399149f, 0.2883001f , 0.5377323f,
|
||||
-0.8774802f, -1.2155888f, -1.8049058f
|
||||
}).reshape(3, 3);
|
||||
|
||||
val op = new LinearSolve(a, b, true);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue