cavis/libnd4j/include/helpers/impl/hhColPivQR.cpp

150 lines
4.9 KiB
C++
Raw Normal View History

2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
Shyrma sqrtm (#429) * - start working on implementation of sqrtm op Signed-off-by: Yurii <iuriish@yahoo.com> * - improving householder procedure Signed-off-by: Yurii <iuriish@yahoo.com> * - further polishing householder stuff Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing hh pivoting qr procedure Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing BiDiagonalUp procedure Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing householder sequence class Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing jacobi svd class Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing svd stuff 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing svd stuff 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing class which performs Hessenberg decomposition of square matrix Signed-off-by: Yurii <iuriish@yahoo.com> * - add static method to JacobiSVD class which makes the continuous Givens rotation generation algorithm Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing auxiliary methods of Schur decomp class Signed-off-by: Yurii <iuriish@yahoo.com> * some references here and there Signed-off-by: raver119 <raver119@gmail.com> * - trying figure out difference between eigen and our Schur alg Signed-off-by: Yurii <iuriish@yahoo.com> * - testing fixing bugs in Schur decomposition op Signed-off-by: Yurii <iuriish@yahoo.com> * - start to implement class which performs calculation of eigen values and vectors Signed-off-by: Yurii <iuriish@yahoo.com> * - add to EigenValsAndVecs method which calculates complex eigen vectors Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in EigenValsAndVecs class Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing triangularSolver class Signed-off-by: Yurii <iuriish@yahoo.com> * Added a 2D routine for triangular systems solve. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored triangularSolve2D routine and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored another test for triangularSolve2D. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored test for triangularSolve for vector-bar case. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored triangularSolve2D routine and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * - implementation of FullPivLU class Signed-off-by: Yurii <iuriish@yahoo.com> * - fix bugs in FullPivLU::solve method Signed-off-by: Yurii <iuriish@yahoo.com> * - correct permutation vector in FullPivLU::solve Signed-off-by: Yurii <iuriish@yahoo.com> * - correct include headers Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation of Sqrtm class Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in Sqrtm class Signed-off-by: Yurii <iuriish@yahoo.com> * - include sqrtm classes to cuda folder, investigate in what places synchronization doesn't work Signed-off-by: Yurii <iuriish@yahoo.com> * Added implementation for cuda triangularSolve2D and also refactored triangularSolve2D for cpu. Signed-off-by: shugeo <sgazeos@gmail.com> * Eliminated waste implementations. Signed-off-by: shugeo <sgazeos@gmail.com> * - make offset calculation faster in t<> methods Signed-off-by: Yurii <iuriish@yahoo.com> * - rename refference T& NDArray::t<> method Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on cuda sqrtm Signed-off-by: Yurii <iuriish@yahoo.com> * - provide correct synchronization to device in Sqrtm class Signed-off-by: Yurii <iuriish@yahoo.com> * - add tests for sqrtm op Signed-off-by: Yurii <iuriish@yahoo.com> * - correct fails which appeared while testing on jenkins Signed-off-by: Yurii <iuriish@yahoo.com> * - trying to find out mistake in svd::deflation method Signed-off-by: Yurii <iuriish@yahoo.com> * Revert "- trying to find out mistake in svd::deflation method" This reverts commit 19d37baddbc509028e4bc67bc932fe7449becdb6. * Revert "- trying to find out mistake in svd::deflation method" This reverts commit 19d37baddbc509028e4bc67bc932fe7449becdb6. Signed-off-by: Yurii <iuriish@yahoo.com> * - change call semantic of r<> and t<> methods Signed-off-by: Yurii <iuriish@yahoo.com> * - ged rid of ambiguity in * operator overloads for windows buikd Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of ambiguity in * operator overloads for windows build 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of ambiguity in * operator overloads for windows build 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts with master Signed-off-by: Yurii <iuriish@yahoo.com> * cmakelists updated Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - minor fix in merge cpu helper - make use of reference getter Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com>
2020-05-14 17:06:13 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
Shyrma sqrtm (#429) * - start working on implementation of sqrtm op Signed-off-by: Yurii <iuriish@yahoo.com> * - improving householder procedure Signed-off-by: Yurii <iuriish@yahoo.com> * - further polishing householder stuff Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing hh pivoting qr procedure Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing BiDiagonalUp procedure Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing householder sequence class Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing jacobi svd class Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing svd stuff 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing svd stuff 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing class which performs Hessenberg decomposition of square matrix Signed-off-by: Yurii <iuriish@yahoo.com> * - add static method to JacobiSVD class which makes the continuous Givens rotation generation algorithm Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing auxiliary methods of Schur decomp class Signed-off-by: Yurii <iuriish@yahoo.com> * some references here and there Signed-off-by: raver119 <raver119@gmail.com> * - trying figure out difference between eigen and our Schur alg Signed-off-by: Yurii <iuriish@yahoo.com> * - testing fixing bugs in Schur decomposition op Signed-off-by: Yurii <iuriish@yahoo.com> * - start to implement class which performs calculation of eigen values and vectors Signed-off-by: Yurii <iuriish@yahoo.com> * - add to EigenValsAndVecs method which calculates complex eigen vectors Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in EigenValsAndVecs class Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing triangularSolver class Signed-off-by: Yurii <iuriish@yahoo.com> * Added a 2D routine for triangular systems solve. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored triangularSolve2D routine and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored another test for triangularSolve2D. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored test for triangularSolve for vector-bar case. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored triangularSolve2D routine and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * - implementation of FullPivLU class Signed-off-by: Yurii <iuriish@yahoo.com> * - fix bugs in FullPivLU::solve method Signed-off-by: Yurii <iuriish@yahoo.com> * - correct permutation vector in FullPivLU::solve Signed-off-by: Yurii <iuriish@yahoo.com> * - correct include headers Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation of Sqrtm class Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in Sqrtm class Signed-off-by: Yurii <iuriish@yahoo.com> * - include sqrtm classes to cuda folder, investigate in what places synchronization doesn't work Signed-off-by: Yurii <iuriish@yahoo.com> * Added implementation for cuda triangularSolve2D and also refactored triangularSolve2D for cpu. Signed-off-by: shugeo <sgazeos@gmail.com> * Eliminated waste implementations. Signed-off-by: shugeo <sgazeos@gmail.com> * - make offset calculation faster in t<> methods Signed-off-by: Yurii <iuriish@yahoo.com> * - rename refference T& NDArray::t<> method Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on cuda sqrtm Signed-off-by: Yurii <iuriish@yahoo.com> * - provide correct synchronization to device in Sqrtm class Signed-off-by: Yurii <iuriish@yahoo.com> * - add tests for sqrtm op Signed-off-by: Yurii <iuriish@yahoo.com> * - correct fails which appeared while testing on jenkins Signed-off-by: Yurii <iuriish@yahoo.com> * - trying to find out mistake in svd::deflation method Signed-off-by: Yurii <iuriish@yahoo.com> * Revert "- trying to find out mistake in svd::deflation method" This reverts commit 19d37baddbc509028e4bc67bc932fe7449becdb6. * Revert "- trying to find out mistake in svd::deflation method" This reverts commit 19d37baddbc509028e4bc67bc932fe7449becdb6. Signed-off-by: Yurii <iuriish@yahoo.com> * - change call semantic of r<> and t<> methods Signed-off-by: Yurii <iuriish@yahoo.com> * - ged rid of ambiguity in * operator overloads for windows buikd Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of ambiguity in * operator overloads for windows build 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of ambiguity in * operator overloads for windows build 3 Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts with master Signed-off-by: Yurii <iuriish@yahoo.com> * cmakelists updated Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - minor fix in merge cpu helper - make use of reference getter Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com>
2020-05-14 17:06:13 +02:00
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by Yurii Shyrma on 11.01.2018
//
#include <helpers/hhColPivQR.h>
#include <helpers/householder.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
HHcolPivQR::HHcolPivQR(const NDArray& matrix) {
_qr = matrix.dup();
_diagSize = math::nd4j_min<int>(matrix.sizeAt(0), matrix.sizeAt(1));
_coeffs = NDArray(matrix.ordering(), {1, _diagSize}, matrix.dataType(), matrix.getContext());
_permut = NDArray(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext());
evalData();
}
void HHcolPivQR::evalData() {
BUILD_SINGLE_SELECTOR(_qr.dataType(), _evalData, (), FLOAT_TYPES);
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
void HHcolPivQR::_evalData() {
const int rows = _qr.sizeAt(0);
const int cols = _qr.sizeAt(1);
NDArray transp(_qr.ordering(), {cols}/*{1, cols}*/, _qr.dataType(), _qr.getContext());
NDArray normsUpd(_qr.ordering(), {cols}/*{1, cols}*/, _qr.dataType(), _qr.getContext());
NDArray normsDir(_qr.ordering(), {cols}/*{1, cols}*/, _qr.dataType(), _qr.getContext());
int transpNum = 0;
for (int k = 0; k < cols; ++k)
normsDir.r<T>(k) = normsUpd.r<T>(k) = _qr({0,0, k,k+1}).reduceNumber(reduce::Norm2).t<T>(0);
T normScaled = (normsUpd.reduceNumber(reduce::Max)).t<T>(0) * DataTypeUtils::eps<T>();
T threshold1 = normScaled * normScaled / (T)rows;
T threshold2 = math::nd4j_sqrt<T,T>(DataTypeUtils::eps<T>());
T nonZeroPivots = _diagSize;
T maxPivot = 0.;
for(int k = 0; k < _diagSize; ++k) {
int biggestColIndex = normsUpd({k,-1}).indexReduceNumber(indexreduce::IndexMax).e<int>(0);
T biggestColNorm = normsUpd({k,-1}).reduceNumber(reduce::Max).t<T>(0);
T biggestColSqNorm = biggestColNorm * biggestColNorm;
biggestColIndex += k;
if(nonZeroPivots == (T)_diagSize && biggestColSqNorm < threshold1 * (T)(rows-k))
nonZeroPivots = k;
transp.r<T>(k) = (T)biggestColIndex;
if(k != biggestColIndex) {
NDArray temp1(_qr({0,0, k,k+1}));
NDArray temp2(_qr({0,0, biggestColIndex,biggestColIndex+1}));
temp1.swapUnsafe(temp2);
math::nd4j_swap<T>(normsUpd.r<T>(k), normsUpd.r<T>(biggestColIndex));
math::nd4j_swap<T>(normsDir.r<T>(k), normsDir.r<T>(biggestColIndex));
++transpNum;
}
T normX, c;
NDArray qrBlock = _qr({k,rows, k,k+1});
Householder<T>::evalHHmatrixDataI(qrBlock, c, normX);
_coeffs.r<T>(k) = c;
_qr.r<T>(k,k) = normX;
T max = math::nd4j_abs<T>(normX);
if(max > maxPivot)
maxPivot = max;
if(k < rows && (k+1) < cols) {
NDArray qrBlock = _qr({k,rows, k+1,cols}, true);
NDArray tail = _qr({k+1,rows, k, k+1}, true);
Householder<T>::mulLeft(qrBlock, tail, _coeffs.t<T>(k));
}
for (int j = k + 1; j < cols; ++j) {
if (normsUpd.t<T>(j) != (T)0.f) {
T temp = math::nd4j_abs<T>(_qr.t<T>(k, j)) / normsUpd.t<T>(j);
temp = ((T)1. + temp) * ((T)1. - temp);
temp = temp < (T)0. ? (T)0. : temp;
T temp2 = temp * normsUpd.t<T>(j) * normsUpd.t<T>(j) / (normsDir.t<T>(j)*normsDir.t<T>(j));
if (temp2 <= threshold2) {
if(k+1 < rows && j < cols)
normsDir.r<T>(j) = _qr({k+1,rows, j,j+1}).reduceNumber(reduce::Norm2).t<T>(0);
normsUpd.r<T>(j) = normsDir.t<T>(j);
}
else
normsUpd.r<T>(j) = normsUpd.t<T>(j) * math::nd4j_sqrt<T, T>(temp);
}
}
}
_permut.setIdentity();
for(int k = 0; k < _diagSize; ++k) {
int idx = transp.e<int>(k);
NDArray temp1 = _permut({0,0, k, k+1});
NDArray temp2 = _permut({0,0, idx,idx+1});
temp1.swapUnsafe(temp2);
}
}
BUILD_SINGLE_TEMPLATE(template void HHcolPivQR::_evalData, (), FLOAT_TYPES);
}
}
}