Shugeo solve linear (#191)

* linear equations systems solve op. Initial commit.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed compiling issues.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Linear equations systems solve. The next stage commit.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added test for linear equations systems solve operation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added an additional test and fixed retrieval of the lower triangular matrix.

* Implementation of the solver for systems of linear equations.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored permutation generation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added restore for permutations batched with cuda helper for solve op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Finished cuda implementation for solve op helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored cpu helpers for solve op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fix gtest output on Windows

* Fixed issue with permutation matrix for cuda implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed issue with permutation matrix for cpu implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Eliminated unnecessary comments.

Signed-off-by: shugeo <sgazeos@gmail.com>

* LinearSolve added

* Mapping added

* Javadoc added

* Refactored implementation of triangular_solve helpers and tests for solve matrix equations generally.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added a test for solve op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Solve test added

* Fix for TF import

Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com>
Co-authored-by: raver119 <raver119@gmail.com>
Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
shugeo 2020-02-04 07:59:11 +02:00 committed by GitHub
parent 57d5eb473b
commit 41ff907bc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 818 additions and 58 deletions

View File

@ -0,0 +1,75 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by GS <sgazeos@gmail.com> at 01/22/2020
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_solve)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/solve.h>
namespace nd4j {
namespace ops {
/**
 * solve op: solves the batched linear systems A * X = B.
 *
 * Inputs:
 *   0 - left-hand tensor A, rank >= 2, last two dimensions must be equal
 *   1 - right-hand tensor B, rank >= 2, its second-to-last dimension must match A's last one
 * Boolean args:
 *   0 - (optional, default false) when true the system is solved with the
 *       matrix produced by helpers::adjointMatrix instead of A itself
 * Output:
 *   0 - the solution, written by helpers::solveFunctor
 *
 * Returns the status reported by helpers::solveFunctor.
 */
CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) {
    auto a = INPUT_VARIABLE(0);
    auto b = INPUT_VARIABLE(1);
    auto z = OUTPUT_VARIABLE(0);

    bool useAdjoint = false;
    if (block.numB() > 0) {
        useAdjoint = B_ARG(0);
    }

    REQUIRE_TRUE(a->rankOf() >=2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
    REQUIRE_TRUE(b->rankOf() >=2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());

    REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
    REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));

    if (useAdjoint) {
        // The adjoint lives on the stack, so it is released on every exit path,
        // including when the LU helper throws on a singular matrix. The previous
        // raw new/delete pair leaked the copy in that case.
        auto adjointA = a->ulike();
        helpers::adjointMatrix(block.launchContext(), a, &adjointA);
        return helpers::solveFunctor(block.launchContext(), &adjointA, b, useAdjoint, z);
    }

    // Propagate the helper's status instead of unconditionally reporting success.
    return helpers::solveFunctor(block.launchContext(), a, b, useAdjoint, z);
}
/**
 * Shape function for the solve op.
 *
 * The output shape info is built from the second input (right-hand sides),
 * with the first input (left-hand matrices) passed as the second argument of
 * ShapeBuilders::copyShapeInfoAndType.
 * NOTE(review): that second argument is presumably the shape info the data
 * type is taken from - confirm against ShapeBuilders.
 */
DECLARE_SHAPE_FN(solve) {
    // Previously both variables read inputShape->at(1), so the left input was never consulted.
    auto in0 = inputShape->at(0);
    auto in1 = inputShape->at(1);
    auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());

    return SHAPELIST(CONSTANT(luShape));
}
// Type constraints of the solve op: inputs and outputs are both restricted
// to ALL_FLOATS, and same-mode is switched off.
DECLARE_TYPES(solve) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes({ALL_FLOATS});
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
    descriptor->setSameMode(false);
}
}
}
#endif

View File

@ -1076,6 +1076,24 @@ namespace nd4j {
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0); DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
#endif #endif
/**
 * solve op. - solve systems of linear equations - general method.
 *
 * input params:
 *    0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
 *    1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
 *
 * boolean args:
 *    0 - adjoint - default is false (optional) - indicates whether the input matrix itself
 *        or its adjoint (Hermitian conjugate) should be used
 *
 * return value:
 *    tensor with dimension (x * y * z * ::: * M * K) with solutions
 *
 */
#if NOT_EXCLUDED(OP_solve)
        // The in-place flag is false to agree with CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0).
        DECLARE_CUSTOM_OP(solve, 2, 1, false, 0, 0);
#endif
/** /**
* lu op. - make LUP decomposition of given batch of 2D square matricies * lu op. - make LUP decomposition of given batch of 2D square matricies
* *

View File

@ -237,25 +237,65 @@ namespace helpers {
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1); samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
} }
/**
 * Doolittle LU decomposition without pivoting, performed in place.
 *
 * On return `compound` holds both factors of its original content: U on and
 * above the main diagonal, and the strictly lower part of L below it (the
 * unit diagonal of L is implied and not stored).
 *
 * @param context  launch context (not used by this routine)
 * @param compound square matrix to decompose; overwritten with the factors
 * @param rowNum   number of rows (and columns) to process
 *
 * No pivoting is applied, so a zero on the diagonal of U leads to a division
 * by zero when the corresponding column of L is evaluated.
 */
