// cavis/libnd4j/include/helpers/HessenbergAndSchur.h
//
// 103 lines
// 2.9 KiB
// C++
// Raw Blame History
//
// This file contains ambiguous Unicode characters!
//
// This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#ifndef LIBND4J_HESSENBERGANDSCHUR_H
#define LIBND4J_HESSENBERGANDSCHUR_H
#include <array/NDArray.h>
namespace sd {
namespace ops {
namespace helpers {
// Hessenberg decomposition of a square matrix by means of an orthogonal similarity transformation:
//   A = Q H Q^T
// where
//   Q - orthogonal matrix
//   H - Hessenberg matrix
// Only the declarations live in this header; the constructor and evalData() are defined elsewhere.
template <typename T>
class Hessenberg {
// N below denotes the order of the square NxN input matrix
public:
// orthogonal factor Q of the decomposition
NDArray _Q; // shape {N,N}
// Hessenberg factor H of the decomposition
NDArray _H; // shape {N,N}
// explicit, so an NDArray is never implicitly converted into a Hessenberg object.
// NOTE(review): presumably validates `matrix` and fills _Q/_H - definition is not visible here, confirm.
explicit Hessenberg(const NDArray& matrix);
private:
// internal helper taking no arguments; presumably performs the actual reduction on the members - TODO confirm
void evalData();
};
// Real Schur decomposition of a square matrix by means of an orthogonal similarity transformation:
//   A = U T U^T
// where
//   T - real quasi-upper-triangular matrix, i.e. a block upper triangular matrix whose diagonal
//       blocks are 1x1 or 2x2 (the 2x2 blocks are the ones with complex eigenvalues)
//   U - real orthogonal matrix
template <typename T>
class Schur {
    // N below denotes the order of the square NxN input matrix
  public:
    NDArray _T;  // quasi-triangular factor, shape {N,N}
    NDArray _U;  // orthogonal factor, shape {N,N}

    // explicit, so an NDArray is never implicitly converted into a Schur object
    explicit Schur(const NDArray& matrix);

    // The following members are defined out of line; judging by their names they are the
    // individual steps of a shifted Francis QR iteration - TODO confirm against the implementation.
    void splitTwoRows(const int ind, const T shift);
    void calcShift(const int ind, const int iter, T& shift, NDArray& shiftInfo);
    void initFrancisQR(const int ind1, const int ind2, const NDArray& shiftVec, int& ind3, NDArray& householderVec);
    void doFrancisQR(const int ind1, const int ind2, const int ind3, const NDArray& householderVec);
    void calcFromHessenberg();

  private:
    // NOTE(review): by its name an iteration budget per row; the place where it is used is not visible here
    static const int _maxItersPerRow = 40;

    void evalData(const NDArray& matrix);

    // Scans the subdiagonal of _T upwards, starting at row inInd, and stops at the first row r > 0 with
    //   |T(r, r-1)| <= eps * (|T(r-1, r-1)| + |T(r, r)|),
    // i.e. where the subdiagonal entry is negligible relative to its two diagonal neighbours.
    // Returns that r, or 0 when no such row exists; a non-positive inInd is returned unchanged.
    FORCEINLINE int getSmallSubdiagEntry(const int inInd) {
        int row = inInd;

        for (; row > 0; --row) {
            // magnitude of the two diagonal neighbours serves as the reference scale
            const T diagScale = math::nd4j_abs<T>(_T.t<T>(row - 1, row - 1)) + math::nd4j_abs<T>(_T.t<T>(row, row));
            const T subdiag = math::nd4j_abs<T>(_T.t<T>(row, row - 1));

            if (subdiag <= DataTypeUtils::eps<T>() * diagScale)
                break;
        }

        return row;
    }
};
}
}
}
#endif //LIBND4J_HESSENBERGANDSCHUR_H