diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/triangular_solve.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/triangular_solve.cpp
new file mode 100644
index 000000000..181f47d3d
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/triangular_solve.cpp
@@ -0,0 +1,82 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit, K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// Created by GS at 01/14/2020
+//
+
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_triangular_solve)
+
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/triangular_solve.h>
+namespace nd4j {
+    namespace ops {
+        CUSTOM_OP_IMPL(triangular_solve, 2, 1, false, 0, 0) {
+            auto a = INPUT_VARIABLE(0);
+            auto b = INPUT_VARIABLE(1);
+            auto z = OUTPUT_VARIABLE(0);
+            bool isLower = true;
+            bool useAdjoint = false;
+
+            if (block.numB() > 0)
+                isLower = B_ARG(0);
+            if (block.numB() > 1)
+                useAdjoint = B_ARG(1);
+
+            REQUIRE_TRUE(a->rankOf() >= 2, 0, "triangular_solve: The rank of the left input tensor should not be less than 2, but %i is given", a->rankOf());
+            REQUIRE_TRUE(b->rankOf() >= 2, 0, "triangular_solve: The rank of the right input tensor should not be less than 2, but %i is given", b->rankOf());
+
+            REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "triangular_solve: The last two dimensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
+            REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "triangular_solve: The last dimension of the left part should equal the second-to-last dimension of the right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
+            auto input = a;
+            if (useAdjoint) {
+                // replace the left part with its adjoint and flip the triangle kind accordingly
+                auto adjointA = a->ulike();
+                helpers::adjointMatrix(block.launchContext(), a, isLower, &adjointA);
+                input = new NDArray(adjointA);
+                isLower = !isLower;
+            }
+
+            auto res = helpers::triangularSolveFunctor(block.launchContext(), input, b, isLower, useAdjoint, z);
+            if (input != a)
+                delete input;
+
+            return res;
+        }
+
+        DECLARE_SHAPE_FN(triangular_solve) {
+            auto in0 = inputShape->at(0);
+            auto in1 = inputShape->at(1);
+            // the output has the shape of the right part and the type of the left part
+            auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());
+
+            return SHAPELIST(CONSTANT(luShape));
+        }
+
+        DECLARE_TYPES(triangular_solve) {
+            getOpDescriptor()
+                    ->setAllowedInputTypes({ALL_FLOATS})
+                    ->setAllowedOutputTypes({ALL_FLOATS})
+                    ->setSameMode(false);
+        }
+    }
+}
+
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h
index 4a1f85130..c218b8516 100644
--- a/libnd4j/include/ops/declarable/headers/parity_ops.h
+++ b/libnd4j/include/ops/declarable/headers/parity_ops.h
@@ -1041,6 +1041,25 @@ namespace nd4j {
         DECLARE_OP(matrix_inverse, 1, 1, true);
     #endif
 
+    /**
+     * triangular_solve op. - solves systems of linear equations with a triangular left part by substitution (the reverse sweep of the Gaussian elimination method).
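+     *
+     * For instance, with the lower-triangular left part [[2, 0], [3, 4]] and the right part [[2], [11]],
+     * forward substitution gives x_1 = 2/2 = 1 and x_2 = (11 - 3*1)/4 = 2, i.e. the solution [[1], [2]].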
+     *
+     * input params:
+     *    0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
+     *    1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
+     *
+     * boolean args:
+     *    0 - lower - default is true (optional) - the left part is a lower triangular matrix
+     *    1 - adjoint - default is false (optional) - indicates whether the input matrix or its adjoint (conjugate transpose) should be used
+     *
+     * return value:
+     *    tensor with dimension (x * y * z * ::: * M * K) with solutions
+     *
+     */
+    #if NOT_EXCLUDED(OP_triangular_solve)
+    DECLARE_CUSTOM_OP(triangular_solve, 2, 1, false, 0, 0);
+    #endif
+
     /**
      * lu op. - make LUP decomposition of given batch of 2D square matricies
      *
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp
new file mode 100644
index 000000000..ab409a0c6
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp
@@ -0,0 +1,135 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit, K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author GS
+//
+#include <op_boilerplate.h>
+#include <NDArray.h>
+#include <execution/Threads.h>
+#include "../triangular_solve.h"
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+    /*
+     * lower triangular process for a system of linear equations (forward substitution):
+     * x_1 = b_1 / a_1,1
+     * x_2 = (b_2 - a_2,1 * x_1) / a_2,2
+     * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3
+     * ...
+     * x_M = (b_M - a_M,1 * x_1 - ... - a_M,M-1 * x_M-1) / a_M,M
+     *
+     * output == x
+     * a == leftInput
+     * b == rightInput
+     *
+     * */
+    template <typename T>
+    static void lowerTriangularSolve(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
+        auto rows = leftInput->rows();
+        // note: only the first column of the right part is processed here (the K == 1 case,
+        // which is what the accompanying tests exercise); adjoint is resolved by the caller
+        for (auto r = 0; r < rows; r++) {
+            auto sum = rightInput->t<T>(r, 0);
+            for (auto c = 0; c < r; c++) {
+                sum -= leftInput->t<T>(r, c) * output->t<T>(c, 0);
+            }
+            output->t<T>(r, 0) = sum / leftInput->t<T>(r, r);
+        }
+    }
+
+    /*
+     * upper triangular process for a system of linear equations (back substitution):
+     * x_M = b_M / a_M,M
+     * x_M-1 = (b_M-1 - a_M-1,M * x_M) / a_M-1,M-1
+     * x_M-2 = (b_M-2 - a_M-2,M-1 * x_M-1 - a_M-2,M * x_M) / a_M-2,M-2
+     * ...
+     * x_1 = (b_1 - a_1,2 * x_2 - ... - a_1,M * x_M) / a_1,1
+     *
+     * output == x
+     * a == leftInput
+     * b == rightInput
+     *
+     * */
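+    // e.g. a = [[4, 3], [0, 2]] and b = [[11], [2]] give x_2 = 2/2 = 1, then
+    // x_1 = (11 - 3*1)/4 = 2; the loop below walks the rows bottom-up in exactly this way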
+
+    template <typename T>
+    static void upperTriangularSolve(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
+        auto rows = leftInput->rows();
+
+        for (auto r = rows; r > 0; r--) {
+            auto sum = rightInput->t<T>(r - 1, 0);
+            for (auto c = r; c < rows; c++) {
+                sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, 0);
+            }
+            output->t<T>(r - 1, 0) = sum / leftInput->t<T>(r - 1, r - 1);
+        }
+    }
+
+    template <typename T>
+    static int triangularSolveFunctor_(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {
+        // view the batch as a list of independent 2D systems and solve them in parallel
+        auto leftPart = leftInput->allTensorsAlongDimension({-2, -1});
+        auto rightPart = rightInput->allTensorsAlongDimension({-2, -1});
+        auto outputPart = output->allTensorsAlongDimension({-2, -1});
+
+        auto batchLoop = PRAGMA_THREADS_FOR {
+            for (auto i = start; i < stop; i += increment) {
+                if (lower) {
+                    lowerTriangularSolve<T>(context, leftPart[i], rightPart[i], adjoint, outputPart[i]);
+                } else {
+                    upperTriangularSolve<T>(context, leftPart[i], rightPart[i], adjoint, outputPart[i]);
+                }
+            }
+        };
+
+        samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1);
+
+        return Status::OK();
+    }
+
+    template <typename T>
+    static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
+        auto inputPart = input->allTensorsAlongDimension({-2, -1});
+        auto outputPart = output->allTensorsAlongDimension({-2, -1});
+        auto batchLoop = PRAGMA_THREADS_FOR {
+            for (auto batch = start; batch < stop; batch += increment) {
+                if (!lower) {
+                    // mirror the upper triangle into the lower one
+                    for (auto r = 0; r < input->rows(); r++) {
+                        for (auto c = 0; c <= r; c++) {
+                            outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
+                        }
+                    }
+                } else {
+                    // mirror the lower triangle into the upper one
+                    for (auto r = 0; r < input->rows(); r++) {
+                        for (auto c = r; c < input->columns(); c++) {
+                            outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
+                        }
+                    }
+                }
+            }
+        };
+        samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
+    }
+
+    int triangularSolveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {
+        BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE);
+    }
+
+    void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
+        BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE);
+    }
+}
+}
+}
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu
new file mode 100644
index 000000000..8846be45c
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu
@@ -0,0 +1,227 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit, K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author GS
+//
+
+#include <op_boilerplate.h>
+#include <NDArray.h>
+#include <ConstantTadHelper.h>
+#include <ShapeUtils.h>
+#include "../triangular_solve.h"
+
+namespace nd4j {
+    namespace ops {
+        namespace helpers {
+            /*
+             * lower triangular process for a system of linear equations (forward substitution):
+             * x_1 = b_1 / a_1,1
+             * x_2 = (b_2 - a_2,1 * x_1) / a_2,2
+             * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3
+             * ...
+             * x_M = (b_M - a_M,1 * x_1 - ... - a_M,M-1 * x_M-1) / a_M,M
+             *
+             * output == x
+             * a == leftInput
+             * b == rightInput
+             *
+             * */
+            template <typename T>
+            static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
+                                                        T const* rightInput, Nd4jLong const* rightInputShape,
+                                                        bool const adjoint, T* output, Nd4jLong* outputShape,
+                                                        Nd4jLong rows) {
+
+                for (auto r = 0; r < rows; r++) {
+                    Nd4jLong posY[] = {r, 0};
+                    Nd4jLong posX[] = {r, r};
+                    auto xIndex = shape::getOffset(leftInputShape, posX, 0);
+                    auto yIndex = shape::getOffset(rightInputShape, posY, 0);
+                    auto zIndex = shape::getOffset(outputShape, posY, 0);
+
+                    auto sum = rightInput[yIndex];
+                    for (auto c = 0; c < r; c++) {
+                        Nd4jLong posZ[] = {c, 0};
+                        Nd4jLong pos[] = {r, c};
+                        auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
+                        auto zcIndex = shape::getOffset(outputShape, posZ, 0);
+                        sum -= leftInput[xcIndex] * output[zcIndex];
+                    }
+                    output[zIndex] = sum / leftInput[xIndex];
+                }
+            }
+
+            /*
+             * upper triangular process for a system of linear equations (back substitution):
+             * x_M = b_M / a_M,M
+             * x_M-1 = (b_M-1 - a_M-1,M * x_M) / a_M-1,M-1
+             * x_M-2 = (b_M-2 - a_M-2,M-1 * x_M-1 - a_M-2,M * x_M) / a_M-2,M-2
+             * ...
+             * x_1 = (b_1 - a_1,2 * x_2 - ... - a_1,M * x_M) / a_1,1
+             *
+             * output == x
+             * a == leftInput
+             * b == rightInput
+             *
+             * */
+            template <typename T>
+            static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
+                                                        T const* rightInput, Nd4jLong const* rightInputShape,
+                                                        bool const adjoint, T* output,
+                                                        Nd4jLong* outputShape, Nd4jLong rows) {
+
+                for (auto r = rows; r > 0; r--) {
+                    Nd4jLong posY[] = {r - 1, 0};
+                    Nd4jLong posX[] = {r - 1, r - 1};
+                    auto xIndex = shape::getOffset(leftInputShape, posX, 0);
+                    auto yIndex = shape::getOffset(rightInputShape, posY, 0);
+                    auto zIndex = shape::getOffset(outputShape, posY, 0);
+                    auto sum = rightInput[yIndex];
+                    for (auto c = r; c < rows; c++) {
+                        Nd4jLong posZ[] = {c, 0};
+                        Nd4jLong pos[] = {r - 1, c};
+                        auto zcIndex = shape::getOffset(outputShape, posZ, 0);
+                        auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
+                        sum -= leftInput[xcIndex] * output[zcIndex];
+                    }
+                    output[zIndex] = sum / leftInput[xIndex];
+                }
+            }
+
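+            // one grid-stride loop over the batch: each thread solves entire matrices on its own,
+            // since substitution is inherently sequential across rows (x_r depends on every x_c, c < r)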
+            template <typename T>
+            static __global__ void triangularSolveKernel(T const* leftInput, Nd4jLong const* leftPartShape,
+                                                         T const* rightInput, Nd4jLong const* rightPartShape, bool const lower, bool const adjoint, T* output,
+                                                         Nd4jLong* outputShape, Nd4jLong* tadLeftShape, Nd4jLong* tadLeftOffset, Nd4jLong* tadRightShape,
+                                                         Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) {
+
+                __shared__ Nd4jLong rows;
+                if (threadIdx.x == 0) {
+                    rows = shape::sizeAt(leftPartShape, -2);
+                }
+                __syncthreads();
+
+                auto start = blockIdx.x * blockDim.x + threadIdx.x;
+                auto stop = batchNum;
+                auto increment = blockDim.x * gridDim.x;
+
+                for (auto i = start; i < stop; i += increment) {
+                    auto pLeftPart = leftInput + tadLeftOffset[i];
+                    auto pRightPart = rightInput + tadRightOffset[i];
+                    auto pOutputPart = output + tadOutputOffset[i];
+                    if (lower) {
+                        lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows);
+                    } else {
+                        upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows);
+                    }
+                }
+            }
+
+            template <typename T>
+            static int triangularSolveFunctor_(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput,
+                                               bool lower, bool adjoint, NDArray* output) {
+                NDArray::prepareSpecialUse({output}, {leftInput, rightInput});
+                // split the inputs and the output into per-matrix TADs along the last two dimensions
+                auto leftTads = ConstantTadHelper::getInstance()->tadForDimensions(leftInput->getShapeInfo(), {-2, -1});
+                auto rightTads = ConstantTadHelper::getInstance()->tadForDimensions(rightInput->getShapeInfo(), {-2, -1});
+                auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1});
+
+                auto stream = context->getCudaStream();
+                T const* leftBuf = reinterpret_cast<T const*>(leftInput->getSpecialBuffer());
+                T const* rightBuf = reinterpret_cast<T const*>(rightInput->getSpecialBuffer());
+                T* outputBuf = reinterpret_cast<T*>(output->specialBuffer());
+                triangularSolveKernel<T><<<128, 128, 256, *stream>>>(leftBuf, leftInput->getSpecialShapeInfo(),
+                        rightBuf, rightInput->getSpecialShapeInfo(), lower, adjoint, outputBuf, output->specialShapeInfo(),
+                        leftTads.specialShapeInfo(), leftTads.specialOffsets(), rightTads.specialShapeInfo(),
+                        rightTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets(),
+                        leftTads.numberOfTads());
+
+                NDArray::registerSpecialUse({output}, {leftInput, rightInput});
+
+                return Status::OK();
+            }
+
+            int triangularSolveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {
+                BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE);
+            }
+
+            template <typename T>
+            static __global__ void upperAdjointKernel(T const* input, T* output,
+                                                      Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns,
+                                                      Nd4jLong* inputTads, Nd4jLong* inputOffsets, Nd4jLong* outputTads, Nd4jLong* outputOffsets) {
+
+                // one block per batch element; threads stride over the rows and columns of the triangle
+                for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) {
+                    auto inputPart = input + inputOffsets[b];
+                    auto outputPart = output + outputOffsets[b];
+                    for (auto r = threadIdx.x; r < rows; r += blockDim.x) {
+                        for (auto c = threadIdx.y; c <= r; c += blockDim.y) {
+                            Nd4jLong zPos[] = {r, c};
+                            Nd4jLong xPos[] = {c, r};
+                            auto zIndex = shape::getOffset(outputTads, zPos);
+                            auto xIndex = shape::getOffset(inputTads, xPos);
+                            outputPart[zIndex] = inputPart[xIndex];
+                        }
+                    }
+                }
+            }
+
+            template <typename T>
+            static __global__ void lowerAdjointKernel(T const* input, T* output,
+                                                      Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns,
+                                                      Nd4jLong* inputTads, Nd4jLong* inputOffsets, Nd4jLong* outputTads, Nd4jLong* outputOffsets) {
+
+                for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) {
+                    auto inputPart = input + inputOffsets[b];
+                    auto outputPart = output + outputOffsets[b];
+                    for (auto r = threadIdx.x; r < rows; r += blockDim.x) {
+                        for (auto c = r + threadIdx.y; c < columns; c += blockDim.y) {
+                            Nd4jLong zPos[] = {r, c};
+                            Nd4jLong xPos[] = {c, r};
+                            auto zIndex = shape::getOffset(outputTads, zPos);
+                            auto xIndex = shape::getOffset(inputTads, xPos);
+                            outputPart[zIndex] = inputPart[xIndex];
+                        }
+                    }
+                }
+            }
+
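+            // the adjoint kernels above are launched below with 1-D blocks (blockDim.y == 1),
+            // so each thread's inner column loop advances one element at a time; a 2-D block
+            // shape would parallelize the columns as well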
+            template <typename T>
+            static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower,
+                                                 NDArray* output) {
+
+                auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {-2, -1});
+                auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {-2, -1});
+                auto stream = context->getCudaStream();
+                auto inputBuf = reinterpret_cast<T const*>(input->getSpecialBuffer());
+                auto outputBuf = reinterpret_cast<T*>(output->specialBuffer());
+                auto rows = input->sizeAt(-2);
+                auto columns = input->sizeAt(-1);
+
+                // sync buffers to device before the kernel and mark the output as modified afterwards
+                NDArray::prepareSpecialUse({output}, {input});
+                if (lower) {
+                    lowerAdjointKernel<T><<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets());
+                } else {
+                    upperAdjointKernel<T><<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets());
+                }
+                NDArray::registerSpecialUse({output}, {input});
+            }
+
+            void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
+                BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE);
+            }
+
+        }
+    }
+}
diff --git a/libnd4j/include/ops/declarable/helpers/triangular_solve.h b/libnd4j/include/ops/declarable/helpers/triangular_solve.h
new file mode 100644
index 000000000..a40a3e144
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/triangular_solve.h
@@ -0,0 +1,34 @@
+/*******************************************************************************
+ * Copyright (c) Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author GS
+//
+#ifndef __TRIANGULAR_SOLVE__H_HELPERS__
+#define __TRIANGULAR_SOLVE__H_HELPERS__
+#include <op_boilerplate.h>
+#include <NDArray.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    int triangularSolveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output);
+    void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output);
+}
+}
+}
+#endif
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp
index 15aa5751c..23d19d013 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp
@@ -2734,3 +2734,157 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_2) {
     ASSERT_TRUE(expP.equalsTo(p));
     delete res;
 }
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) {
+
+    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
+            3.f, 0.f, 0.f, 0.f,
+            2.f, 1.f, 0.f, 0.f,
+            1.f, 0.f, 1.f, 0.f,
+            1.f, 1.f, 1.f, 1.f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {4, 1}, {
+            4.f, 2.f, 4.f, 2.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {4, 1}, {
+            1.333333f, -0.6666667f, 2.6666667f, -1.3333333f });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.execute({&a, &b}, {}, {});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("TriangularSolve");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests12, TriangularSolve_Test_2) {
+
+    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
+            1.f, 1.f, 1.f, 1.f,
+            0.f, 1.f, 1.f, 0.f,
+            0.f, 0.f, 2.f, 1.f,
+            0.f, 0.f, 0.f, 3.f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {4, 1}, {
+            2.f, 4.f, 2.f, 4.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {4, 1}, {
+            2.f, 4.f, 1.f, 1.3333333f });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.execute({&a, &b}, {}, {});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("TriangularSolve");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests12, TriangularSolve_Test_3) {
+
+    auto a = NDArrayFactory::create<float>('c', {2, 4, 4}, {
+            3.f, 0.f, 0.f, 0.f,
+            2.f, 1.f, 0.f, 0.f,
+            1.f, 0.f, 1.f, 0.f,
+            1.f, 1.f, 1.f, 1.f,
+
+            3.f, 0.f, 0.f, 0.f,
+            2.f, 1.f, 0.f, 0.f,
+            1.f, 0.f, 1.f, 0.f,
+            1.f, 1.f, 1.f, 1.f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {2, 4, 1}, {
+            4.f, 2.f, 4.f, 2.f,
+            4.f, 2.f, 4.f, 2.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {2, 4, 1}, {
+            1.333333f, -0.6666667f, 2.6666667f, -1.3333333f,
+            1.333333f, -0.6666667f, 2.6666667f, -1.3333333f
+    });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.execute({&a, &b}, {}, {});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("TriangularSolve");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
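+
+// the remaining tests pass explicit boolean arguments: B_ARG(0) = lower (false selects
+// the upper-triangular back-substitution path) and B_ARG(1) = adjoint (true makes the op
+// transpose the left part before solving)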
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) {
+
+    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
+            1.f, 1.f, 1.f, 1.f,
+            0.f, 1.f, 1.f, 0.f,
+            0.f, 0.f, 2.f, 1.f,
+            0.f, 0.f, 0.f, 3.f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {4, 1}, {
+            2.f, 4.f, 2.f, 4.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {4, 1}, {
+            -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f
+    });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.execute({&a, &b}, {}, {}, {false});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("TriangularSolve");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
+
+    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
+            5.f, 1.f, -3.f, 3.f,
+            0.f, 1.f, 1.f, -1.f,
+            0.f, 0.f, 2.f, -9.f,
+            0.f, 0.f, 0.f, 4.f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {4, 1}, {
+            5.f, 2.f, 0.f, -3.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {4, 1}, {
+            1.f, 1.f, 1.f, 1.f
+    });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.execute({&a, &b}, {}, {}, {false, true});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("TriangularSolve with adjoint");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 7a4b00d3f..f804d8c95 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -622,7 +622,8 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.custom.Igamma.class,
             org.nd4j.linalg.api.ops.custom.Igammac.class,
             org.nd4j.linalg.api.ops.custom.Digamma.class,
-            org.nd4j.linalg.api.ops.custom.Lu.class
+            org.nd4j.linalg.api.ops.custom.Lu.class,
+            org.nd4j.linalg.api.ops.custom.TriangularSolve.class
     );
 
     static {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java
new file mode 100644
index 000000000..7423d3a91
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java
@@ -0,0 +1,43 @@
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class TriangularSolve extends DynamicCustomOp {
+
+    public TriangularSolve(INDArray matrix, INDArray rhs, boolean lower, boolean adjoint) {
+        addInputArgument(matrix, rhs);
+        addBArgument(lower, adjoint);
+    }
+
+    public TriangularSolve(SameDiff sameDiff, SDVariable matrix, SDVariable rhs,
+                           SDVariable lower, SDVariable adjoint) {
+        super(sameDiff, new SDVariable[] {matrix, rhs, lower, adjoint});
+    }
+
+    @Override
+    public String opName() {
+        return "triangular_solve";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "MatrixTriangularSolve";
+    }
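+
+    // the single output inherits the data type of the first (matrix) input,
+    // mirroring the op's C++ shape function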
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        int n = args().length;
+        Preconditions.checkState(dataTypes != null && dataTypes.size() == n,
+                "Expected %s input data types for %s, got %s", n, getClass(), dataTypes);
+        return Collections.singletonList(dataTypes.get(0));
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index 944966654..cfa3a9785 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -1653,4 +1653,27 @@ public class CustomOpsTests extends BaseNd4jTest {
         INDArray[] ret = Nd4j.exec(op);
         assertArrayEquals(image.shape(), ret[0].shape());
     }
+
+    @Test
+    public void testTriangularSolve() {
+        INDArray a = Nd4j.createFromArray(new float[]{
+                3.f, 0.f, 0.f, 0.f,
+                2.f, 1.f, 0.f, 0.f,
+                1.f, 0.f, 1.f, 0.f,
+                1.f, 1.f, 1.f, 1.f
+        }).reshape(4, 4);
+
+        INDArray b = Nd4j.createFromArray(new float[]{
+                4.f, 2.f, 4.f, 2.f
+        }).reshape(4, 1);
+
+        INDArray expected = Nd4j.createFromArray(new float[]{
+                1.333333f, -0.6666667f, 2.6666667f, -1.3333333f
+        }).reshape(4, 1);
+
+        val op = new TriangularSolve(a, b, true, false);
+        INDArray[] ret = Nd4j.exec(op);
+
+        assertEquals(expected, ret[0]);
+    }
 }