// cavis/libnd4j/include/helpers/impl/Sqrtm.cpp
//
// 278 lines
// 12 KiB
// C++

/* ******************************************************************************
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <helpers/Sqrtm.h>
#include <ops/declarable/helpers/lup.h>
#include <helpers/EigenValsAndVecs.h>
#include <helpers/HessenbergAndSchur.h>
#include <helpers/FullPivLU.h>
#include <helpers/MmulHelper.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void sqrtmQuasiTrianDiag(const NDArray& matrixT, NDArray& sqrtT ) {
const int rows = matrixT.sizeAt(0);
for(int i = 0; i < rows; i++) {
if (i == rows - 1 || matrixT.t<T>(i+1, i) == (T)0) {
const auto elemT = matrixT.t<T>(i, i);
if(elemT < (T)0)
throw std::runtime_error("ops::helpers::Sqrtm::sqrtmQuasiTrianDiag: can't take sqrt of negative diagonal element of T matrix !");
sqrtT.r<T>(i,i) = math::nd4j_sqrt<T,T>(elemT);
}
else {
EigenValsAndVecs<T> es(matrixT({i,i+2, i,i+2}, true)); // es._Vecs {2,2,2}, es._Vals{2,2}
const NDArray& vecs = es._Vecs;
const NDArray& vals = es._Vals;
const T& vecsReal00 = vecs.t<T>(0,0,0);
const T& vecsImag00 = vecs.t<T>(0,0,1);
const T& vecsReal01 = vecs.t<T>(0,1,0);
const T& vecsImag01 = vecs.t<T>(0,1,1);
const T& vecsReal10 = vecs.t<T>(1,0,0);
const T& vecsImag10 = vecs.t<T>(1,0,1);
const T& vecsReal11 = vecs.t<T>(1,1,0);
const T& vecsImag11 = vecs.t<T>(1,1,1);
// es.eigenvalues().cwiseSqrt().asDiagonal()
T eigenValsSqrt[2][2];
eigenValsSqrt[0][0] = vals.t<T>(0,0);
eigenValsSqrt[0][1] = vals.t<T>(0,1);
eigenValsSqrt[1][0] = vals.t<T>(1,0);
eigenValsSqrt[1][1] = vals.t<T>(1,1);
EigenValsAndVecs<T>::sqrtComplexNum(eigenValsSqrt[0][0], eigenValsSqrt[0][1]);
EigenValsAndVecs<T>::sqrtComplexNum(eigenValsSqrt[1][0], eigenValsSqrt[1][1]);
// es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal()
T vecsElem[2][2][2];
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal00,vecsImag00, eigenValsSqrt[0][0],eigenValsSqrt[0][1], vecsElem[0][0][0],vecsElem[0][0][1]);
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal01,vecsImag01, eigenValsSqrt[1][0],eigenValsSqrt[1][1], vecsElem[0][1][0],vecsElem[0][1][1]);
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal10,vecsImag10, eigenValsSqrt[0][0],eigenValsSqrt[0][1], vecsElem[1][0][0],vecsElem[1][0][1]);
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal11,vecsImag11, eigenValsSqrt[1][0],eigenValsSqrt[1][1], vecsElem[1][1][0],vecsElem[1][1][1]);
// es.eigenvectors().inverse()
T vecsElemInv[2][2][2];
T tempReal, tempImag, divisorReal, divisorImag;
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal00,vecsImag00, vecsReal11,vecsImag11, divisorReal,divisorImag);
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal01,vecsImag01, vecsReal10,vecsImag10, tempReal,tempImag);
divisorReal -= tempReal;
divisorImag -= tempImag;
EigenValsAndVecs<T>::divideComplexNums(vecsReal11,vecsImag11, divisorReal,divisorImag, vecsElemInv[0][0][0],vecsElemInv[0][0][1]);
EigenValsAndVecs<T>::divideComplexNums(-vecsReal01,-vecsImag01, divisorReal,divisorImag, vecsElemInv[0][1][0],vecsElemInv[0][1][1]);
EigenValsAndVecs<T>::divideComplexNums(-vecsReal10,-vecsImag10, divisorReal,divisorImag, vecsElemInv[1][0][0],vecsElemInv[1][0][1]);
EigenValsAndVecs<T>::divideComplexNums(vecsReal00,vecsImag00, divisorReal,divisorImag, vecsElemInv[1][1][0],vecsElemInv[1][1][1]);
// result
T result[2][2][2];
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][0][0],vecsElem[0][0][1], vecsElemInv[0][0][0],vecsElemInv[0][0][1], tempReal,tempImag);
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][1][0],vecsElem[0][1][1], vecsElemInv[1][0][0],vecsElemInv[1][0][1], result[0][0][0],result[0][0][1]);
result[0][0][0] += tempReal;
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][0][0],vecsElem[0][0][1], vecsElemInv[0][1][0],vecsElemInv[0][1][1], tempReal,tempImag);
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][1][0],vecsElem[0][1][1], vecsElemInv[1][1][0],vecsElemInv[1][1][1], result[0][1][0],result[0][1][1]);
result[0][1][0] += tempReal;
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][0][0],vecsElem[1][0][1], vecsElemInv[0][0][0],vecsElemInv[0][0][1], tempReal,tempImag);
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][1][0],vecsElem[1][1][1], vecsElemInv[1][0][0],vecsElemInv[1][0][1], result[1][0][0],result[1][0][1]);
result[1][0][0] += tempReal;
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][0][0],vecsElem[1][0][1], vecsElemInv[0][1][0],vecsElemInv[0][1][1], tempReal,tempImag);
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][1][0],vecsElem[1][1][1], vecsElemInv[1][1][0],vecsElemInv[1][1][1], result[1][1][0],result[1][1][1]);
result[1][1][0] += tempReal;
sqrtT.r<T>(i,i) = result[0][0][0];
sqrtT.r<T>(i,i+1) = result[0][1][0];
sqrtT.r<T>(i+1,i) = result[1][0][0];
sqrtT.r<T>(i+1,i+1) = result[1][1][0];
++i;
}
}
}
//////////////////////////////////////////////////////////////////////////
// all matrices are {2,2} here
template <typename T>
static void sqrtmQuasiTrianAuxEq(const NDArray& A, const NDArray& B, const NDArray& C, NDArray& X) {
NDArray tempMatrix(A.ordering(), {4,4}, A.dataType(), A.getContext());
tempMatrix.r<T>(0,0) = A.t<T>(0,0) + B.t<T>(0,0);
tempMatrix.r<T>(1,1) = A.t<T>(0,0) + B.t<T>(1,1);
tempMatrix.r<T>(2,2) = A.t<T>(1,1) + B.t<T>(0,0);
tempMatrix.r<T>(3,3) = A.t<T>(1,1) + B.t<T>(1,1);
tempMatrix.r<T>(0,1) = B.t<T>(1,0);
tempMatrix.r<T>(0,2) = A.t<T>(0,1);
tempMatrix.r<T>(1,0) = B.t<T>(0,1);
tempMatrix.r<T>(1,3) = A.t<T>(0,1);
tempMatrix.r<T>(2,0) = A.t<T>(1,0);
tempMatrix.r<T>(2,3) = B.t<T>(1,0);
tempMatrix.r<T>(3,1) = A.t<T>(1,0);
tempMatrix.r<T>(3,2) = B.t<T>(0,1);
tempMatrix.r<T>(0,3) = (T)0;
tempMatrix.r<T>(1,2) = (T)0;
tempMatrix.r<T>(2,1) = (T)0;
tempMatrix.r<T>(3,0) = (T)0;
NDArray result(A.ordering(), {4,1}, A.dataType(), A.getContext());
result.r<T>(0,0) = C.t<T>(0,0);
result.r<T>(1,0) = C.t<T>(0,1);
result.r<T>(2,0) = C.t<T>(1,0);
result.r<T>(3,0) = C.t<T>(1,1);
FullPivLU<T>::solve(tempMatrix, result, result);
X.r<T>(0,0) = result.t<T>(0);
X.r<T>(0,1) = result.t<T>(1);
X.r<T>(1,0) = result.t<T>(2);
X.r<T>(1,1) = result.t<T>(3);
}
//////////////////////////////////////////////////////////////////////////
/**
 * Fills the entries of sqrtT that lie above its diagonal blocks, given that the diagonal blocks
 * have already been computed.
 *
 * Columns j are processed left to right and, inside a column, rows i bottom to top. Indices that are
 * the second row of a 2x2 diagonal block of matrixT (non-zero entry directly to the left on the
 * sub-diagonal) are skipped, because they are handled together with the first row of that block.
 * For every (i, j) pair of block starts the block X = sqrtT(i-block, j-block) is obtained from
 *
 *      sqrtT_ii * X + X * sqrtT_jj = matrixT_ij - sqrtT(i-block, between) * sqrtT(between, j-block)
 *
 * where "between" are the rows/columns strictly between the two diagonal blocks. Depending on the
 * sizes of the two diagonal blocks this is a 2x2 matrix equation, a 2x2 linear system or a scalar division.
 *
 * @param matrixT quasi-triangular input matrix; its sub-diagonal is used to detect 2x2 blocks
 * @param sqrtT   in/out matrix; diagonal blocks are read, entries above them are written
 *
 * NOTE(review): X / rhs are taken from matrixT without dup() (the dup() calls are commented out).
 * If sub-array access returns a view, the `-=` below writes into the buffer of matrixT although it is
 * passed as const - confirm that callers do not need matrixT afterwards.
 */
