/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author Yurii Shyrma (iuriish@yahoo.com) // #include #include #include #include #include #include namespace sd { namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template static void sqrtmQuasiTrianDiag(const NDArray& matrixT, NDArray& sqrtT ) { const int rows = matrixT.sizeAt(0); for(int i = 0; i < rows; i++) { if (i == rows - 1 || matrixT.t(i+1, i) == (T)0) { const auto elemT = matrixT.t(i, i); if(elemT < (T)0) throw std::runtime_error("ops::helpers::Sqrtm::sqrtmQuasiTrianDiag: can't take sqrt of negative diagonal element of T matrix !"); sqrtT.r(i,i) = math::nd4j_sqrt(elemT); } else { EigenValsAndVecs es(matrixT({i,i+2, i,i+2}, true)); // es._Vecs {2,2,2}, es._Vals{2,2} const NDArray& vecs = es._Vecs; const NDArray& vals = es._Vals; const T& vecsReal00 = vecs.t(0,0,0); const T& vecsImag00 = vecs.t(0,0,1); const T& vecsReal01 = vecs.t(0,1,0); const T& vecsImag01 = vecs.t(0,1,1); const T& vecsReal10 = vecs.t(1,0,0); const T& vecsImag10 = vecs.t(1,0,1); const T& vecsReal11 = vecs.t(1,1,0); const T& vecsImag11 = vecs.t(1,1,1); // es.eigenvalues().cwiseSqrt().asDiagonal() T eigenValsSqrt[2][2]; eigenValsSqrt[0][0] = vals.t(0,0); eigenValsSqrt[0][1] = vals.t(0,1); eigenValsSqrt[1][0] = vals.t(1,0); eigenValsSqrt[1][1] = vals.t(1,1); EigenValsAndVecs::sqrtComplexNum(eigenValsSqrt[0][0], eigenValsSqrt[0][1]); EigenValsAndVecs::sqrtComplexNum(eigenValsSqrt[1][0], eigenValsSqrt[1][1]); // es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() T vecsElem[2][2][2]; EigenValsAndVecs::multiplyComplexNums(vecsReal00,vecsImag00, eigenValsSqrt[0][0],eigenValsSqrt[0][1], vecsElem[0][0][0],vecsElem[0][0][1]); EigenValsAndVecs::multiplyComplexNums(vecsReal01,vecsImag01, eigenValsSqrt[1][0],eigenValsSqrt[1][1], vecsElem[0][1][0],vecsElem[0][1][1]); EigenValsAndVecs::multiplyComplexNums(vecsReal10,vecsImag10, eigenValsSqrt[0][0],eigenValsSqrt[0][1], vecsElem[1][0][0],vecsElem[1][0][1]); EigenValsAndVecs::multiplyComplexNums(vecsReal11,vecsImag11, eigenValsSqrt[1][0],eigenValsSqrt[1][1], vecsElem[1][1][0],vecsElem[1][1][1]); // es.eigenvectors().inverse() T vecsElemInv[2][2][2]; T tempReal, tempImag, divisorReal, divisorImag; EigenValsAndVecs::multiplyComplexNums(vecsReal00,vecsImag00, vecsReal11,vecsImag11, divisorReal,divisorImag); EigenValsAndVecs::multiplyComplexNums(vecsReal01,vecsImag01, vecsReal10,vecsImag10, tempReal,tempImag); divisorReal -= tempReal; divisorImag -= tempImag; EigenValsAndVecs::divideComplexNums(vecsReal11,vecsImag11, divisorReal,divisorImag, vecsElemInv[0][0][0],vecsElemInv[0][0][1]); EigenValsAndVecs::divideComplexNums(-vecsReal01,-vecsImag01, divisorReal,divisorImag, vecsElemInv[0][1][0],vecsElemInv[0][1][1]); EigenValsAndVecs::divideComplexNums(-vecsReal10,-vecsImag10, divisorReal,divisorImag, vecsElemInv[1][0][0],vecsElemInv[1][0][1]); EigenValsAndVecs::divideComplexNums(vecsReal00,vecsImag00, divisorReal,divisorImag, vecsElemInv[1][1][0],vecsElemInv[1][1][1]); // result T result[2][2][2]; EigenValsAndVecs::multiplyComplexNums(vecsElem[0][0][0],vecsElem[0][0][1], vecsElemInv[0][0][0],vecsElemInv[0][0][1], tempReal,tempImag); EigenValsAndVecs::multiplyComplexNums(vecsElem[0][1][0],vecsElem[0][1][1], vecsElemInv[1][0][0],vecsElemInv[1][0][1], result[0][0][0],result[0][0][1]); result[0][0][0] += tempReal; EigenValsAndVecs::multiplyComplexNums(vecsElem[0][0][0],vecsElem[0][0][1], vecsElemInv[0][1][0],vecsElemInv[0][1][1], tempReal,tempImag); EigenValsAndVecs::multiplyComplexNums(vecsElem[0][1][0],vecsElem[0][1][1], vecsElemInv[1][1][0],vecsElemInv[1][1][1], result[0][1][0],result[0][1][1]); result[0][1][0] += tempReal; EigenValsAndVecs::multiplyComplexNums(vecsElem[1][0][0],vecsElem[1][0][1], vecsElemInv[0][0][0],vecsElemInv[0][0][1], tempReal,tempImag); EigenValsAndVecs::multiplyComplexNums(vecsElem[1][1][0],vecsElem[1][1][1], vecsElemInv[1][0][0],vecsElemInv[1][0][1], result[1][0][0],result[1][0][1]); result[1][0][0] += tempReal; EigenValsAndVecs::multiplyComplexNums(vecsElem[1][0][0],vecsElem[1][0][1], vecsElemInv[0][1][0],vecsElemInv[0][1][1], tempReal,tempImag); EigenValsAndVecs::multiplyComplexNums(vecsElem[1][1][0],vecsElem[1][1][1], vecsElemInv[1][1][0],vecsElemInv[1][1][1], result[1][1][0],result[1][1][1]); result[1][1][0] += tempReal; sqrtT.r(i,i) = result[0][0][0]; sqrtT.r(i,i+1) = result[0][1][0]; sqrtT.r(i+1,i) = result[1][0][0]; sqrtT.r(i+1,i+1) = result[1][1][0]; ++i; } } } ////////////////////////////////////////////////////////////////////////// // all matrices are {2,2} here template static void sqrtmQuasiTrianAuxEq(const NDArray& A, const NDArray& B, const NDArray& C, NDArray& X) { NDArray tempMatrix(A.ordering(), {4,4}, A.dataType(), A.getContext()); tempMatrix.r(0,0) = A.t(0,0) + B.t(0,0); tempMatrix.r(1,1) = A.t(0,0) + B.t(1,1); tempMatrix.r(2,2) = A.t(1,1) + B.t(0,0); tempMatrix.r(3,3) = A.t(1,1) + B.t(1,1); tempMatrix.r(0,1) = B.t(1,0); tempMatrix.r(0,2) = A.t(0,1); tempMatrix.r(1,0) = B.t(0,1); tempMatrix.r(1,3) = A.t(0,1); tempMatrix.r(2,0) = A.t(1,0); tempMatrix.r(2,3) = B.t(1,0); tempMatrix.r(3,1) = A.t(1,0); tempMatrix.r(3,2) = B.t(0,1); tempMatrix.r(0,3) = (T)0; tempMatrix.r(1,2) = (T)0; tempMatrix.r(2,1) = (T)0; tempMatrix.r(3,0) = (T)0; NDArray result(A.ordering(), {4,1}, A.dataType(), A.getContext()); result.r(0,0) = C.t(0,0); result.r(1,0) = C.t(0,1); result.r(2,0) = C.t(1,0); result.r(3,0) = C.t(1,1); FullPivLU::solve(tempMatrix, result, result); X.r(0,0) = result.t(0); X.r(0,1) = result.t(1); X.r(1,0) = result.t(2); X.r(1,1) = result.t(3); } ////////////////////////////////////////////////////////////////////////// template static void sqrtmQuasiTrianOffDiag(const NDArray& matrixT, NDArray& sqrtT ) { const int rows = matrixT.sizeAt(0); for (int j = 1; j < rows; j++) { if (matrixT.t(j, j-1) != (T)0) continue; for (int i = j - 1; i >= 0; i--) { if (i > 0 && matrixT.t(i, i-1) != (T)0) continue; const bool iBlockIs2x2 = (i < rows - 1) && (matrixT.t(i+1, i) != (T)0); const bool jBlockIs2x2 = (j < rows - 1) && (matrixT.t(j+1, j) != (T)0); if (iBlockIs2x2 && jBlockIs2x2) { NDArray A = sqrtT({i,i+2, i,i+2}, true); NDArray B = sqrtT({j,j+2, j,j+2}, true); NDArray X = matrixT({i,i+2, j,j+2}, true);//.dup(); if (j - i > 2) X -= mmul(sqrtT({i,i+2, i+2,j}, true), sqrtT({i+2,j, j,j+2}, true)); sqrtmQuasiTrianAuxEq(A, B, X, X); sqrtT.syncToDevice(); sqrtT({i,i+2, j,j+2}, true).assign(X); } else if (iBlockIs2x2 && !jBlockIs2x2) { NDArray rhs = matrixT({i,i+2, j,j+1}, true);//.dup(); if (j - i > 2) rhs -= mmul(sqrtT({i,i+2, i+2,j}, true), sqrtT({i+2,j, j,j+1}, true)); NDArray A(matrixT.ordering(), {2,2}, matrixT.dataType(), matrixT.getContext()); A.r(0,0) = A.r(1,1) = sqrtT.t(j,j); A.r(0,1) = A.r(1,0) = T(0); A += sqrtT({i,i+2, i,i+2}, true); FullPivLU::solve(A,rhs,rhs); // sqrtT.syncToDevice(); sqrtT({i,i+2, j,j+1}, true).assign(rhs); } else if (!iBlockIs2x2 && jBlockIs2x2) { NDArray rhs = matrixT({i,i+1, j,j+2}, true);//.dup(); if (j - i > 1) rhs -= mmul(sqrtT({i,i+1, i+1,j}, true), sqrtT({i+1,j, j,j+2}, true)); NDArray A(matrixT.ordering(), {2,2}, matrixT.dataType(), matrixT.getContext()); A.r(0,0) = A.r(1,1) = sqrtT.t(i,i); A.r(0,1) = A.r(1,0) = T(0); A += sqrtT({j,j+2, j,j+2}, true).transpose(); NDArray rhsT = rhs.transpose(); FullPivLU::solve(A,rhsT,rhsT); // sqrtT.syncToDevice(); sqrtT({i,i+1, j,j+2}, true).assign(rhs); } else if (!iBlockIs2x2 && !jBlockIs2x2) { T temp = mmul(sqrtT({i,i+1, i+1,j}), sqrtT({i+1,j, j,j+1})).t(0); // dot sqrtT.r(i,j) = (matrixT.t(i,j) - temp ) / (sqrtT.t(i,i) + sqrtT.t(j,j)); } } } } ////////////////////////////////////////////////////////////////////////// template void Sqrtm::calc(const NDArray& in, NDArray& out) { if(in.rankOf() != 2 || in.sizeAt(0) != in.sizeAt(1)) throw std::runtime_error("ops::helpers::Sqrtm::calc: input matrix must have rank 2 and be square !"); if(!out.isSameShape(in)) throw std::runtime_error("ops::helpers::Sqrtm::calc: output matrix must have the same shape as input one!"); if(in.lengthOf() == 1) { out.r(0) = math::nd4j_sqrt(in.t(0)); return; } ops::helpers::Schur schur(in); const NDArray& t1 = schur._T; const NDArray& t2 = schur._U; NDArray sqrtT = in.ulike(); sqrtT.nullify(); sqrtmQuasiTrianDiag(schur._T, sqrtT); sqrtmQuasiTrianOffDiag(schur._T, sqrtT); // out = U * sqrtT * U^T; NDArray temp = mmul(sqrtT, schur._U.transpose()); MmulHelper::mmul(&schur._U, &temp, &out); } template class ND4J_EXPORT Sqrtm; template class ND4J_EXPORT Sqrtm; template class ND4J_EXPORT Sqrtm; template class ND4J_EXPORT Sqrtm; } } }