276 lines
11 KiB
C++
276 lines
11 KiB
C++
|
/*******************************************************************************
|
||
|
* Copyright (c) 2020 Konduit K.K.
|
||
|
*
|
||
|
* This program and the accompanying materials are made available under the
|
||
|
* terms of the Apache License, Version 2.0 which is available at
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* License for the specific language governing permissions and limitations
|
||
|
* under the License.
|
||
|
*
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
******************************************************************************/
|
||
|
|
||
|
//
|
||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||
|
//
|
||
|
|
||
|
#include <helpers/Sqrtm.h>
|
||
|
#include <ops/declarable/helpers/lup.h>
|
||
|
#include <helpers/EigenValsAndVecs.h>
|
||
|
#include <helpers/HessenbergAndSchur.h>
|
||
|
#include <helpers/FullPivLU.h>
|
||
|
#include <helpers/MmulHelper.h>
|
||
|
|
||
|
|
||
|
namespace sd {
|
||
|
namespace ops {
|
||
|
namespace helpers {
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
template <typename T>
|
||
|
static void sqrtmQuasiTrianDiag(const NDArray& matrixT, NDArray& sqrtT ) {
|
||
|
|
||
|
const int rows = matrixT.sizeAt(0);
|
||
|
|
||
|
for(int i = 0; i < rows; i++) {
|
||
|
|
||
|
if (i == rows - 1 || matrixT.t<T>(i+1, i) == (T)0) {
|
||
|
const auto elemT = matrixT.t<T>(i, i);
|
||
|
if(elemT < (T)0)
|
||
|
throw std::runtime_error("ops::helpers::Sqrtm::sqrtmQuasiTrianDiag: can't take sqrt of negative diagonal element of T matrix !");
|
||
|
sqrtT.r<T>(i,i) = math::nd4j_sqrt<T,T>(elemT);
|
||
|
}
|
||
|
else {
|
||
|
|
||
|
EigenValsAndVecs<T> es(matrixT({i,i+2, i,i+2}, true)); // es._Vecs {2,2,2}, es._Vals{2,2}
|
||
|
|
||
|
const NDArray& vecs = es._Vecs;
|
||
|
const NDArray& vals = es._Vals;
|
||
|
|
||
|
const T& vecsReal00 = vecs.t<T>(0,0,0);
|
||
|
const T& vecsImag00 = vecs.t<T>(0,0,1);
|
||
|
const T& vecsReal01 = vecs.t<T>(0,1,0);
|
||
|
const T& vecsImag01 = vecs.t<T>(0,1,1);
|
||
|
const T& vecsReal10 = vecs.t<T>(1,0,0);
|
||
|
const T& vecsImag10 = vecs.t<T>(1,0,1);
|
||
|
const T& vecsReal11 = vecs.t<T>(1,1,0);
|
||
|
const T& vecsImag11 = vecs.t<T>(1,1,1);
|
||
|
|
||
|
// es.eigenvalues().cwiseSqrt().asDiagonal()
|
||
|
T eigenValsSqrt[2][2];
|
||
|
eigenValsSqrt[0][0] = vals.t<T>(0,0);
|
||
|
eigenValsSqrt[0][1] = vals.t<T>(0,1);
|
||
|
eigenValsSqrt[1][0] = vals.t<T>(1,0);
|
||
|
eigenValsSqrt[1][1] = vals.t<T>(1,1);
|
||
|
EigenValsAndVecs<T>::sqrtComplexNum(eigenValsSqrt[0][0], eigenValsSqrt[0][1]);
|
||
|
EigenValsAndVecs<T>::sqrtComplexNum(eigenValsSqrt[1][0], eigenValsSqrt[1][1]);
|
||
|
|
||
|
// es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal()
|
||
|
T vecsElem[2][2][2];
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal00,vecsImag00, eigenValsSqrt[0][0],eigenValsSqrt[0][1], vecsElem[0][0][0],vecsElem[0][0][1]);
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal01,vecsImag01, eigenValsSqrt[1][0],eigenValsSqrt[1][1], vecsElem[0][1][0],vecsElem[0][1][1]);
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal10,vecsImag10, eigenValsSqrt[0][0],eigenValsSqrt[0][1], vecsElem[1][0][0],vecsElem[1][0][1]);
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal11,vecsImag11, eigenValsSqrt[1][0],eigenValsSqrt[1][1], vecsElem[1][1][0],vecsElem[1][1][1]);
|
||
|
|
||
|
// es.eigenvectors().inverse()
|
||
|
T vecsElemInv[2][2][2];
|
||
|
|
||
|
T tempReal, tempImag, divisorReal, divisorImag;
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal00,vecsImag00, vecsReal11,vecsImag11, divisorReal,divisorImag);
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal01,vecsImag01, vecsReal10,vecsImag10, tempReal,tempImag);
|
||
|
divisorReal -= tempReal;
|
||
|
divisorImag -= tempImag;
|
||
|
|
||
|
EigenValsAndVecs<T>::divideComplexNums(vecsReal11,vecsImag11, divisorReal,divisorImag, vecsElemInv[0][0][0],vecsElemInv[0][0][1]);
|
||
|
EigenValsAndVecs<T>::divideComplexNums(-vecsReal01,-vecsImag01, divisorReal,divisorImag, vecsElemInv[0][1][0],vecsElemInv[0][1][1]);
|
||
|
EigenValsAndVecs<T>::divideComplexNums(-vecsReal10,-vecsImag10, divisorReal,divisorImag, vecsElemInv[1][0][0],vecsElemInv[1][0][1]);
|
||
|
EigenValsAndVecs<T>::divideComplexNums(vecsReal00,vecsImag00, divisorReal,divisorImag, vecsElemInv[1][1][0],vecsElemInv[1][1][1]);
|
||
|
|
||
|
// result
|
||
|
T result[2][2][2];
|
||
|
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][0][0],vecsElem[0][0][1], vecsElemInv[0][0][0],vecsElemInv[0][0][1], tempReal,tempImag);
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][1][0],vecsElem[0][1][1], vecsElemInv[1][0][0],vecsElemInv[1][0][1], result[0][0][0],result[0][0][1]);
|
||
|
result[0][0][0] += tempReal;
|
||
|
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][0][0],vecsElem[0][0][1], vecsElemInv[0][1][0],vecsElemInv[0][1][1], tempReal,tempImag);
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][1][0],vecsElem[0][1][1], vecsElemInv[1][1][0],vecsElemInv[1][1][1], result[0][1][0],result[0][1][1]);
|
||
|
result[0][1][0] += tempReal;
|
||
|
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][0][0],vecsElem[1][0][1], vecsElemInv[0][0][0],vecsElemInv[0][0][1], tempReal,tempImag);
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][1][0],vecsElem[1][1][1], vecsElemInv[1][0][0],vecsElemInv[1][0][1], result[1][0][0],result[1][0][1]);
|
||
|
result[1][0][0] += tempReal;
|
||
|
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][0][0],vecsElem[1][0][1], vecsElemInv[0][1][0],vecsElemInv[0][1][1], tempReal,tempImag);
|
||
|
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][1][0],vecsElem[1][1][1], vecsElemInv[1][1][0],vecsElemInv[1][1][1], result[1][1][0],result[1][1][1]);
|
||
|
result[1][1][0] += tempReal;
|
||
|
|
||
|
sqrtT.r<T>(i,i) = result[0][0][0];
|
||
|
sqrtT.r<T>(i,i+1) = result[0][1][0];
|
||
|
sqrtT.r<T>(i+1,i) = result[1][0][0];
|
||
|
sqrtT.r<T>(i+1,i+1) = result[1][1][0];
|
||
|
|
||
|
++i;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
// all matrices are {2,2} here
|
||
|
template <typename T>
|
||
|
static void sqrtmQuasiTrianAuxEq(const NDArray& A, const NDArray& B, const NDArray& C, NDArray& X) {
|
||
|
|
||
|
NDArray tempMatrix(A.ordering(), {4,4}, A.dataType(), A.getContext());
|
||
|
|
||
|
tempMatrix.r<T>(0,0) = A.t<T>(0,0) + B.t<T>(0,0);
|
||
|
tempMatrix.r<T>(1,1) = A.t<T>(0,0) + B.t<T>(1,1);
|
||
|
tempMatrix.r<T>(2,2) = A.t<T>(1,1) + B.t<T>(0,0);
|
||
|
tempMatrix.r<T>(3,3) = A.t<T>(1,1) + B.t<T>(1,1);
|
||
|
tempMatrix.r<T>(0,1) = B.t<T>(1,0);
|
||
|
tempMatrix.r<T>(0,2) = A.t<T>(0,1);
|
||
|
tempMatrix.r<T>(1,0) = B.t<T>(0,1);
|
||
|
tempMatrix.r<T>(1,3) = A.t<T>(0,1);
|
||
|
tempMatrix.r<T>(2,0) = A.t<T>(1,0);
|
||
|
tempMatrix.r<T>(2,3) = B.t<T>(1,0);
|
||
|
tempMatrix.r<T>(3,1) = A.t<T>(1,0);
|
||
|
tempMatrix.r<T>(3,2) = B.t<T>(0,1);
|
||
|
tempMatrix.r<T>(0,3) = (T)0;
|
||
|
tempMatrix.r<T>(1,2) = (T)0;
|
||
|
tempMatrix.r<T>(2,1) = (T)0;
|
||
|
tempMatrix.r<T>(3,0) = (T)0;
|
||
|
|
||
|
NDArray result(A.ordering(), {4,1}, A.dataType(), A.getContext());
|
||
|
result.r<T>(0,0) = C.t<T>(0,0);
|
||
|
result.r<T>(1,0) = C.t<T>(0,1);
|
||
|
result.r<T>(2,0) = C.t<T>(1,0);
|
||
|
result.r<T>(3,0) = C.t<T>(1,1);
|
||
|
|
||
|
FullPivLU<T>::solve(tempMatrix, result, result);
|
||
|
|
||
|
X.r<T>(0,0) = result.t<T>(0);
|
||
|
X.r<T>(0,1) = result.t<T>(1);
|
||
|
X.r<T>(1,0) = result.t<T>(2);
|
||
|
X.r<T>(1,1) = result.t<T>(3);
|
||
|
}
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
template <typename T>
|
||
|
static void sqrtmQuasiTrianOffDiag(const NDArray& matrixT, NDArray& sqrtT ) {
|
||
|
|
||
|
const int rows = matrixT.sizeAt(0);
|
||
|
|
||
|
for (int j = 1; j < rows; j++) {
|
||
|
|
||
|
if (matrixT.t<T>(j, j-1) != (T)0)
|
||
|
continue;
|
||
|
|
||
|
for (int i = j - 1; i >= 0; i--) {
|
||
|
|
||
|
if (i > 0 && matrixT.t<T>(i, i-1) != (T)0)
|
||
|
continue;
|
||
|
|
||
|
const bool iBlockIs2x2 = (i < rows - 1) && (matrixT.t<T>(i+1, i) != (T)0);
|
||
|
const bool jBlockIs2x2 = (j < rows - 1) && (matrixT.t<T>(j+1, j) != (T)0);
|
||
|
|
||
|
if (iBlockIs2x2 && jBlockIs2x2) {
|
||
|
|
||
|
NDArray A = sqrtT({i,i+2, i,i+2}, true);
|
||
|
NDArray B = sqrtT({j,j+2, j,j+2}, true);
|
||
|
NDArray X = matrixT({i,i+2, j,j+2}, true);//.dup();
|
||
|
|
||
|
if (j - i > 2)
|
||
|
X -= mmul(sqrtT({i,i+2, i+2,j}, true), sqrtT({i+2,j, j,j+2}, true));
|
||
|
|
||
|
sqrtmQuasiTrianAuxEq<T>(A, B, X, X);
|
||
|
|
||
|
sqrtT.syncToDevice();
|
||
|
sqrtT({i,i+2, j,j+2}, true).assign(X);
|
||
|
}
|
||
|
else if (iBlockIs2x2 && !jBlockIs2x2) {
|
||
|
|
||
|
NDArray rhs = matrixT({i,i+2, j,j+1}, true);//.dup();
|
||
|
|
||
|
if (j - i > 2)
|
||
|
rhs -= mmul(sqrtT({i,i+2, i+2,j}, true), sqrtT({i+2,j, j,j+1}, true));
|
||
|
|
||
|
NDArray A(matrixT.ordering(), {2,2}, matrixT.dataType(), matrixT.getContext());
|
||
|
A.r<T>(0,0) = A.r<T>(1,1) = sqrtT.t<T>(j,j);
|
||
|
A.r<T>(0,1) = A.r<T>(1,0) = T(0);
|
||
|
A += sqrtT({i,i+2, i,i+2}, true);
|
||
|
|
||
|
FullPivLU<T>::solve(A,rhs,rhs);
|
||
|
|
||
|
// sqrtT.syncToDevice();
|
||
|
sqrtT({i,i+2, j,j+1}, true).assign(rhs);
|
||
|
}
|
||
|
else if (!iBlockIs2x2 && jBlockIs2x2) {
|
||
|
|
||
|
NDArray rhs = matrixT({i,i+1, j,j+2}, true);//.dup();
|
||
|
|
||
|
if (j - i > 1)
|
||
|
rhs -= mmul(sqrtT({i,i+1, i+1,j}, true), sqrtT({i+1,j, j,j+2}, true));
|
||
|
|
||
|
NDArray A(matrixT.ordering(), {2,2}, matrixT.dataType(), matrixT.getContext());
|
||
|
A.r<T>(0,0) = A.r<T>(1,1) = sqrtT.t<T>(i,i);
|
||
|
A.r<T>(0,1) = A.r<T>(1,0) = T(0);
|
||
|
A += sqrtT({j,j+2, j,j+2}, true).transpose();
|
||
|
|
||
|
NDArray rhsT = rhs.transpose();
|
||
|
FullPivLU<T>::solve(A,rhsT,rhsT);
|
||
|
|
||
|
// sqrtT.syncToDevice();
|
||
|
sqrtT({i,i+1, j,j+2}, true).assign(rhs);
|
||
|
}
|
||
|
else if (!iBlockIs2x2 && !jBlockIs2x2) {
|
||
|
|
||
|
T temp = mmul(sqrtT({i,i+1, i+1,j}), sqrtT({i+1,j, j,j+1})).t<T>(0); // dot
|
||
|
sqrtT.r<T>(i,j) = (matrixT.t<T>(i,j) - temp ) / (sqrtT.t<T>(i,i) + sqrtT.t<T>(j,j));
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
template <typename T>
|
||
|
void Sqrtm<T>::calc(const NDArray& in, NDArray& out) {
|
||
|
|
||
|
if(in.rankOf() != 2 || in.sizeAt(0) != in.sizeAt(1))
|
||
|
throw std::runtime_error("ops::helpers::Sqrtm::calc: input matrix must have rank 2 and be square !");
|
||
|
if(!out.isSameShape(in))
|
||
|
throw std::runtime_error("ops::helpers::Sqrtm::calc: output matrix must have the same shape as input one!");
|
||
|
|
||
|
if(in.lengthOf() == 1) {
|
||
|
out.r<T>(0) = math::nd4j_sqrt<T,T>(in.t<T>(0));
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
ops::helpers::Schur<T> schur(in);
|
||
|
|
||
|
const NDArray& t1 = schur._T;
|
||
|
const NDArray& t2 = schur._U;
|
||
|
|
||
|
NDArray sqrtT = in.ulike();
|
||
|
sqrtT.nullify();
|
||
|
|
||
|
sqrtmQuasiTrianDiag<T>(schur._T, sqrtT);
|
||
|
sqrtmQuasiTrianOffDiag<T>(schur._T, sqrtT);
|
||
|
|
||
|
// out = U * sqrtT * U^T;
|
||
|
NDArray temp = mmul(sqrtT, schur._U.transpose());
|
||
|
MmulHelper::mmul(&schur._U, &temp, &out);
|
||
|
}
|
||
|
|
||
|
// explicit instantiations of Sqrtm for the element types provided by this translation unit
template class ND4J_EXPORT Sqrtm<double>;
template class ND4J_EXPORT Sqrtm<float>;
template class ND4J_EXPORT Sqrtm<float16>;
template class ND4J_EXPORT Sqrtm<bfloat16>;
|
||
|
|
||
|
|
||
|
}
|
||
|
}
|
||
|
}
|