/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by Yurii Shyrma on 03.01.2018 // #ifndef LIBND4J_SVD_H #define LIBND4J_SVD_H #include #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { template class SVD { public: int _switchSize = 10; NDArray _m; NDArray _s; NDArray _u; NDArray _v; int _diagSize; bool _transp; bool _calcU; bool _calcV; bool _fullUV; /** * constructor */ SVD(const NDArray& matrix, const int switchSize, const bool calcV, const bool calcU, const bool fullUV); SVD(const NDArray& matrix, const int switchSize, const bool calcV, const bool calcU, const bool fullUV, const char t); void deflation1(int col1, int shift, int ind, int size); void deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, int ind2, int size); void deflation(int col1, int col2, int ind, int row1W, int col1W, int shift); // FIXME: proper T support required here T secularEq(const T diff, const NDArray& col0, const NDArray& diag, const NDArray &permut, const NDArray& diagShifted, const T shift); void calcSingVals(const NDArray& col0, const NDArray& diag, const NDArray& permut, NDArray& singVals, NDArray& shifts, NDArray& mus); void perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat); void calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArray& perm, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& U, NDArray& V); void calcBlockSVD(int firstCol, int size, NDArray& U, NDArray& singVals, NDArray& V); void DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift); void exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDArray& U, const NDArray& V); void evalData(const NDArray& matrix); FORCEINLINE NDArray& getS(); FORCEINLINE NDArray& getU(); FORCEINLINE NDArray& getV(); }; ////////////////////////////////////////////////////////////////////////// template FORCEINLINE NDArray& SVD::getS() { return _s; } ////////////////////////////////////////////////////////////////////////// template FORCEINLINE NDArray& SVD::getU() { return _u; } ////////////////////////////////////////////////////////////////////////// template FORCEINLINE NDArray& SVD::getV() { return _v; } } } } #endif //LIBND4J_SVD_H