template <typename T>
static void doolitleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) {
    auto input = compound->dup();
    compound->nullify();

    // Decomposing matrix into Upper and Lower triangular matrix
    for (auto i = 0; i < rowNum; i++) {

        // Upper Triangular
        for (auto k = i; k < rowNum; k++) {
            // Summation of L(i, j) * U(j, k).
            // The accumulator has the element type T: an integer accumulator
            // would truncate every fractional partial sum.
            T sum = T(0.f);
            for (int j = 0; j < i; j++)
                sum += compound->t<T>(i, j) * compound->t<T>(j, k);

            // Evaluating U(i, k)
            compound->t<T>(i, k) = input.t<T>(i, k) - sum;
        }

        // Lower Triangular
        for (int k = i + 1; k < rowNum; k++) {
            // Summation of L(k, j) * U(j, i)
            T sum = T(0.f);
            for (int j = 0; j < i; j++)
                sum += compound->t<T>(k, j) * compound->t<T>(j, i);

            // Evaluating L(k, i)
            compound->t<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i, i);
        }
    }
}
template <typename T, typename I> template <typename T, typename I>
static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) { static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
//const int rowNum = compound->rows(); //const int rowNum = compound->rows();
// const int columnNum = output->columns(); // const int columnNum = output->columns();
permutation->linspace(0); if (permutation) { // LUP algorithm
auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>(); permutation->linspace(0);
auto compoundBuf = compound->bufferAsT<T>(); auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
auto compoundShape = compound->shapeInfo(); auto compoundBuf = compound->bufferAsT<T>();
auto permutationShape = permutation->shapeInfo(); auto compoundShape = compound->shapeInfo();
for (auto i = 0; i < rowNum - 1; i++) { auto permutationShape = permutation->shapeInfo();
auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape); for (auto i = 0; i < rowNum - 1; i++) {
if (pivotIndex < 0) { auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
throw std::runtime_error("helpers::luNN_: input matrix is singular."); if (pivotIndex < 0) {
} throw std::runtime_error("helpers::luNN_: input matrix is singular.");
math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]); }
swapRows(compoundBuf, compoundShape, i, pivotIndex); math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)],
permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
swapRows(compoundBuf, compoundShape, i, pivotIndex);
processColumns(i, rowNum, compoundBuf, compoundShape); processColumns(i, rowNum, compoundBuf, compoundShape);
}
}
else { // Doolitle algorithm with LU decomposition
doolitleLU<T>(context, compound, rowNum);
} }
} }
@ -265,17 +305,20 @@ namespace helpers {
output->assign(input); // fill up output tensor with zeros output->assign(input); // fill up output tensor with zeros
ResultSet outputs = output->allTensorsAlongDimension({-2, -1}); ResultSet outputs = output->allTensorsAlongDimension({-2, -1});
ResultSet permutations = permutationVectors->allTensorsAlongDimension({-1}); ResultSet permutations;
if (permutationVectors)
permutations = permutationVectors->allTensorsAlongDimension({-1});
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i += increment) {
luNN_<T, I>(context, outputs.at(i), permutations.at(i), n); luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n);
} }
}; };
samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
} }
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) { void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), permutation?permutation->dataType():DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
} }
// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); // BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);

View File

@ -0,0 +1,100 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#include <NDArray.h>
#include <NDArrayFactory.h>
#include <execution/Threads.h>
#include <helpers/MmulHelper.h>
#include "../triangular_solve.h"
#include "../lup.h"
#include "../solve.h"
namespace nd4j {
namespace ops {
namespace helpers {
// --------------------------------------------------------------------------------------------------------------------------------------- //
/**
 * Copies `input` into `output` and then transposes every trailing 2-D matrix
 * of `output` in place by swapping the elements below the main diagonal with
 * their mirrored counterparts above it.
 *
 * @param context launch context (not used by this routine)
 * @param input   source tensor; its last two dimensions form the matrices
 * @param output  destination tensor, overwritten with the transposed matrices
 *
 * The swap touches (r, c) and (c, r) for c < r < rows only, so the matrices
 * are expected to be square (the solve op checks this before calling).
 */
template <typename T>
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    auto inputPart = input->allTensorsAlongDimension({-2, -1});
    auto outputPart = output->allTensorsAlongDimension({-2, -1});
    // Row count of a single matrix. sizeAt(-2) is well defined for batched
    // tensors too and is what the CUDA helper uses; the previous input->rows()
    // is a matrix accessor and was re-evaluated on every loop iteration.
    // NOTE(review): rows() is believed to reject rank > 2 - confirm against NDArray.
    const Nd4jLong rows = input->sizeAt(-2);
    output->assign(input);

    auto batchLoop = PRAGMA_THREADS_FOR {
        for (auto batch = start; batch < stop; batch += increment) {
            for (Nd4jLong r = 0; r < rows; r++) {
                for (Nd4jLong c = 0; c < r; c++) {
                    math::nd4j_swap(outputPart[batch]->t<T>(r, c), outputPart[batch]->t<T>(c, r));
                }
            }
        }
    };
    samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
/**
 * Typed CPU implementation of the general linear solver.
 *
 * Pipeline: batched LU decomposition with a permutation vector per matrix,
 * expansion of those vectors into permutation matrices P, computation of
 * P * rightInput, then a lower-triangular solve (with a unit diagonal written
 * into the copy of the LU result) followed by an upper-triangular solve whose
 * result lands in `output`.
 *
 * @param context    launch context forwarded to the helpers
 * @param leftInput  left-hand matrices
 * @param rightInput right-hand sides
 * @param adjoint    accepted for interface compatibility; not read here
 * @param output     receives the solution
 * @return Status::OK()
 */
template <typename T>
static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
    // stage 1: LU decomposition batched
    auto luCompound = leftInput->ulike();
    auto pivotShape = rightInput->getShapeAsVector();
    pivotShape.pop_back(); // drop the last dimension of the right-hand shape
    auto pivots = NDArrayFactory::create<int>('c', pivotShape, context);
    helpers::lu(context, leftInput, &luCompound, &pivots);

