From 41ff907bc628f86db2bde6034ec809921280cfe1 Mon Sep 17 00:00:00 2001
From: shugeo
Date: Tue, 4 Feb 2020 07:59:11 +0200
Subject: [PATCH] Shugeo solve linear (#191)

* Linear equation systems solve op. Initial commit.

Signed-off-by: shugeo

* Fixed compiling issues.

Signed-off-by: shugeo

* Linear equation systems solve. The next stage commit.

Signed-off-by: shugeo

* Added test for linear equation systems solve operation.

Signed-off-by: shugeo

* Added additional test and fixed lower matrix retrieval.

* Implementation of solve for systems of linear equations.

Signed-off-by: shugeo

* Refactored permutation generation.

Signed-off-by: shugeo

* Added restore for batched permutations with cuda helper for solve op.

Signed-off-by: shugeo

* Finished cuda implementation for solve op helpers.

Signed-off-by: shugeo

* Refactored cpu helpers for solve op.

Signed-off-by: shugeo

* Fix gtest output on Windows

* Fixed issue with permutation matrix for cuda implementation.

Signed-off-by: shugeo

* Fixed issue with permutation matrix for cpu implementation.

Signed-off-by: shugeo

* Eliminated wasted comments.

Signed-off-by: shugeo

* LinearSolve added

* Mapping added

* Javadoc added

* Refactored implementation of triangular_solve helpers and tests for solving matrix equations generally.

Signed-off-by: shugeo

* Added a test for solve op.

Signed-off-by: shugeo

* Solve test added

* Fix for TF import

Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com>
Co-authored-by: raver119
Co-authored-by: Alexander Stoyakin
---
 .../declarable/generic/parity_ops/solve.cpp   |  75 ++++++++
 .../ops/declarable/headers/parity_ops.h       |  18 ++
 .../ops/declarable/helpers/cpu/lup.cpp        |  75 ++++++--
 .../ops/declarable/helpers/cpu/solve.cpp      | 100 +++++++++++
 .../helpers/cpu/triangular_solve.cpp          |  23 ++-
 .../ops/declarable/helpers/cuda/solve.cu      | 140 +++++++++++++++
 .../helpers/cuda/triangular_solve.cu          |  67 ++++----
 .../include/ops/declarable/helpers/solve.h    |  34 ++++
 .../layers_tests/DeclarableOpsTests11.cpp     | 160 ++++++++++++++++++
 .../layers_tests/DeclarableOpsTests12.cpp     |  30 ++++
 libnd4j/tests_cpu/run_tests.sh                |   6 +-
 .../converters/ImportClassMapping.java        |   3 +-
 .../org/nd4j/linalg/api/buffer/DataType.java  |  22 +++
 .../linalg/api/ops/custom/LinearSolve.java    |  77 +++++++
 .../nd4j/linalg/custom/CustomOpsTests.java    |  46 +++++
 15 files changed, 818 insertions(+), 58 deletions(-)
 create mode 100644 libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp
 create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/solve.cpp
 create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/solve.cu
 create mode 100644 libnd4j/include/ops/declarable/helpers/solve.h
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java
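
Implementation notes: for each batched system A x = b, the op chains the
existing LUP and triangular-solve helpers in three stages:

    P A = L U      stage 1: helpers::lu (partial pivoting; P is rebuilt from the permutation vectors)
    L y = P b      stage 2: lower triangular solve, forward substitution
    U x = y        stage 3: upper triangular solve, back substitution

To support stages 2 and 3 with M x K right parts, the triangular_solve
helpers are generalized from a single right-hand column to K columns (the
new inner loops over cols on both the cpu and cuda paths). With the boolean
arg set, the op solves adjoint(A) x = b instead; since only real FLOAT types
are accepted, helpers::adjointMatrix amounts to a plain transpose.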
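The expected values in Solve_Test_1 and testLinearSolve can be reproduced
with a minimal, dependency-free sketch of the same three stages (plain
arrays instead of NDArray; the names and layout here are ad hoc
illustrations, not the libnd4j implementation):

#include <array>
#include <cmath>
#include <cstdio>
#include <utility>

constexpr int N = 3;
using Mat = std::array<std::array<double, N>, N>;
using Vec = std::array<double, N>;

// In-place LUP decomposition with partial pivoting: on return, m packs L
// (unit diagonal, strictly below) and U (diagonal and above); perm[i] records
// which source row ended up in pivoted row i. Returns false if singular.
static bool lup(Mat& m, std::array<int, N>& perm) {
    for (int i = 0; i < N; i++) perm[i] = i;
    for (int col = 0; col < N; col++) {
        int pivot = col;
        for (int r = col + 1; r < N; r++)
            if (std::fabs(m[r][col]) > std::fabs(m[pivot][col])) pivot = r;
        if (m[pivot][col] == 0.0) return false;       // singular input
        std::swap(m[col], m[pivot]);
        std::swap(perm[col], perm[pivot]);
        for (int r = col + 1; r < N; r++) {
            m[r][col] /= m[col][col];                 // multiplier -> L entry
            for (int c = col + 1; c < N; c++)
                m[r][c] -= m[r][col] * m[col][c];     // elimination -> U update
        }
    }
    return true;
}

int main() {
    // Same system as Solve_Test_1 / testLinearSolve; expected x = (7.625, 3.25, 5).
    Mat a = {{{2, -1, -2}, {-4, 6, 3}, {-4, -2, 8}}};
    Vec b = {2, 4, 3};
    std::array<int, N> perm;
    if (!lup(a, perm)) return 1;

    Vec y{}, x{};
    for (int r = 0; r < N; r++) {          // stage 2: L y = P b (forward substitution)
        y[r] = b[perm[r]];
        for (int c = 0; c < r; c++) y[r] -= a[r][c] * y[c];
    }
    for (int r = N - 1; r >= 0; r--) {     // stage 3: U x = y (back substitution)
        x[r] = y[r];
        for (int c = r + 1; c < N; c++) x[r] -= a[r][c] * x[c];
        x[r] /= a[r][r];
    }
    std::printf("%g %g %g\n", x[0], x[1], x[2]);      // prints: 7.625 3.25 5
    return 0;
}

Compiled standalone, it prints "7.625 3.25 5", matching the expected output
of those tests.
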
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp
new file mode 100644
index 000000000..5790ae960
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp
@@ -0,0 +1,75 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit, K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// Created by GS at 01/22/2020
+//
+
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_solve)
+
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/solve.h>
+namespace nd4j {
+    namespace ops {
+        CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) {
+            auto a = INPUT_VARIABLE(0);
+            auto b = INPUT_VARIABLE(1);
+            auto z = OUTPUT_VARIABLE(0);
+
+            bool useAdjoint = false;
+
+            if (block.numB() > 0) {
+                useAdjoint = B_ARG(0);
+            }
+
+            REQUIRE_TRUE(a->rankOf() >= 2, 0, "solve: The rank of the left input tensor should not be less than 2, but %i is given", a->rankOf());
+            REQUIRE_TRUE(b->rankOf() >= 2, 0, "solve: The rank of the right input tensor should not be less than 2, but %i is given", b->rankOf());
+
+            REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
+            REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimension of the left part should be equal to the second-to-last of the right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
+            auto input = a;
+            if (useAdjoint) {
+                auto adjointA = a->ulike();
+                helpers::adjointMatrix(block.launchContext(), a, &adjointA);
+                input = new NDArray(adjointA); //.detach();
+            }
+
+            helpers::solveFunctor(block.launchContext(), input, b, useAdjoint, z);
+            if (input != a)
+                delete input;
+
+            return Status::OK();
+        }
+
+        DECLARE_SHAPE_FN(solve) {
+            auto in0 = inputShape->at(0);
+            auto in1 = inputShape->at(1);
+            auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());
+
+            return SHAPELIST(CONSTANT(luShape));
+        }
+
+        DECLARE_TYPES(solve) {
+            getOpDescriptor()
+                    ->setAllowedInputTypes({ALL_FLOATS})
+                    ->setAllowedOutputTypes({ALL_FLOATS})
+                    ->setSameMode(false);
+        }
+    }
+}
+
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h
index c5d0ff207..1b5bda29a 100644
--- a/libnd4j/include/ops/declarable/headers/parity_ops.h
+++ b/libnd4j/include/ops/declarable/headers/parity_ops.h
@@ -1076,6 +1076,24 @@ namespace nd4j {
         DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
     #endif
 
+    /**
+     * solve op. - solve systems of linear equations - general method.
+     *
+     * input params:
+     *    0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
+     *    1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
+     *
+     * boolean args:
+     *    0 - adjoint - default is false (optional) - indicates whether the input matrix or its adjoint (Hermitian transpose) should be used
+     *
+     * return value:
+     *    tensor with dimension (x * y * z * ::: * M * K) with solutions
+     *
+     */
+    #if NOT_EXCLUDED(OP_solve)
+        DECLARE_CUSTOM_OP(solve, 2, 1, true, 0, 0);
+    #endif
+
     /**
      * lu op. - make LUP decomposition of given batch of 2D square matrices
      *
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
index 9c7cb1bfe..2856e73b9 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
@@ -237,25 +237,65 @@ namespace helpers {
         samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
     }
 
+    template <typename T>
+    static void doolittleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) {
+        auto input = compound->dup();
+        compound->nullify();
+
+        // Decomposing matrix into Upper and Lower
+        // triangular matrix
+        for (auto i = 0; i < rowNum; i++) {
+
+            // Upper Triangular
+            for (auto k = i; k < rowNum; k++) {
+
+                // Summation of L(i, j) * U(j, k)
+                T sum = 0;
+                for (int j = 0; j < i; j++)
+                    sum += compound->t<T>(i, j) * compound->t<T>(j, k);
+
+                // Evaluating U(i, k)
+                compound->t<T>(i, k) = input.t<T>(i, k) - sum;
+            }
+
+            // Lower Triangular
+            for (int k = i + 1; k < rowNum; k++) {
+                // Summation of L(k, j) * U(j, i)
+                T sum = 0;
+                for (int j = 0; j < i; j++)
+                    sum += compound->t<T>(k, j) * compound->t<T>(j, i);
+
+                // Evaluating L(k, i)
+                compound->t<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i, i);
+            }
+        }
+    }
+
     template <typename T, typename I>
     static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
 
         //const int rowNum = compound->rows();
         // const int columnNum = output->columns();
-        permutation->linspace(0);
-        auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
-        auto compoundBuf = compound->bufferAsT<T>();
-        auto compoundShape = compound->shapeInfo();
-        auto permutationShape = permutation->shapeInfo();
-        for (auto i = 0; i < rowNum - 1; i++) {
-            auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
-            if (pivotIndex < 0) {
-                throw std::runtime_error("helpers::luNN_: input matrix is singular.");
-            }
-            math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
-            swapRows(compoundBuf, compoundShape, i, pivotIndex);
-
-            processColumns(i, rowNum, compoundBuf, compoundShape);
+        if (permutation) { // LUP algorithm
+            permutation->linspace(0);
+            auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
+            auto compoundBuf = compound->bufferAsT<T>();
+            auto compoundShape = compound->shapeInfo();
+            auto permutationShape = permutation->shapeInfo();
+            for (auto i = 0; i < rowNum - 1; i++) {
+                auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
+                if (pivotIndex < 0) {
+                    throw std::runtime_error("helpers::luNN_: input matrix is singular.");
+                }
+                math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)],
+                                permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
+                swapRows(compoundBuf, compoundShape, i, pivotIndex);
+
+                processColumns(i, rowNum, compoundBuf, compoundShape);
+            }
+        }
+        else { // Doolittle algorithm with LU decomposition
+            doolittleLU<T>(context, compound, rowNum);
+        }
     }
 
@@ -265,17 +305,20 @@
         output->assign(input); // fill up output tensor with zeros
         ResultSet outputs = output->allTensorsAlongDimension({-2, -1});
-        ResultSet permutations = permutationVectors->allTensorsAlongDimension({-1});
+        ResultSet permutations;
+        if (permutationVectors)
+            permutations = permutationVectors->allTensorsAlongDimension({-1});
+
         auto loop = PRAGMA_THREADS_FOR {
             for (auto i = start; i < stop; i += increment) {
-                luNN_<T, I>(context, outputs.at(i), permutations.at(i), n);
+                luNN_<T, I>(context,
outputs.at(i), permutationVectors?permutations.at(i):nullptr, n); } }; samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); } void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) { - BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation?permutation->dataType():DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES); } // BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp new file mode 100644 index 000000000..8583d9cba --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp @@ -0,0 +1,100 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit, K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author GS +// +#include +#include +#include +#include +#include + +#include "../triangular_solve.h" +#include "../lup.h" +#include "../solve.h" + +namespace nd4j { +namespace ops { +namespace helpers { + +// --------------------------------------------------------------------------------------------------------------------------------------- // + template + static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) { + auto inputPart = input->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + output->assign(input); + auto batchLoop = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch += increment) { + for (auto r = 0; r < input->rows(); r++) { + for (auto c = 0; c < r; c++) { + math::nd4j_swap(outputPart[batch]->t(r, c) , outputPart[batch]->t(c, r)); + } + } + } + }; + samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); + } + +// --------------------------------------------------------------------------------------------------------------------------------------- // + template + static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) { + + // stage 1: LU decomposition batched + auto leftOutput = leftInput->ulike(); + auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back(); + auto permutations = NDArrayFactory::create('c', permuShape, context); + helpers::lu(context, leftInput, &leftOutput, &permutations); + auto P = leftInput->ulike(); //permutations batched matrix + P.nullify(); // to fill up matricies with zeros + auto PPart = P.allTensorsAlongDimension({-2,-1}); + auto permutationsPart = permutations.allTensorsAlongDimension({-1}); + + for (auto batch = 0; batch < 
permutationsPart.size(); ++batch) { + for (auto row = 0; row < PPart[batch]->rows(); ++row) { + PPart[batch]->t(row, permutationsPart[batch]->t(row)) = T(1.f); + } + } + + auto leftLower = leftOutput.dup(); + auto rightOutput = rightInput->ulike(); + auto rightPermuted = rightOutput.ulike(); + MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0); + ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1}); + for (auto i = 0; i < leftLowerPart.size(); i++) { + for (auto r = 0; r < leftLowerPart[i]->rows(); r++) + leftLowerPart[i]->t(r,r) = (T)1.f; + } + // stage 2: triangularSolveFunctor for Lower with given b + helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput); + // stage 3: triangularSolveFunctor for Upper with output of previous stage + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output); + + return Status::OK(); + } + +// --------------------------------------------------------------------------------------------------------------------------------------- // + int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) { + BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES); + } +// --------------------------------------------------------------------------------------------------------------------------------------- // + void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES); + } +// --------------------------------------------------------------------------------------------------------------------------------------- // +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp index ab409a0c6..e904d219c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp @@ -41,13 +41,16 @@ namespace helpers { template static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { auto rows = leftInput->rows(); + auto cols = rightInput->columns(); //output->t(0,0) = rightInput->t(0,0) / leftInput->t(0,0); for (auto r = 0; r < rows; r++) { - auto sum = rightInput->t(r, 0); - for (auto c = 0; c < r; c++) { - sum -= leftInput->t(r,c) * output->t(c, 0); + for (auto j = 0; j < cols; j++) { + auto sum = rightInput->t(r, j); + for (auto c = 0; c < r; c++) { + sum -= leftInput->t(r, c) * output->t(c, j); + } + output->t(r, j) = sum / leftInput->t(r, r); } - output->t(r, 0) = sum / leftInput->t(r, r); } } @@ -68,13 +71,15 @@ namespace helpers { template static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { auto rows = leftInput->rows(); - + auto cols = rightInput->columns(); for (auto r = rows; r > 0; r--) { - auto sum = rightInput->t(r - 1, 0); - for (auto c = r; c < rows; c++) { - sum -= leftInput->t(r - 1, c) * output->t(c, 0); + for (auto j = 0; j < cols; j++) { + auto sum = rightInput->t(r - 1, j); + for (auto c = r; c < rows; c++) { + sum -= leftInput->t(r - 1, c) * output->t(c, j); + } + output->t(r - 1, j) = sum / leftInput->t(r - 1, r - 1); } - output->t(r - 1, 0) = sum / leftInput->t(r - 
1, r - 1); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/solve.cu new file mode 100644 index 000000000..6437b80bd --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/solve.cu @@ -0,0 +1,140 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit, K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author GS +// + +#include +#include +#include +#include + +#include +#include +#include "../triangular_solve.h" +#include "../lup.h" +#include "../solve.h" + +namespace nd4j { + namespace ops { + namespace helpers { + + template + static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong* ioShape, Nd4jLong* tadShape, Nd4jLong* tadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) { + for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { + auto matrixPart = ioBuf + tadOffsets[i]; + for (auto j = threadIdx.x; j < rowNum; j += blockDim.x) { + Nd4jLong pos[] = {j, j}; + auto offset = shape::getOffset(tadShape, pos); + + matrixPart[offset] = T(1.f); + } + } + } + + template + static __global__ void restorePermutationsKernel(T* PBuf, Nd4jLong* PShapeInfo, int const* permutationsBuf, + Nd4jLong* PTadShapeInfo, Nd4jLong* PTadSOffsets, Nd4jLong* permutationsTadShapeInfo, Nd4jLong* permutationsTadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) { + for (auto batch = blockIdx.x; batch < batchNum; batch += gridDim.x) { + auto permutations = permutationsBuf + permutationsTadOffsets[batch]; + auto P = PBuf + PTadSOffsets[batch]; + + for (auto row = threadIdx.x; row < rowNum; row += blockDim.x) { + //auto posX[] = {row}; + Nd4jLong posZ[] = {row, permutations[row]}; + auto zOffset = shape::getOffset(PTadShapeInfo, posZ); + P[zOffset] = T(1.f); + } + } + } + + template + static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, + bool adjoint, NDArray* output) { + NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); + // stage 1: LU decomposition batched + auto leftOutput = leftInput->ulike(); + auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back(); + auto permutations = NDArrayFactory::create('c', permuShape, context); + helpers::lu(context, leftInput, &leftOutput, &permutations); + auto leftLower = leftOutput.dup(); + auto rightOutput = rightInput->ulike(); + auto leftLowerTad = ConstantTadHelper::getInstance()->tadForDimensions(leftLower.getShapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + oneOnDiagonalKernel<<<128, 256, 256, *stream>>>(leftLower.dataBuffer()->specialAsT(), leftLower.specialShapeInfo(), leftLowerTad.specialShapeInfo(), leftLowerTad.specialOffsets(), leftLowerTad.numberOfTads(), leftLower.sizeAt(-1)); + auto P = leftOutput.ulike(); P.nullify(); + auto PTad = ConstantTadHelper::getInstance()->tadForDimensions(P.getShapeInfo(), {-2, -1}); + auto permutationsTad = 
ConstantTadHelper::getInstance()->tadForDimensions(permutations.getShapeInfo(), {-1}); + restorePermutationsKernel<<<128, 256, 256, *stream>>>(P.dataBuffer()->specialAsT(), P.specialShapeInfo(), permutations.dataBuffer()->specialAsT(), + PTad.specialShapeInfo(), PTad.specialOffsets(), permutationsTad.specialShapeInfo(), permutationsTad.specialOffsets(), permutationsTad.numberOfTads(), permutations.sizeAt(-1)); + P.tickWriteDevice(); + auto rightPart = rightInput->ulike(); + MmulHelper::matmul(&P, rightInput, &rightPart, 0, 0); + + // stage 2: triangularSolveFunctor for Lower with given b + helpers::triangularSolveFunctor(context, &leftLower, &rightPart, true, false, &rightOutput); + // stage 3: triangularSolveFunctor for Upper with output of previous stage + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output); + NDArray::registerSpecialUse({output}, {leftInput, rightInput}); + + return Status::OK(); + } + + int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { + BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES); + } + + template + static __global__ void adjointKernel(T* output, Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, Nd4jLong* outputTads, + Nd4jLong* outputOffsets) { + + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = threadIdx.y; c < r; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(outputTads, xPos); + math::nd4j_swap(outputPart[zIndex], outputPart[xIndex]); + } + } + } + + } + + template + static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input}); + auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + auto outputBuf = reinterpret_cast(output->specialBuffer()); + auto rows = input->sizeAt(-2); + auto columns = input->sizeAt(-1); + output->assign(input); + adjointKernel<<<128, 256, 256, *stream>>>(outputBuf, outputTads.numberOfTads(), rows, columns, outputTads.specialShapeInfo(), outputTads.specialOffsets()); + NDArray::registerSpecialUse({output}, {input}); + } + + void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES); + } + + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu index 8846be45c..6f5fe6b8c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu @@ -44,24 +44,26 @@ namespace nd4j { static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output, Nd4jLong* outputShape, - Nd4jLong rows) { + Nd4jLong rows, Nd4jLong cols) { for (auto r = 0; r < rows; r++) { - Nd4jLong posY[] = {r, 0}; - Nd4jLong posX[] = {r, r}; - auto xIndex = 
shape::getOffset(leftInputShape, posX, 0); - auto yIndex = shape::getOffset(rightInputShape, posY, 0); - auto zIndex = shape::getOffset(outputShape, posY, 0); + for (auto j = 0; j < cols; j++) { + Nd4jLong posY[] = {r, j}; + Nd4jLong posX[] = {r, r}; + auto xIndex = shape::getOffset(leftInputShape, posX, 0); + auto yIndex = shape::getOffset(rightInputShape, posY, 0); + auto zIndex = shape::getOffset(outputShape, posY, 0); - auto sum = rightInput[yIndex]; - for (auto c = 0; c < r; c++) { - Nd4jLong posZ[] = {c, 0}; - Nd4jLong pos[] = {r, c}; - auto xcIndex = shape::getOffset(leftInputShape, pos, 0); - auto zcIndex = shape::getOffset(outputShape, posZ, 0); - sum -= leftInput[xcIndex] * output[zcIndex]; + auto sum = rightInput[yIndex]; + for (auto c = 0; c < r; c++) { + Nd4jLong posZ[] = {c, j}; + Nd4jLong pos[] = {r, c}; + auto xcIndex = shape::getOffset(leftInputShape, pos, 0); + auto zcIndex = shape::getOffset(outputShape, posZ, 0); + sum -= leftInput[xcIndex] * output[zcIndex]; + } + output[zIndex] = sum / leftInput[xIndex]; } - output[zIndex] = sum / leftInput[xIndex]; } } @@ -82,23 +84,25 @@ namespace nd4j { template static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output, - Nd4jLong* outputShape, Nd4jLong rows) { + Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) { for (auto r = rows; r > 0; r--) { - Nd4jLong posY[] = {r - 1, 0}; - Nd4jLong posX[] = {r - 1, r - 1}; - auto xIndex = shape::getOffset(leftInputShape, posX, 0); - auto yIndex = shape::getOffset(rightInputShape, posY, 0); - auto zIndex = shape::getOffset(outputShape, posY, 0); - auto sum = rightInput[yIndex]; - for (auto c = r; c < rows; c++) { - Nd4jLong posZ[] = {c, 0}; - Nd4jLong pos[] = {r - 1, c}; - auto zcIndex = shape::getOffset(outputShape, posZ, 0); - auto xcIndex = shape::getOffset(leftInputShape, pos, 0); - sum -= leftInput[xcIndex] * output[zcIndex]; + for (auto j = 0; j < cols; j++) { + Nd4jLong posY[] = {r - 1, j}; + Nd4jLong posX[] = {r - 1, r - 1}; + auto xIndex = shape::getOffset(leftInputShape, posX, 0); + auto yIndex = shape::getOffset(rightInputShape, posY, 0); + auto zIndex = shape::getOffset(outputShape, posY, 0); + auto sum = rightInput[yIndex]; + for (auto c = r; c < rows; c++) { + Nd4jLong posZ[] = {c, j}; + Nd4jLong pos[] = {r - 1, c}; + auto zcIndex = shape::getOffset(outputShape, posZ, 0); + auto xcIndex = shape::getOffset(leftInputShape, pos, 0); + sum -= leftInput[xcIndex] * output[zcIndex]; + } + output[zIndex] = sum / leftInput[xIndex]; } - output[zIndex] = sum / leftInput[xIndex]; } } @@ -109,8 +113,11 @@ namespace nd4j { Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) { __shared__ Nd4jLong rows; + __shared__ Nd4jLong cols; + if (threadIdx.x == 0) { rows = shape::sizeAt(leftPartShape, -2); + cols = shape::sizeAt(rightPartShape, -1); } __syncthreads(); @@ -123,9 +130,9 @@ namespace nd4j { auto pRightPart = rightInput + tadRightOffset[i]; auto pOutputPart = output + tadOutputOffset[i]; if (lower) { - lowerTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows); + lowerTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols); } else { - upperTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows); + upperTriangularSolve(pLeftPart, tadLeftShape, 
pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols); } } } diff --git a/libnd4j/include/ops/declarable/helpers/solve.h b/libnd4j/include/ops/declarable/helpers/solve.h new file mode 100644 index 000000000..d097fa217 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/solve.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author GS +// +#ifndef __SOLVE__H_HELPERS__ +#define __SOLVE__H_HELPERS__ +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + + int solveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output); + void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output); +} +} +} +#endif diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 71ebdc7e6..de4bdc31b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -1541,6 +1541,166 @@ TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { delete []arr; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_1) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f + }); + + auto b = NDArrayFactory::create('c', {3, 1}, { + 2.f, 4.f, 3.f + }); + + auto exp = NDArrayFactory::create('c', {3, 1}, { + 7.625f, 3.25f, 5.f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printIndexedBuffer("Solve of 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_2) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 1.f, 1.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 2.f, 1.f, + 0.f, 0.f, 0.f, 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 2.f, 4.f, 2.f, 4.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printIndexedBuffer("Solve 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_3) { + + auto a = NDArrayFactory::create('c', {2, 4, 4}, { + 1.f, 1.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 2.f, 1.f, + 0.f, 0.f, 0.f, 3.f, + + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + + }); + + auto b = 
NDArrayFactory::create('c', {2, 4, 1}, { + 2.f, 4.f, 2.f, 4.f, + 4.f, 2.f, 4.f, 2.f + }); + + auto exp = NDArrayFactory::create('c', {2, 4, 1}, { + -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f, + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printIndexedBuffer("Solve 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4) { + + auto a = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto b = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7717f, 0.9281f, 0.9846f, 0.4838f, + 0.6433f, 0.6041f, 0.6501f, 0.7612f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, { +// 1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f, +// 0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f + 1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f, + 0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4 Solve 4x4"); +// exp.printBuffer("4 Expec 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_5) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 1.5504692f, 1.8953944f, 2.2765768f, + 0.03399149f, 0.2883001f, 0.5377323f, + -0.8774802f, -1.2155888f, -1.8049058f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + + z->printBuffer("4 Solve 4x4"); + exp.printBuffer("4 Expec 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 142a3dbd4..6025216f9 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -3008,3 +3008,33 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) { ASSERT_TRUE(exp.equalsTo(z)); delete res; } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 5.f, 1.f, -3.f, 3.f, + 0.f, 1.f, 1.f, -1.f, + 0.f, 0.f, 2.f, -9.f, + 0.f, 0.f, 0.f, 4.f + }); + + auto b = NDArrayFactory::create('c', {4, 2}, { + 5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f + }); + + auto exp = NDArrayFactory::create('c', {4, 2}, { + 1.f,0.2f, 1.f,0.8f, 1.f,0.4f, 1.f,1.2f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {}, {}, {false, true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + + z->printIndexedBuffer("TriangularSolve with adjoint"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; 
+}
\ No newline at end of file
diff --git a/libnd4j/tests_cpu/run_tests.sh b/libnd4j/tests_cpu/run_tests.sh
index 9b1271df6..8f412dee5 100755
--- a/libnd4j/tests_cpu/run_tests.sh
+++ b/libnd4j/tests_cpu/run_tests.sh
@@ -39,7 +39,7 @@ do
 done
 
 CHIP="${CHIP:-cpu}"
-export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml"
+export GTEST_OUTPUT="xml:surefire-reports/TEST-${CHIP}-results.xml"
 
 # On Mac, make sure it can find libraries for GCC
 export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
@@ -48,9 +48,11 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/
 if [ -n "$BUILD_PATH" ]; then
     if which cygpath; then
         BUILD_PATH=$(cygpath -p $BUILD_PATH)
-        export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'"
     fi
     export PATH="$PATH:$BUILD_PATH"
 fi
 
 ../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
+
+# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
+[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index f804d8c95..3ed96fe9c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -623,7 +623,8 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.custom.Igammac.class,
             org.nd4j.linalg.api.ops.custom.Digamma.class,
             org.nd4j.linalg.api.ops.custom.Lu.class,
-            org.nd4j.linalg.api.ops.custom.TriangularSolve.class
+            org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
+            org.nd4j.linalg.api.ops.custom.LinearSolve.class
     );
 
     static {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java
index 7555bce21..94cfdca43 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java
@@ -16,26 +16,48 @@
 
 package org.nd4j.linalg.api.buffer;
 
+/**
+ * Enumeration of supported data types.
+ * + */ public enum DataType { DOUBLE, FLOAT, + /** + * @deprecated Replaced by {@link DataType#FLOAT16}, use that instead + */ @Deprecated HALF, + /** + * @deprecated Replaced by {@link DataType#INT64}, use that instead + */ @Deprecated LONG, + /** + * @deprecated Replaced by {@link DataType#INT32}, use that instead + */ @Deprecated INT, + /** + * @deprecated Replaced by {@link DataType#INT16}, use that instead + */ @Deprecated SHORT, + /** + * @deprecated Replaced by {@link DataType#UINT8}, use that instead + */ @Deprecated UBYTE, + /** + * @deprecated Replaced by {@link DataType#INT8}, use that instead + */ @Deprecated BYTE, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java new file mode 100644 index 000000000..7c835006e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java @@ -0,0 +1,77 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@NoArgsConstructor +public class LinearSolve extends DynamicCustomOp { + + public LinearSolve(INDArray a, INDArray b, boolean adjoint) { + addInputArgument(a, b); + addBArgument(adjoint); + } + + public LinearSolve(INDArray a, INDArray b) { + this(a,b,false); + } + + public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, SDVariable adjoint) { + super(sameDiff, new SDVariable[] {a, b, adjoint}); + } + + public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, boolean adjoint) { + super(sameDiff, new SDVariable[] {a, b}); + addBArgument(adjoint); + } + + @Override + public String opName() { + return "solve"; + } + + @Override + public String tensorflowName() { + return "MatrixSolve"; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + boolean adjoint = attributesForNode.containsKey("adjoint") ? 
attributesForNode.get("adjoint").getB() : false;
+        addBArgument(adjoint);
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        int n = args().length;
+        Preconditions.checkState(dataTypes != null && dataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), dataTypes);
+        return Collections.singletonList(dataTypes.get(0));
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index 49ff345e7..b8d795460 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -1691,4 +1691,50 @@ public class CustomOpsTests extends BaseNd4jTest {
 
         assertEquals(e, x);
     }
+
+    @Test
+    public void testLinearSolve() {
+        INDArray a = Nd4j.createFromArray(new float[]{
+                2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f
+        }).reshape(3, 3);
+
+        INDArray b = Nd4j.createFromArray(new float[]{
+                2.f, 4.f, 3.f
+        }).reshape(3, 1);
+
+        INDArray expected = Nd4j.createFromArray(new float[]{
+                7.625f, 3.25f, 5.f
+        }).reshape(3, 1);
+
+        val op = new LinearSolve(a, b);
+        INDArray[] ret = Nd4j.exec(op);
+
+        assertEquals(expected, ret[0]);
+    }
+
+    @Test
+    public void testLinearSolveAdjoint() {
+        INDArray a = Nd4j.createFromArray(new float[]{
+                0.7788f, 0.8012f, 0.7244f,
+                0.2309f, 0.7271f, 0.1804f,
+                0.5056f, 0.8925f, 0.5461f
+        }).reshape(3, 3);
+
+        INDArray b = Nd4j.createFromArray(new float[]{
+                0.7717f, 0.9281f, 0.9846f,
+                0.4838f, 0.6433f, 0.6041f,
+                0.6501f, 0.7612f, 0.7605f
+        }).reshape(3, 3);
+
+        INDArray expected = Nd4j.createFromArray(new float[]{
+                1.5504692f, 1.8953944f, 2.2765768f,
+                0.03399149f, 0.2883001f, 0.5377323f,
+                -0.8774802f, -1.2155888f, -1.8049058f
+        }).reshape(3, 3);
+
+        val op = new LinearSolve(a, b, true);
+        INDArray[] ret = Nd4j.exec(op);
+
+        assertEquals(expected, ret[0]);
+    }
 }