template <typename T>
static void sqrtmQuasiTrianOffDiag(const NDArray& matrixT, NDArray& sqrtT ) {

    const int rows = matrixT.sizeAt(0);

    for (int j = 1; j < rows; j++) {

        // j is the second column of a 2x2 diagonal block -> already covered by column j-1
        if (matrixT.t<T>(j, j-1) != (T)0)
            continue;

        for (int i = j - 1; i >= 0; i--) {

            // i is the second row of a 2x2 diagonal block -> will be covered by row i-1
            if (i > 0 && matrixT.t<T>(i, i-1) != (T)0)
                continue;

            const bool iBlockIs2x2 = (i < rows - 1) && (matrixT.t<T>(i+1, i) != (T)0);
            const bool jBlockIs2x2 = (j < rows - 1) && (matrixT.t<T>(j+1, j) != (T)0);

            if (iBlockIs2x2 && jBlockIs2x2) {

                // 2x2 unknown block: A*X + X*B = C with A = sqrtT_ii, B = sqrtT_jj
                NDArray A = sqrtT({i,i+2, i,i+2}, true);
                NDArray B = sqrtT({j,j+2, j,j+2}, true);
                NDArray X = matrixT({i,i+2, j,j+2}, true);//.dup();

                // subtract the contribution of the blocks lying between the two diagonal blocks (if any)
                if (j - i > 2)
                    X -= mmul(sqrtT({i,i+2, i+2,j}, true), sqrtT({i+2,j, j,j+2}, true));

                // X is passed both as right-hand side and as output
                sqrtmQuasiTrianAuxEq<T>(A, B, X, X);

                // NOTE(review): only this branch calls syncToDevice(); in the two mixed branches below
                // the call is commented out - confirm whether the asymmetry is intended
                sqrtT.syncToDevice();
                sqrtT({i,i+2, j,j+2}, true).assign(X);
            }
            else if (iBlockIs2x2 && !jBlockIs2x2) {

                // 2x1 unknown column: (sqrtT_ii + sqrtT(j,j) * I) * x = rhs
                NDArray rhs = matrixT({i,i+2, j,j+1}, true);//.dup();

                if (j - i > 2)
                    rhs -= mmul(sqrtT({i,i+2, i+2,j}, true), sqrtT({i+2,j, j,j+1}, true));

                NDArray A(matrixT.ordering(), {2,2}, matrixT.dataType(), matrixT.getContext());
                A.r<T>(0,0) = A.r<T>(1,1) = sqrtT.t<T>(j,j);
                A.r<T>(0,1) = A.r<T>(1,0) = T(0);
                A += sqrtT({i,i+2, i,i+2}, true);

                // solution overwrites rhs
                FullPivLU<T>::solve(A,rhs,rhs);

                // sqrtT.syncToDevice();
                sqrtT({i,i+2, j,j+1}, true).assign(rhs);
            }
            else if (!iBlockIs2x2 && jBlockIs2x2) {

                // 1x2 unknown row: (sqrtT(i,i) * I + sqrtT_jj^T) * x^T = rhs^T
                NDArray rhs = matrixT({i,i+1, j,j+2}, true);//.dup();

                if (j - i > 1)
                    rhs -= mmul(sqrtT({i,i+1, i+1,j}, true), sqrtT({i+1,j, j,j+2}, true));

                NDArray A(matrixT.ordering(), {2,2}, matrixT.dataType(), matrixT.getContext());
                A.r<T>(0,0) = A.r<T>(1,1) = sqrtT.t<T>(i,i);
                A.r<T>(0,1) = A.r<T>(1,0) = T(0);
                A += sqrtT({j,j+2, j,j+2}, true).transpose();

                // the system is solved on the transposed right-hand side, then rhs is stored;
                // presumably rhsT shares its buffer with rhs so the solution is visible through rhs - TODO confirm
                NDArray rhsT = rhs.transpose();
                FullPivLU<T>::solve(A,rhsT,rhsT);

                // sqrtT.syncToDevice();
                sqrtT({i,i+1, j,j+2}, true).assign(rhs);
            }
            else if (!iBlockIs2x2 && !jBlockIs2x2) {

                // scalar unknown: sqrtT(i,j) = (matrixT(i,j) - row_i . col_j) / (sqrtT(i,i) + sqrtT(j,j)),
                // the dot product running over the indices strictly between i and j.
                // For adjacent entries (j == i + 1) that index range is empty and the correction is zero;
                // like the other branches, skip the product then instead of calling mmul on zero-length intervals.
                T temp = T(0);
                if (j - i > 1)
                    temp = mmul(sqrtT({i,i+1, i+1,j}), sqrtT({i+1,j, j,j+1})).t<T>(0);  // dot

                sqrtT.r<T>(i,j) = (matrixT.t<T>(i,j) - temp ) / (sqrtT.t<T>(i,i) + sqrtT.t<T>(j,j));
            }
        }
    }
}
//////////////////////////////////////////////////////////////////////////
/**
 * Computes the square root of a square matrix and stores it in out.
 *
 * A single-element input is handled directly with a scalar square root. Otherwise the input is
 * decomposed with Schur<T> into _T and _U, the square root of _T is assembled block-wise
 * (diagonal blocks first, then the entries above them) and the result is formed as U * sqrt(T) * U^T.
 *
 * @param in  input matrix, must have rank 2 and equal sizes along both dimensions
 * @param out output matrix, must have the same shape as in
 * @throws std::runtime_error if in is not a square rank-2 array or out has a different shape;
 *         the diagonal helper additionally throws when a 1x1 diagonal block of _T is negative
 *
 * NOTE(review): the single-element shortcut does not check the sign of the element, unlike the
 * diagonal helper used for larger inputs - confirm whether that difference is intended.
 */