    // Expand every permutation vector into a permutation matrix:
    // start from zeros and put a single 1 into each row.
    auto permutationMatrix = leftInput->ulike();
    permutationMatrix.nullify();
    auto permutationMatrices = permutationMatrix.allTensorsAlongDimension({-2, -1});
    auto pivotVectors = pivots.allTensorsAlongDimension({-1});
    for (auto batch = 0; batch < pivotVectors.size(); ++batch) {
        for (auto row = 0; row < permutationMatrices[batch]->rows(); ++row) {
            const auto column = pivotVectors[batch]->t<int>(row);
            permutationMatrices[batch]->t<T>(row, column) = T(1.f);
        }
    }

    // Apply the row permutation to the right-hand sides.
    auto intermediate = rightInput->ulike();
    auto permutedRight = intermediate.ulike();
    MmulHelper::matmul(&permutationMatrix, rightInput, &permutedRight, 0, 0);

    // Copy of the LU result with ones written onto the main diagonal,
    // used as the lower factor.
    auto lowerFactor = luCompound.dup();
    ResultSet lowerFactors = lowerFactor.allTensorsAlongDimension({-2, -1});
    for (auto batch = 0; batch < lowerFactors.size(); batch++) {
        for (auto d = 0; d < lowerFactors[batch]->rows(); d++) {
            lowerFactors[batch]->t<T>(d, d) = (T)1.f;
        }
    }

    // stage 2: triangularSolveFunctor for Lower with given b
    helpers::triangularSolveFunctor(context, &lowerFactor, &permutedRight, true, false, &intermediate);
    // stage 3: triangularSolveFunctor for Upper with output of previous stage
    helpers::triangularSolveFunctor(context, &luCompound, &intermediate, false, false, output);

    return Status::OK();
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
// Public CPU entry point of the linear solver.
// Selects the typed solveFunctor_ implementation from leftInput's data type
// (restricted to FLOAT_TYPES) and returns its result.
int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
    BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
// Public CPU entry point for the adjoint helper.
// Selects the typed adjointMatrix_ implementation from input's data type
// (restricted to FLOAT_TYPES); the result is written into `output`.
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
}
}
}

View File

@ -41,13 +41,16 @@ namespace helpers {
template <typename T> template <typename T>
static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
auto rows = leftInput->rows(); auto rows = leftInput->rows();
auto cols = rightInput->columns();
//output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0); //output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
for (auto r = 0; r < rows; r++) { for (auto r = 0; r < rows; r++) {
auto sum = rightInput->t<T>(r, 0); for (auto j = 0; j < cols; j++) {
for (auto c = 0; c < r; c++) { auto sum = rightInput->t<T>(r, j);
sum -= leftInput->t<T>(r,c) * output->t<T>(c, 0); for (auto c = 0; c < r; c++) {
sum -= leftInput->t<T>(r, c) * output->t<T>(c, j);
}
output->t<T>(r, j) = sum / leftInput->t<T>(r, r);
} }
output->t<T>(r, 0) = sum / leftInput->t<T>(r, r);
} }
} }
@ -68,13 +71,15 @@ namespace helpers {
template <typename T> template <typename T>
static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
auto rows = leftInput->rows(); auto rows = leftInput->rows();
auto cols = rightInput->columns();
for (auto r = rows; r > 0; r--) { for (auto r = rows; r > 0; r--) {
auto sum = rightInput->t<T>(r - 1, 0); for (auto j = 0; j < cols; j++) {
for (auto c = r; c < rows; c++) { auto sum = rightInput->t<T>(r - 1, j);
sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, 0); for (auto c = r; c < rows; c++) {
sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, j);
}
output->t<T>(r - 1, j) = sum / leftInput->t<T>(r - 1, r - 1);
} }
output->t<T>(r - 1, 0) = sum / leftInput->t<T>(r - 1, r - 1);
} }
} }

View File

