cavis/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu

/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#include <NDArray.h>
#include <MmulHelper.h>
#include <ShapeUtils.h>
#include <ConstantTadHelper.h>
#include "../triangular_solve.h"
#include "../lup.h"
#include "../qr.h"
#include "../lstsq.h"
namespace nd4j {
namespace ops {
namespace helpers {
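    // Writes `value` onto the main diagonal of every 2D sub-matrix (TAD over the last two
    // dimensions) of a batched matrix: one block handles one matrix of the batch, one thread
    // handles one diagonal element.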
    template <typename T>
    static __global__ void fillRegularizerKernel(T* ioMatrixData, Nd4jLong* ioMatrixShape, Nd4jLong* ioMatrixTads, Nd4jLong* ioMatrixOffsets, Nd4jLong batchSize, Nd4jLong rows, T const value) {
        for (auto x = blockIdx.x; x < batchSize; x += gridDim.x) {
            auto z = ioMatrixData + ioMatrixOffsets[x];
            for (auto r = threadIdx.x; r < rows; r += blockDim.x) {
                Nd4jLong pos[] = {r, r};
                auto zIndex = shape::getOffset(ioMatrixTads, pos);
                z[zIndex] = value;
            }
        }
    }
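    // Host-side launcher: builds TADs over the last two dimensions of ioMatrix and fills each
    // diagonal with `value` (the caller zeroes the matrix first, yielding value * I per batch entry).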
    template <typename T>
    static void fillRegularizer(nd4j::LaunchContext* context, NDArray& ioMatrix, double const value) {
        auto lastDimsTads = ConstantTadHelper::getInstance()->tadForDimensions(ioMatrix.shapeInfo(), {-2, -1});
        auto stream = context->getCudaStream();
        auto rows = ioMatrix.sizeAt(-2);
        //auto cols = ioMatrix.sizeAt(-1);
        fillRegularizerKernel<T><<<256, 256, 128, *stream>>>(ioMatrix.dataBuffer()->specialAsT<T>(), ioMatrix.specialShapeInfo(), lastDimsTads.specialShapeInfo(), lastDimsTads.specialOffsets(), lastDimsTads.numberOfTads(), rows, (T)value);
    }
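    // Solves the (batched) linear least squares problem x = argmin ||leftInput * x - rightInput||.
    // With fast == true the l2-regularized normal equations are solved via Cholesky decomposition;
    // otherwise a QR decomposition of leftInput is used.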
    template <typename T>
    int leastSquaresSolveFunctor_(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
        if (fast) { // Cholesky decomposition approach
            // Solve the normal equations A^T * A * x = A^T * b:
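            // (minimizing ||A * x - b||^2 + l2Regularizer * ||x||^2 leads to (A^T * A + l2Regularizer * I) * x = A^T * b)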
            // 1. Computing A2:
            auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->getShapeInfo(), leftInput->getShapeInfo(), true, false);
            //tAtShape[tAtShape.size() - 2] = output->sizeAt(-2);
            NDArray leftOutput(leftInput->ordering(), tAtShape, output->dataType(), context);
            MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A
            // 2. Computing B' = A^T * b
            auto rightOutput = output->ulike();
            MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b
            // 3. Regularization (i.e. A' = A2 + l2Regularizer * I)
            if (l2Regularizer != 0.0) {
                auto regularizer = leftOutput.ulike(); regularizer.nullify();
                fillRegularizer<T>(context, regularizer, (T)l2Regularizer);
                leftOutput += regularizer;
            }
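            // With l2Regularizer > 0, A^T * A + l2Regularizer * I is symmetric positive definite,
            // which keeps the Cholesky factorization below well defined.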
            // 4. Cholesky decomposition -- output matrix is square and lower triangular
            helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition
            // 5. Solve two triangular systems:
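            //    with A^T * A = L * L^T: first L * z = B' (forward substitution), then L^T * x = z (back substitution)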
            auto rightB = rightOutput.ulike(); rightB.nullify();
            helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB);
            helpers::adjointMatrix(context, &leftOutput, true, &leftOutput);
            helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output);
            // All done
        }
        else { // QR decomposition approach
            // Solve R * x = Q^T * b, where A = Q * R with Q orthogonal and R upper triangular
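            // (QR avoids forming A^T * A explicitly, which is preferable when A is ill-conditioned;
            // Q is built with shape [..., m, m] and R with shape [..., m, n] below)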
            // 1. QR decomposition
            auto qShape = leftInput->getShapeAsVector();
            auto rShape = leftInput->getShapeAsVector();
            qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2);
            NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context); // = leftInput->ulike();
            NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike();
            helpers::qr(context, leftInput, &Q, &R, true);
            // 2. b' = Q^T * b:
            auto rightOutput = rightInput->ulike();
            MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false);
            // 3. Solve triangular system
            helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output);
        }
        return Status::OK();
    }
    int leastSquaresSolveFunctor(nd4j::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) {
        BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES);
    }
}
}
}
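//
// A minimal usage sketch (comment only, not part of the build), assuming the NDArrayFactory and
// LaunchContext facilities used elsewhere in this codebase; shapes and values are illustrative:
//
//   auto a = nd4j::NDArrayFactory::create<float>('c', {3, 2}, {1.f, 1.f, 1.f, 2.f, 1.f, 3.f});
//   auto b = nd4j::NDArrayFactory::create<float>('c', {3, 1}, {6.f, 0.f, 0.f});
//   auto x = nd4j::NDArrayFactory::create<float>('c', {2, 1});
//   nd4j::ops::helpers::leastSquaresSolveFunctor(nd4j::LaunchContext::defaultContext(), &a, &b, 0.0, true, &x);
//   // x now approximates argmin ||a * x - b||
//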