cavis/libnd4j/include/helpers/cpu/hhColPivQR.cpp

172 lines
5.8 KiB
C++
Raw Normal View History

2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by Yurii Shyrma on 11.01.2018
//
#include <helpers/hhColPivQR.h>

#include <vector>

#include <helpers/householder.h>
#include <array/NDArrayFactory.h>
2019-06-06 14:21:15 +02:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
// Builds the column-pivoting QR factorization of a 2D matrix.
// The input is copied into _qr, the work arrays are allocated, and the
// factorization is computed immediately via evalData().
HHcolPivQR::HHcolPivQR(const NDArray& matrix) {
    // _evalData() overwrites _qr in place, so the input is copied first.
    _qr = matrix;

    const auto numRows = matrix.sizeAt(0);
    const auto numCols = matrix.sizeAt(1);
    _diagSize = math::nd4j_min<int>(numRows, numCols);

    // All work arrays share the input's ordering, data type and context.
    const auto order   = matrix.ordering();
    const auto dtype   = matrix.dataType();
    const auto context = matrix.getContext();

    // one Householder coefficient per diagonal position
    _coeffs = NDArrayFactory::create(order, {1, _diagSize}, dtype, context);
    // cols x cols column-permutation matrix
    _permut = NDArrayFactory::create(order, {numCols, numCols}, dtype, context);

    evalData();
}
// Runs the factorization of _qr by dispatching, on _qr's runtime data type,
// to the matching _evalData<T> instantiation. Only the types listed in
// FLOAT_TYPES are dispatched (the same list used for the explicit
// instantiation at the end of this file).
void HHcolPivQR::evalData() {
BUILD_SINGLE_SELECTOR(_qr.dataType(), _evalData, (), FLOAT_TYPES);
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
void HHcolPivQR::_evalData() {
int rows = _qr.sizeAt(0);
int cols = _qr.sizeAt(1);
auto transp = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
auto normsUpd = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
auto normsDir = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
int transpNum = 0;
for (int k = 0; k < cols; ++k) {
T norm = _qr({0,0, k,k+1}).reduceNumber(reduce::Norm2).e<T>(0);
normsDir.p<T>(k, norm);
normsUpd.p<T>(k, norm);
}
T normScaled = (normsUpd.reduceNumber(reduce::Max)).e<T>(0) * DataTypeUtils::eps<T>();
T threshold1 = normScaled * normScaled / (T)rows;
T threshold2 = math::nd4j_sqrt<T,T>(DataTypeUtils::eps<T>());
T nonZeroPivots = _diagSize;
T maxPivot = 0.;
for(int k = 0; k < _diagSize; ++k) {
int biggestColIndex = normsUpd({0,0, k,-1}).indexReduceNumber(indexreduce::IndexMax).e<int>(0);
T biggestColNorm = normsUpd({0,0, k,-1}).reduceNumber(reduce::Max).e<T>(0);
T biggestColSqNorm = biggestColNorm * biggestColNorm;
biggestColIndex += k;
if(nonZeroPivots == (T)_diagSize && biggestColSqNorm < threshold1 * (T)(rows-k))
nonZeroPivots = k;
transp.p<T>(k, (T)biggestColIndex);
if(k != biggestColIndex) {
auto temp1 = new NDArray(_qr({0,0, k,k+1}, true));
auto temp2 = new NDArray(_qr({0,0, biggestColIndex,biggestColIndex+1}, true));
auto temp3 = *temp1;
temp1->assign(temp2);
temp2->assign(temp3);
delete temp1;
delete temp2;
T e0 = normsUpd.e<T>(k);
T e1 = normsUpd.e<T>(biggestColIndex);
normsUpd.p(k, e1);
normsUpd.p(biggestColIndex, e0);
//math::nd4j_swap<T>(normsUpd(k), normsUpd(biggestColIndex));
e0 = normsDir.e<T>(k);
e1 = normsDir.e<T>(biggestColIndex);
normsDir.p(k, e1);
normsDir.p(biggestColIndex, e0);
//math::nd4j_swap<T>(normsDir(k), normsDir(biggestColIndex));
++transpNum;
}
T normX;
NDArray* qrBlock = new NDArray(_qr({k,rows, k,k+1}, true));
T c;
Householder<T>::evalHHmatrixDataI(*qrBlock, c, normX);
_coeffs.p<T>(k, c);
delete qrBlock;
_qr.p<T>(k,k, normX);
T max = math::nd4j_abs<T>(normX);
if(max > maxPivot)
maxPivot = max;
if(k < rows && (k+1) < cols) {
qrBlock = new NDArray(_qr({k, rows, k+1,cols}, true));
auto tail = new NDArray(_qr({k+1,rows, k, k+1}, true));
Householder<T>::mulLeft(*qrBlock, *tail, _coeffs.e<T>(k));
delete qrBlock;
delete tail;
}
for (int j = k + 1; j < cols; ++j) {
if (normsUpd.e<T>(j) != (T)0.f) {
T temp = math::nd4j_abs<T>(_qr.e<T>(k, j)) / normsUpd.e<T>(j);
temp = (1. + temp) * (1. - temp);
temp = temp < (T)0. ? (T)0. : temp;
T temp2 = temp * normsUpd.e<T>(j) * normsUpd.e<T>(j) / (normsDir.e<T>(j)*normsDir.e<T>(j));
if (temp2 <= threshold2) {
if(k+1 < rows && j < cols)
normsDir.p<T>(j, _qr({k+1,rows, j,j+1}).reduceNumber(reduce::Norm2).e<T>(0));
normsUpd.p<T>(j, normsDir.e<T>(j));
}
else
normsUpd.p<T>(j, normsUpd.e<T>(j) * math::nd4j_sqrt<T, T>(temp));
}
}
}
_permut.setIdentity();
for(int k = 0; k < _diagSize; ++k) {
int idx = transp.e<int>(k);
auto temp1 = new NDArray(_permut({0,0, k, k+1}, true));
auto temp2 = new NDArray(_permut({0,0, idx,idx+1}, true));
auto temp3 = *temp1;
temp1->assign(temp2);
temp2->assign(temp3);
delete temp1;
delete temp2;
}
}
// Explicit instantiations of HHcolPivQR::_evalData<T> for every type in
// FLOAT_TYPES; this is the same type list evalData() dispatches on.
BUILD_SINGLE_TEMPLATE(template void HHcolPivQR::_evalData, (), FLOAT_TYPES);
}
}
}