@ -0,0 +1,140 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#include <NDArray.h>
#include <NDArrayFactory.h>
#include <MmulHelper.h>
#include <execution/Threads.h>
#include <ConstantTadHelper.h>
#include "../triangular_solve.h"
#include "../lup.h"
#include "../solve.h"
namespace nd4j {
namespace ops {
namespace helpers {
/**
 * Writes 1 onto the main diagonal of every matrix of a batch.
 *
 * Blocks iterate over the matrices (starting at blockIdx.x, stepping by
 * gridDim.x); threads iterate over the diagonal positions (starting at
 * threadIdx.x, stepping by blockDim.x).
 *
 * @param ioBuf      buffer holding the whole batch, modified in place
 * @param ioShape    shape info of the whole batch (not read by the kernel)
 * @param tadShape   shape info used to compute offsets inside one matrix
 * @param tadOffsets offset of every matrix inside ioBuf
 * @param batchNum   number of matrices
 * @param rowNum     number of diagonal elements to set per matrix
 */
template <typename T>
static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong* ioShape, Nd4jLong* tadShape, Nd4jLong* tadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) {
    for (auto batch = blockIdx.x; batch < batchNum; batch += gridDim.x) {
        T* matrix = ioBuf + tadOffsets[batch];
        for (auto d = threadIdx.x; d < rowNum; d += blockDim.x) {
            Nd4jLong diagonalPos[] = {d, d};
            matrix[shape::getOffset(tadShape, diagonalPos)] = T(1.f);
        }
    }
}
/**
 * Expands permutation vectors into permutation matrices: for every batch and
 * every row, the element at (row, permutations[row]) of the matching matrix
 * is set to 1. All other elements are left untouched, so the caller has to
 * zero the destination beforehand.
 *
 * Blocks iterate over the batches, threads over the rows.
 *
 * @param PBuf                     destination buffer with the matrices
 * @param PShapeInfo               shape info of the whole destination (not read here)
 * @param permutationsBuf          buffer with the permutation vectors
 * @param PTadShapeInfo            shape info used for offsets inside one matrix
 * @param PTadSOffsets             offset of every matrix inside PBuf
 * @param permutationsTadShapeInfo shape info of a single permutation vector (not read here)
 * @param permutationsTadOffsets   offset of every vector inside permutationsBuf
 * @param batchNum                 number of matrices/vectors
 * @param rowNum                   number of entries per permutation vector
 */
template <typename T>
static __global__ void restorePermutationsKernel(T* PBuf, Nd4jLong* PShapeInfo, int const* permutationsBuf,
        Nd4jLong* PTadShapeInfo, Nd4jLong* PTadSOffsets, Nd4jLong* permutationsTadShapeInfo, Nd4jLong* permutationsTadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) {
    for (auto b = blockIdx.x; b < batchNum; b += gridDim.x) {
        int const* pivotVector = permutationsBuf + permutationsTadOffsets[b];
        T* matrix = PBuf + PTadSOffsets[b];

        for (auto r = threadIdx.x; r < rowNum; r += blockDim.x) {
            // the vector is addressed directly by the row number
            Nd4jLong target[] = {r, pivotVector[r]};
            matrix[shape::getOffset(PTadShapeInfo, target)] = T(1.f);
        }
    }
}
/**
 * Typed CUDA implementation of the general linear solver.
 *
 * Steps, in the order they are issued:
 *   1. batched LU decomposition of leftInput together with permutation vectors;
 *   2. a copy of the LU result gets ones on its main diagonal (oneOnDiagonalKernel);
 *   3. the permutation vectors are expanded into matrices P (restorePermutationsKernel)
 *      and P * rightInput is computed;
 *   4. lower-triangular solve, then upper-triangular solve into `output`.
 *
 * Both kernels are launched on the context's CUDA stream with 128 blocks of
 * 256 threads and 256 bytes of dynamic shared memory.
 *
 * @param context    launch context (provides the CUDA stream)
 * @param leftInput  left-hand matrices
 * @param rightInput right-hand sides
 * @param adjoint    accepted for interface compatibility; not read here
 * @param output     receives the solution
 * @return Status::OK()
 */
template <typename T>
static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput,
        bool adjoint, NDArray* output) {
    NDArray::prepareSpecialUse({output}, {leftInput, rightInput});
    // stage 1: LU decomposition batched
    auto leftOutput = leftInput->ulike();
    // shape of the permutation vectors: the right-hand shape without its last dimension
    auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back();
    auto permutations = NDArrayFactory::create<int>('c', permuShape, context);
    helpers::lu(context, leftInput, &leftOutput, &permutations);
    // copy of the LU result that serves as the lower factor once its diagonal is set to 1
    auto leftLower = leftOutput.dup();
    auto rightOutput = rightInput->ulike();
    auto leftLowerTad = ConstantTadHelper::getInstance()->tadForDimensions(leftLower.getShapeInfo(), {-2, -1});
    auto stream = context->getCudaStream();
    oneOnDiagonalKernel<T><<<128, 256, 256, *stream>>>(leftLower.dataBuffer()->specialAsT<T>(), leftLower.specialShapeInfo(), leftLowerTad.specialShapeInfo(), leftLowerTad.specialOffsets(), leftLowerTad.numberOfTads(), leftLower.sizeAt(-1));
    // P starts as all zeros; the kernel below only writes the ones
    auto P = leftOutput.ulike(); P.nullify();
    auto PTad = ConstantTadHelper::getInstance()->tadForDimensions(P.getShapeInfo(), {-2, -1});
    auto permutationsTad = ConstantTadHelper::getInstance()->tadForDimensions(permutations.getShapeInfo(), {-1});
    restorePermutationsKernel<T><<<128, 256, 256, *stream>>>(P.dataBuffer()->specialAsT<T>(), P.specialShapeInfo(), permutations.dataBuffer()->specialAsT<int>(),
            PTad.specialShapeInfo(), PTad.specialOffsets(), permutationsTad.specialShapeInfo(), permutationsTad.specialOffsets(), permutationsTad.numberOfTads(), permutations.sizeAt(-1));
    // record that P was modified on the device side
    // NOTE(review): leftLower is also written by a kernel above but is not ticked - confirm this is intended
    P.tickWriteDevice();
    auto rightPart = rightInput->ulike();
    // rightPart = P * rightInput, i.e. the right-hand sides with permuted rows
    MmulHelper::matmul(&P, rightInput, &rightPart, 0, 0);

    // stage 2: triangularSolveFunctor for Lower with given b
    helpers::triangularSolveFunctor(context, &leftLower, &rightPart, true, false, &rightOutput);
    // stage 3: triangularSolveFunctor for Upper with output of previous stage
    helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output);
    NDArray::registerSpecialUse({output}, {leftInput, rightInput});

