/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by Yurii Shyrma on 02.01.2018 // #ifndef LIBND4J_HHSEQUENCE_H #define LIBND4J_HHSEQUENCE_H #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { class HHsequence { public: /* * matrix containing the Householder vectors */ const NDArray& _vectors; /* * vector containing the Householder coefficients */ const NDArray& _coeffs; /* * shift of the Householder sequence */ int _shift; /* * length of the Householder sequence */ int _diagSize; /* * type of sequence, type = 'u' (acting on columns, left) or type = 'v' (acting on rows, right) */ char _type; /* * constructor */ HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type); /** * this method mathematically multiplies input matrix on Householder sequence from the left H0*H1*...Hn * matrix * * matrix - input matrix to be multiplied */ template void mulLeft_(NDArray& matrix); void mulLeft(NDArray& matrix); NDArray getTail(const int idx) const; template void applyTo_(NDArray& dest); void applyTo(NDArray& dest); FORCEINLINE int rows() const; }; ////////////////////////////////////////////////////////////////////////// FORCEINLINE int HHsequence::rows() const { return _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); } } } } #endif //LIBND4J_HHSEQUENCE_H