/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by Yurii Shyrma on 11.01.2018
//

#include <hhColPivQR.h>

#include <vector>

#include <householder.h>
#include <NDArrayFactory.h>

namespace nd4j {
namespace ops {
namespace helpers {


//////////////////////////////////////////////////////////////////////////
/**
 * Constructs the decomposition object for `matrix` and immediately evaluates it.
 *
 * The input is copied into _qr, _diagSize is set to min(rows, cols) of the input,
 * _coeffs is allocated as a {1, _diagSize} array and _permut as a {cols, cols}
 * array. Both auxiliary arrays share the input's ordering, data type and context.
 * evalData() is then called to fill them.
 *
 * @param matrix input array; only sizeAt(0) and sizeAt(1) are queried here.
 */
HHcolPivQR::HHcolPivQR(const NDArray& matrix) {

    const auto order   = matrix.ordering();
    const auto dtype   = matrix.dataType();
    const auto context = matrix.getContext();
    const auto numCols = matrix.sizeAt(1);

    _diagSize = math::nd4j_min<int>(matrix.sizeAt(0), numCols);
    _qr       = matrix;

    _coeffs = NDArrayFactory::create(order, {1, _diagSize}, dtype, context);
    _permut = NDArrayFactory::create(order, {numCols, numCols}, dtype, context);

    evalData();
}

    // Type-dispatching entry point: forwards to the _evalData<T>() member template,
    // with T chosen from _qr.dataType() among the types listed in FLOAT_TYPES.
    // NOTE(review): behaviour for a data type outside FLOAT_TYPES is determined by
    // the BUILD_SINGLE_SELECTOR macro, which is not visible here - confirm.
    void HHcolPivQR::evalData() {
        BUILD_SINGLE_SELECTOR(_qr.dataType(), _evalData, (), FLOAT_TYPES);
    }

//////////////////////////////////////////////////////////////////////////
template <typename T>
void HHcolPivQR::_evalData() {

    int rows = _qr.sizeAt(0);
    int cols = _qr.sizeAt(1);    
        
    auto transp = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
    auto normsUpd = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
    auto normsDir = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
          
    int transpNum = 0;

    for (int k = 0; k < cols; ++k) {
        
        T norm = _qr({0,0, k,k+1}).reduceNumber(reduce::Norm2).e<T>(0);
        normsDir.p<T>(k, norm);
        normsUpd.p<T>(k, norm);
    }

    T normScaled = (normsUpd.reduceNumber(reduce::Max)).e<T>(0) * DataTypeUtils::eps<T>();
    T threshold1 = normScaled * normScaled / (T)rows;     
    T threshold2 = math::nd4j_sqrt<T,T>(DataTypeUtils::eps<T>());

    T nonZeroPivots = _diagSize; 
    T maxPivot = 0.;

    for(int k = 0; k < _diagSize; ++k) {

        int biggestColIndex = normsUpd({0,0, k,-1}).indexReduceNumber(indexreduce::IndexMax).e<int>(0);
        T biggestColNorm = normsUpd({0,0, k,-1}).reduceNumber(reduce::Max).e<T>(0);
        T biggestColSqNorm = biggestColNorm * biggestColNorm;
        biggestColIndex += k;
    
        if(nonZeroPivots == (T)_diagSize && biggestColSqNorm < threshold1 * (T)(rows-k))
            nonZeroPivots = k;
        
        transp.p<T>(k, (T)biggestColIndex);

        if(k != biggestColIndex) {
        
            auto temp1 = new NDArray(_qr({0,0, k,k+1}, true));
            auto temp2 = new NDArray(_qr({0,0, biggestColIndex,biggestColIndex+1}, true));
            auto temp3 = *temp1;
            temp1->assign(temp2);
            temp2->assign(temp3);
            delete temp1;
            delete temp2;

            T e0 = normsUpd.e<T>(k);
            T e1 = normsUpd.e<T>(biggestColIndex);
            normsUpd.p(k, e1);
            normsUpd.p(biggestColIndex, e0);
            //math::nd4j_swap<T>(normsUpd(k), normsUpd(biggestColIndex));

            e0 = normsDir.e<T>(k);
            e1 = normsDir.e<T>(biggestColIndex);
            normsDir.p(k, e1);
            normsDir.p(biggestColIndex, e0);
            //math::nd4j_swap<T>(normsDir(k), normsDir(biggestColIndex));
                        
            ++transpNum;
        }
        
        T normX;
        NDArray* qrBlock = new NDArray(_qr({k,rows, k,k+1}, true));
        T c;
        Householder<T>::evalHHmatrixDataI(*qrBlock, c, normX);
        _coeffs.p<T>(k, c);
        delete qrBlock;        

        _qr.p<T>(k,k, normX);
        
        T max = math::nd4j_abs<T>(normX);
        if(max > maxPivot) 
            maxPivot = max;
        
        if(k < rows && (k+1) < cols) {
            qrBlock   = new NDArray(_qr({k,  rows,  k+1,cols}, true));
            auto tail = new NDArray(_qr({k+1,rows,  k, k+1},   true));
            Householder<T>::mulLeft(*qrBlock, *tail, _coeffs.e<T>(k));
            delete qrBlock;
            delete tail;
        }

        for (int j = k + 1; j < cols; ++j) {            
            
            if (normsUpd.e<T>(j) != (T)0.f) {
                T temp = math::nd4j_abs<T>(_qr.e<T>(k, j)) / normsUpd.e<T>(j);
                temp = (1. + temp) * (1. - temp);
                temp = temp < (T)0. ? (T)0. : temp;
                T temp2 = temp * normsUpd.e<T>(j) * normsUpd.e<T>(j) / (normsDir.e<T>(j)*normsDir.e<T>(j));
                
                if (temp2 <= threshold2) {          
                    if(k+1 < rows && j < cols)
                        normsDir.p<T>(j, _qr({k+1,rows, j,j+1}).reduceNumber(reduce::Norm2).e<T>(0));

                    normsUpd.p<T>(j, normsDir.e<T>(j));
                } 
                else 
                    normsUpd.p<T>(j, normsUpd.e<T>(j) * math::nd4j_sqrt<T, T>(temp));
            }
        }
    }

    _permut.setIdentity();
    
    for(int k = 0; k < _diagSize; ++k) {

        int idx = transp.e<int>(k);
        auto temp1 = new NDArray(_permut({0,0, k, k+1},    true));
        auto temp2 = new NDArray(_permut({0,0, idx,idx+1}, true));
        auto  temp3 = *temp1;
        temp1->assign(temp2);
        temp2->assign(temp3);
        delete temp1;
        delete temp2;
    }    
}

    // Explicit instantiations of HHcolPivQR::_evalData for the types listed in FLOAT_TYPES,
    // matching the set evalData() dispatches over.
    BUILD_SINGLE_TEMPLATE(template void HHcolPivQR::_evalData, (), FLOAT_TYPES);

}
}
}