    return Status::OK();
}
// Public CUDA entry point of the linear solver.
// Selects the typed solveFunctor_ implementation from leftInput's data type
// (restricted to FLOAT_TYPES) and returns its result.
int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
    BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
}
/**
 * In-place transposition of every matrix of a batch: each element below the
 * main diagonal is swapped with its mirrored element above the diagonal.
 *
 * Blocks iterate over the batch, threadIdx.x over the rows and threadIdx.y
 * over the columns left of the diagonal.
 *
 * @param output        buffer with the matrices, modified in place
 * @param batchSize     number of matrices
 * @param rows          number of rows of one matrix
 * @param columns       number of columns of one matrix (not read by the kernel)
 * @param outputTads    shape info used for offsets inside one matrix
 * @param outputOffsets offset of every matrix inside `output`
 */
template <typename T>
static __global__ void adjointKernel(T* output, Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, Nd4jLong* outputTads,
        Nd4jLong* outputOffsets) {
    for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x) {
        T* matrix = output + outputOffsets[batch];

        for (auto row = threadIdx.x; row < rows; row += blockDim.x) {
            for (auto col = threadIdx.y; col < row; col += blockDim.y) {
                Nd4jLong belowDiagonal[] = {row, col};
                Nd4jLong aboveDiagonal[] = {col, row};
                const auto lowerOffset = shape::getOffset(outputTads, belowDiagonal);
                const auto upperOffset = shape::getOffset(outputTads, aboveDiagonal);
                math::nd4j_swap(matrix[lowerOffset], matrix[upperOffset]);
            }
        }
    }
}
/**
 * Typed CUDA implementation of the adjoint helper: copies `input` into
 * `output` and then launches adjointKernel, which swaps the mirrored
 * off-diagonal elements of every trailing 2-D matrix of `output` in place.
 *
 * @param context launch context (provides the CUDA stream)
 * @param input   source tensor
 * @param output  destination tensor, overwritten
 */
template <typename T>
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    NDArray::prepareSpecialUse({output}, {input});
    // inputTads is requested but not used afterwards; only the output TADs are passed to the kernel
    auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {-2, -1});
    auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {-2, -1});
    auto stream = context->getCudaStream();
    auto outputBuf = reinterpret_cast<T*>(output->specialBuffer());
    auto rows = input->sizeAt(-2);
    auto columns = input->sizeAt(-1);
    // the kernel works in place on output, so it has to hold a copy of input first
    output->assign(input);
    // launched with a 1-D block of 256 threads (128 blocks, 256 bytes of dynamic shared memory)
    adjointKernel<T><<<128, 256, 256, *stream>>>(outputBuf, outputTads.numberOfTads(), rows, columns, outputTads.specialShapeInfo(), outputTads.specialOffsets());
    NDArray::registerSpecialUse({output}, {input});
}
// Public CUDA entry point for the adjoint helper.
// Selects the typed adjointMatrix_ implementation from input's data type
// (restricted to FLOAT_TYPES); the result is written into `output`.
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES);
}
}
}
}

View File

