/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author Yurii Shyrma (iuriish@yahoo.com) // #include #include namespace sd { namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template EigenValsAndVecs::EigenValsAndVecs(const NDArray& matrix) { if(matrix.rankOf() != 2) throw std::runtime_error("ops::helpers::EigenValsAndVecs constructor: input matrix must be 2D !"); if(matrix.sizeAt(0) != matrix.sizeAt(1)) throw std::runtime_error("ops::helpers::EigenValsAndVecs constructor: input array must be 2D square matrix !"); Schur schur(matrix); NDArray& schurMatrixU = schur._U; NDArray& schurMatrixT = schur._T; _Vecs = NDArray(matrix.ordering(), {schurMatrixU.sizeAt(1), schurMatrixU.sizeAt(1), 2}, matrix.dataType(), matrix.getContext()); _Vals = NDArray(matrix.ordering(), {matrix.sizeAt(1), 2}, matrix.dataType(), matrix.getContext()); // sequence of methods calls matters calcEigenVals(schurMatrixT); calcPseudoEigenVecs(schurMatrixT, schurMatrixU); // pseudo-eigenvectors are real and will be stored in schurMatrixU calcEigenVecs(schurMatrixU); } ////////////////////////////////////////////////////////////////////////// template void EigenValsAndVecs::calcEigenVals(const NDArray& schurMatrixT) { const int numOfCols = schurMatrixT.sizeAt(1); // calculate eigenvalues _Vals int i = 0; while (i < numOfCols) { if (i == numOfCols - 1 || schurMatrixT.t(i+1, i) == T(0.f)) { _Vals.r(i, 0) = schurMatrixT.t(i, i); // real part _Vals.r(i, 1) = T(0); // imaginary part if(!math::nd4j_isfin(_Vals.t(i, 0))) { throw std::runtime_error("ops::helpers::igenValsAndVec::calcEigenVals: got infinite eigen value !"); return; } ++i; } else { T p = T(0.5) * (schurMatrixT.t(i, i) - schurMatrixT.t(i+1, i+1)); T z; { T t0 = schurMatrixT.t(i+1, i); T t1 = schurMatrixT.t(i, i+1); T maxval = math::nd4j_max(math::nd4j_abs(p), math::nd4j_max(math::nd4j_abs(t0), math::nd4j_abs(t1))); t0 /= maxval; t1 /= maxval; T p0 = p / maxval; z = maxval * math::nd4j_sqrt(math::nd4j_abs(p0 * p0 + t0 * t1)); } _Vals.r(i, 0) = _Vals.r(i+1, 0) = schurMatrixT.t(i+1, i+1) + p; _Vals.r(i, 1) = z; _Vals.r(i+1,1) = -z; if(!(math::nd4j_isfin(_Vals.t(i,0)) && math::nd4j_isfin(_Vals.t(i+1,0)) && math::nd4j_isfin(_Vals.t(i,1))) && math::nd4j_isfin(_Vals.t(i+1,1))) { throw std::runtime_error("ops::helpers::igenValsAndVec::calcEigenVals: got infinite eigen value !"); return; } i += 2; } } } ////////////////////////////////////////////////////////////////////////// template void EigenValsAndVecs::calcPseudoEigenVecs(NDArray& schurMatrixT, NDArray& schurMatrixU) { const int numOfCols = schurMatrixU.sizeAt(1); T norm = 0; for (int j = 0; j < numOfCols; ++j) norm += schurMatrixT({j,j+1, math::nd4j_max(j-1, 0),numOfCols}).reduceNumber(reduce::ASum).template t(0); if (norm == T(0)) return; for (int n = numOfCols-1; n >= 0; n--) { T p = _Vals.t(n, 0); // real part T q = _Vals.t(n, 1); // imaginary part if(q == (T)0) { // not complex T lastr((T)0), lastw((T)0); int l = n; schurMatrixT.r(n, n) = T(1); for (int i = n-1; i >= 0; i--) { T w = schurMatrixT.t(i,i) - p; T r = mmul(schurMatrixT({i,i+1, l,n+1}, true), schurMatrixT({l,n+1, n,n+1}, true)).template t(0); // dot if (_Vals.t(i, 1) < T(0)) { lastw = w; lastr = r; } else { l = i; if (_Vals.t(i, 1) == T(0)) { if (w != T(0)) schurMatrixT.r(i, n) = -r / w; else schurMatrixT.r(i, n) = -r / (DataTypeUtils::eps() * norm); } else { T x = schurMatrixT.t(i, i+1); T y = schurMatrixT.t(i+1, i); T denom = (_Vals.t(i, 0) - p) * (_Vals.t(i, 0) - p) + _Vals.t(i, 1) * _Vals.t(i, 1); T t = (x * lastr - lastw * r) / denom; schurMatrixT.r(i, n) = t; if (math::nd4j_abs(x) > math::nd4j_abs(lastw)) schurMatrixT.r(i+1, n) = (-r - w * t) / x; else schurMatrixT.r(i+1, n) = (-lastr - y * t) / lastw; } T t = math::nd4j_abs(schurMatrixT.t(i, n)); if((DataTypeUtils::eps() * t) * t > T(1)) schurMatrixT({schurMatrixT.sizeAt(0)-numOfCols+i,-1, n,n+1}) /= t; } } } else if(q < T(0) && n > 0) { // complex T lastra(0), lastsa(0), lastw(0); int l = n - 1; if(math::nd4j_abs(schurMatrixT.t(n, n-1)) > math::nd4j_abs(schurMatrixT.t(n-1, n))) { schurMatrixT.r(n-1, n-1) = q / schurMatrixT.t(n, n-1); schurMatrixT.r(n-1, n) = -(schurMatrixT.t(n, n) - p) / schurMatrixT.t(n, n-1); } else { divideComplexNums(T(0),-schurMatrixT.t(n-1,n), schurMatrixT.t(n-1,n-1)-p,q, schurMatrixT.r(n-1,n-1),schurMatrixT.r(n-1,n)); } schurMatrixT.r(n,n-1) = T(0); schurMatrixT.r(n,n) = T(1); for (int i = n-2; i >= 0; i--) { T ra = mmul(schurMatrixT({i,i+1, l,n+1}, true), schurMatrixT({l,n+1, n-1,n}, true)).template t(0); // dot T sa = mmul(schurMatrixT({i,i+1, l,n+1}, true), schurMatrixT({l,n+1, n,n+1}, true)).template t(0); // dot T w = schurMatrixT.t(i,i) - p; if (_Vals.t(i, 1) < T(0)) { lastw = w; lastra = ra; lastsa = sa; } else { l = i; if (_Vals.t(i, 1) == T(0)) { divideComplexNums(-ra,-sa, w,q, schurMatrixT.r(i,n-1),schurMatrixT.r(i,n)); } else { T x = schurMatrixT.t(i,i+1); T y = schurMatrixT.t(i+1,i); T vr = (_Vals.t(i, 0) - p) * (_Vals.t(i, 0) - p) + _Vals.t(i, 1) * _Vals.t(i, 1) - q * q; T vi = (_Vals.t(i, 0) - p) * T(2) * q; if ((vr == T(0)) && (vi == T(0))) vr = DataTypeUtils::eps() * norm * (math::nd4j_abs(w) + math::nd4j_abs(q) + math::nd4j_abs(x) + math::nd4j_abs(y) + math::nd4j_abs(lastw)); divideComplexNums(x*lastra-lastw*ra+q*sa,x*lastsa-lastw*sa-q*ra, vr,vi, schurMatrixT.r(i,n-1),schurMatrixT.r(i,n)); if(math::nd4j_abs(x) > (math::nd4j_abs(lastw) + math::nd4j_abs(q))) { schurMatrixT.r(i+1,n-1) = (-ra - w * schurMatrixT.t(i,n-1) + q * schurMatrixT.t(i,n)) / x; schurMatrixT.r(i+1,n) = (-sa - w * schurMatrixT.t(i,n) - q * schurMatrixT.t(i,n-1)) / x; } else divideComplexNums(-lastra-y*schurMatrixT.t(i,n-1),-lastsa-y*schurMatrixT.t(i,n), lastw,q, schurMatrixT.r(i+1,n-1),schurMatrixT.r(i+1,n)); } T t = math::nd4j_max(math::nd4j_abs(schurMatrixT.t(i, n-1)), math::nd4j_abs(schurMatrixT.t(i,n))); if ((DataTypeUtils::eps() * t) * t > T(1)) schurMatrixT({i,numOfCols, n-1,n+1}) /= t; } } n--; } else throw std::runtime_error("ops::helpers::EigenValsAndVecs::calcEigenVecs: internal bug !"); } for (int j = numOfCols-1; j >= 0; j--) schurMatrixU({0,0, j,j+1}, true).assign( mmul(schurMatrixU({0,0, 0,j+1}, true), schurMatrixT({0,j+1, j,j+1}, true)) ); } ////////////////////////////////////////////////////////////////////////// template void EigenValsAndVecs::calcEigenVecs(const NDArray& schurMatrixU) { const T precision = T(2) * DataTypeUtils::eps(); const int numOfCols = schurMatrixU.sizeAt(1); for (int j = 0; j < numOfCols; ++j) { if(math::nd4j_abs(_Vals.t(j, 1)) <= math::nd4j_abs(_Vals.t(j, 0)) * precision || j+1 == numOfCols) { // real _Vecs.syncToDevice(); _Vecs({0,0, j,j+1, 0,1}).assign(schurMatrixU({0,0, j,j+1})); _Vecs({0,0, j,j+1, 1,2}) = (T)0; // normalize const T norm2 = _Vecs({0,0, j,j+1, 0,1}).reduceNumber(reduce::SquaredNorm).template t(0); if(norm2 > (T)0) _Vecs({0,0, j,j+1, 0,1}) /= math::nd4j_sqrt(norm2); } else { // complex for (int i = 0; i < numOfCols; ++i) { _Vecs.r(i, j, 0) = _Vecs.r(i, j+1, 0) = schurMatrixU.t(i, j); _Vecs.r(i, j, 1) = schurMatrixU.t(i, j+1); _Vecs.r(i, j+1, 1) = -schurMatrixU.t(i, j+1); } // normalize T norm2 = _Vecs({0,0, j,j+1, 0,0}).reduceNumber(reduce::SquaredNorm).template t(0); if(norm2 > (T)0) _Vecs({0,0, j,j+1, 0,0}) /= math::nd4j_sqrt(norm2); // normalize norm2 = _Vecs({0,0, j+1,j+2, 0,0}).reduceNumber(reduce::SquaredNorm).template t(0); if(norm2 > (T)0) _Vecs({0,0, j+1,j+2, 0,0}) /= math::nd4j_sqrt(norm2); ++j; } } } template class ND4J_EXPORT EigenValsAndVecs; template class ND4J_EXPORT EigenValsAndVecs; template class ND4J_EXPORT EigenValsAndVecs; template class ND4J_EXPORT EigenValsAndVecs; } } }