template <typename T>
void Sqrtm<T>::calc(const NDArray& in, NDArray& out) {

    if(in.rankOf() != 2 || in.sizeAt(0) != in.sizeAt(1))
        throw std::runtime_error("ops::helpers::Sqrtm::calc: input matrix must have rank 2 and be square !");
    if(!out.isSameShape(in))
        throw std::runtime_error("ops::helpers::Sqrtm::calc: output matrix must have the same shape as input one!");

    // scalar case, no decomposition needed
    if(in.lengthOf() == 1) {
        out.r<T>(0) = math::nd4j_sqrt<T,T>(in.t<T>(0));
        return;
    }

    ops::helpers::Schur<T> schur(in);

    // start from an all-zero matrix shaped like the input, the helpers fill in the blocks they compute
    NDArray sqrtT = in.ulike();
    sqrtT.nullify();

    sqrtmQuasiTrianDiag<T>(schur._T, sqrtT);
    sqrtmQuasiTrianOffDiag<T>(schur._T, sqrtT);

    // out = U * sqrtT * U^T;
    NDArray temp = mmul(sqrtT, schur._U.transpose());
    MmulHelper::mmul(&schur._U, &temp, &out);
}
// Explicit instantiations of Sqrtm for the element types float, float16, bfloat16 and double.
template class ND4J_EXPORT Sqrtm<float>;
template class ND4J_EXPORT Sqrtm<float16>;
template class ND4J_EXPORT Sqrtm<bfloat16>;
template class ND4J_EXPORT Sqrtm<double>;
}
}
}