@ -44,24 +44,26 @@ namespace nd4j {
static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
T const* rightInput, Nd4jLong const* rightInputShape, T const* rightInput, Nd4jLong const* rightInputShape,
bool const adjoint, T* output, Nd4jLong* outputShape, bool const adjoint, T* output, Nd4jLong* outputShape,
Nd4jLong rows) { Nd4jLong rows, Nd4jLong cols) {
for (auto r = 0; r < rows; r++) { for (auto r = 0; r < rows; r++) {
Nd4jLong posY[] = {r, 0}; for (auto j = 0; j < cols; j++) {
Nd4jLong posX[] = {r, r}; Nd4jLong posY[] = {r, j};
auto xIndex = shape::getOffset(leftInputShape, posX, 0); Nd4jLong posX[] = {r, r};
auto yIndex = shape::getOffset(rightInputShape, posY, 0); auto xIndex = shape::getOffset(leftInputShape, posX, 0);
auto zIndex = shape::getOffset(outputShape, posY, 0); auto yIndex = shape::getOffset(rightInputShape, posY, 0);
auto zIndex = shape::getOffset(outputShape, posY, 0);
auto sum = rightInput[yIndex]; auto sum = rightInput[yIndex];
for (auto c = 0; c < r; c++) { for (auto c = 0; c < r; c++) {
Nd4jLong posZ[] = {c, 0}; Nd4jLong posZ[] = {c, j};
Nd4jLong pos[] = {r, c}; Nd4jLong pos[] = {r, c};
auto xcIndex = shape::getOffset(leftInputShape, pos, 0); auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
auto zcIndex = shape::getOffset(outputShape, posZ, 0); auto zcIndex = shape::getOffset(outputShape, posZ, 0);
sum -= leftInput[xcIndex] * output[zcIndex]; sum -= leftInput[xcIndex] * output[zcIndex];
}
output[zIndex] = sum / leftInput[xIndex];
} }
output[zIndex] = sum / leftInput[xIndex];
} }
} }
@ -82,23 +84,25 @@ namespace nd4j {
template <typename T> template <typename T>
static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output, T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output,
Nd4jLong* outputShape, Nd4jLong rows) { Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) {
for (auto r = rows; r > 0; r--) { for (auto r = rows; r > 0; r--) {
Nd4jLong posY[] = {r - 1, 0}; for (auto j = 0; j < cols; j++) {
Nd4jLong posX[] = {r - 1, r - 1}; Nd4jLong posY[] = {r - 1, j};
auto xIndex = shape::getOffset(leftInputShape, posX, 0); Nd4jLong posX[] = {r - 1, r - 1};
auto yIndex = shape::getOffset(rightInputShape, posY, 0); auto xIndex = shape::getOffset(leftInputShape, posX, 0);
auto zIndex = shape::getOffset(outputShape, posY, 0); auto yIndex = shape::getOffset(rightInputShape, posY, 0);
auto sum = rightInput[yIndex]; auto zIndex = shape::getOffset(outputShape, posY, 0);
for (auto c = r; c < rows; c++) { auto sum = rightInput[yIndex];
Nd4jLong posZ[] = {c, 0}; for (auto c = r; c < rows; c++) {
Nd4jLong pos[] = {r - 1, c}; Nd4jLong posZ[] = {c, j};
auto zcIndex = shape::getOffset(outputShape, posZ, 0); Nd4jLong pos[] = {r - 1, c};
auto xcIndex = shape::getOffset(leftInputShape, pos, 0); auto zcIndex = shape::getOffset(outputShape, posZ, 0);
sum -= leftInput[xcIndex] * output[zcIndex]; auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
sum -= leftInput[xcIndex] * output[zcIndex];
}
output[zIndex] = sum / leftInput[xIndex];
} }
output[zIndex] = sum / leftInput[xIndex];
} }
} }
@ -109,8 +113,11 @@ namespace nd4j {
Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) { Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) {
__shared__ Nd4jLong rows; __shared__ Nd4jLong rows;
__shared__ Nd4jLong cols;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
rows = shape::sizeAt(leftPartShape, -2); rows = shape::sizeAt(leftPartShape, -2);
cols = shape::sizeAt(rightPartShape, -1);
} }
__syncthreads(); __syncthreads();
@ -123,9 +130,9 @@ namespace nd4j {
auto pRightPart = rightInput + tadRightOffset[i]; auto pRightPart = rightInput + tadRightOffset[i];
auto pOutputPart = output + tadOutputOffset[i]; auto pOutputPart = output + tadOutputOffset[i];
if (lower) { if (lower) {
lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows); lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
} else { } else {
upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows); upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
} }
} }
} }

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#ifndef __SOLVE__H_HELPERS__
#define __SOLVE__H_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
    /**
     * Solves the batched linear systems leftInput * X = rightInput and writes X into output.
     *
     * @param context    launch context
     * @param leftInput  left-hand matrices
     * @param rightInput right-hand sides
     * @param adjoint    adjoint flag forwarded by the caller
     * @param output     receives the solution
     * @return status code of the operation
     */
    int solveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output);
    /**
     * Fills output with the content of input in which every trailing 2-D matrix is transposed.
     *
     * @param context launch context
     * @param input   source tensor
     * @param output  destination tensor
     */
    void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output);
}
}
}
#endif

View File

@ -1541,6 +1541,166 @@ TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
delete []arr; delete []arr;
} }
////////////////////////////////////////////////////////////////////////////////
// solve op: single 3x3 system with one right-hand-side column.
TEST_F(DeclarableOpsTests11, Solve_Test_1) {
    auto lhs = NDArrayFactory::create<float>('c', {3, 3}, {
             2.f, -1.f, -2.f,
            -4.f,  6.f,  3.f,
            -4.f, -2.f,  8.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 1}, {2.f, 4.f, 3.f});
    auto expected = NDArrayFactory::create<float>('c', {3, 1}, {7.625f, 3.25f, 5.f});

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&lhs, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
    ASSERT_TRUE(expected.equalsTo(solution));
    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// solve op: single 4x4 upper-triangular system with one right-hand-side column.
TEST_F(DeclarableOpsTests11, Solve_Test_2) {
    auto lhs = NDArrayFactory::create<float>('c', {4, 4}, {
            1.f, 1.f, 1.f, 1.f,
            0.f, 1.f, 1.f, 0.f,
            0.f, 0.f, 2.f, 1.f,
            0.f, 0.f, 0.f, 3.f,
    });
    auto rhs = NDArrayFactory::create<float>('c', {4, 1}, {2.f, 4.f, 2.f, 4.f});
    auto expected = NDArrayFactory::create<float>('c', {4, 1}, {
            -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f
    });

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&lhs, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
    ASSERT_TRUE(expected.equalsTo(solution));
    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// solve op: batch of two 4x4 systems (one upper-, one lower-triangular),
// each with one right-hand-side column.
TEST_F(DeclarableOpsTests11, Solve_Test_3) {
    auto lhs = NDArrayFactory::create<float>('c', {2, 4, 4}, {
            // first matrix
            1.f, 1.f, 1.f, 1.f,
            0.f, 1.f, 1.f, 0.f,
            0.f, 0.f, 2.f, 1.f,
            0.f, 0.f, 0.f, 3.f,
            // second matrix
            3.f, 0.f, 0.f, 0.f,
            2.f, 1.f, 0.f, 0.f,
            1.f, 0.f, 1.f, 0.f,
            1.f, 1.f, 1.f, 1.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 4, 1}, {
            2.f, 4.f, 2.f, 4.f,
            4.f, 2.f, 4.f, 2.f
    });
    auto expected = NDArrayFactory::create<float>('c', {2, 4, 1}, {
            -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f,
            1.333333f, -0.6666667f, 2.6666667f, -1.3333333f
    });

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&lhs, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
    ASSERT_TRUE(expected.equalsTo(solution));
    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// solve op: batch of two 2x2 systems, each with two right-hand-side columns.
TEST_F(DeclarableOpsTests11, Solve_Test_4) {
    auto lhs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
            0.7788f, 0.8012f, 0.7244f, 0.2309f,
            0.7271f, 0.1804f, 0.5056f, 0.8925f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
            0.7717f, 0.9281f, 0.9846f, 0.4838f,
            0.6433f, 0.6041f, 0.6501f, 0.7612f
    });
    // An alternative set of reference values that is kept for comparison only:
    //   1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f,
    //   0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f
    auto expected = NDArrayFactory::create<float>('c', {2, 2, 2}, {
            1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f,
            0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f
    });

    nd4j::ops::solve solveOp;
    auto results = solveOp.evaluate({&lhs, &rhs});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    auto solution = results->at(0);
    ASSERT_TRUE(expected.equalsTo(solution));
    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// solve op with the adjoint flag set: single 3x3 system with three right-hand-side columns.
TEST_F(DeclarableOpsTests11, Solve_Test_5) {

    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
            0.7788f, 0.8012f, 0.7244f,
            0.2309f, 0.7271f, 0.1804f,
            0.5056f, 0.8925f, 0.5461f
    });

    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
            0.7717f, 0.9281f, 0.9846f,
            0.4838f, 0.6433f, 0.6041f,
            0.6501f, 0.7612f, 0.7605f
    });

    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
            1.5504692f, 1.8953944f, 2.2765768f,
            0.03399149f, 0.2883001f, 0.5377323f,
            -0.8774802f, -1.2155888f, -1.8049058f
    });

    nd4j::ops::solve op;

    // the single boolean argument enables the adjoint mode
    auto res = op.evaluate({&a, &b}, {true});
    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
    auto z = res->at(0);

    // Debug printing is kept disabled, as in the sibling Solve tests.
    // z->printBuffer("Solve 3x3 adjoint");
    // exp.printBuffer("Expected 3x3 adjoint");

    ASSERT_TRUE(exp.equalsTo(z));
    delete res;
}
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) { TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {

View File

@ -3008,3 +3008,33 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete res; delete res;
} }
////////////////////////////////////////////////////////////////////////////////
// Triangular solve on a 4x4 upper-triangular matrix with a 4x2 right-hand side,
// run with boolean arguments {false, true}. The expected values satisfy
// transpose(a) * x = b, i.e. the adjoint of the upper-triangular input is solved.
TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) {
    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
        5.f, 1.f, -3.f, 3.f,
        0.f, 1.f, 1.f, -1.f,
        0.f, 0.f, 2.f, -9.f,
        0.f, 0.f, 0.f, 4.f
    });

    auto b = NDArrayFactory::create<float>('c', {4, 2}, {
        5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f
    });

    auto exp = NDArrayFactory::create<float>('c', {4, 2}, {
        1.f,0.2f, 1.f,0.8f, 1.f,0.4f, 1.f,1.2f
    });

    nd4j::ops::triangular_solve op;
    auto res = op.evaluate({&a, &b}, {}, {}, {false, true});
    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
    auto z = res->at(0);

    // Debug output is disabled so the test run stays quiet; re-enable locally if needed.
    // z->printIndexedBuffer("TriangularSolve with adjoint");

    ASSERT_TRUE(exp.equalsTo(z));
    delete res;
}

View File

@ -39,7 +39,7 @@ do
done done
CHIP="${CHIP:-cpu}" CHIP="${CHIP:-cpu}"
export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml" export GTEST_OUTPUT="xml:surefire-reports/TEST-${CHIP}-results.xml"
# On Mac, make sure it can find libraries for GCC # On Mac, make sure it can find libraries for GCC
export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
@ -48,9 +48,11 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/
if [ -n "$BUILD_PATH" ]; then if [ -n "$BUILD_PATH" ]; then
if which cygpath; then if which cygpath; then
BUILD_PATH=$(cygpath -p $BUILD_PATH) BUILD_PATH=$(cygpath -p $BUILD_PATH)
export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'"
fi fi
export PATH="$PATH:$BUILD_PATH" export PATH="$PATH:$BUILD_PATH"
fi fi
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests ../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/

View File

@ -623,7 +623,8 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.Igammac.class, org.nd4j.linalg.api.ops.custom.Igammac.class,
org.nd4j.linalg.api.ops.custom.Digamma.class, org.nd4j.linalg.api.ops.custom.Digamma.class,
org.nd4j.linalg.api.ops.custom.Lu.class, org.nd4j.linalg.api.ops.custom.Lu.class,
org.nd4j.linalg.api.ops.custom.TriangularSolve.class org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
org.nd4j.linalg.api.ops.custom.LinearSolve.class
); );
static { static {

View File

@ -16,26 +16,48 @@
package org.nd4j.linalg.api.buffer; package org.nd4j.linalg.api.buffer;
/**
* Enum lists supported data types.
*
*/
public enum DataType { public enum DataType {
DOUBLE, DOUBLE,
FLOAT, FLOAT,
/**
* @deprecated Replaced by {@link DataType#FLOAT16}, use that instead
*/
@Deprecated @Deprecated
HALF, HALF,
/**
* @deprecated Replaced by {@link DataType#INT64}, use that instead
*/
@Deprecated @Deprecated
LONG, LONG,
/**
* @deprecated Replaced by {@link DataType#INT32}, use that instead
*/
@Deprecated @Deprecated
INT, INT,
/**
* @deprecated Replaced by {@link DataType#INT16}, use that instead
*/
@Deprecated @Deprecated
SHORT, SHORT,
/**
* @deprecated Replaced by {@link DataType#UINT8}, use that instead
*/
@Deprecated @Deprecated
UBYTE, UBYTE,
/**
* @deprecated Replaced by {@link DataType#INT8}, use that instead
*/
@Deprecated @Deprecated
BYTE, BYTE,

View File

@ -0,0 +1,77 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
 * Wrapper for the native custom op named "solve", imported from TensorFlow as "MatrixSolve".
 * <p>
 * Takes two inputs, {@code a} and {@code b}, plus a boolean {@code adjoint} flag that is
 * forwarded to the op as its boolean argument.
 * NOTE(review): presumably {@code a} is the coefficient matrix and {@code b} the right-hand side,
 * with {@code adjoint} selecting a solve against the adjoint of {@code a} - confirm against the
 * native op documentation.
 */
@NoArgsConstructor
public class LinearSolve extends DynamicCustomOp {

    /**
     * Creates the op on concrete arrays with the adjoint flag switched off.
     *
     * @param a first input array
     * @param b second input array
     */
    public LinearSolve(INDArray a, INDArray b) {
        this(a, b, false);
    }

    /**
     * Creates the op on concrete arrays.
     *
     * @param a       first input array
     * @param b       second input array
     * @param adjoint value registered as the op's boolean argument
     */
    public LinearSolve(INDArray a, INDArray b, boolean adjoint) {
        addInputArgument(a, b);
        addBArgument(adjoint);
    }

    /**
     * Creates the op inside a SameDiff graph with a compile-time adjoint flag.
     *
     * @param sameDiff graph instance the op belongs to
     * @param a        first input variable
     * @param b        second input variable
     * @param adjoint  value registered as the op's boolean argument
     */
    public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, boolean adjoint) {
        super(sameDiff, new SDVariable[] {a, b});
        addBArgument(adjoint);
    }

    /**
     * Creates the op inside a SameDiff graph, passing the adjoint flag as a third input variable.
     *
     * @param sameDiff graph instance the op belongs to
     * @param a        first input variable
     * @param b        second input variable
     * @param adjoint  third input variable
     */
    public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, SDVariable adjoint) {
        super(sameDiff, new SDVariable[] {a, b, adjoint});
    }

    /** @return the op name, "solve" */
    @Override
    public String opName() {
        return "solve";
    }

    /** @return the TensorFlow op name this class is mapped to, "MatrixSolve" */
    @Override
    public String tensorflowName() {
        return "MatrixSolve";
    }

    /**
     * Reads the optional "adjoint" attribute of the TensorFlow node and registers it as the
     * boolean argument; a missing attribute is treated as {@code false}.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        boolean adjoint = false;
        if (attributesForNode.containsKey("adjoint")) {
            adjoint = attributesForNode.get("adjoint").getB();
        }
        addBArgument(adjoint);
    }

    /**
     * The single output takes the data type of the first input.
     *
     * @param dataTypes input data types; must be non-null with one entry per op argument
     * @return a one-element list holding the data type of the first input
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        final int expectedCount = args().length;
        final boolean countMatches = dataTypes != null && dataTypes.size() == expectedCount;
        Preconditions.checkState(countMatches, "Expected %s input data types for %s, got %s", expectedCount, getClass(), dataTypes);
        return Collections.singletonList(dataTypes.get(0));
    }
}

View File

@ -1691,4 +1691,50 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(e, x); assertEquals(e, x);
} }
/**
 * Runs LinearSolve without the adjoint flag on a 3x3 matrix and a 3x1 right-hand side.
 * The expected vector satisfies a * x = b.
 */
@Test
public void testLinearSolve() {
    INDArray coefficients = Nd4j.createFromArray(new float[]{
            2.f, -1.f, -2.f,
            -4.f, 6.f, 3.f,
            -4.f, -2.f, 8.f
    }).reshape(3, 3);

    INDArray rightHandSide = Nd4j.createFromArray(new float[]{
            2.f, 4.f, 3.f
    }).reshape(3, 1);

    INDArray expectedSolution = Nd4j.createFromArray(new float[]{
            7.625f, 3.25f, 5.f
    }).reshape(3, 1);

    LinearSolve solveOp = new LinearSolve(coefficients, rightHandSide);
    INDArray[] outputs = Nd4j.exec(solveOp);

    assertEquals(expectedSolution, outputs[0]);
}
/**
 * Runs LinearSolve with the adjoint flag set to true on a 3x3 matrix and a 3x3 right-hand side.
 * The expected matrix satisfies transpose(a) * x = b.
 */
@Test
public void testLinearSolveAdjust() {
    INDArray coefficients = Nd4j.createFromArray(new float[]{
            0.7788f, 0.8012f, 0.7244f,
            0.2309f, 0.7271f, 0.1804f,
            0.5056f, 0.8925f, 0.5461f
    }).reshape(3, 3);

    INDArray rightHandSide = Nd4j.createFromArray(new float[]{
            0.7717f, 0.9281f, 0.9846f,
            0.4838f, 0.6433f, 0.6041f,
            0.6501f, 0.7612f, 0.7605f
    }).reshape(3, 3);

    INDArray expectedSolution = Nd4j.createFromArray(new float[]{
            1.5504692f, 1.8953944f, 2.2765768f,
            0.03399149f, 0.2883001f , 0.5377323f,
            -0.8774802f, -1.2155888f, -1.8049058f
    }).reshape(3, 3);

    LinearSolve adjointSolveOp = new LinearSolve(coefficients, rightHandSide, true);
    INDArray[] outputs = Nd4j.exec(adjointSolveOp);

    assertEquals(expectedSolution, outputs[0]);
}
} }