Shyrma sqrtm (#429)

* - start working on implementation of sqrtm op  Signed-off-by: Yurii <iuriish@yahoo.com>
* - improving householder procedure  Signed-off-by: Yurii <iuriish@yahoo.com>
* - further polishing householder stuff  Signed-off-by: Yurii <iuriish@yahoo.com>
* - polishing hh pivoting qr procedure  Signed-off-by: Yurii <iuriish@yahoo.com>
* - polishing BiDiagonalUp procedure  Signed-off-by: Yurii <iuriish@yahoo.com>
* - polishing householder sequence class  Signed-off-by: Yurii <iuriish@yahoo.com>
* - polishing jacobi svd class  Signed-off-by: Yurii <iuriish@yahoo.com>
* - polishing svd stuff 1  Signed-off-by: Yurii <iuriish@yahoo.com>
* - polishing svd stuff 2  Signed-off-by: Yurii <iuriish@yahoo.com>
* - implementation and testing class which performs Hessenberg decomposition of square matrix  Signed-off-by: Yurii <iuriish@yahoo.com>
* - add static method to JacobiSVD class which makes the continuous Givens rotation generation algorithm  Signed-off-by: Yurii <iuriish@yahoo.com>
* - implementation and testing auxiliary methods of Schur decomp class  Signed-off-by: Yurii <iuriish@yahoo.com>
* some references here and there  Signed-off-by: raver119 <raver119@gmail.com>
* - trying to figure out difference between eigen and our Schur alg  Signed-off-by: Yurii <iuriish@yahoo.com>
* - testing and fixing bugs in Schur decomposition op  Signed-off-by: Yurii <iuriish@yahoo.com>
* - start to implement class which performs calculation of eigen values and vectors  Signed-off-by: Yurii <iuriish@yahoo.com>
* - add to EigenValsAndVecs method which calculates complex eigen vectors  Signed-off-by: Yurii <iuriish@yahoo.com>
* - testing and fixing bugs in EigenValsAndVecs class  Signed-off-by: Yurii <iuriish@yahoo.com>
* - implementation and testing triangularSolver class  Signed-off-by: Yurii <iuriish@yahoo.com>
* Added a 2D routine for triangular systems solve.  Signed-off-by: shugeo <sgazeos@gmail.com>
* Refactored triangularSolve2D routine and tests.  Signed-off-by: shugeo <sgazeos@gmail.com>
* Refactored another test for triangularSolve2D.  Signed-off-by: shugeo <sgazeos@gmail.com>
* Refactored test for triangularSolve for vector-bar case.  Signed-off-by: shugeo <sgazeos@gmail.com>
* Refactored triangularSolve2D routine and tests.  Signed-off-by: shugeo <sgazeos@gmail.com>
* - implementation of FullPivLU class  Signed-off-by: Yurii <iuriish@yahoo.com>
* - fix bugs in FullPivLU::solve method  Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct permutation vector in FullPivLU::solve  Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct include headers  Signed-off-by: Yurii <iuriish@yahoo.com>
* - implementation of Sqrtm class  Signed-off-by: Yurii <iuriish@yahoo.com>
* - testing and fixing bugs in Sqrtm class  Signed-off-by: Yurii <iuriish@yahoo.com>
* - include sqrtm classes to cuda folder, investigate in what places synchronization doesn't work  Signed-off-by: Yurii <iuriish@yahoo.com>
* Added implementation for cuda triangularSolve2D and also refactored triangularSolve2D for cpu.  Signed-off-by: shugeo <sgazeos@gmail.com>
* Eliminated waste implementations.  Signed-off-by: shugeo <sgazeos@gmail.com>
* - make offset calculation faster in t<> methods  Signed-off-by: Yurii <iuriish@yahoo.com>
* - rename reference T& NDArray::t<> method  Signed-off-by: Yurii <iuriish@yahoo.com>
* - further work on cuda sqrtm  Signed-off-by: Yurii <iuriish@yahoo.com>
* - provide correct synchronization to device in Sqrtm class  Signed-off-by: Yurii <iuriish@yahoo.com>
* - add tests for sqrtm op  Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct fails which appeared while testing on jenkins  Signed-off-by: Yurii <iuriish@yahoo.com>
* - trying to find out mistake in svd::deflation method  Signed-off-by: Yurii <iuriish@yahoo.com>
* Revert "- trying to find out mistake in svd::deflation method"  This reverts commit 19d37baddbc509028e4bc67bc932fe7449becdb6.
* Revert "- trying to find out mistake in svd::deflation method"  This reverts commit 19d37baddbc509028e4bc67bc932fe7449becdb6.  Signed-off-by: Yurii <iuriish@yahoo.com>
* - change call semantic of r<> and t<> methods  Signed-off-by: Yurii <iuriish@yahoo.com>
* - get rid of ambiguity in * operator overloads for windows build  Signed-off-by: Yurii <iuriish@yahoo.com>
* - get rid of ambiguity in * operator overloads for windows build 2  Signed-off-by: Yurii <iuriish@yahoo.com>
* - get rid of ambiguity in * operator overloads for windows build 3  Signed-off-by: Yurii <iuriish@yahoo.com>
* - resolve conflicts with master  Signed-off-by: Yurii <iuriish@yahoo.com>
* cmakelists updated  Signed-off-by: raver119@gmail.com <raver119@gmail.com>
* - minor fix in merge cpu helper - make use of reference getter  Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
Co-authored-by: shugeo <sgazeos@gmail.com>
parent 2214175934
commit 753ce28a92
@@ -1163,7 +1163,7 @@ namespace sd {
     /**
     *  fill target matrix with given value in one or two directions from main diagonal:
-    *      - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
+    *      - down from main diagonal starting at subdiagonal number "lower" if direction = 'l' (down) or 'b' (both)
     *      - up from main diagonal starting at superdiagonal number "upper" if direction = 'u' (up) or 'b' (both)
     *  direction - in what direction to fill matrix. There are 3 possible directions:
     *      'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected
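For concreteness, a small illustration of the documented semantics (a hypothetical 4x4 target, value = 9, lower = 1, upper = 1; not taken from the library's tests):

    // direction = 'l' (fill down from subdiagonal 1)    direction = 'u' (fill up from superdiagonal 1)
    //   [ .  .  .  . ]                                    [ .  9  9  9 ]
    //   [ 9  .  .  . ]                                    [ .  .  9  9 ]
    //   [ 9  9  .  . ]                                    [ .  .  .  9 ]
    //   [ 9  9  9  . ]                                    [ .  .  .  . ]
    // direction = 'b' applies both fills; '.' marks elements left untouched.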
@@ -1230,14 +1230,13 @@ namespace sd {
     * returns reference on array element with given index
     */
     template<typename T>
-    FORCEINLINE T& t(const Nd4jLong index);
+    FORCEINLINE T& r(const Nd4jLong index);

     template<typename T>
-    FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j);
+    FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j);

     template<typename T>
-    FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k);
+    FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k);

     template<typename T>
-    FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w);
+    FORCEINLINE T& r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w);

     /**
@@ -1246,7 +1245,6 @@ namespace sd {
     */
     template<typename T>
     FORCEINLINE T t(const Nd4jLong i) const;

     template<typename T>
     FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const;

     template<typename T>
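The hunks above rename the non-const, reference-returning accessor from t<>() to r<>(), while the const t<>() overloads keep returning by value. A hedged usage sketch (the array shape, the factory overload and the context type are illustrative assumptions, not taken from this diff):

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    void accessorDemo(sd::LaunchContext* context) {
        // assumes this NDArrayFactory::create overload: (order, shape, context)
        auto a = sd::NDArrayFactory::create<float>('c', {2, 3}, context);

        a.r<float>(0, 1) = 42.f;       // r<>: writable host reference (syncs to host, marks the host buffer written)
        float v = a.t<float>(0, 1);    // const t<>: read-only access by value
        (void) v;
    }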
@@ -1778,70 +1776,60 @@ DataType NDArray::dataType() const {

////////////////////////////////////////////////////////////////////////
template <typename T>
-T& NDArray::t(const Nd4jLong i) {
+T& NDArray::r(const Nd4jLong i) {

    // if (i >= _length)
    //     throw std::invalid_argument("NDArray::t(i): input index is out of array length !");

    if (DataTypeUtils::fromT<T>() != _dataType)
        throw std::invalid_argument("NDArray::t(i): type of array is not equal to template type T!");

    if(!isActualOnHostSide())
        syncToHost();

    tickWriteHost();

    return *(reinterpret_cast<T*>(bufferWithOffset(getOffset(i))));
}

////////////////////////////////////////////////////////////////////////
template <typename T>
-T& NDArray::t(const Nd4jLong i, const Nd4jLong j) {
+T& NDArray::r(const Nd4jLong i, const Nd4jLong j) {

    if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1))
        throw std::invalid_argument("NDArray::t(i,j): one of input indexes is out of array length or rank!=2 !");
    if (DataTypeUtils::fromT<T>() != _dataType)
        throw std::invalid_argument("NDArray::t(i,j): type of array is not equal to template type T!");

    if(!isActualOnHostSide())
        syncToHost();

-   Nd4jLong coords[2] = {i, j};
-   auto offset = shape::getOffset(shapeInfo(), coords);
    tickWriteHost();
-   return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
+   return *(reinterpret_cast<T*>(bufferWithOffset(i * strideAt(0) + j * strideAt(1))));
}

template <typename T>
-T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
+T& NDArray::r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {

    if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
        throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!");
    if (DataTypeUtils::fromT<T>() != _dataType)
        throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");

    if(!isActualOnHostSide())
        syncToHost();

-   Nd4jLong coords[3] = {i, j, k};
-   auto offset = shape::getOffset(shapeInfo(), coords);
    tickWriteHost();
-   return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
+   return *(reinterpret_cast<T*>(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2))));
}

template <typename T>
-T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) {
+T& NDArray::r(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) {

    if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3))
        throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !");
    if (DataTypeUtils::fromT<T>() != _dataType)
        throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!");

    if(!isActualOnHostSide())
        syncToHost();

-   Nd4jLong coords[4] = {i, j, k, w};
-   auto offset = shape::getOffset(shapeInfo(), coords);
    tickWriteHost();
-   return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
+   return *(reinterpret_cast<T*>(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + w * strideAt(3))));
}

////////////////////////////////////////////////////////////////////////
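Both the removed shape::getOffset path and the new inline expression compute the same element offset: the dot product of the indices with the per-dimension strides. A minimal standalone sketch of that arithmetic (plain C++, not the library's shape helpers):

    #include <cassert>
    #include <cstdint>

    // Offset of element (i, j) in a 2-D array stored with arbitrary strides.
    // For a C-ordered ("row-major") M x N array the strides are {N, 1}.
    int64_t offset2d(int64_t i, int64_t j, const int64_t stride[2]) {
        return i * stride[0] + j * stride[1];
    }

    int main() {
        const int64_t strideC[2] = {4, 1};   // 3 x 4 matrix, C order
        assert(offset2d(2, 3, strideC) == 11);
        const int64_t strideF[2] = {1, 3};   // 3 x 4 matrix, Fortran (column-major) order
        assert(offset2d(2, 3, strideF) == 11);
        return 0;
    }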
@ -1853,10 +1841,8 @@ T NDArray::t(const Nd4jLong i) const {
|
|||
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||
throw std::invalid_argument("NDArray::t(i): type of array is not equal to template type T!");
|
||||
|
||||
if(!isActualOnHostSide())
|
||||
syncToHost();
|
||||
|
||||
tickReadHost();
|
||||
return *(reinterpret_cast<const T*>(bufferWithOffset(getOffset(i))));
|
||||
}
|
||||
|
||||
|
@ -1869,48 +1855,38 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
|
|||
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||
throw std::invalid_argument("NDArray::t(i,j): type of array is not equal to template type T!");
|
||||
|
||||
if(!isActualOnHostSide())
|
||||
syncToHost();
|
||||
|
||||
Nd4jLong coords[2] = {i, j};
|
||||
auto offset = shape::getOffset(shapeInfo(), coords);
|
||||
tickReadHost();
|
||||
return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
|
||||
return *(reinterpret_cast<const T*>(bufferWithOffset(i * strideAt(0) + j * strideAt(1))));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
|
||||
|
||||
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
||||
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!");
|
||||
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
||||
|
||||
if(!isActualOnHostSide())
|
||||
syncToHost();
|
||||
|
||||
Nd4jLong coords[3] = {i, j, k};
|
||||
auto offset = shape::getOffset(shapeInfo(), coords);
|
||||
tickReadHost();
|
||||
return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
|
||||
}
|
||||
return *(reinterpret_cast<const T*>(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2))));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const {
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const {
|
||||
|
||||
if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3))
|
||||
throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4!");
|
||||
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||
throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!");
|
||||
|
||||
if(!isActualOnHostSide())
|
||||
syncToHost();
|
||||
|
||||
Nd4jLong coords[4] = {i, j, k, w};
|
||||
auto offset = shape::getOffset(shapeInfo(), coords);
|
||||
tickReadHost();
|
||||
return *(reinterpret_cast<const T*>(bufferWithOffset(offset)));
|
||||
}
|
||||
return *(reinterpret_cast<const T*>(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + w * strideAt(3))));
|
||||
}
|
||||
|
||||
#ifndef __JAVACPP_HACK__
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -2170,7 +2170,7 @@ const std::string* ND4J_EXPORT NDArray::bufferAsT() const {
|
|||
template <typename T>
|
||||
const T* NDArray::bufferAsT() const {
|
||||
// FIXME: do we REALLY want sync here?
|
||||
syncToHost();
|
||||
// syncToHost();
|
||||
|
||||
return reinterpret_cast<const T*>(buffer());
|
||||
}
|
||||
|
@@ -2597,11 +2597,9 @@ void NDArray::operator+=(const T value) {

    auto other = NDArrayFactory::create(this->dataType(), value, getContext());

-   NDArray::prepareSpecialUse({this}, {&other});
+   NDArray::prepareSpecialUse({this}, {this, &other});
    NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);

-   NDArray::registerSpecialUse({this}, {});
+   NDArray::registerSpecialUse({this}, {this, &other});
}
template ND4J_EXPORT void NDArray::operator+=(const double value);
template ND4J_EXPORT void NDArray::operator+=(const float value);
@@ -2619,11 +2617,9 @@ void NDArray::operator-=(const T value) {

    auto other = NDArrayFactory::create(dataType(), value, getContext());

-   NDArray::prepareSpecialUse({this}, {&other});
+   NDArray::prepareSpecialUse({this}, {this, &other});
    NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);

-   NDArray::registerSpecialUse({this}, {});
+   NDArray::registerSpecialUse({this}, {this, &other});
}
template ND4J_EXPORT void NDArray::operator-=(const double value);
template ND4J_EXPORT void NDArray::operator-=(const float value);
@@ -2640,10 +2636,9 @@ void NDArray::operator*=(const T scalar) {
        throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!");

    auto other = NDArrayFactory::create(this->dataType(), scalar, getContext());
-   NDArray::prepareSpecialUse({this}, {&other});
+   NDArray::prepareSpecialUse({this}, {this, &other});
    NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);

-   NDArray::registerSpecialUse({this}, {});
+   NDArray::registerSpecialUse({this}, {this, &other});
}
template ND4J_EXPORT void NDArray::operator*=(const double scalar);
template ND4J_EXPORT void NDArray::operator*=(const float scalar);
@@ -2663,9 +2658,9 @@ void NDArray::operator/=(const T scalar) {
        throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!");

    auto other = NDArrayFactory::create(this->dataType(), scalar, getContext());
-   NDArray::prepareSpecialUse({this}, {&other});
+   NDArray::prepareSpecialUse({this}, {this, &other});
    NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
-   NDArray::registerSpecialUse({this}, {});
+   NDArray::registerSpecialUse({this}, {this, &other});
}
template ND4J_EXPORT void NDArray::operator/=(const double scalar);
template ND4J_EXPORT void NDArray::operator/=(const float scalar);
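All four in-place scalar operators follow the same pattern: wrap the scalar in a temporary NDArray, run an element-wise scalar op that reads and overwrites this array's buffer, and (after this change) register the array as both input and output of that op. The element-wise part, reduced to a standalone sketch (plain C++, not the library's NativeOpExecutioner):

    #include <functional>
    #include <vector>

    // In-place scalar op over a contiguous buffer: the gist of what the
    // Add/Subtract/Multiply/Divide scalar kernels do in the simplest (ews == 1) case.
    template <typename T>
    void applyScalarInPlace(std::vector<T>& buf, T scalar, const std::function<T(T, T)>& op) {
        for (auto& v : buf)
            v = op(v, scalar);
    }

    int main() {
        std::vector<float> x{1.f, 2.f, 3.f};
        applyScalarInPlace<float>(x, 2.f, [](float a, float b) { return a * b; });  // x *= 2
        return 0;
    }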
@ -3758,8 +3753,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const {
|
|||
if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1])
|
||||
throw std::invalid_argument("NDArray::e(i,j): one of input indexes is out of array length or rank!=2 !");
|
||||
|
||||
const Nd4jLong coords[2] = {i, j};
|
||||
const auto xOffset = shape::getOffset(shapeInfo(), coords);
|
||||
const auto xOffset = i * strideAt(0) + j * strideAt(1);
|
||||
|
||||
NDArray::preparePrimaryUse({}, {this});
|
||||
NDArray::registerPrimaryUse({}, {this});
|
||||
|
@ -3778,8 +3772,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
|
|||
if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
|
||||
throw std::invalid_argument("NDArray::e(i,j,k): one of input indexes is out of array length or rank!=3 !");
|
||||
|
||||
const Nd4jLong coords[3] = {i, j, k};
|
||||
const auto xOffset = shape::getOffset(shapeInfo(), coords);
|
||||
const auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2);
|
||||
|
||||
NDArray::preparePrimaryUse({}, {this});
|
||||
NDArray::registerPrimaryUse({}, {this});
|
||||
|
@ -3798,8 +3791,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon
|
|||
if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3])
|
||||
throw std::invalid_argument("NDArray::e(i,j,k,l): one of input indexes is out of array length or rank!=4 !");
|
||||
|
||||
const Nd4jLong coords[4] = {i, j, k, l};
|
||||
const auto xOffset = shape::getOffset(shapeInfo(), coords);
|
||||
const auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + l * strideAt(3);
|
||||
|
||||
NDArray::preparePrimaryUse({}, {this});
|
||||
NDArray::registerPrimaryUse({}, {this});
|
||||
|
@ -4411,8 +4403,7 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {
|
|||
throw std::invalid_argument("NDArray:pe(i,j, value): one of input indexes is out of array length or rank!=2 !");
|
||||
|
||||
void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
|
||||
Nd4jLong coords[2] = {i, j};
|
||||
auto xOffset = shape::getOffset(shapeInfo(), coords);
|
||||
auto xOffset = i * strideAt(0) + j * strideAt(1);
|
||||
|
||||
NDArray::preparePrimaryUse({this}, {}, true);
|
||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES);
|
||||
|
@ -4440,11 +4431,10 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va
|
|||
if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
|
||||
throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !");
|
||||
|
||||
NDArray::preparePrimaryUse({this}, {}, true);
|
||||
|
||||
void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
|
||||
Nd4jLong coords[3] = {i, j, k};
|
||||
auto xOffset = shape::getOffset(shapeInfo(), coords);
|
||||
auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2);
|
||||
|
||||
NDArray::preparePrimaryUse({this}, {}, true);
|
||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES);
|
||||
NDArray::registerPrimaryUse({this}, {});
|
||||
}
|
||||
|
@ -4470,8 +4460,7 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j
|
|||
throw std::invalid_argument("NDArray::p(i,j,k,l, value): one of input indexes is out of array length or rank!=4 !");
|
||||
|
||||
void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
|
||||
Nd4jLong coords[4] = {i, j, k, l};
|
||||
auto xOffset = shape::getOffset(shapeInfo(), coords);
|
||||
auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + l * strideAt(3);
|
||||
|
||||
NDArray::preparePrimaryUse({this}, {}, true);
|
||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES);
|
||||
|
|
|
@@ -153,21 +153,38 @@ void NDArray::setIdentity() {

////////////////////////////////////////////////////////////////////////
template <typename T>
-static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
+static void templatedSwap(void *xBuffer, void *yBuffer, const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, Nd4jLong length) {
    auto x = reinterpret_cast<T *>(xBuffer);
    auto y = reinterpret_cast<T *>(yBuffer);

+   const bool isSameOrders = shape::order(xShapeInfo) == shape::order(yShapeInfo);
+
+   const auto xEws = shape::elementWiseStride(xShapeInfo);
+   const auto yEws = shape::elementWiseStride(yShapeInfo);

    auto func = PRAGMA_THREADS_FOR {
-       for (auto i = start; i < stop; i++) {
-           auto temp = x[i];
-           x[i] = y[i];
-           y[i] = temp;
+       if(isSameOrders && xEws > 0 && yEws > 0) {
+           for(auto i = start; i < stop; i++)
+               sd::math::nd4j_swap(x[i*xEws], y[i*yEws]);
        }
+       else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
+           for(auto i = start; i < stop; i++) {
+               const auto ind = shape::getIndexOffset(i, xShapeInfo);
+               sd::math::nd4j_swap(x[ind], y[ind]);
+           }
+       }
+       else {
+           for(auto i = start; i < stop; i++) {
+               const auto xInd = shape::getIndexOffset(i, xShapeInfo);
+               const auto yInd = shape::getIndexOffset(i, yShapeInfo);
+               sd::math::nd4j_swap(x[xInd], y[yInd]);
+           }
+       }
    };

    samediff::Threads::parallel_for(func, 0, length);
}
-BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, Nd4jLong length), LIBND4J_TYPES);

////////////////////////////////////////////////////////////////////////
void NDArray::swapUnsafe(NDArray& other) {
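The reworked templatedSwap picks one of three loops: a fast path when both arrays expose a positive element-wise stride and share ordering, a shared-offset path when shapes and strides match exactly, and a general path that maps the linear index to each array's own offset. A simplified single-threaded sketch of the first and last branches (the index-to-offset helper is a stand-in for shape::getIndexOffset, not the library routine):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // linear index -> buffer offset for a C-ordered array with explicit strides
    int64_t linearToOffset(int64_t idx, const std::vector<int64_t>& shape,
                           const std::vector<int64_t>& strides) {
        int64_t offset = 0;
        for (int64_t d = (int64_t)shape.size() - 1; d >= 0; --d) {
            offset += (idx % shape[d]) * strides[d];
            idx /= shape[d];
        }
        return offset;
    }

    template <typename T>
    void swapBuffers(T* x, T* y, int64_t length,
                     int64_t xEws, int64_t yEws,            // element-wise strides, <= 0 if not applicable
                     const std::vector<int64_t>& shape,
                     const std::vector<int64_t>& xStrides,
                     const std::vector<int64_t>& yStrides) {
        if (xEws > 0 && yEws > 0) {                          // fast path: plain strided walk
            for (int64_t i = 0; i < length; ++i)
                std::swap(x[i * xEws], y[i * yEws]);
        } else {                                             // general path: per-array offset mapping
            for (int64_t i = 0; i < length; ++i)
                std::swap(x[linearToOffset(i, shape, xStrides)],
                          y[linearToOffset(i, shape, yStrides)]);
        }
    }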
@ -182,7 +199,7 @@ void NDArray::swapUnsafe(NDArray& other) {
|
|||
if(lengthOf() != other.lengthOf())
|
||||
throw std::runtime_error("NDArray::swapUnsafe method: input arrays should have the same length!");
|
||||
|
||||
BUILD_SINGLE_SELECTOR(xType, templatedSwap, (buffer(), other.buffer(), this->lengthOf()), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(xType, templatedSwap, (buffer(), other.buffer(), shapeInfo(), other.shapeInfo(), this->lengthOf()), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -225,7 +225,13 @@ void NDArray::swapUnsafe(NDArray& other) {
|
|||
if(lengthOf() != other.lengthOf())
|
||||
throw std::runtime_error("NDArray::swapUnsafe method: input arrays should have the same length!");
|
||||
|
||||
PointersManager manager(getContext(), "NDArray::swapUnsafe");
|
||||
|
||||
prepareSpecialUse({&other, this}, {&other, this});
|
||||
BUILD_SINGLE_SELECTOR(xType, templatedSwapUnsafe, (specialBuffer(), specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), getContext()->getCudaStream()), LIBND4J_TYPES);
|
||||
registerSpecialUse({&other, this}, {&other, this});
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@@ -546,21 +552,18 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
    if(specialBuffer() == nullptr || _length == 0)
        { printf("NDArray::printSpecialBuffer: special buffer is nullptr !\n"); return; }

-   void* pHost = operator new(sizeof(T) * _length);
+   const auto sizeOfBuffer = sizeOfT() * (getOffset(_length - 1) + 1);

-   if (ews() != 1) {
-       for (uint i = 0; i < _length; i++)
-           cudaMemcpyAsync(reinterpret_cast<T*>(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
-   }
-   else
-       cudaMemcpyAsync(pHost, specialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
+   void* pHost = operator new(sizeOfBuffer);
+
+   cudaMemcpyAsync(pHost, specialBuffer(), sizeOfBuffer, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());

    cudaError_t cudaResult = cudaStreamSynchronize(*getContext()->getCudaStream());
    if(cudaResult != 0)
        throw std::runtime_error("NDArray::printSpecialBuffer: cudaStreamSynchronize failed!");

    for (uint i = 0; i < _length; i++)
-       printf("%.*f, ", precision, (double)reinterpret_cast<T*>(pHost)[i]);
+       printf("%.*f, ", precision, (double)reinterpret_cast<T*>(pHost)[getOffset(i)]);
    printf("\n");

    operator delete(pHost);
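The new printing path sizes the host copy from the last element's offset, copies the whole (possibly strided) device buffer in one transfer, and then addresses each element through its offset on the host. A minimal standalone sketch of that idea with the plain CUDA runtime API (shape and stride values here are illustrative, not from the library):

    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>

    int main() {
        const int length = 3, stride = 2;                       // 3 logical elements, every 2nd slot used
        const size_t bufBytes = sizeof(float) * ((length - 1) * stride + 1);

        float host[5] = {1.f, 0.f, 2.f, 0.f, 3.f};
        float* device = nullptr;
        cudaMalloc(reinterpret_cast<void**>(&device), bufBytes);
        cudaMemcpy(device, host, bufBytes, cudaMemcpyHostToDevice);

        std::vector<float> copy(bufBytes / sizeof(float));
        cudaMemcpy(copy.data(), device, bufBytes, cudaMemcpyDeviceToHost);  // one bulk copy

        for (int i = 0; i < length; ++i)                        // index through the stride (the offset)
            printf("%f, ", (double)copy[i * stride]);
        printf("\n");

        cudaFree(device);
        return 0;
    }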
@@ -0,0 +1,86 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#ifndef LIBND4J_EIGENVALSANDVECS_H
#define LIBND4J_EIGENVALSANDVECS_H

#include <array/NDArray.h>

namespace sd {
namespace ops {
namespace helpers {

// this class calculates eigenvalues and eigenvectors of given input matrix
template <typename T>
class EigenValsAndVecs {

    public:
        // suppose we got input square NxN matrix

        NDArray _Vals;  // {N,2} matrix of eigenvalues, 2 means real and imaginary part
        NDArray _Vecs;  // {N,N,2} matrix, whose columns are the eigenvectors (complex), 2 means real and imaginary part

        explicit EigenValsAndVecs(const NDArray& matrix);

        //////////////////////////////////////////////////////////////////////////
        FORCEINLINE static void divideComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) {

            T norm2 = a2*a2 + b2*b2;

            a3 = (a1*a2 + b1*b2) / norm2;
            b3 = (a2*b1 - a1*b2) / norm2;
        }

        //////////////////////////////////////////////////////////////////////////
        FORCEINLINE static void multiplyComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) {

            a3 = (a1*a2 - b1*b2);
            b3 = (a1*b2 + b1*a2);
        }

        //////////////////////////////////////////////////////////////////////////
        FORCEINLINE static void sqrtComplexNum(T& a, T& b) {

            T norm = math::nd4j_sqrt<T,T>(a*a + b*b);

            if(b < (T)0)
                b = -math::nd4j_sqrt<T,T>((T)0.5 * (norm - a));
            else
                b = math::nd4j_sqrt<T,T>((T)0.5 * (norm - a));
            a = math::nd4j_sqrt<T,T>((T)0.5 * (norm + a));
        }


    private:

        void calcEigenVals(const NDArray& schurMatrixT);                        // calculates _Vals
        void calcPseudoEigenVecs(NDArray& schurMatrixT, NDArray& schurMatrixU); // makes changes both in schurMatrixT(NxN) and schurMatrixU(NxN), also calculates and stores pseudo-eigenvectors (real) in schurMatrixU columns
        void calcEigenVecs(const NDArray& schurMatrixU);                        // calculates _Vecs
};


}
}
}


#endif //LIBND4J_EIGENVALSANDVECS_H
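The three inline helpers are plain component-wise complex arithmetic: division and multiplication of (a1 + b1 i) and (a2 + b2 i), and the principal square root written via half-angle identities. A standalone check of the divide and sqrt formulas against std::complex (double in place of the template parameter T; a restatement, not library code):

    #include <cassert>
    #include <cmath>
    #include <complex>

    static void divideComplex(double a1, double b1, double a2, double b2, double& a3, double& b3) {
        const double norm2 = a2*a2 + b2*b2;            // same formula as divideComplexNums
        a3 = (a1*a2 + b1*b2) / norm2;
        b3 = (a2*b1 - a1*b2) / norm2;
    }

    static void sqrtComplex(double& a, double& b) {    // same formula as sqrtComplexNum
        const double norm = std::sqrt(a*a + b*b);
        const double sign = b < 0 ? -1.0 : 1.0;
        b = sign * std::sqrt(0.5 * (norm - a));
        a = std::sqrt(0.5 * (norm + a));
    }

    int main() {
        double re, im;
        divideComplex(1, 2, 3, -4, re, im);
        const auto ref = std::complex<double>(1, 2) / std::complex<double>(3, -4);
        assert(std::abs(re - ref.real()) < 1e-12 && std::abs(im - ref.imag()) < 1e-12);

        double a = -3, b = 4;                          // sqrt(-3 + 4i) = 1 + 2i
        sqrtComplex(a, b);
        assert(std::abs(a - 1) < 1e-12 && std::abs(b - 2) < 1e-12);
        return 0;
    }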
@@ -0,0 +1,52 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#ifndef LIBND4J_FULLPIVLU_H
#define LIBND4J_FULLPIVLU_H

#include <array/NDArray.h>

namespace sd {
namespace ops {
namespace helpers {

// class solves equation A*x = b for x, by procedure of LU decomposition of input matrix A with complete pivoting
// LU decomposition of a matrix is:
// A = P^-1 * L * U * Q^-1
// L is unit-lower-triangular,
// U is upper-triangular,
// and P and Q are permutation matrices for rows and columns correspondingly

template <typename T>
class FullPivLU {

    public:

        // A{M,K} * x{K,N} = b{M,N}
        static void solve(const NDArray& A, const NDArray& b, NDArray& x);
};


}
}
}


#endif //LIBND4J_FULLPIVLU_H
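Given the factorization spelled out in the comment, solving A*x = b is two permutations around two triangular solves. The algebra, in the standard full-pivoting form (a sketch of the textbook method, not a description of this solve() implementation):

    % A = P^{-1} L U Q^{-1}, so  A x = b  <=>  L U (Q^{-1} x) = P b
    \begin{aligned}
      y  &:= P b   && \text{permute the rows of } b \\
      L z &= y     && \text{forward substitution, } L \text{ unit lower triangular} \\
      U w &= z     && \text{back substitution, } U \text{ upper triangular} \\
      x  &= Q w    && \text{undo the column permutation}
    \end{aligned}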
@@ -0,0 +1,102 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#ifndef LIBND4J_HESSENBERGANDSCHUR_H
#define LIBND4J_HESSENBERGANDSCHUR_H

#include <array/NDArray.h>

namespace sd {
namespace ops {
namespace helpers {

// this class implements Hessenberg decomposition of square matrix using orthogonal similarity transformation
// A = Q H Q^T
// Q - orthogonal matrix
// H - Hessenberg matrix
template <typename T>
class Hessenberg {
    // suppose we got input square NxN matrix

    public:

        NDArray _Q;  // {N,N}
        NDArray _H;  // {N,N}

        explicit Hessenberg(const NDArray& matrix);

    private:
        void evalData();
};


// this class implements real Schur decomposition of square matrix using orthogonal similarity transformation
// A = U T U^T
// T - real quasi-upper-triangular matrix - block upper triangular matrix where the blocks on the diagonal are 1×1 or 2×2 with complex eigenvalues
// U - real orthogonal matrix

template <typename T>
class Schur {
    // suppose we got input square NxN matrix

    public:

        NDArray _T;  // {N,N}
        NDArray _U;  // {N,N}

        explicit Schur(const NDArray& matrix);

        void splitTwoRows(const int ind, const T shift);

        void calcShift(const int ind, const int iter, T& shift, NDArray& shiftInfo);

        void initFrancisQR(const int ind1, const int ind2, const NDArray& shiftVec, int& ind3, NDArray& householderVec);

        void doFrancisQR(const int ind1, const int ind2, const int ind3, const NDArray& householderVec);

        void calcFromHessenberg();

    private:

        static const int _maxItersPerRow = 40;

        void evalData(const NDArray& matrix);

        //////////////////////////////////////////////////////////////////////////
        FORCEINLINE int getSmallSubdiagEntry(const int inInd) {

            int outInd = inInd;
            while (outInd > 0) {
                T factor = math::nd4j_abs<T>(_T.t<T>(outInd-1, outInd-1)) + math::nd4j_abs<T>(_T.t<T>(outInd, outInd));
                if (math::nd4j_abs<T>(_T.t<T>(outInd, outInd-1)) <= DataTypeUtils::eps<T>() * factor)
                    break;
                outInd--;
            }
            return outInd;
        }
};


}
}
}


#endif //LIBND4J_HESSENBERGANDSCHUR_H
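A consequence worth keeping in mind when reading EigenValsAndVecs above: in the real Schur form A = U T U^T, every 1x1 diagonal block of T is a real eigenvalue of A, and every 2x2 diagonal block contributes a pair via the usual quadratic. This is standard linear algebra, not something specific to this implementation:

    % eigenvalues contributed by a 2x2 diagonal block of T
    \begin{aligned}
      B &= \begin{pmatrix} t_{ii} & t_{i,i+1} \\ t_{i+1,i} & t_{i+1,i+1} \end{pmatrix}, \qquad
      \lambda_{1,2} = \frac{t_{ii} + t_{i+1,i+1}}{2}
        \pm \sqrt{\left(\frac{t_{ii} - t_{i+1,i+1}}{2}\right)^{2} + t_{i,i+1}\, t_{i+1,i}}
    \end{aligned}
    % a complex-conjugate pair exactly when the expression under the root is negative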
@@ -0,0 +1,45 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#ifndef LIBND4J_SQRTM_H
#define LIBND4J_SQRTM_H

#include <array/NDArray.h>

namespace sd {
namespace ops {
namespace helpers {

template <typename T>
class Sqrtm {

    public:

        static void calc(const NDArray& in, NDArray& out);
};


}
}
}


#endif //LIBND4J_SQRTM_H
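The header does not say how calc() works, but the Schur, Hessenberg and FullPivLU helpers added in this same commit are the usual ingredients of the Schur-based matrix square root. As a sketch of that standard method only (an assumption about the approach, not a statement about this implementation): compute the real Schur form, take the square root of the (quasi-)triangular factor column by column, and transform back.

    % Schur-based matrix square root (Bjorck-Hammarling recurrence), for triangular T:
    \begin{aligned}
      A &= U T U^{T}, \qquad \operatorname{sqrtm}(A) = U S U^{T}, \qquad S S = T, \\
      s_{ii} &= \sqrt{t_{ii}}, \qquad
      s_{ij} = \frac{t_{ij} - \sum_{k=i+1}^{j-1} s_{ik}\, s_{kj}}{s_{ii} + s_{jj}} \quad (i < j).
    \end{aligned}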
@ -35,6 +35,7 @@ class BiDiagonalUp {
|
|||
|
||||
NDArray _HHmatrix; // 2D Householder matrix
|
||||
NDArray _HHbidiag; // vector which contains Householder coefficients
|
||||
NDArray _hhCoeffs; // vector of Householder coefficients
|
||||
|
||||
/**
|
||||
* constructor
|
||||
|
@ -63,9 +64,9 @@ class BiDiagonalUp {
|
|||
* type - type of sequence, type = 'u' (acting on columns) or type = 'v' (acting on rows)
|
||||
*/
|
||||
template <typename T>
|
||||
HHsequence makeHHsequence_(const char type) const;
|
||||
HHsequence makeHHsequence_(const char type);
|
||||
|
||||
HHsequence makeHHsequence(const char type) const;
|
||||
HHsequence makeHHsequence(const char type);
|
||||
|
||||
};
|
||||
|
||||
|
|
|
@ -1,180 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by Yurii Shyrma on 18.12.2017
|
||||
//
|
||||
|
||||
|
||||
#include <helpers/householder.h>
|
||||
#include <helpers/biDiagonalUp.h>
|
||||
#include <array/NDArrayFactory.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
BiDiagonalUp::BiDiagonalUp(const NDArray& matrix): _HHmatrix(sd::NDArrayFactory::create(matrix.ordering(), {matrix.sizeAt(0), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext())),
|
||||
_HHbidiag(sd::NDArrayFactory::create(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext())) {
|
||||
|
||||
// input validation
|
||||
if(matrix.rankOf() != 2 || matrix.isScalar())
|
||||
throw std::runtime_error("ops::helpers::biDiagonalizeUp constructor: input array must be 2D matrix !");
|
||||
|
||||
_HHmatrix.assign(&matrix);
|
||||
_HHbidiag.assign(0.);
|
||||
|
||||
evalData();
|
||||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BiDiagonalUp::_evalData() {
|
||||
|
||||
const auto rows = _HHmatrix.sizeAt(0);
|
||||
const auto cols = _HHmatrix.sizeAt(1);
|
||||
|
||||
if(rows < cols)
|
||||
throw std::runtime_error("ops::helpers::BiDiagonalizeUp::evalData method: this procedure is applicable only for input matrix with rows >= cols !");
|
||||
|
||||
NDArray* bottomRightCorner(nullptr), *column(nullptr), *row(nullptr);
|
||||
T coeff, normX;
|
||||
|
||||
T _x, _y;
|
||||
|
||||
for(Nd4jLong i = 0; i < cols-1; ++i ) {
|
||||
|
||||
// evaluate Householder matrix nullifying columns
|
||||
column = new NDArray(_HHmatrix({i,rows, i,i+1}, true));
|
||||
|
||||
_x = _HHmatrix.e<T>(i,i);
|
||||
_y = _HHbidiag.e<T>(i,i);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(*column, _x, _y);
|
||||
|
||||
_HHmatrix.p<T>(i, i, _x);
|
||||
_HHbidiag.p<T>(i, i, _y);
|
||||
|
||||
// multiply corresponding matrix block on householder matrix from the left: P * bottomRightCorner
|
||||
bottomRightCorner = new NDArray(_HHmatrix({i,rows, i+1,cols}, true)); // {i, cols}
|
||||
Householder<T>::mulLeft(*bottomRightCorner, _HHmatrix({i+1,rows, i,i+1}, true), _HHmatrix.e<T>(i,i));
|
||||
|
||||
delete bottomRightCorner;
|
||||
delete column;
|
||||
|
||||
if(i == cols-2)
|
||||
continue; // do not apply right multiplying at last iteration
|
||||
|
||||
// evaluate Householder matrix nullifying rows
|
||||
row = new NDArray(_HHmatrix({i,i+1, i+1,cols}, true));
|
||||
|
||||
_x = _HHmatrix.e<T>(i,i+1);
|
||||
_y = _HHbidiag.e<T>(i,i+1);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(*row, _x, _y);
|
||||
|
||||
_HHmatrix.p<T>(i, i+1, _x);
|
||||
_HHbidiag.p<T>(i, i+1, _y);
|
||||
|
||||
// multiply corresponding matrix block on householder matrix from the right: bottomRightCorner * P
|
||||
bottomRightCorner = new NDArray(_HHmatrix({i+1,rows, i+1,cols}, true)); // {i, rows}
|
||||
|
||||
Householder<T>::mulRight(*bottomRightCorner, _HHmatrix({i,i+1, i+2,cols}, true), _HHmatrix.e<T>(i,i+1));
|
||||
|
||||
delete bottomRightCorner;
|
||||
delete row;
|
||||
}
|
||||
|
||||
row = new NDArray(_HHmatrix({cols-2,cols-1, cols-1,cols}, true));
|
||||
|
||||
_x = _HHmatrix.e<T>(cols-2,cols-1);
|
||||
_y = _HHbidiag.e<T>(cols-2,cols-1);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(*row, _x, _y);
|
||||
|
||||
_HHmatrix.p<T>(cols-2,cols-1, _x);
|
||||
_HHbidiag.p<T>(cols-2,cols-1, _y);
|
||||
|
||||
delete row;
|
||||
|
||||
column = new NDArray(_HHmatrix({cols-1,rows, cols-1,cols}, true));
|
||||
|
||||
_x = _HHmatrix.e<T>(cols-1,cols-1);
|
||||
_y = _HHbidiag.e<T>(cols-1,cols-1);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(*column, _x, _y);
|
||||
|
||||
_HHmatrix.p<T>(cols-1, cols-1, _x);
|
||||
_HHbidiag.p<T>(cols-1, cols-1, _y);
|
||||
|
||||
delete column;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void BiDiagonalUp::evalData() {
|
||||
auto xType = _HHmatrix.dataType();
|
||||
|
||||
BUILD_SINGLE_SELECTOR(xType, _evalData, ();, FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
HHsequence BiDiagonalUp::makeHHsequence_(const char type) const {
|
||||
|
||||
if(type == 'u') {
|
||||
|
||||
const int diagSize = _HHbidiag.sizeAt(0);
|
||||
auto colOfCoeffs = NDArrayFactory::create(_HHmatrix.ordering(), {diagSize, 1}, _HHmatrix.dataType(), _HHmatrix.getContext());
|
||||
|
||||
for(int i = 0; i < diagSize; ++i)
|
||||
colOfCoeffs.p(i, _HHmatrix.e<T>(i,i));
|
||||
|
||||
return HHsequence(_HHmatrix, colOfCoeffs, type);
|
||||
}
|
||||
else {
|
||||
|
||||
const int diagUpSize = _HHbidiag.sizeAt(0) - 1;
|
||||
NDArray colOfCoeffs = NDArrayFactory::create(_HHmatrix.ordering(), {diagUpSize, 1}, _HHmatrix.dataType(), _HHmatrix.getContext());
|
||||
|
||||
for(int i = 0; i < diagUpSize; ++i)
|
||||
colOfCoeffs.p(i, _HHmatrix.e<T>(i,i+1));
|
||||
|
||||
HHsequence result(_HHmatrix, colOfCoeffs, type);
|
||||
result._diagSize = diagUpSize;
|
||||
result._shift = 1;
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
HHsequence BiDiagonalUp::makeHHsequence(const char type) const {
|
||||
auto xType = _HHmatrix.dataType();
|
||||
|
||||
BUILD_SINGLE_SELECTOR(xType, return makeHHsequence_, (type);, FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void BiDiagonalUp::_evalData, (), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template HHsequence BiDiagonalUp::makeHHsequence_, (const char type) const, FLOAT_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,171 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by Yurii Shyrma on 11.01.2018
|
||||
//
|
||||
|
||||
#include <helpers/hhColPivQR.h>
|
||||
#include <helpers/householder.h>
|
||||
#include <array/NDArrayFactory.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
HHcolPivQR::HHcolPivQR(const NDArray& matrix) {
|
||||
|
||||
_qr = matrix;
|
||||
_diagSize = math::nd4j_min<int>(matrix.sizeAt(0), matrix.sizeAt(1));
|
||||
_coeffs = NDArrayFactory::create(matrix.ordering(), {1, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
_permut = NDArrayFactory::create(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
evalData();
|
||||
}
|
||||
|
||||
void HHcolPivQR::evalData() {
|
||||
BUILD_SINGLE_SELECTOR(_qr.dataType(), _evalData, (), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void HHcolPivQR::_evalData() {
|
||||
|
||||
int rows = _qr.sizeAt(0);
|
||||
int cols = _qr.sizeAt(1);
|
||||
|
||||
auto transp = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
|
||||
auto normsUpd = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
|
||||
auto normsDir = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext());
|
||||
|
||||
int transpNum = 0;
|
||||
|
||||
for (int k = 0; k < cols; ++k) {
|
||||
|
||||
T norm = _qr({0,0, k,k+1}).reduceNumber(reduce::Norm2).e<T>(0);
|
||||
normsDir.p<T>(k, norm);
|
||||
normsUpd.p<T>(k, norm);
|
||||
}
|
||||
|
||||
T normScaled = (normsUpd.reduceNumber(reduce::Max)).e<T>(0) * DataTypeUtils::eps<T>();
|
||||
T threshold1 = normScaled * normScaled / (T)rows;
|
||||
T threshold2 = math::nd4j_sqrt<T,T>(DataTypeUtils::eps<T>());
|
||||
|
||||
T nonZeroPivots = _diagSize;
|
||||
T maxPivot = 0.;
|
||||
|
||||
for(int k = 0; k < _diagSize; ++k) {
|
||||
|
||||
int biggestColIndex = normsUpd({0,0, k,-1}).indexReduceNumber(indexreduce::IndexMax).e<int>(0);
|
||||
T biggestColNorm = normsUpd({0,0, k,-1}).reduceNumber(reduce::Max).e<T>(0);
|
||||
T biggestColSqNorm = biggestColNorm * biggestColNorm;
|
||||
biggestColIndex += k;
|
||||
|
||||
if(nonZeroPivots == (T)_diagSize && biggestColSqNorm < threshold1 * (T)(rows-k))
|
||||
nonZeroPivots = k;
|
||||
|
||||
transp.p<T>(k, (T)biggestColIndex);
|
||||
|
||||
if(k != biggestColIndex) {
|
||||
|
||||
auto temp1 = new NDArray(_qr({0,0, k,k+1}, true));
|
||||
auto temp2 = new NDArray(_qr({0,0, biggestColIndex,biggestColIndex+1}, true));
|
||||
auto temp3 = *temp1;
|
||||
temp1->assign(temp2);
|
||||
temp2->assign(temp3);
|
||||
delete temp1;
|
||||
delete temp2;
|
||||
|
||||
T e0 = normsUpd.e<T>(k);
|
||||
T e1 = normsUpd.e<T>(biggestColIndex);
|
||||
normsUpd.p(k, e1);
|
||||
normsUpd.p(biggestColIndex, e0);
|
||||
//math::nd4j_swap<T>(normsUpd(k), normsUpd(biggestColIndex));
|
||||
|
||||
e0 = normsDir.e<T>(k);
|
||||
e1 = normsDir.e<T>(biggestColIndex);
|
||||
normsDir.p(k, e1);
|
||||
normsDir.p(biggestColIndex, e0);
|
||||
//math::nd4j_swap<T>(normsDir(k), normsDir(biggestColIndex));
|
||||
|
||||
++transpNum;
|
||||
}
|
||||
|
||||
T normX;
|
||||
NDArray* qrBlock = new NDArray(_qr({k,rows, k,k+1}, true));
|
||||
T c;
|
||||
Householder<T>::evalHHmatrixDataI(*qrBlock, c, normX);
|
||||
_coeffs.p<T>(k, c);
|
||||
delete qrBlock;
|
||||
|
||||
_qr.p<T>(k,k, normX);
|
||||
|
||||
T max = math::nd4j_abs<T>(normX);
|
||||
if(max > maxPivot)
|
||||
maxPivot = max;
|
||||
|
||||
if(k < rows && (k+1) < cols) {
|
||||
qrBlock = new NDArray(_qr({k, rows, k+1,cols}, true));
|
||||
auto tail = new NDArray(_qr({k+1,rows, k, k+1}, true));
|
||||
Householder<T>::mulLeft(*qrBlock, *tail, _coeffs.e<T>(k));
|
||||
delete qrBlock;
|
||||
delete tail;
|
||||
}
|
||||
|
||||
for (int j = k + 1; j < cols; ++j) {
|
||||
|
||||
if (normsUpd.e<T>(j) != (T)0.f) {
|
||||
T temp = math::nd4j_abs<T>(_qr.e<T>(k, j)) / normsUpd.e<T>(j);
|
||||
temp = (1. + temp) * (1. - temp);
|
||||
temp = temp < (T)0. ? (T)0. : temp;
|
||||
T temp2 = temp * normsUpd.e<T>(j) * normsUpd.e<T>(j) / (normsDir.e<T>(j)*normsDir.e<T>(j));
|
||||
|
||||
if (temp2 <= threshold2) {
|
||||
if(k+1 < rows && j < cols)
|
||||
normsDir.p<T>(j, _qr({k+1,rows, j,j+1}).reduceNumber(reduce::Norm2).e<T>(0));
|
||||
|
||||
normsUpd.p<T>(j, normsDir.e<T>(j));
|
||||
}
|
||||
else
|
||||
normsUpd.p<T>(j, normsUpd.e<T>(j) * math::nd4j_sqrt<T, T>(temp));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_permut.setIdentity();
|
||||
|
||||
for(int k = 0; k < _diagSize; ++k) {
|
||||
|
||||
int idx = transp.e<int>(k);
|
||||
auto temp1 = new NDArray(_permut({0,0, k, k+1}, true));
|
||||
auto temp2 = new NDArray(_permut({0,0, idx,idx+1}, true));
|
||||
auto temp3 = *temp1;
|
||||
temp1->assign(temp2);
|
||||
temp2->assign(temp3);
|
||||
delete temp1;
|
||||
delete temp2;
|
||||
}
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void HHcolPivQR::_evalData, (), FLOAT_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,221 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by Yurii Shyrma on 18.12.2017
|
||||
//
|
||||
|
||||
#include <helpers/householder.h>
|
||||
#include <array/NDArrayFactory.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
NDArray Householder<T>::evalHHmatrix(const NDArray& x) {
|
||||
|
||||
// input validation
|
||||
if(!x.isVector() && !x.isScalar())
|
||||
throw std::runtime_error("ops::helpers::Householder::evalHHmatrix method: input array must be vector or scalar!");
|
||||
|
||||
auto w = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), 1}, x.dataType(), x.getContext()); // column-vector
|
||||
auto wT = NDArrayFactory::create(x.ordering(), {1, (int)x.lengthOf()}, x.dataType(), x.getContext()); // row-vector (transposed w)
|
||||
|
||||
T coeff;
|
||||
T normX = x.reduceNumber(reduce::Norm2).e<T>(0);
|
||||
|
||||
if(normX*normX - x.e<T>(0) * x.e<T>(0) <= DataTypeUtils::min<T>() || x.lengthOf() == 1) {
|
||||
|
||||
normX = x.e<T>(0);
|
||||
coeff = 0.f;
|
||||
w = 0.f;
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
if(x.e<T>(0) >= (T)0.f)
|
||||
normX = -normX; // choose opposite sign to lessen roundoff error
|
||||
|
||||
T u0 = x.e<T>(0) - normX;
|
||||
coeff = -u0 / normX;
|
||||
w.assign(x / u0);
|
||||
}
|
||||
|
||||
w.p(Nd4jLong(0), 1.f);
|
||||
wT.assign(&w);
|
||||
|
||||
NDArray identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext());
|
||||
identity.setIdentity(); // identity matrix
|
||||
|
||||
return identity - mmul(w, wT) * coeff;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Householder<T>::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, T& normX) {
|
||||
|
||||
// input validation
|
||||
if(!x.isVector() && !x.isScalar())
|
||||
throw std::runtime_error("ops::helpers::Householder::evalHHmatrixData method: input array must be vector or scalar!");
|
||||
|
||||
if(!x.isScalar() && x.lengthOf() != tail.lengthOf() + 1)
|
||||
throw std::runtime_error("ops::helpers::Householder::evalHHmatrixData method: input tail vector must have length less than unity compared to input x vector!");
|
||||
|
||||
normX = x.reduceNumber(reduce::Norm2, nullptr).e<T>(0);
|
||||
|
||||
if(normX*normX - x.e<T>(0) * x.e<T>(0) <= DataTypeUtils::min<T>() || x.lengthOf() == 1) {
|
||||
|
||||
normX = x.e<T>(0);
|
||||
coeff = (T)0.f;
|
||||
tail = (T)0.f;
|
||||
}
|
||||
else {
|
||||
|
||||
if(x.e<T>(0) >= (T)0.f)
|
||||
normX = -normX; // choose opposite sign to lessen roundoff error
|
||||
|
||||
T u0 = x.e<T>(0) - normX;
|
||||
coeff = -u0 / normX;
|
||||
|
||||
if(x.isRowVector())
|
||||
tail.assign(static_cast<const NDArray&>(x({0,0, 1,-1})) / u0);
|
||||
else
|
||||
tail.assign(static_cast<const NDArray&>(x({1,-1, 0,0,})) / u0);
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Householder<T>::evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX) {
|
||||
|
||||
int rows = (int)x.lengthOf()-1;
|
||||
int num = 1;
|
||||
|
||||
if(rows == 0) {
|
||||
rows = 1;
|
||||
num = 0;
|
||||
}
|
||||
|
||||
auto tail = NDArrayFactory::create(x.ordering(), {rows, 1}, x.dataType(), x.getContext());
|
||||
evalHHmatrixData(x, tail, coeff, normX);
|
||||
|
||||
if(x.isRowVector()) {
|
||||
auto temp = x({0,0, num, x.sizeAt(1)}, true);
|
||||
temp.assign(tail);
|
||||
}
|
||||
else {
|
||||
auto temp = x({num,x.sizeAt(0), 0,0}, true);
|
||||
temp.assign(tail);
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff) {
|
||||
|
||||
// if(matrix.rankOf() != 2)
|
||||
// throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !";
|
||||
|
||||
if(matrix.sizeAt(0) == 1) {
|
||||
matrix *= (T) 1.f - coeff;
|
||||
}
|
||||
else if(coeff != (T)0.f) {
|
||||
|
||||
auto bottomPart = new NDArray(matrix({1,matrix.sizeAt(0), 0,0}, true));
|
||||
auto bottomPartCopy = *bottomPart;
|
||||
|
||||
if(tail.isColumnVector()) {
|
||||
|
||||
auto column = tail;
|
||||
auto row = tail.transpose();
|
||||
auto resultingRow = mmul(row, bottomPartCopy);
|
||||
auto fistRow = matrix({0,1, 0,0}, true);
|
||||
resultingRow += fistRow;
|
||||
fistRow -= resultingRow * coeff;
|
||||
*bottomPart -= mmul(column, resultingRow) * coeff;
|
||||
}
|
||||
else {
|
||||
|
||||
auto row = tail;
|
||||
auto column = tail.transpose();
|
||||
auto resultingRow = mmul(row, bottomPartCopy);
|
||||
auto fistRow = matrix({0,1, 0,0}, true);
|
||||
resultingRow += fistRow;
|
||||
fistRow -= resultingRow * coeff;
|
||||
*bottomPart -= mmul(column, resultingRow) * coeff;
|
||||
}
|
||||
delete bottomPart;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Householder<T>::mulRight(NDArray& matrix, const NDArray& tail, const T coeff) {
|
||||
|
||||
// if(matrix.rankOf() != 2)
|
||||
// throw "ops::helpers::Householder::mulRight method: input array must be 2D matrix !";
|
||||
|
||||
if(matrix.sizeAt(1) == 1)
|
||||
matrix *= (T)1.f - coeff;
|
||||
|
||||
else if(coeff != (T)0.f) {
|
||||
|
||||
auto rightPart = new NDArray(matrix({0,0, 1,matrix.sizeAt(1)}, true));
|
||||
auto rightPartCopy = *rightPart;
|
||||
auto fistCol = new NDArray(matrix({0,0, 0,1}, true));
|
||||
|
||||
if(tail.isColumnVector()) {
|
||||
|
||||
auto column = tail;
|
||||
auto row = tail.transpose();
|
||||
auto resultingCol = mmul(rightPartCopy, column);
|
||||
resultingCol += *fistCol;
|
||||
*fistCol -= resultingCol * coeff;
|
||||
*rightPart -= mmul(resultingCol, row) * coeff;
|
||||
}
|
||||
else {
|
||||
|
||||
auto row = tail;
|
||||
auto column = tail.transpose();
|
||||
auto resultingCol = mmul(rightPartCopy, column);
|
||||
resultingCol += *fistCol;
|
||||
*fistCol -= resultingCol * coeff;
|
||||
*rightPart -= mmul(resultingCol, row) * coeff;
|
||||
}
|
||||
delete rightPart;
|
||||
delete fistCol;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template class ND4J_EXPORT Householder<float>;
template class ND4J_EXPORT Householder<float16>;
template class ND4J_EXPORT Householder<bfloat16>;
template class ND4J_EXPORT Householder<double>;

}
}
}

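For orientation (not part of the patch): mulLeft never forms the reflector P explicitly. With v = [1, tail]^T and P = I - coeff * v * v^T, it computes P * M as the rank-1 update M - coeff * v * (v^T * M). The standalone sketch below shows that arithmetic on plain std::vector matrices; the Matrix alias and the applyHouseholderLeft helper are invented for this illustration and are not the library's NDArray API.

// Minimal illustration of the rank-1 update performed by Householder<T>::mulLeft:
// for v = [1, tail]^T and P = I - coeff * v * v^T, compute P * M without forming P.
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<double>>;   // row-major, M[r][c]

void applyHouseholderLeft(Matrix& M, const std::vector<double>& tail, double coeff) {
    const size_t rows = M.size(), cols = M[0].size();
    // w = firstRow + tail^T * bottomPart  (a 1 x cols row vector)
    std::vector<double> w(M[0]);
    for (size_t r = 1; r < rows; ++r)
        for (size_t c = 0; c < cols; ++c)
            w[c] += tail[r - 1] * M[r][c];
    // firstRow -= coeff * w
    for (size_t c = 0; c < cols; ++c)
        M[0][c] -= coeff * w[c];
    // bottomPart -= coeff * tail * w  (rank-1 update)
    for (size_t r = 1; r < rows; ++r)
        for (size_t c = 0; c < cols; ++c)
            M[r][c] -= coeff * tail[r - 1] * w[c];
}

int main() {
    Matrix M = {{4, 1}, {2, 3}, {2, 5}};
    std::vector<double> tail = {0.5, 0.25};   // reflector vector is [1, 0.5, 0.25]^T
    applyHouseholderLeft(M, tail, 1.2);       // coeff chosen arbitrarily for the demo
    for (auto& row : M) printf("%8.4f %8.4f\n", row[0], row[1]);
    return 0;
}

mulRight applies the mirrored update M - coeff * (M * v) * v^T to the columns instead of the rows.
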
@ -22,7 +22,6 @@
|
|||
#include <helpers/jacobiSVD.h>
|
||||
#include <helpers/biDiagonalUp.h>
|
||||
#include <array/ResultSet.h>
|
||||
#include <array/NDArrayFactory.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
|
@ -59,19 +58,19 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
|
|||
if (_transp)
|
||||
math::nd4j_swap<bool>(_calcU, _calcV);
|
||||
|
||||
_s = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize, 1}, matrix.getContext());
|
||||
_m = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.getContext());
|
||||
_m.assign(0.);
|
||||
_s = NDArray(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext());
|
||||
_m = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
// _m.assign(0.);
|
||||
|
||||
if (_calcU)
|
||||
_u = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext());
|
||||
_u = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.dataType(), matrix.getContext());
|
||||
else
|
||||
_u = NDArrayFactory::create<T>(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext());
|
||||
_u.assign(0.);
|
||||
_u = NDArray(matrix.ordering(), {2, _diagSize + 1}, matrix.dataType(), matrix.getContext());
|
||||
// _u.assign(0.);
|
||||
|
||||
if (_calcV) {
|
||||
_v = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize, _diagSize}, matrix.getContext());
|
||||
_v.assign(0.);
|
||||
_v = NDArray(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
// _v.assign(0.);
|
||||
}
|
||||
|
||||
evalData(matrix);
|
||||
|
@ -106,19 +105,19 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
|
|||
if (_transp)
|
||||
math::nd4j_swap<bool>(_calcU, _calcV);
|
||||
|
||||
_s = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize, 1}, matrix.getContext());
|
||||
_m = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.getContext());
|
||||
_m.assign(0.f);
|
||||
_s = NDArray(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext());
|
||||
_m = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
// _m.assign(0.f);
|
||||
|
||||
if (_calcU)
|
||||
_u = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext());
|
||||
_u = NDArray(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.dataType(), matrix.getContext());
|
||||
else
|
||||
_u = NDArrayFactory::create<T>(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext());
|
||||
_u.assign(0.);
|
||||
_u = NDArray(matrix.ordering(), {2, _diagSize + 1}, matrix.dataType(), matrix.getContext());
|
||||
// _u.assign(0.);
|
||||
|
||||
if (_calcV) {
|
||||
_v = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize, _diagSize}, matrix.getContext());
|
||||
_v.assign(0.);
|
||||
_v = NDArray(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
// _v.assign(0.);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -131,28 +130,27 @@ void SVD<T>::deflation1(int col1, int shift, int ind, int size) {
|
|||
throw std::runtime_error("ops::helpers::SVD::deflation1 method: input int must satisfy condition ind > 0 !");
|
||||
|
||||
int first = col1 + shift;
|
||||
T cos = _m.e<T>(first, first);
|
||||
T sin = _m.e<T>(first+ind, first);
|
||||
T cos = _m.t<T>(first, first);
|
||||
T sin = _m.t<T>(first+ind, first);
|
||||
T denom = math::nd4j_sqrt<T, T>(cos*cos + sin*sin);
|
||||
|
||||
if (denom == (T)0.) {
|
||||
|
||||
_m.p(first+ind, first+ind, 0.f);
|
||||
_m.r<T>(first+ind, first+ind) = (T)0;
|
||||
return;
|
||||
}
|
||||
|
||||
cos /= denom;
|
||||
sin /= denom;
|
||||
|
||||
_m.p(first,first, denom);
|
||||
_m.p(first+ind, first, 0.f);
|
||||
_m.p(first+ind, first+ind, 0.f);
|
||||
_m.r<T>(first,first) = denom;
|
||||
_m.r<T>(first+ind, first) = (T)0;
|
||||
_m.r<T>(first+ind, first+ind) = (T)0;
|
||||
|
||||
auto rotation = NDArrayFactory::create<T>(_m.ordering(), {2, 2}, _m.getContext());
|
||||
rotation.p(0, 0, cos);
|
||||
rotation.p(0, 1, -sin);
|
||||
rotation.p(1, 0, sin);
|
||||
rotation.p(1, 1, cos);
|
||||
NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext());
|
||||
|
||||
rotation.r<T>(0,0) = rotation.r<T>(1,1) = cos;
|
||||
rotation.r<T>(0,1) = -sin;
|
||||
rotation.r<T>(1,0) = sin;
|
||||
|
||||
if (_calcU) {
|
||||
auto temp = _u({col1,col1+size+1, 0,0}, true);
|
||||
|
@ -172,28 +170,26 @@ void SVD<T>::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, i
|
|||
if(size <= 0)
|
||||
throw std::runtime_error("ops::helpers::SVD::deflation2 method: input size must satisfy condition size > 0 !");
|
||||
|
||||
T cos = _m.e<T>(col1M+ind1, col1M);
|
||||
T sin = _m.e<T>(col1M+ind2, col1M);
|
||||
T cos = _m.t<T>(col1M+ind1, col1M);
|
||||
T sin = _m.t<T>(col1M+ind2, col1M);
|
||||
T denom = math::nd4j_sqrt<T,T>(cos*cos + sin*sin);
|
||||
|
||||
if (denom == (T)0.) {
|
||||
|
||||
_m.p(col1M + ind1, col1M + ind1, _m.e<T>(col1M + ind2, col1M + ind2));
|
||||
_m.r<T>(col1M+ind1, col1M+ind1) = _m.t<T>(col1M+ind2, col1M+ind2);
|
||||
return;
|
||||
}
|
||||
|
||||
cos /= denom;
|
||||
sin /= denom;
|
||||
_m.p(col1M + ind1, col1M, denom);
|
||||
_m.p(col1M + ind2, col1M + ind2, _m.e<T>(col1M + ind1, col1M + ind1));
|
||||
_m.p(col1M + ind2, col1M, 0.f);
|
||||
_m.r<T>(col1M+ind1, col1M) = denom;
|
||||
_m.r<T>(col1M+ind2, col1M+ind2) = _m.t<T>(col1M+ind1, col1M+ind1);
|
||||
_m.r<T>(col1M+ind2, col1M) = (T)0;
|
||||
|
||||
auto rotation = NDArrayFactory::create<T>(_m.ordering(), {2, 2}, _m.getContext());
|
||||
rotation.p(0,0, cos);
|
||||
rotation.p(1,1, cos);
|
||||
NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext());
|
||||
|
||||
rotation.p(0,1, -sin);
|
||||
rotation.p(1,0, sin);
|
||||
rotation.r<T>(0,0) = rotation.r<T>(1,1) = cos;
|
||||
rotation.r<T>(0,1) = -sin;
|
||||
rotation.r<T>(1,0) = sin;
|
||||
|
||||
if (_calcU) {
|
||||
auto temp = _u({col1U,col1U+size+1, 0,0}, true);
|
||||
|
@ -216,40 +212,40 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
|
|||
|
||||
const int len = col2 + 1 - col1;
|
||||
|
||||
auto colVec0 = new NDArray(_m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true));
|
||||
NDArray colVec0 = _m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true);
|
||||
|
||||
auto diagInterval = _m({col1+shift, col1+shift+len, col1+shift,col1+shift+len}, true).diagonal('c');
|
||||
NDArray diagInterval = _m({col1+shift,col1+shift+len, col1+shift,col1+shift+len}, true).diagonal('c');
|
||||
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
T maxElem;
|
||||
if(len == 1)
|
||||
maxElem = math::nd4j_abs<T>(diagInterval.template e<T>(0));
|
||||
maxElem = math::nd4j_abs<T>(diagInterval.template t<T>(0));
|
||||
else
|
||||
maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e<T>(0);
|
||||
T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e<T>(0);
|
||||
maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template t<T>(0);
|
||||
T maxElem0 = colVec0.reduceNumber(reduce::AMax).template t<T>(0);
|
||||
|
||||
T eps = math::nd4j_max<T>(almostZero, DataTypeUtils::eps<T>() * maxElem);
|
||||
T epsBig = (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(maxElem0, maxElem);
|
||||
|
||||
if(diagInterval.template e<T>(0) < epsBig)
|
||||
diagInterval.p(Nd4jLong(0), epsBig);
|
||||
if(diagInterval.template t<T>(0) < epsBig)
|
||||
diagInterval.r<T>(0) = epsBig;
|
||||
|
||||
for(int i=1; i < len; ++i)
|
||||
if(math::nd4j_abs<T>(colVec0->template e<T>(i)) < eps)
|
||||
colVec0->p(i, 0.f);
|
||||
if(math::nd4j_abs<T>(colVec0.template t<T>(i)) < eps)
|
||||
colVec0.r<T>(i) = (T)0;
|
||||
|
||||
for(int i=1; i < len; i++)
|
||||
if(diagInterval.template e<T>(i) < epsBig) {
|
||||
if(diagInterval.template t<T>(i) < epsBig) {
|
||||
deflation1(col1, shift, i, len);
|
||||
for(int i = 0; i < len; ++i)
|
||||
diagInterval.p(i, _m.e<T>(col1+shift+i,col1+shift+i));
|
||||
diagInterval.r<T>(i) = _m.t<T>(col1+shift+i,col1+shift+i);
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
bool totDefl = true;
|
||||
for(int i=1; i < len; i++)
|
||||
if(colVec0->template e<T>(i) >= almostZero) {
|
||||
if(colVec0.template t<T>(i) >= almostZero) {
|
||||
totDefl = false;
|
||||
break;
|
||||
}
|
||||
|
@ -261,7 +257,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
|
|||
int p = 1;
|
||||
|
||||
for(int i=1; i<len; ++i)
|
||||
if(math::nd4j_abs<T>(diagInterval.template e<T>(i)) < almostZero)
|
||||
if(math::nd4j_abs<T>(diagInterval.template t<T>(i)) < almostZero)
|
||||
permut[p++] = i;
|
||||
|
||||
int k = 1, m = ind+1;
|
||||
|
@ -271,7 +267,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
|
|||
permut[p] = m++;
|
||||
else if(m >= len)
|
||||
permut[p] = k++;
|
||||
else if(diagInterval.template e<T>(k) < diagInterval.template e<T>(m))
|
||||
else if(diagInterval.template t<T>(k) < diagInterval.template t<T>(m))
|
||||
permut[p] = m++;
|
||||
else
|
||||
permut[p] = k++;
|
||||
|
@ -281,7 +277,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
|
|||
if(totDefl) {
|
||||
for(int i=1; i<len; ++i) {
|
||||
int ki = permut[i];
|
||||
if(math::nd4j_abs<T>(diagInterval.template e<T>(ki)) < almostZero || diagInterval.template e<T>(0) < diagInterval.template e<T>(ki))
|
||||
if(math::nd4j_abs<T>(diagInterval.template t<T>(ki)) < almostZero || diagInterval.template t<T>(0) < diagInterval.template t<T>(ki))
|
||||
permut[i-1] = permut[i];
|
||||
else {
|
||||
permut[i-1] = 0;
|
||||
|
@ -303,39 +299,26 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
|
|||
const int ki = permut[len - (totDefl ? i+1 : i)];
|
||||
const int jac = tCol[ki];
|
||||
|
||||
T _e0 = diagInterval.template e<T>(jac);
|
||||
//math::nd4j_swap<T>(diagInterval)(i), (*diagInterval)(jac));
|
||||
diagInterval.p(jac, diagInterval.template e<T>(i));
|
||||
diagInterval.p(i, _e0);
|
||||
math::nd4j_swap<T>(diagInterval.r<T>(i), diagInterval.r<T>(jac));
|
||||
|
||||
if(i!=0 && jac!=0) {
|
||||
_e0 = colVec0->template e<T>(jac);
|
||||
//math::nd4j_swap<T>((*colVec0)(i), (*colVec0)(jac));
|
||||
colVec0->p(jac, colVec0->template e<T>(i));
|
||||
colVec0->p(i, _e0);
|
||||
}
|
||||
if(i!=0 && jac!=0)
|
||||
math::nd4j_swap<T>(colVec0.r<T>(i), colVec0.r<T>(jac));
|
||||
|
||||
if (_calcU) {
|
||||
auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true);
|
||||
auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1});
|
||||
auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1});
|
||||
temp1.swapUnsafe(temp2);
|
||||
}
|
||||
else {
|
||||
auto temp1 = _u({0,2, col1+i, col1+i+1}, true);
|
||||
auto temp2 = _u({0,2, col1+jac, col1+jac+1}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
auto temp1 = _u({0,2, col1+i, col1+i+1});
|
||||
auto temp2 = _u({0,2, col1+jac, col1+jac+1});
|
||||
temp1.swapUnsafe(temp2);
|
||||
}
|
||||
|
||||
if(_calcV) {
|
||||
auto temp1 = _v({row1W,row1W+len, col1W+i, col1W+i+1}, true);
|
||||
auto temp2 = _v({row1W,row1W+len, col1W+jac, col1W+jac+1}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
auto temp1 = _v({row1W,row1W+len, col1W+i, col1W+i+1});
|
||||
auto temp2 = _v({row1W,row1W+len, col1W+jac, col1W+jac+1});
|
||||
temp1.swapUnsafe(temp2);
|
||||
}
|
||||
|
||||
const int tI = tInd[i];
|
||||
|
@ -351,19 +334,17 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
|
|||
{
|
||||
int i = len-1;
|
||||
|
||||
while(i > 0 && (math::nd4j_abs<T>(diagInterval.template e<T>(i)) < almostZero || math::nd4j_abs<T>(colVec0->template e<T>(i)) < almostZero))
|
||||
while(i > 0 && (math::nd4j_abs<T>(diagInterval.template t<T>(i)) < almostZero || math::nd4j_abs<T>(colVec0.template t<T>(i)) < almostZero))
|
||||
--i;
|
||||
|
||||
for(; i > 1; --i) {
|
||||
if( (diagInterval.template e<T>(i) - diagInterval.template e<T>(i-1)) < DataTypeUtils::eps<T>()*maxElem ) {
|
||||
if (math::nd4j_abs<T>(diagInterval.template e<T>(i) - diagInterval.template e<T>(i-1)) >= epsBig)
|
||||
if( (diagInterval.template t<T>(i) - diagInterval.template t<T>(i-1)) < DataTypeUtils::eps<T>()*maxElem ) {
|
||||
if (math::nd4j_abs<T>(diagInterval.template t<T>(i) - diagInterval.template t<T>(i-1)) >= epsBig)
|
||||
throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !");
|
||||
deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete colVec0;
|
||||
}
|
||||
|
||||
|
||||
|
@ -374,10 +355,10 @@ T SVD<T>::secularEq(const T diff, const NDArray& col0, const NDArray& diag, cons
     auto len = permut.lengthOf();
     T res = 1.;
     T item;
-    for(Nd4jLong i=0; i<len; ++i) {
-        auto j = permut.e<int>(i);
-        item = col0.e<T>(j) / ((diagShifted.e<T>(j) - diff) * (diag.e<T>(j) + shift + diff));
-        res += item * col0.e<T>(j);
+    for(int i=0; i<len; ++i) {
+        int j = (int)permut.t<T>(i);
+        item = col0.t<T>(j) / ((diagShifted.t<T>(j) - diff) * (diag.t<T>(j) + shift + diff));
+        res += item * col0.t<T>(j);
     }

     return res;
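For reference, secularEq evaluates the secular function of the divide-and-conquer step. Writing the accumulation above in the code's own symbols (z = col0, d = diag, d-hat = diagShifted = d - shift, mu = diff):

\[
f(\mu) \;=\; 1 \;+\; \sum_{j \in \text{permut}} \frac{z_j^{2}}{\bigl(\hat d_j - \mu\bigr)\bigl(d_j + \text{shift} + \mu\bigr)}
       \;=\; 1 \;+\; \sum_{j} \frac{z_j^{2}}{d_j^{2} - (\text{shift} + \mu)^{2}} .
\]

This is the standard secular-equation formulation: the updated singular values are the roots of f, which calcSingVals below brackets and refines; the helper itself only evaluates f.
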
@ -390,34 +371,34 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
|
|||
auto len = col0.lengthOf();
|
||||
auto curLen = len;
|
||||
|
||||
while(curLen > 1 && col0.e<T>(curLen-1) == (T)0.f)
|
||||
while(curLen > 1 && col0.t<T>(curLen-1) == (T)0.f)
|
||||
--curLen;
|
||||
|
||||
for (Nd4jLong k = 0; k < len; ++k) {
|
||||
|
||||
if (col0.e<T>(k) == (T)0.f || curLen==1) {
|
||||
if (col0.t<T>(k) == (T)0.f || curLen==1) {
|
||||
|
||||
singVals.p(k, k==0 ? col0.e<T>(0) : diag.e<T>(k));
|
||||
mus.p(k, 0.f);
|
||||
shifts.p(k, k==0 ? col0.e<T>(0) : diag.e<T>(k));
|
||||
singVals.r<T>(k) = k==0 ? col0.t<T>(0) : diag.t<T>(k);
|
||||
mus.r<T>(k) = (T)0;
|
||||
shifts.r<T>(k) = k==0 ? col0.t<T>(0) : diag.t<T>(k);
|
||||
continue;
|
||||
}
|
||||
|
||||
T left = diag.e<T>(k);
|
||||
T left = diag.t<T>(k);
|
||||
T right;
|
||||
|
||||
if(k==curLen-1)
|
||||
right = diag.e<T>(curLen-1) + col0.reduceNumber(reduce::Norm2).e<T>(0);
|
||||
right = diag.t<T>(curLen-1) + col0.reduceNumber(reduce::Norm2).t<T>(0);
|
||||
else {
|
||||
|
||||
int l = k+1;
|
||||
while(col0.e<T>(l) == (T)0.f) {
|
||||
while(col0.t<T>(l) == (T)0.f) {
|
||||
++l;
|
||||
if(l >= curLen)
|
||||
throw std::runtime_error("ops::helpers::SVD::calcSingVals method: l >= curLen !");
|
||||
}
|
||||
|
||||
right = diag.e<T>(l);
|
||||
right = diag.t<T>(l);
|
||||
}
|
||||
|
||||
T mid = left + (right - left) / (T)2.;
|
||||
|
@ -464,13 +445,12 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
|
|||
|
||||
if (shift == left && (muCur < (T)0. || muCur > right - left))
|
||||
useBisection = true;
|
||||
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
|
||||
else if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
|
||||
useBisection = true;
|
||||
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev) && math::nd4j_abs<T>(fCur - fPrev) > (T)16. * DataTypeUtils::eps<T>())
|
||||
else if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev) && math::nd4j_abs<T>(fCur - fPrev) > (T)16. * DataTypeUtils::eps<T>())
|
||||
useBisection = true;
|
||||
}
|
||||
|
||||
|
||||
if (useBisection) {
|
||||
|
||||
T leftShifted, rightShifted;
|
||||
|
@ -479,7 +459,6 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
|
|||
rightShifted = (k==curLen-1) ? right : ((right - left) * (T)0.6);
|
||||
}
|
||||
else {
|
||||
|
||||
leftShifted = -(right - left) * (T)0.6;
|
||||
rightShifted = -DataTypeUtils::min<T>();
|
||||
}
|
||||
|
@ -502,14 +481,12 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
|
|||
}
|
||||
muCur = (leftShifted + rightShifted) / (T)2.;
|
||||
}
|
||||
singVals.p(k, shift + muCur);
|
||||
shifts.p(k, shift);
|
||||
mus.p(k, muCur);
|
||||
singVals.r<T>(k) = shift + muCur;
|
||||
shifts.r<T>(k) = shift;
|
||||
mus.r<T>(k) = muCur;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void SVD<T>::perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat) {
|
||||
|
@ -517,29 +494,29 @@ void SVD<T>::perturb(const NDArray& col0, const NDArray& diag, const NDArray& pe
|
|||
int n = col0.lengthOf();
|
||||
int m = permut.lengthOf();
|
||||
if(m==0) {
|
||||
zhat.assign(0.);
|
||||
zhat.nullify();
|
||||
return;
|
||||
}
|
||||
|
||||
int last = permut.e<int>(m-1);
|
||||
int last = permut.t<T>(m-1);
|
||||
|
||||
for (int k = 0; k < n; ++k) {
|
||||
|
||||
if (col0.e<T>(k) == (T)0.f)
|
||||
zhat.p(k, (T)0.f);
|
||||
if (col0.t<T>(k) == (T)0.f)
|
||||
zhat.r<T>(k) = (T)0;
|
||||
else {
|
||||
T dk = diag.e<T>(k);
|
||||
T prod = (singVals.e<T>(last) + dk) * (mus.e<T>(last) + (shifts.e<T>(last) - dk));
|
||||
T dk = diag.t<T>(k);
|
||||
T prod = (singVals.t<T>(last) + dk) * (mus.t<T>(last) + (shifts.t<T>(last) - dk));
|
||||
|
||||
for(int l = 0; l<m; ++l) {
|
||||
int i = permut.e<int>(l);
|
||||
int i = (int)permut.t<T>(l);
|
||||
if(i!=k) {
|
||||
int j = i<k ? i : permut.e<int>(l-1);
|
||||
prod *= ((singVals.e<T>(j)+dk) / ((diag.e<T>(i)+dk))) * ((mus.e<T>(j)+(shifts.e<T>(j)-dk)) / ((diag.e<T>(i)-dk)));
|
||||
int j = i<k ? i : (int)permut.t<T>(l-1);
|
||||
prod *= ((singVals.t<T>(j)+dk) / ((diag.t<T>(i)+dk))) * ((mus.t<T>(j)+(shifts.t<T>(j)-dk)) / ((diag.t<T>(i)-dk)));
|
||||
}
|
||||
}
|
||||
T tmp = math::nd4j_sqrt<T,T>(prod);
|
||||
zhat.p(k, col0.e<T>(k) > (T)0.f ? tmp : -tmp);
|
||||
zhat.r<T>(k) = col0.t<T>(k) > (T)0 ? tmp : -tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -555,48 +532,46 @@ void SVD<T>::calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArra
|
|||
|
||||
for (int k = 0; k < n; ++k) {
|
||||
|
||||
auto colU = new NDArray(U({0,0, k,k+1}, true));
|
||||
*colU = 0.;
|
||||
NDArray* colV = nullptr;
|
||||
NDArray colU = U({0,0, k,k+1});
|
||||
colU.nullify();
|
||||
|
||||
NDArray colV;
|
||||
|
||||
if (_calcV) {
|
||||
colV = new NDArray(V({0,0, k,k+1}, true));
|
||||
*colV = 0.;
|
||||
colV = V({0,0, k,k+1});
|
||||
colV.nullify();
|
||||
}
|
||||
|
||||
if (zhat.e<T>(k) == (T)0.f) {
|
||||
colU->p(k, 1.f);
|
||||
if (zhat.t<T>(k) == (T)0.f) {
|
||||
colU.r<T>(k) = (T)1;
|
||||
|
||||
if (_calcV)
|
||||
colV->p(k, 1.f);
|
||||
colV.r<T>(k) = (T)1;
|
||||
}
|
||||
else {
|
||||
|
||||
for(int l = 0; l < m; ++l) {
|
||||
int i = perm.e<int>(l);
|
||||
U.p(i,k, zhat.e<T>(i)/(((diag.e<T>(i) - shifts.e<T>(k)) - mus.e<T>(k)) )/( (diag.e<T>(i) + singVals.e<T>(k))));
|
||||
int i = (int)perm.t<T>(l);
|
||||
U.r<T>(i,k) = zhat.t<T>(i)/(((diag.t<T>(i) - shifts.t<T>(k)) - mus.t<T>(k)) )/( (diag.t<T>(i) + singVals.t<T>(k)));
|
||||
}
|
||||
U.p(n,k, 0.f);
|
||||
*colU /= colU->reduceNumber(reduce::Norm2);
|
||||
U.r<T>(n,k) = (T)0;
|
||||
colU /= colU.reduceNumber(reduce::Norm2);
|
||||
|
||||
if (_calcV) {
|
||||
|
||||
for(int l = 1; l < m; ++l){
|
||||
int i = perm.e<T>(l);
|
||||
V.p(i,k, diag.e<T>(i) * zhat.e<T>(i) / (((diag.e<T>(i) - shifts.e<T>(k)) - mus.e<T>(k)) )/( (diag.e<T>(i) + singVals.e<T>(k))));
|
||||
int i = perm.t<T>(l);
|
||||
V.r<T>(i,k) = diag.t<T>(i) * zhat.t<T>(i) / (((diag.t<T>(i) - shifts.t<T>(k)) - mus.t<T>(k)) )/( (diag.t<T>(i) + singVals.t<T>(k)));
|
||||
}
|
||||
V.p(0,k, -1.f);
|
||||
*colV /= colV->reduceNumber(reduce::Norm2);
|
||||
V.r<T>(0,k) = (T)-1;
|
||||
colV /= colV.reduceNumber(reduce::Norm2);
|
||||
}
|
||||
}
|
||||
delete colU;
|
||||
if (_calcV)
|
||||
delete colV;
|
||||
}
|
||||
|
||||
auto colU = U({0,0, n,n+1}, true);
|
||||
colU = 0.;
|
||||
colU.p(n, 1.);
|
||||
NDArray colU = U({0,0, n,n+1});
|
||||
colU.nullify();
|
||||
colU.r<T>(n) = (T)1;
|
||||
}
|
||||
|
||||
|
||||
|
@ -608,26 +583,29 @@ void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA
|
|||
auto col0 = _m({col1, col1+size, col1, col1+1}, true);
|
||||
auto diag = static_cast<const NDArray&>(_m({col1, col1+size, col1, col1+size}, true).diagonal('c'));
|
||||
|
||||
diag.p(Nd4jLong(0), T(0));
|
||||
singVals = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
|
||||
U = NDArrayFactory::create<T>(_u.ordering(), {size+1, size+1}, _u.getContext());
|
||||
diag.r<T>(0) = (T)0;
|
||||
singVals = NDArray(_m.ordering(), {size, 1}, _m.dataType(), _m.getContext());
|
||||
U = NDArray(_u.ordering(), {size+1, size+1}, _u.dataType(), _u.getContext());
|
||||
if (_calcV)
|
||||
V = NDArrayFactory::create<T>(_v.ordering(), {size, size}, _v.getContext());
|
||||
V = NDArray(_v.ordering(), {size, size}, _v.dataType(), _v.getContext());
|
||||
|
||||
int curSize = size;
|
||||
while(curSize > 1 && diag.template e<T>(curSize-1) == (T)0.f)
|
||||
while(curSize > 1 && diag.template t<T>(curSize-1) == (T)0.f)
|
||||
--curSize;
|
||||
|
||||
int m = 0;
|
||||
std::vector<T> indices;
|
||||
std::vector<int> indices;
|
||||
for(int k = 0; k < curSize; ++k)
|
||||
if(math::nd4j_abs<T>(col0.template e<T>(k)) > almostZero)
|
||||
indices.push_back((T)k);
|
||||
if(math::nd4j_abs<T>(col0.template t<T>(k)) > almostZero)
|
||||
indices.push_back(k);
|
||||
|
||||
auto permut = NDArrayFactory::create<T>(_m.ordering(), {1, (int)indices.size()}, indices, _m.getContext());
|
||||
auto shifts = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
|
||||
auto mus = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
|
||||
auto zhat = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
|
||||
NDArray permut(_m.ordering(), {(int)indices.size()}, _m.dataType(), _m.getContext());
|
||||
for(int k = 0; k < indices.size(); ++k)
|
||||
permut.r<T>(k) = (T)indices[k];
|
||||
|
||||
NDArray shifts(_m.ordering(), {size, 1}, _m.dataType(), _m.getContext());
|
||||
NDArray mus(_m.ordering(), {size, 1}, _m.dataType(), _m.getContext());
|
||||
NDArray zhat(_m.ordering(), {size, 1}, _m.dataType(), _m.getContext());
|
||||
|
||||
calcSingVals(col0, diag, permut, singVals, shifts, mus);
|
||||
perturb(col0, diag, permut, singVals, shifts, mus, zhat);
|
||||
|
@ -635,53 +613,39 @@ void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA
|
|||
|
||||
for(int i=0; i<curSize-1; ++i) {
|
||||
|
||||
if(singVals.e<T>(i) > singVals.e<T>(i+1)) {
|
||||
T _e0 = singVals.e<T>(i);
|
||||
T _e1 = singVals.e<T>(i+1);
|
||||
//math::nd4j_swap<T>(singVals(i),singVals(i+1));
|
||||
singVals.p(i, _e1);
|
||||
singVals.p(i+1, _e0);
|
||||
if(singVals.t<T>(i) > singVals.t<T>(i+1)) {
|
||||
|
||||
auto temp1 = U({0,0, i,i+1}, true);
|
||||
auto temp2 = U({0,0, i+1,i+2}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
math::nd4j_swap<T>(singVals.r<T>(i), singVals.r<T>(i+1));
|
||||
|
||||
auto temp1 = U({0,0, i,i+1});
|
||||
auto temp2 = U({0,0, i+1,i+2});
|
||||
temp1.swapUnsafe(temp2);
|
||||
|
||||
if(_calcV) {
|
||||
auto temp1 = V({0,0, i,i+1}, true);
|
||||
auto temp2 = V({0,0, i+1,i+2}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
auto temp1 = V({0,0, i,i+1});
|
||||
auto temp2 = V({0,0, i+1,i+2});
|
||||
temp1.swapUnsafe(temp2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto temp1 = singVals({0,curSize, 0,0}, true);
|
||||
for (int e = 0; e < curSize / 2; ++e) {
|
||||
T tmp = temp1.e<T>(e);
|
||||
temp1.p(e, temp1.e<T>(curSize-1-e));
|
||||
temp1.p(curSize-1-e, tmp);
|
||||
}
|
||||
auto temp1 = singVals({0,curSize, 0,0});
|
||||
for (int e = 0; e < curSize / 2; ++e)
|
||||
math::nd4j_swap<T>(temp1.r<T>(e), temp1.r<T>(curSize-1-e));
|
||||
|
||||
auto temp2 = U({0,0, 0,curSize}, true);
|
||||
for(int i = 0; i < curSize/2; ++i) {
|
||||
auto temp3 = temp2({0,0, i,i+1}, true);
|
||||
auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true);
|
||||
auto temp5 = temp3;
|
||||
temp3.assign(temp4);
|
||||
temp4.assign(temp5);
|
||||
auto temp3 = temp2({0,0, i,i+1});
|
||||
auto temp4 = temp2({0,0, curSize-1-i,curSize-i});
|
||||
temp3.swapUnsafe(temp4);
|
||||
}
|
||||
|
||||
if (_calcV) {
|
||||
auto temp2 = V({0,0, 0,curSize}, true);
|
||||
for(int i = 0; i < curSize/2; ++i) {
|
||||
auto temp3 = temp2({0,0, i,i+1}, true);
|
||||
auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true);
|
||||
auto temp5 = temp3;
|
||||
temp3.assign(temp4);
|
||||
temp4.assign(temp5);
|
||||
auto temp3 = temp2({0,0, i,i+1});
|
||||
auto temp4 = temp2({0,0, curSize-1-i,curSize-i});
|
||||
temp3.swapUnsafe(temp4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -695,54 +659,45 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
|
|||
const int n = col2 - col1 + 1;
|
||||
const int k = n/2;
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
T alphaK;
|
||||
T betaK;
|
||||
T r0;
|
||||
T lambda, phi, c0, s0;
|
||||
auto l = NDArrayFactory::create<T>(_u.ordering(), {1, k}, _u.getContext());
|
||||
auto f = NDArrayFactory::create<T>(_u.ordering(), {1, n-k-1}, _u.getContext());
|
||||
T alphaK, betaK, r0, lambda, phi, c0, s0;
|
||||
|
||||
NDArray l(_u.ordering(), {1, k}, _u.dataType(), _u.getContext());
|
||||
NDArray f(_u.ordering(), {1, n-k-1}, _u.dataType(), _u.getContext());
|
||||
|
||||
if(n < _switchSize) {
|
||||
|
||||
JacobiSVD<T> jac(_m({col1,col1+n+1, col1,col1+n}, true), _calcU, _calcV, _fullUV);
|
||||
|
||||
if (_calcU) {
|
||||
auto temp = _u({col1,col1+n+1, col1,col1+n+1}, true);
|
||||
temp.assign(jac._u);
|
||||
}
|
||||
if (_calcU)
|
||||
_u({col1,col1+n+1, col1,col1+n+1}, true).assign(jac._u);
|
||||
else {
|
||||
auto temp1 = _u({0,1, col1,col1+n+1}, true);
|
||||
temp1.assign(jac._u({0,1, 0,0}, true));
|
||||
auto temp2 = _u({1,2, col1,col1+n+1}, true);
|
||||
temp2.assign(jac._u({n,n+1, 0,0}, true));
|
||||
_u({0,1, col1,col1+n+1}, true).assign(jac._u({0,1, 0,0}, true));
|
||||
_u({1,2, col1,col1+n+1}, true).assign(jac._u({n,n+1, 0,0}, true));
|
||||
}
|
||||
|
||||
if (_calcV) {
|
||||
auto temp = _v({row1W,row1W+n, col1W,col1W+n}, true);
|
||||
temp.assign(jac._v);
|
||||
}
|
||||
if (_calcV)
|
||||
_v({row1W,row1W+n, col1W,col1W+n}, true).assign(jac._v);
|
||||
|
||||
auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true);
|
||||
temp.assign(0.);
|
||||
_m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true).nullify();
|
||||
auto diag = _m.diagonal('c');
|
||||
diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
alphaK = _m.e<T>(col1 + k, col1 + k);
|
||||
betaK = _m.e<T>(col1 + k + 1, col1 + k);
|
||||
alphaK = _m.t<T>(col1 + k, col1 + k);
|
||||
betaK = _m.t<T>(col1 + k + 1, col1 + k);
|
||||
|
||||
DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift);
|
||||
DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1);
|
||||
|
||||
if (_calcU) {
|
||||
lambda = _u.e<T>(col1 + k, col1 + k);
|
||||
phi = _u.e<T>(col1 + k + 1, col2 + 1);
|
||||
lambda = _u.t<T>(col1 + k, col1 + k);
|
||||
phi = _u.t<T>(col1 + k + 1, col2 + 1);
|
||||
}
|
||||
else {
|
||||
lambda = _u.e<T>(1, col1 + k);
|
||||
phi = _u.e<T>(0, col2 + 1);
|
||||
lambda = _u.t<T>(1, col1 + k);
|
||||
phi = _u.t<T>(0, col2 + 1);
|
||||
}
|
||||
|
||||
r0 = math::nd4j_sqrt<T, T>((math::nd4j_abs<T>(alphaK * lambda) * math::nd4j_abs<T>(alphaK * lambda)) + math::nd4j_abs<T>(betaK * phi) * math::nd4j_abs<T>(betaK * phi));
|
||||
|
@ -757,7 +712,7 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
|
|||
}
|
||||
|
||||
if (_calcV)
|
||||
_v.p(row1W+k, col1W, 1.f);
|
||||
_v.r<T>(row1W+k, col1W) = (T)1;
|
||||
|
||||
if (r0 < almostZero){
|
||||
c0 = 1.;
|
||||
|
@ -770,39 +725,37 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
|
|||
|
||||
if (_calcU) {
|
||||
|
||||
auto temp = _u({col1,col1+k+1, col1+k,col1+k+1}, true);
|
||||
NDArray q1(temp);
|
||||
NDArray q1 = _u({col1,col1+k+1, col1+k,col1+k+1}, true).dup();
|
||||
|
||||
for (int i = col1 + k - 1; i >= col1; --i) {
|
||||
auto temp = _u({col1,col1+k+1, i+1,i+2}, true);
|
||||
temp.assign(_u({col1, col1+k+1, i, i+1}, true));
|
||||
}
|
||||
for (int i = col1 + k - 1; i >= col1; --i)
|
||||
_u({col1,col1+k+1, i+1,i+2}, true).assign(_u({col1,col1+k+1, i,i+1}, true));
|
||||
|
||||
NDArray temp1 = _u({col1+k+1,col1+n+1, col2+1,col2+2}, true);
|
||||
|
||||
_u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0);
|
||||
_u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0));
|
||||
_u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast<const NDArray&>(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0);
|
||||
_u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0;
|
||||
_u({col1+k+1,col1+n+1, col1,col1+1}, true).assign(temp1 * s0);
|
||||
temp1 *= c0;
|
||||
}
|
||||
else {
|
||||
|
||||
T q1 = _u.e<T>(0, col1 + k);
|
||||
T q1 = _u.t<T>(0, col1 + k);
|
||||
|
||||
for (int i = col1 + k - 1; i >= col1; --i)
|
||||
_u.p(0, i+1, _u.e<T>(0, i));
|
||||
_u.r<T>(0, i+1) = _u.r<T>(0, i);
|
||||
|
||||
_u.p(0, col1, q1 * c0);
|
||||
_u.p(0, col2+1, -q1*s0);
|
||||
_u.p(1, col1, _u.e<T>(1, col2+1) * s0);
|
||||
_u.p(1, col2 + 1, _u.e<T>(1, col2 + 1) * c0);
|
||||
_u({1,2, col1+1, col1+k+1}, true) = 0.f;
|
||||
_u({0,1, col1+k+1, col1+n}, true) = 0.f;
|
||||
_u.r<T>(0, col1) = q1 * c0;
|
||||
_u.r<T>(0, col2+1) = -q1*s0;
|
||||
_u.r<T>(1, col1) = _u.t<T>(1, col2+1) * s0;
|
||||
_u.r<T>(1, col2+1) = _u.t<T>(1, col2+1) * c0;
|
||||
_u({1,2, col1+1, col1+k+1}).nullify();
|
||||
_u({0,1, col1+k+1, col1+n}).nullify();
|
||||
}
|
||||
|
||||
_m.p(col1 + shift, col1 + shift, r0);
|
||||
auto temp1 = _m({col1+shift+1,col1+shift+k+1, col1+shift,col1+shift+1}, true);
|
||||
temp1.assign(l*alphaK);
|
||||
auto temp2 = _m({col1+shift+k+1,col1+shift+n, col1+shift,col1+shift+1}, true);
|
||||
temp2.assign(f*betaK);
|
||||
_m.r<T>(col1+shift, col1+shift) = r0;
|
||||
|
||||
_m({col1+shift+1,col1+shift+k+1, col1+shift,col1+shift+1}, true).assign(l*alphaK);
|
||||
_m({col1+shift+k+1,col1+shift+n, col1+shift,col1+shift+1}, true).assign(f*betaK);
|
||||
|
||||
deflation(col1, col2, k, row1W, col1W, shift);
|
||||
|
||||
|
@ -810,26 +763,22 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
|
|||
calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD);
|
||||
|
||||
if(_calcU) {
|
||||
auto pTemp = _u({col1, col1+n+1, col1,col1+n+1}, true);
|
||||
auto temp = pTemp;
|
||||
pTemp.assign(mmul(temp, UofSVD));
|
||||
auto temp = _u({col1, col1+n+1, col1,col1+n+1}, true);
|
||||
temp.assign(mmul(temp, UofSVD));
|
||||
}
|
||||
else {
|
||||
auto pTemp = _u({0,0, col1,col1+n+1}, true);
|
||||
auto temp = pTemp;
|
||||
pTemp.assign(mmul(temp, UofSVD));
|
||||
auto temp = _u({0,0, col1,col1+n+1}, true);
|
||||
temp.assign(mmul(temp, UofSVD));
|
||||
}
|
||||
|
||||
if (_calcV) {
|
||||
auto pTemp = _v({row1W,row1W+n, row1W,row1W+n}, true);
|
||||
auto temp = pTemp;
|
||||
pTemp.assign(mmul(temp, VofSVD));
|
||||
auto temp = _v({row1W,row1W+n, row1W,row1W+n}, true);
|
||||
temp.assign(mmul(temp, VofSVD));
|
||||
}
|
||||
|
||||
auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true);
|
||||
blockM = 0.f;
|
||||
auto diag = blockM.diagonal('c');
|
||||
diag.assign(singVals);
|
||||
blockM.nullify();
|
||||
blockM.diagonal('c').assign(singVals);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -839,24 +788,22 @@ void SVD<T>::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDAr
|
|||
if (_calcU) {
|
||||
|
||||
int colsU = _fullUV ? hhU.rows() : _diagSize;
|
||||
auto temp1 = NDArrayFactory::create<T>(_u.ordering(), {hhU.rows(), colsU}, _u.getContext());
|
||||
NDArray temp1(_u.ordering(), {hhU.rows(), colsU}, _u.dataType(), _u.getContext());
|
||||
temp1.setIdentity();
|
||||
_u = temp1;
|
||||
|
||||
auto temp2 = _u({0,_diagSize, 0,_diagSize}, true);
|
||||
temp2.assign(V({0,_diagSize, 0,_diagSize}, true));
|
||||
_u({0,_diagSize, 0,_diagSize}, true).assign(V({0,_diagSize, 0,_diagSize}, true));
|
||||
const_cast<HHsequence&>(hhU).mulLeft(_u);
|
||||
}
|
||||
|
||||
if (_calcV) {
|
||||
|
||||
int colsV = _fullUV ? hhV.rows() : _diagSize;
|
||||
auto temp1 = NDArrayFactory::create<T>(_v.ordering(), {hhV.rows(), colsV}, _v.getContext());
|
||||
NDArray temp1(_v.ordering(), {hhV.rows(), colsV}, _v.dataType(), _v.getContext());
|
||||
temp1.setIdentity();
|
||||
_v = temp1;
|
||||
|
||||
auto temp2 = _v({0,_diagSize, 0,_diagSize}, true);
|
||||
temp2.assign(U({0,_diagSize, 0,_diagSize}, true));
|
||||
_v({0,_diagSize, 0,_diagSize}, true).assign(U({0,_diagSize, 0,_diagSize}, true));
|
||||
const_cast<HHsequence&>(hhV).mulLeft(_v);
|
||||
}
|
||||
}
|
||||
|
@ -881,48 +828,40 @@ void SVD<T>::evalData(const NDArray& matrix) {
|
|||
return;
|
||||
}
|
||||
|
||||
T scale = matrix.reduceNumber(reduce::AMax).e<T>(0);
|
||||
T scale = matrix.reduceNumber(reduce::AMax).t<T>(0);
|
||||
|
||||
if(scale == (T)0.)
|
||||
scale = 1.;
|
||||
|
||||
NDArray copy;
|
||||
if(_transp)
|
||||
copy = matrix.transpose();
|
||||
else
|
||||
copy = matrix / scale;
|
||||
BiDiagonalUp biDiag(_transp ? matrix.transpose() : matrix / scale);
|
||||
|
||||
BiDiagonalUp biDiag(copy);
|
||||
_u.nullify();
|
||||
_v.nullify();
|
||||
|
||||
_u = 0.;
|
||||
_v = 0.;
|
||||
_m({0,_diagSize, 0,0}, true).assign(biDiag._HHbidiag.transpose());
|
||||
|
||||
auto temp1 = biDiag._HHbidiag.transpose();
|
||||
auto temp2 = _m({0,_diagSize, 0,0}, true);
|
||||
temp2.assign(temp1);
|
||||
|
||||
|
||||
auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true);
|
||||
temp3.assign(0.);
|
||||
_m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}).nullify();
|
||||
|
||||
DivideAndConquer(0, _diagSize - 1, 0, 0, 0);
|
||||
|
||||
for (int i = 0; i < _diagSize; ++i) {
|
||||
T a = math::nd4j_abs<T>(_m.e<T>(i, i));
|
||||
_s.p(i, a * scale);
|
||||
T a = math::nd4j_abs<T>(_m.t<T>(i, i));
|
||||
_s.r<T>(i) = a * scale;
|
||||
if (a < almostZero) {
|
||||
auto temp = _s({i+1,_diagSize, 0,0}, true);
|
||||
temp.assign(0.);
|
||||
_s({i+1,_diagSize, 0,0}).nullify();
|
||||
break;
|
||||
}
|
||||
else if (i == _diagSize-1)
|
||||
break;
|
||||
}
|
||||
|
||||
HHsequence hhV = biDiag.makeHHsequence('v');
|
||||
HHsequence hhU = biDiag.makeHHsequence('u');
|
||||
|
||||
if(_transp)
|
||||
exchangeUV(biDiag.makeHHsequence('v'), biDiag.makeHHsequence('u'), _v, _u);
|
||||
exchangeUV(hhV, hhU, _v, _u);
|
||||
else
|
||||
exchangeUV(biDiag.makeHHsequence('u'), biDiag.makeHHsequence('v'), _u, _v);
|
||||
exchangeUV(hhU, hhV, _u, _v);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -35,12 +35,12 @@ class HHsequence {
|
|||
/*
|
||||
* matrix containing the Householder vectors
|
||||
*/
|
||||
NDArray _vectors;
|
||||
const NDArray& _vectors;
|
||||
|
||||
/*
|
||||
* vector containing the Householder coefficients
|
||||
*/
|
||||
NDArray _coeffs;
|
||||
const NDArray& _coeffs;
|
||||
|
||||
/*
|
||||
* shift of the Householder sequence
|
||||
|
@ -68,14 +68,14 @@ class HHsequence {
|
|||
* matrix - input matrix to be multiplied
|
||||
*/
|
||||
template <typename T>
|
||||
void _mulLeft(NDArray& matrix);
|
||||
void mulLeft_(NDArray& matrix);
|
||||
|
||||
void mulLeft(NDArray& matrix);
|
||||
|
||||
NDArray getTail(const int idx) const;
|
||||
|
||||
template <typename T>
|
||||
void _applyTo(NDArray& dest);
|
||||
void applyTo_(NDArray& dest);
|
||||
|
||||
void applyTo(NDArray& dest);
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ class Householder {
|
|||
*
|
||||
* x - input vector, remains unaffected
|
||||
*/
|
||||
static NDArray evalHHmatrix(const NDArray& x);
|
||||
// static NDArray evalHHmatrix(const NDArray& x);
|
||||
|
||||
/**
|
||||
* this method evaluates data required for calculation of Householder matrix P = identity_matrix - coeff * w * w^T
|
||||
|
@ -64,7 +64,7 @@ class Householder {
|
|||
*/
|
||||
static void evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, T& normX);
|
||||
|
||||
static void evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX);
|
||||
static void evalHHmatrixDataI(NDArray& x, T& coeff, T& normX); // in-place, x to be affected
|
||||
|
||||
/**
|
||||
* this method mathematically multiplies input matrix on Householder from the left P * matrix
|
||||
|
|
|
@ -0,0 +1,293 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <helpers/HessenbergAndSchur.h>
|
||||
#include <helpers/EigenValsAndVecs.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
EigenValsAndVecs<T>::EigenValsAndVecs(const NDArray& matrix) {
|
||||
|
||||
if(matrix.rankOf() != 2)
|
||||
throw std::runtime_error("ops::helpers::EigenValsAndVecs constructor: input matrix must be 2D !");
|
||||
|
||||
if(matrix.sizeAt(0) != matrix.sizeAt(1))
|
||||
throw std::runtime_error("ops::helpers::EigenValsAndVecs constructor: input array must be 2D square matrix !");
|
||||
|
||||
Schur<T> schur(matrix);
|
||||
|
||||
NDArray& schurMatrixU = schur._U;
|
||||
NDArray& schurMatrixT = schur._T;
|
||||
|
||||
_Vecs = NDArray(matrix.ordering(), {schurMatrixU.sizeAt(1), schurMatrixU.sizeAt(1), 2}, matrix.dataType(), matrix.getContext());
|
||||
_Vals = NDArray(matrix.ordering(), {matrix.sizeAt(1), 2}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
// sequence of methods calls matters
|
||||
calcEigenVals(schurMatrixT);
|
||||
calcPseudoEigenVecs(schurMatrixT, schurMatrixU); // pseudo-eigenvectors are real and will be stored in schurMatrixU
|
||||
calcEigenVecs(schurMatrixU);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void EigenValsAndVecs<T>::calcEigenVals(const NDArray& schurMatrixT) {
|
||||
|
||||
const int numOfCols = schurMatrixT.sizeAt(1);
|
||||
|
||||
// calculate eigenvalues _Vals
|
||||
int i = 0;
|
||||
while (i < numOfCols) {
|
||||
|
||||
if (i == numOfCols - 1 || schurMatrixT.t<T>(i+1, i) == T(0.f)) {
|
||||
|
||||
_Vals.r<T>(i, 0) = schurMatrixT.t<T>(i, i); // real part
|
||||
_Vals.r<T>(i, 1) = T(0); // imaginary part
|
||||
|
||||
if(!math::nd4j_isfin<T>(_Vals.t<T>(i, 0))) {
|
||||
throw std::runtime_error("ops::helpers::igenValsAndVec::calcEigenVals: got infinite eigen value !");
|
||||
return;
|
||||
}
|
||||
|
||||
++i;
|
||||
}
|
||||
else {
|
||||
|
||||
T p = T(0.5) * (schurMatrixT.t<T>(i, i) - schurMatrixT.t<T>(i+1, i+1));
|
||||
T z;
|
||||
{
|
||||
T t0 = schurMatrixT.t<T>(i+1, i);
|
||||
T t1 = schurMatrixT.t<T>(i, i+1);
|
||||
T maxval = math::nd4j_max<T>(math::nd4j_abs<T>(p), math::nd4j_max<T>(math::nd4j_abs<T>(t0), math::nd4j_abs<T>(t1)));
|
||||
t0 /= maxval;
|
||||
t1 /= maxval;
|
||||
T p0 = p / maxval;
|
||||
z = maxval * math::nd4j_sqrt<T,T>(math::nd4j_abs<T>(p0 * p0 + t0 * t1));
|
||||
}
|
||||
|
||||
_Vals.r<T>(i, 0) = _Vals.r<T>(i+1, 0) = schurMatrixT.t<T>(i+1, i+1) + p;
|
||||
_Vals.r<T>(i, 1) = z;
|
||||
_Vals.r<T>(i+1,1) = -z;
|
||||
|
||||
if(!(math::nd4j_isfin<T>(_Vals.t<T>(i,0)) && math::nd4j_isfin<T>(_Vals.t<T>(i+1,0)) && math::nd4j_isfin<T>(_Vals.t<T>(i,1)) && math::nd4j_isfin<T>(_Vals.t<T>(i+1,1)))) {
|
||||
throw std::runtime_error("ops::helpers::igenValsAndVec::calcEigenVals: got infinite eigen value !");
|
||||
return;
|
||||
}
|
||||
|
||||
i += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
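A short reference for the 2x2 branch above: for a diagonal block [[a, b], [c, d]] of the real Schur form, with a = T(i,i), b = T(i,i+1), c = T(i+1,i), d = T(i+1,i+1), the loop evaluates the standard closed form

\[
p = \frac{a-d}{2}, \qquad z = \sqrt{\left|\,p^{2} + b\,c\,\right|}, \qquad \lambda_{1,2} = \frac{a+d}{2} \pm i\,z ,
\]

the pair being complex conjugate because p^2 + bc < 0 for an unreduced 2x2 Schur block; the division by maxval is only an overflow guard and cancels out of z exactly.
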
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void EigenValsAndVecs<T>::calcPseudoEigenVecs(NDArray& schurMatrixT, NDArray& schurMatrixU) {
|
||||
|
||||
const int numOfCols = schurMatrixU.sizeAt(1);
|
||||
|
||||
T norm = 0;
|
||||
for (int j = 0; j < numOfCols; ++j)
|
||||
norm += schurMatrixT({j,j+1, math::nd4j_max<Nd4jLong>(j-1, 0),numOfCols}).reduceNumber(reduce::ASum).template t<T>(0);
|
||||
|
||||
if (norm == T(0))
|
||||
return;
|
||||
|
||||
for (int n = numOfCols-1; n >= 0; n--) {
|
||||
|
||||
T p = _Vals.t<T>(n, 0); // real part
|
||||
T q = _Vals.t<T>(n, 1); // imaginary part
|
||||
|
||||
if(q == (T)0) { // not complex
|
||||
|
||||
T lastr((T)0), lastw((T)0);
|
||||
int l = n;
|
||||
|
||||
schurMatrixT.r<T>(n, n) = T(1);
|
||||
|
||||
for (int i = n-1; i >= 0; i--) {
|
||||
|
||||
T w = schurMatrixT.t<T>(i,i) - p;
|
||||
T r = mmul(schurMatrixT({i,i+1, l,n+1}, true), schurMatrixT({l,n+1, n,n+1}, true)).template t<T>(0); // dot
|
||||
|
||||
if (_Vals.t<T>(i, 1) < T(0)) {
|
||||
lastw = w;
|
||||
lastr = r;
|
||||
}
|
||||
else {
|
||||
|
||||
l = i;
|
||||
if (_Vals.t<T>(i, 1) == T(0)) {
|
||||
|
||||
if (w != T(0))
|
||||
schurMatrixT.r<T>(i, n) = -r / w;
|
||||
else
|
||||
schurMatrixT.r<T>(i, n) = -r / (DataTypeUtils::eps<T>() * norm);
|
||||
}
|
||||
else {
|
||||
|
||||
T x = schurMatrixT.t<T>(i, i+1);
|
||||
T y = schurMatrixT.t<T>(i+1, i);
|
||||
T denom = (_Vals.t<T>(i, 0) - p) * (_Vals.t<T>(i, 0) - p) + _Vals.t<T>(i, 1) * _Vals.t<T>(i, 1);
|
||||
T t = (x * lastr - lastw * r) / denom;
|
||||
schurMatrixT.r<T>(i, n) = t;
|
||||
|
||||
if (math::nd4j_abs<T>(x) > math::nd4j_abs<T>(lastw))
|
||||
schurMatrixT.r<T>(i+1, n) = (-r - w * t) / x;
|
||||
else
|
||||
schurMatrixT.r<T>(i+1, n) = (-lastr - y * t) / lastw;
|
||||
}
|
||||
|
||||
|
||||
T t = math::nd4j_abs<T>(schurMatrixT.t<T>(i, n));
|
||||
if((DataTypeUtils::eps<T>() * t) * t > T(1))
|
||||
schurMatrixT({schurMatrixT.sizeAt(0)-numOfCols+i,-1, n,n+1}) /= t;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(q < T(0) && n > 0) { // complex
|
||||
|
||||
T lastra(0), lastsa(0), lastw(0);
|
||||
int l = n - 1;
|
||||
|
||||
if(math::nd4j_abs<T>(schurMatrixT.t<T>(n, n-1)) > math::nd4j_abs<T>(schurMatrixT.t<T>(n-1, n))) {
|
||||
|
||||
schurMatrixT.r<T>(n-1, n-1) = q / schurMatrixT.t<T>(n, n-1);
|
||||
schurMatrixT.r<T>(n-1, n) = -(schurMatrixT.t<T>(n, n) - p) / schurMatrixT.t<T>(n, n-1);
|
||||
}
|
||||
else {
|
||||
divideComplexNums(T(0),-schurMatrixT.t<T>(n-1,n), schurMatrixT.t<T>(n-1,n-1)-p,q, schurMatrixT.r<T>(n-1,n-1),schurMatrixT.r<T>(n-1,n));
|
||||
}
|
||||
|
||||
schurMatrixT.r<T>(n,n-1) = T(0);
|
||||
schurMatrixT.r<T>(n,n) = T(1);
|
||||
|
||||
for (int i = n-2; i >= 0; i--) {
|
||||
|
||||
T ra = mmul(schurMatrixT({i,i+1, l,n+1}, true), schurMatrixT({l,n+1, n-1,n}, true)).template t<T>(0); // dot
|
||||
T sa = mmul(schurMatrixT({i,i+1, l,n+1}, true), schurMatrixT({l,n+1, n,n+1}, true)).template t<T>(0); // dot
|
||||
|
||||
T w = schurMatrixT.t<T>(i,i) - p;
|
||||
|
||||
if (_Vals.t<T>(i, 1) < T(0)) {
|
||||
lastw = w;
|
||||
lastra = ra;
|
||||
lastsa = sa;
|
||||
}
|
||||
else {
|
||||
|
||||
l = i;
|
||||
|
||||
if (_Vals.t<T>(i, 1) == T(0)) {
|
||||
divideComplexNums(-ra,-sa, w,q, schurMatrixT.r<T>(i,n-1),schurMatrixT.r<T>(i,n));
|
||||
}
|
||||
else {
|
||||
|
||||
T x = schurMatrixT.t<T>(i,i+1);
|
||||
T y = schurMatrixT.t<T>(i+1,i);
|
||||
T vr = (_Vals.t<T>(i, 0) - p) * (_Vals.t<T>(i, 0) - p) + _Vals.t<T>(i, 1) * _Vals.t<T>(i, 1) - q * q;
|
||||
T vi = (_Vals.t<T>(i, 0) - p) * T(2) * q;
|
||||
|
||||
if ((vr == T(0)) && (vi == T(0)))
|
||||
vr = DataTypeUtils::eps<T>() * norm * (math::nd4j_abs<T>(w) + math::nd4j_abs<T>(q) + math::nd4j_abs<T>(x) + math::nd4j_abs<T>(y) + math::nd4j_abs<T>(lastw));
|
||||
|
||||
divideComplexNums(x*lastra-lastw*ra+q*sa,x*lastsa-lastw*sa-q*ra, vr,vi, schurMatrixT.r<T>(i,n-1),schurMatrixT.r<T>(i,n));
|
||||
|
||||
if(math::nd4j_abs<T>(x) > (math::nd4j_abs<T>(lastw) + math::nd4j_abs<T>(q))) {
|
||||
|
||||
schurMatrixT.r<T>(i+1,n-1) = (-ra - w * schurMatrixT.t<T>(i,n-1) + q * schurMatrixT.t<T>(i,n)) / x;
|
||||
schurMatrixT.r<T>(i+1,n) = (-sa - w * schurMatrixT.t<T>(i,n) - q * schurMatrixT.t<T>(i,n-1)) / x;
|
||||
}
|
||||
else
|
||||
divideComplexNums(-lastra-y*schurMatrixT.t<T>(i,n-1),-lastsa-y*schurMatrixT.t<T>(i,n), lastw,q, schurMatrixT.r<T>(i+1,n-1),schurMatrixT.r<T>(i+1,n));
|
||||
}
|
||||
|
||||
T t = math::nd4j_max<T>(math::nd4j_abs<T>(schurMatrixT.t<T>(i, n-1)), math::nd4j_abs<T>(schurMatrixT.t<T>(i,n)));
|
||||
if ((DataTypeUtils::eps<T>() * t) * t > T(1))
|
||||
schurMatrixT({i,numOfCols, n-1,n+1}) /= t;
|
||||
}
|
||||
}
|
||||
n--;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("ops::helpers::EigenValsAndVecs::calcEigenVecs: internal bug !");
|
||||
}
|
||||
|
||||
for (int j = numOfCols-1; j >= 0; j--)
|
||||
schurMatrixU({0,0, j,j+1}, true).assign( mmul(schurMatrixU({0,0, 0,j+1}, true), schurMatrixT({0,j+1, j,j+1}, true)) );
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void EigenValsAndVecs<T>::calcEigenVecs(const NDArray& schurMatrixU) {
|
||||
|
||||
const T precision = T(2) * DataTypeUtils::eps<T>();
|
||||
|
||||
const int numOfCols = schurMatrixU.sizeAt(1);
|
||||
|
||||
for (int j = 0; j < numOfCols; ++j) {
|
||||
|
||||
if(math::nd4j_abs<T>(_Vals.t<T>(j, 1)) <= math::nd4j_abs<T>(_Vals.t<T>(j, 0)) * precision || j+1 == numOfCols) { // real
|
||||
|
||||
_Vecs.syncToDevice();
|
||||
_Vecs({0,0, j,j+1, 0,1}).assign(schurMatrixU({0,0, j,j+1}));
|
||||
_Vecs({0,0, j,j+1, 1,2}) = (T)0;
|
||||
|
||||
// normalize
|
||||
const T norm2 = _Vecs({0,0, j,j+1, 0,1}).reduceNumber(reduce::SquaredNorm).template t<T>(0);
|
||||
if(norm2 > (T)0)
|
||||
_Vecs({0,0, j,j+1, 0,1}) /= math::nd4j_sqrt<T,T>(norm2);
|
||||
}
|
||||
else { // complex
|
||||
|
||||
for (int i = 0; i < numOfCols; ++i) {
|
||||
_Vecs.r<T>(i, j, 0) = _Vecs.r<T>(i, j+1, 0) = schurMatrixU.t<T>(i, j);
|
||||
_Vecs.r<T>(i, j, 1) = schurMatrixU.t<T>(i, j+1);
|
||||
_Vecs.r<T>(i, j+1, 1) = -schurMatrixU.t<T>(i, j+1);
|
||||
}
|
||||
|
||||
// normalize
|
||||
T norm2 = _Vecs({0,0, j,j+1, 0,0}).reduceNumber(reduce::SquaredNorm).template t<T>(0);
|
||||
if(norm2 > (T)0)
|
||||
_Vecs({0,0, j,j+1, 0,0}) /= math::nd4j_sqrt<T,T>(norm2);
|
||||
|
||||
// normalize
|
||||
norm2 = _Vecs({0,0, j+1,j+2, 0,0}).reduceNumber(reduce::SquaredNorm).template t<T>(0);
|
||||
if(norm2 > (T)0)
|
||||
_Vecs({0,0, j+1,j+2, 0,0}) /= math::nd4j_sqrt<T,T>(norm2);
|
||||
|
||||
++j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template class ND4J_EXPORT EigenValsAndVecs<float>;
|
||||
template class ND4J_EXPORT EigenValsAndVecs<float16>;
|
||||
template class ND4J_EXPORT EigenValsAndVecs<bfloat16>;
|
||||
template class ND4J_EXPORT EigenValsAndVecs<double>;
|
||||
|
||||
}
|
||||
}
|
||||
}
@ -0,0 +1,170 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <helpers/FullPivLU.h>
|
||||
#include <ops/declarable/helpers/triangular_solve.h>
|
||||
#include <numeric>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// A{M,K} * x{K,N} = b{M,N}
|
||||
template <typename T>
|
||||
void FullPivLU<T>::solve(const NDArray& A, const NDArray& b, NDArray& x) {
|
||||
|
||||
if(A.rankOf() != 2)
|
||||
throw std::runtime_error("FullPivLU::solve: input matrix A must be 2D !");
|
||||
|
||||
if(A.sizeAt(0) != b.sizeAt(0))
|
||||
throw std::runtime_error("FullPivLU::solve: A and b must have the same number of rows !");
|
||||
|
||||
if(A.sizeAt(1) != x.sizeAt(0))
|
||||
throw std::runtime_error("FullPivLU::solve: number of A columns must be equal to number of x rows !");
|
||||
|
||||
NDArray LU = A.dup();
|
||||
|
||||
const int rows = LU.sizeAt(0);
|
||||
const int cols = LU.sizeAt(1);
|
||||
const int diagLen = math::nd4j_min<int>(rows, cols);
|
||||
|
||||
std::vector<int> rowsInds(rows), colsInds(cols);
|
||||
|
||||
int numOfTranspos = 0;
|
||||
int nonZeroPivots1 = diagLen;
|
||||
|
||||
T maxPivot = T(0);
|
||||
|
||||
for(int k = 0; k < diagLen; ++k) {
|
||||
|
||||
NDArray bottomRightCorner = LU({k,rows, k,cols}, true);
|
||||
const int indPivot = static_cast<int>(bottomRightCorner.indexReduceNumber(indexreduce::IndexAbsoluteMax).t<Nd4jLong>(0));
|
||||
|
||||
int colPivot = indPivot % (cols-k);
|
||||
int rowPivot = indPivot / (cols-k);
|
||||
|
||||
T currentMax = math::nd4j_abs<T>(bottomRightCorner.t<T>(rowPivot, colPivot));
|
||||
|
||||
// take into account that this was calculated in corner, not in whole LU
|
||||
rowPivot += k;
|
||||
colPivot += k;
|
||||
|
||||
if(currentMax == T(0)) {
|
||||
|
||||
nonZeroPivots1 = k;
|
||||
|
||||
for(int i = k; i < diagLen; ++i)
|
||||
rowsInds[i] = colsInds[i] = i;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if(currentMax > maxPivot)
|
||||
maxPivot = currentMax;
|
||||
|
||||
rowsInds[k] = rowPivot;
|
||||
colsInds[k] = colPivot;
|
||||
|
||||
if(k != rowPivot) {
|
||||
NDArray row1 = LU({k,k+1, 0,0}, true);
|
||||
NDArray row2 = LU({rowPivot,rowPivot+1, 0,0}, true);
|
||||
row1.swapUnsafe(row2);
|
||||
++numOfTranspos;
|
||||
}
|
||||
if(k != colPivot) {
|
||||
NDArray col1 = LU({0,0, k,k+1}, true);
|
||||
NDArray col2 = LU({0,0, colPivot,colPivot+1}, true);
|
||||
col1.swapUnsafe(col2);
|
||||
++numOfTranspos;
|
||||
}
|
||||
|
||||
if(k < rows-1)
|
||||
LU({k+1,rows, k,k+1}, true) /= LU.t<T>(k, k);
|
||||
|
||||
if(k < diagLen-1)
|
||||
LU({k+1,rows, k+1,cols},true) -= mmul(LU({k+1,rows, k,k+1},true), LU({k,k+1, k+1,cols},true));
|
||||
}
|
||||
|
||||
//***************************************************//
|
||||
|
||||
const T threshold = maxPivot * DataTypeUtils::eps<T>() * (T)diagLen;
|
||||
|
||||
int nonZeroPivots2 = 0;
|
||||
for(int i = 0; i < nonZeroPivots1; ++i)
|
||||
nonZeroPivots2 += static_cast<int>(math::nd4j_abs<T>(LU.t<T>(i,i)) > threshold);
|
||||
|
||||
if(nonZeroPivots2 == 0) {
|
||||
x.nullify();
|
||||
return;
|
||||
}
|
||||
|
||||
//***************************************************//
|
||||
|
||||
std::vector<int> rowsPermut1(rows), rowsPermut2(rows), colsPermut(cols);
|
||||
std::iota(rowsPermut1.begin(), rowsPermut1.end(), 0);
|
||||
std::iota(colsPermut.begin(), colsPermut.end(), 0);
|
||||
|
||||
for(int k = diagLen-1; k >= 0; --k)
|
||||
math::nd4j_swap<int>(rowsPermut1[k], rowsPermut1[rowsInds[k]]);
|
||||
|
||||
for(int k = 0; k < diagLen; ++k)
|
||||
math::nd4j_swap<int>(colsPermut[k], colsPermut[colsInds[k]]);
|
||||
|
||||
for(int i = 0; i < rows; ++i)
|
||||
for(int j = 0; j < rows; ++j)
|
||||
if(i == rowsPermut1[j]) { rowsPermut2[i] = j; break; }
|
||||
|
||||
//***************************************************//
|
||||
|
||||
NDArray c = b.ulike();
|
||||
|
||||
for (int i = 0; i < rows; ++i)
|
||||
c({i,i+1, 0,0}, true).assign(b({rowsPermut2[i],rowsPermut2[i]+1, 0,0}, true));
|
||||
|
||||
|
||||
NDArray cTopRows1 = c({0,diagLen, 0,0}, true);
|
||||
// TriangularSolver<T>::solve(LU({0,diagLen, 0,diagLen}, true), cTopRows1, true, true, cTopRows1);
|
||||
ops::helpers::triangularSolve2D<T>(nullptr, LU({0,diagLen, 0,diagLen}, true), cTopRows1,true,true, cTopRows1);
|
||||
|
||||
if(rows > cols)
|
||||
c({cols,-1, 0,0}, true) -= mmul(LU({cols,-1, 0,0},true), c({0,cols, 0,0}, true));
|
||||
|
||||
NDArray cTopRows2 = c({0,nonZeroPivots2, 0,0}, true);
|
||||
// TriangularSolver<T>::solve(LU({0,nonZeroPivots2, 0,nonZeroPivots2}, true), cTopRows2, false, false, cTopRows2);
|
||||
ops::helpers::triangularSolve2D<T>(nullptr, LU({0,nonZeroPivots2, 0,nonZeroPivots2}, true),cTopRows2,false,false, cTopRows2);
|
||||
|
||||
for(int i = 0; i < nonZeroPivots2; ++i)
|
||||
x({colsPermut[i],colsPermut[i]+1, 0,0}, true).assign(c({i,i+1, 0,0}, true));
|
||||
|
||||
for(int i = nonZeroPivots2; i < cols; ++i)
|
||||
x({colsPermut[i],colsPermut[i]+1, 0,0}, true).nullify();
|
||||
}
|
||||
|
||||
template class ND4J_EXPORT FullPivLU<float>;
|
||||
template class ND4J_EXPORT FullPivLU<float16>;
|
||||
template class ND4J_EXPORT FullPivLU<bfloat16>;
|
||||
template class ND4J_EXPORT FullPivLU<double>;
|
||||
|
||||
}
|
||||
}
|
||||
}
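As an aside, to make the complete-pivoting bookkeeping above easier to follow (this is not the library code): the routine factors P*A*Q = L*U, choosing at every step the largest remaining |entry| as pivot, then solves with one forward and one backward substitution and finally undoes the column permutation. Below is a compact standalone sketch on plain std::vector matrices, with an invented fullPivLUSolve helper, a square nonsingular A, and none of the rank/threshold handling done above.

#include <cmath>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Solve A*x = b for square A using LU with full (complete) pivoting:
// P*A*Q = L*U, then L*y = P*b (forward), U*z = y (backward), x = Q*z.
std::vector<double> fullPivLUSolve(std::vector<std::vector<double>> A, std::vector<double> b) {
    const int n = (int)A.size();
    std::vector<int> colPerm(n);
    std::iota(colPerm.begin(), colPerm.end(), 0);

    for (int k = 0; k < n; ++k) {
        // pick the largest |A[i][j]| in the bottom-right corner as pivot
        int pr = k, pc = k;
        for (int i = k; i < n; ++i)
            for (int j = k; j < n; ++j)
                if (std::fabs(A[i][j]) > std::fabs(A[pr][pc])) { pr = i; pc = j; }
        std::swap(A[k], A[pr]);
        std::swap(b[k], b[pr]);                          // keep b consistent with the row swap
        for (int i = 0; i < n; ++i) std::swap(A[i][k], A[i][pc]);
        std::swap(colPerm[k], colPerm[pc]);              // remember the column swap for the end

        for (int i = k + 1; i < n; ++i) {                // eliminate below the pivot
            A[i][k] /= A[k][k];
            for (int j = k + 1; j < n; ++j)
                A[i][j] -= A[i][k] * A[k][j];
        }
    }

    // forward substitution with the unit-lower factor stored below the diagonal
    std::vector<double> y(b);
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < i; ++j)
            y[i] -= A[i][j] * y[j];

    // backward substitution with the upper-triangular factor
    std::vector<double> z(n);
    for (int i = n - 1; i >= 0; --i) {
        z[i] = y[i];
        for (int j = i + 1; j < n; ++j)
            z[i] -= A[i][j] * z[j];
        z[i] /= A[i][i];
    }

    // undo the column permutation: x[colPerm[i]] = z[i]
    std::vector<double> x(n);
    for (int i = 0; i < n; ++i)
        x[colPerm[i]] = z[i];
    return x;
}

int main() {
    std::vector<std::vector<double>> A = {{2, 1, 1}, {4, -6, 0}, {-2, 7, 2}};
    std::vector<double> b = {5, -2, 9};
    auto x = fullPivLUSolve(A, b);                       // expected solution: x = (1, 1, 2)
    printf("%g %g %g\n", x[0], x[1], x[2]);
    return 0;
}

The rowsPermut2/colsPermut arrays in the helper above play essentially the same role as the row swaps and colPerm here: the row permutation is applied to b, the column permutation to the solution.
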
@ -0,0 +1,383 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <helpers/HessenbergAndSchur.h>
|
||||
#include <helpers/householder.h>
|
||||
#include <helpers/hhSequence.h>
|
||||
#include <helpers/jacobiSVD.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
Hessenberg<T>::Hessenberg(const NDArray& matrix) {
|
||||
|
||||
if(matrix.rankOf() != 2)
|
||||
throw std::runtime_error("ops::helpers::Hessenberg constructor: input matrix must be 2D !");
|
||||
|
||||
if(matrix.sizeAt(0) == 1) {
|
||||
_Q = NDArray(matrix.ordering(), {1,1}, matrix.dataType(), matrix.getContext());
|
||||
_Q = 1;
|
||||
_H = matrix.dup();
|
||||
return;
|
||||
}
|
||||
|
||||
if(matrix.sizeAt(0) != matrix.sizeAt(1))
|
||||
throw std::runtime_error("ops::helpers::Hessenberg constructor: input array must be 2D square matrix !");
|
||||
|
||||
_H = matrix.dup();
|
||||
_Q = matrix.ulike();
|
||||
|
||||
evalData();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Hessenberg<T>::evalData() {
|
||||
|
||||
const int rows = _H.sizeAt(0);
|
||||
|
||||
NDArray hhCoeffs(_H.ordering(), {rows - 1}, _H.dataType(), _H.getContext());
|
||||
|
||||
// calculate _H
|
||||
for(uint i = 0; i < rows - 1; ++i) {
|
||||
|
||||
T coeff, norm;
|
||||
|
||||
NDArray tail1 = _H({i+1,-1, i,i+1});
|
||||
NDArray tail2 = _H({i+2,-1, i,i+1}, true);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(tail1, coeff, norm);
|
||||
|
||||
_H({0,0, i,i+1}). template r<T>(i+1) = norm;
|
||||
hhCoeffs. template r<T>(i) = coeff;
|
||||
|
||||
NDArray bottomRightCorner = _H({i+1,-1, i+1,-1}, true);
|
||||
Householder<T>::mulLeft(bottomRightCorner, tail2, coeff);
|
||||
|
||||
NDArray rightCols = _H({0,0, i+1,-1}, true);
|
||||
Householder<T>::mulRight(rightCols, tail2.transpose(), coeff);
|
||||
}
|
||||
|
||||
// calculate _Q
|
||||
HHsequence hhSeq(_H, hhCoeffs, 'u');
|
||||
hhSeq._diagSize = rows - 1;
|
||||
hhSeq._shift = 1;
|
||||
hhSeq.applyTo_<T>(_Q);
|
||||
|
||||
// zero out all elements below the first subdiagonal so that _H becomes upper Hessenberg
|
||||
_H.fillAsTriangular<T>(0, -1, 0, _H, 'l');
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
Schur<T>::Schur(const NDArray& matrix) {
|
||||
|
||||
if(matrix.rankOf() != 2)
|
||||
throw std::runtime_error("ops::helpers::Schur constructor: input matrix must be 2D !");
|
||||
|
||||
if(matrix.sizeAt(0) != matrix.sizeAt(1))
|
||||
throw std::runtime_error("ops::helpers::Schur constructor: input array must be 2D square matrix !");
|
||||
|
||||
evalData(matrix);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Schur<T>::evalData(const NDArray& matrix) {
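// computes the real Schur decomposition matrix = _U * _T * _U^T, where _U is orthogonal and
// _T is quasi upper triangular (1x1 and 2x2 blocks on its diagonal)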
|
||||
|
||||
const T scale = matrix.reduceNumber(reduce::AMax).template t<T>(0);
|
||||
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
|
||||
if(scale < almostZero) {
|
||||
|
||||
_T = matrix.ulike();
|
||||
_U = matrix.ulike();
|
||||
|
||||
_T.nullify();
|
||||
_U.setIdentity();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// perform Hessenberg decomposition
|
||||
Hessenberg<T> hess(matrix / scale);
|
||||
|
||||
_T = std::move(hess._H);
|
||||
_U = std::move(hess._Q);
|
||||
|
||||
calcFromHessenberg();
|
||||
|
||||
_T *= scale;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void Schur<T>::splitTwoRows(const int ind, const T shift) {
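// deflates a trailing 2x2 diagonal block: if its eigenvalues are real (q >= 0) a Givens rotation
// splits it into two 1x1 blocks, otherwise it is kept as a 2x2 block with complex conjugate
// eigenvalues; the accumulated shift is restored on the diagonal in both cases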
|
||||
|
||||
const int numCols = _T.sizeAt(1);
|
||||
|
||||
T p = (T)0.5 * (_T.t<T>(ind-1, ind-1) - _T.t<T>(ind, ind));
|
||||
|
||||
T q = p*p + _T.t<T>(ind, ind-1) * _T.t<T>(ind-1, ind);
|
||||
|
||||
_T.r<T>(ind, ind) += shift;
|
||||
_T.r<T>(ind-1, ind-1) += shift;
|
||||
|
||||
if (q >= (T)0) {
|
||||
|
||||
T z = math::nd4j_sqrt<T,T>(math::nd4j_abs<T>(q));
|
||||
|
||||
NDArray rotation(_T.ordering(), {2, 2}, _T.dataType(), _T.getContext());
|
||||
|
||||
if (p >= (T)0)
|
||||
JacobiSVD<T>::createJacobiRotationGivens(p+z, _T.t<T>(ind, ind-1), rotation);
|
||||
else
|
||||
JacobiSVD<T>::createJacobiRotationGivens(p-z, _T.t<T>(ind, ind-1), rotation);
|
||||
|
||||
NDArray rightCols = _T({0,0, ind-1,-1});
|
||||
JacobiSVD<T>::mulRotationOnLeft(ind-1, ind, rightCols, rotation.transpose());
|
||||
|
||||
NDArray topRows = _T({0,ind+1, 0,0});
|
||||
JacobiSVD<T>::mulRotationOnRight(ind-1, ind, topRows, rotation);
|
||||
|
||||
JacobiSVD<T>::mulRotationOnRight(ind-1, ind, _U, rotation);
|
||||
|
||||
_T.r<T>(ind, ind-1) = (T)0;
|
||||
}
|
||||
|
||||
if (ind > 1)
|
||||
_T.r<T>(ind-1, ind-2) = (T)0;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void Schur<T>::calcShift(const int ind, const int iter, T& shift, NDArray& shiftVec) {
|
||||
|
||||
// shiftVec has length = 3
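// shiftVec accumulates the data for the implicit double shift; at iterations 10 and 30 the usual
// shift is replaced by ad hoc "exceptional" shifts to break possible convergence stagnation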
|
||||
|
||||
shiftVec.r<T>(0) = _T.t<T>(ind, ind);
|
||||
shiftVec.r<T>(1) = _T.t<T>(ind-1, ind-1);
|
||||
shiftVec.r<T>(2) = _T.t<T>(ind, ind-1) * _T.t<T>(ind-1, ind);
|
||||
|
||||
if (iter == 10) {
|
||||
shift += shiftVec.t<T>(0);
|
||||
|
||||
for (int i = 0; i <= ind; ++i)
|
||||
_T.r<T>(i,i) -= shiftVec.t<T>(0);
|
||||
|
||||
T s = math::nd4j_abs<T>(_T.t<T>(ind, ind-1)) + math::nd4j_abs<T>(_T.t<T>(ind-1, ind-2));
|
||||
|
||||
shiftVec.r<T>(0) = T(0.75) * s;
|
||||
shiftVec.r<T>(1) = T(0.75) * s;
|
||||
shiftVec.r<T>(2) = T(-0.4375) * s*s;
|
||||
}
|
||||
|
||||
if (iter == 30) {
|
||||
|
||||
T s = (shiftVec.t<T>(1) - shiftVec.t<T>(0)) / T(2.0);
|
||||
s = s*s + shiftVec.t<T>(2);
|
||||
|
||||
if (s > T(0)) {
|
||||
|
||||
s = math::nd4j_sqrt<T,T>(s);
|
||||
|
||||
if (shiftVec.t<T>(1) < shiftVec.t<T>(0))
|
||||
s = -s;
|
||||
|
||||
s = s + (shiftVec.t<T>(1) - shiftVec.t<T>(0)) / T(2.0);
|
||||
s = shiftVec.t<T>(0) - shiftVec.t<T>(2) / s;
|
||||
shift += s;
|
||||
|
||||
for (int i = 0; i <= ind; ++i)
|
||||
_T.r<T>(i,i) -= s;
|
||||
|
||||
shiftVec = T(0.964);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void Schur<T>::initFrancisQR(const int ind1, const int ind2, const NDArray& shiftVec, int& ind3, NDArray& householderVec) {
|
||||
|
||||
// shiftVec has length = 3
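// determines the row ind3 at which the implicit double-shift (Francis) QR sweep may start and
// forms the first 3x1 Householder vector from the shift polynomial; the loop walks up from the
// bottom of the active block until the deflation criterion below is satisfied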
|
||||
|
||||
for (ind3 = ind2-2; ind3 >= ind1; --ind3) {
|
||||
|
||||
const T mm = _T.t<T>(ind3, ind3);
|
||||
const T r = shiftVec.t<T>(0) - mm;
|
||||
const T s = shiftVec.t<T>(1) - mm;
|
||||
|
||||
householderVec.r<T>(0) = (r * s - shiftVec.t<T>(2)) / _T.t<T>(ind3+1, ind3) + _T.t<T>(ind3, ind3+1);
|
||||
householderVec.r<T>(1) = _T.t<T>(ind3+1, ind3+1) - mm - r - s;
|
||||
householderVec.r<T>(2) = _T.t<T>(ind3+2, ind3+1);
|
||||
|
||||
if (ind3 == ind1)
|
||||
break;
|
||||
|
||||
const T lhs = _T.t<T>(ind3,ind3-1) * (math::nd4j_abs<T>(householderVec.t<T>(1)) + math::nd4j_abs<T>(householderVec.t<T>(2)));
|
||||
const T rhs = householderVec.t<T>(0) * (math::nd4j_abs<T>(_T.t<T>(ind3-1, ind3-1)) + math::nd4j_abs<T>(mm) + math::nd4j_abs<T>(_T.t<T>(ind3+1, ind3+1)));
|
||||
|
||||
if(math::nd4j_abs<T>(lhs) < DataTypeUtils::eps<T>() * rhs)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void Schur<T>::doFrancisQR(const int ind1, const int ind2, const int ind3, const NDArray& householderVec) {
|
||||
|
||||
if(!(ind2 >= ind1))
|
||||
throw std::runtime_error("ops::helpers::Schur:doFrancisQR: wrong input indexes, condition ind2 >= ind1 must be true !");
|
||||
if(!(ind2 <= ind3-2))
|
||||
throw std::runtime_error("ops::helpers::Schur:doFrancisQR: wrong input indexes, condition iind2 <= ind3-2 must be true !");
|
||||
|
||||
const int numCols = _T.sizeAt(1);
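// bulge chasing: a 3x1 Householder reflector is created at the starting row of the sweep and then
// chased down the subdiagonal, being applied to _T from both sides and accumulated into _U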
|
||||
|
||||
for (int k = ind2; k <= ind3-2; ++k) {
|
||||
|
||||
const bool firstIter = (k == ind2);
|
||||
|
||||
T coeff, normX;
|
||||
NDArray tail(_T.ordering(), {2, 1}, _T.dataType(), _T.getContext());
|
||||
Householder<T>::evalHHmatrixData(firstIter ? householderVec : _T({k,k+3, k-1,k}), tail, coeff, normX);
|
||||
|
||||
if (normX != T(0)) {
|
||||
|
||||
if (firstIter && k > ind1)
|
||||
_T.r<T>(k, k-1) = -_T.t<T>(k, k-1);
|
||||
else if (!firstIter)
|
||||
_T.r<T>(k, k-1) = normX;
|
||||
|
||||
NDArray block1 = _T({k,k+3, k,numCols}, true);
|
||||
Householder<T>::mulLeft(block1, tail, coeff);
|
||||
|
||||
NDArray block2 = _T({0,math::nd4j_min<int>(ind3,k+3)+1, k,k+3}, true);
|
||||
Householder<T>::mulRight(block2, tail, coeff);
|
||||
|
||||
NDArray block3 = _U({0,numCols, k,k+3}, true);
|
||||
Householder<T>::mulRight(block3, tail, coeff);
|
||||
}
|
||||
}
|
||||
|
||||
T coeff, normX;
|
||||
NDArray tail(_T.ordering(), {1, 1}, _T.dataType(), _T.getContext());
|
||||
Householder<T>::evalHHmatrixData(_T({ind3-1,ind3+1, ind3-2,ind3-1}), tail, coeff, normX);
|
||||
|
||||
if (normX != T(0)) {
|
||||
|
||||
_T.r<T>(ind3-1, ind3-2) = normX;
|
||||
|
||||
NDArray block1 = _T({ind3-1,ind3+1, ind3-1,numCols}, true);
|
||||
Householder<T>::mulLeft(block1, tail, coeff);
|
||||
|
||||
NDArray block2 = _T({0,ind3+1, ind3-1,ind3+1}, true);
|
||||
Householder<T>::mulRight(block2, tail, coeff);
|
||||
|
||||
NDArray block3 = _U({0,numCols, ind3-1,ind3+1}, true);
|
||||
Householder<T>::mulRight(block3, tail, coeff);
|
||||
}
|
||||
|
||||
for (int i = ind2+2; i <= ind3; ++i) {
|
||||
_T.r<T>(i, i-2) = T(0);
|
||||
if (i > ind2+2)
|
||||
_T.r<T>(i, i-3) = T(0);
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void Schur<T>::calcFromHessenberg() {
|
||||
|
||||
const int maxIters = _maxItersPerRow * _T.sizeAt(0);
|
||||
|
||||
const int numCols = _T.sizeAt(1);
|
||||
int iu = numCols - 1;
|
||||
int iter = 0;
|
||||
int totalIter = 0;
|
||||
|
||||
T shift = T(0);
|
||||
|
||||
T norm = 0;
|
||||
for (int j = 0; j < numCols; ++j)
|
||||
norm += _T({0,math::nd4j_min<int>(numCols,j+2), j,j+1}).reduceNumber(reduce::ASum).template t<T>(0);
|
||||
|
||||
if(norm != T(0)) {
|
||||
|
||||
while (iu >= 0) {
|
||||
|
||||
const int il = getSmallSubdiagEntry(iu);
|
||||
|
||||
if (il == iu) {
|
||||
|
||||
_T.r<T>(iu,iu) = _T.t<T>(iu,iu) + shift;
|
||||
if (iu > 0)
|
||||
_T.r<T>(iu, iu-1) = T(0);
|
||||
iu--;
|
||||
iter = 0;
|
||||
|
||||
}
|
||||
else if (il == iu-1) {
|
||||
|
||||
splitTwoRows(iu, shift);
|
||||
iu -= 2;
|
||||
iter = 0;
|
||||
}
|
||||
else {
|
||||
|
||||
NDArray householderVec(_T.ordering(), {3}, _T.dataType(), _T.getContext());
|
||||
NDArray shiftVec (_T.ordering(), {3}, _T.dataType(), _T.getContext());
|
||||
|
||||
calcShift(iu, iter, shift, shiftVec);
|
||||
|
||||
++iter;
|
||||
++totalIter;
|
||||
|
||||
if (totalIter > maxIters)
|
||||
break;
|
||||
|
||||
int im;
|
||||
initFrancisQR(il, iu, shiftVec, im, householderVec);
|
||||
doFrancisQR(il, im, iu, householderVec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template class ND4J_EXPORT Hessenberg<float>;
|
||||
template class ND4J_EXPORT Hessenberg<float16>;
|
||||
template class ND4J_EXPORT Hessenberg<bfloat16>;
|
||||
template class ND4J_EXPORT Hessenberg<double>;
|
||||
|
||||
template class ND4J_EXPORT Schur<float>;
|
||||
template class ND4J_EXPORT Schur<float16>;
|
||||
template class ND4J_EXPORT Schur<bfloat16>;
|
||||
template class ND4J_EXPORT Schur<double>;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -207,7 +207,7 @@ sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::ND
|
|||
const bool isBVector = shape::isCommonVector(B->shapeInfo(), lenDim);
|
||||
|
||||
// dot product of 2 vectors
|
||||
if(isAVector && isBVector && (aRank != 2 || aRank == 2 && (A->isSameShape(B) || bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) or (4x1 * 4x1) or (4x1 * 4)
|
||||
if(A->lengthOf() == B->lengthOf() && isAVector && isBVector && (aRank != 2 || aRank == 2 && (A->isSameShape(B) || bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) or (4x1 * 4x1) or (4x1 * 4)
|
||||
return dot(A, B, C, alpha, beta);
|
||||
|
||||
// matrix x matrix
|
||||
|
|
|
@ -0,0 +1,276 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <helpers/Sqrtm.h>
|
||||
#include <ops/declarable/helpers/lup.h>
|
||||
#include <helpers/EigenValsAndVecs.h>
|
||||
#include <helpers/HessenbergAndSchur.h>
|
||||
#include <helpers/FullPivLU.h>
|
||||
#include <helpers/MmulHelper.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void sqrtmQuasiTrianDiag(const NDArray& matrixT, NDArray& sqrtT ) {
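// computes square roots of the diagonal blocks of the quasi upper triangular matrix T:
// a 1x1 block is replaced by its scalar sqrt, while a 2x2 block (complex conjugate eigenvalues)
// is processed through its complex eigen decomposition: sqrt(block) = V * sqrt(D) * V^-1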
|
||||
|
||||
const int rows = matrixT.sizeAt(0);
|
||||
|
||||
for(int i = 0; i < rows; i++) {
|
||||
|
||||
if (i == rows - 1 || matrixT.t<T>(i+1, i) == (T)0) {
|
||||
const auto elemT = matrixT.t<T>(i, i);
|
||||
if(elemT < (T)0)
|
||||
throw std::runtime_error("ops::helpers::Sqrtm::sqrtmQuasiTrianDiag: can't take sqrt of negative diagonal element of T matrix !");
|
||||
sqrtT.r<T>(i,i) = math::nd4j_sqrt<T,T>(elemT);
|
||||
}
|
||||
else {
|
||||
|
||||
EigenValsAndVecs<T> es(matrixT({i,i+2, i,i+2}, true)); // es._Vecs {2,2,2}, es._Vals{2,2}
|
||||
|
||||
const NDArray& vecs = es._Vecs;
|
||||
const NDArray& vals = es._Vals;
|
||||
|
||||
const T& vecsReal00 = vecs.t<T>(0,0,0);
|
||||
const T& vecsImag00 = vecs.t<T>(0,0,1);
|
||||
const T& vecsReal01 = vecs.t<T>(0,1,0);
|
||||
const T& vecsImag01 = vecs.t<T>(0,1,1);
|
||||
const T& vecsReal10 = vecs.t<T>(1,0,0);
|
||||
const T& vecsImag10 = vecs.t<T>(1,0,1);
|
||||
const T& vecsReal11 = vecs.t<T>(1,1,0);
|
||||
const T& vecsImag11 = vecs.t<T>(1,1,1);
|
||||
|
||||
// es.eigenvalues().cwiseSqrt().asDiagonal()
|
||||
T eigenValsSqrt[2][2];
|
||||
eigenValsSqrt[0][0] = vals.t<T>(0,0);
|
||||
eigenValsSqrt[0][1] = vals.t<T>(0,1);
|
||||
eigenValsSqrt[1][0] = vals.t<T>(1,0);
|
||||
eigenValsSqrt[1][1] = vals.t<T>(1,1);
|
||||
EigenValsAndVecs<T>::sqrtComplexNum(eigenValsSqrt[0][0], eigenValsSqrt[0][1]);
|
||||
EigenValsAndVecs<T>::sqrtComplexNum(eigenValsSqrt[1][0], eigenValsSqrt[1][1]);
|
||||
|
||||
// es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal()
|
||||
T vecsElem[2][2][2];
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal00,vecsImag00, eigenValsSqrt[0][0],eigenValsSqrt[0][1], vecsElem[0][0][0],vecsElem[0][0][1]);
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal01,vecsImag01, eigenValsSqrt[1][0],eigenValsSqrt[1][1], vecsElem[0][1][0],vecsElem[0][1][1]);
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal10,vecsImag10, eigenValsSqrt[0][0],eigenValsSqrt[0][1], vecsElem[1][0][0],vecsElem[1][0][1]);
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal11,vecsImag11, eigenValsSqrt[1][0],eigenValsSqrt[1][1], vecsElem[1][1][0],vecsElem[1][1][1]);
|
||||
|
||||
// es.eigenvectors().inverse()
|
||||
T vecsElemInv[2][2][2];
|
||||
|
||||
T tempReal, tempImag, divisorReal, divisorImag;
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal00,vecsImag00, vecsReal11,vecsImag11, divisorReal,divisorImag);
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsReal01,vecsImag01, vecsReal10,vecsImag10, tempReal,tempImag);
|
||||
divisorReal -= tempReal;
|
||||
divisorImag -= tempImag;
|
||||
|
||||
EigenValsAndVecs<T>::divideComplexNums(vecsReal11,vecsImag11, divisorReal,divisorImag, vecsElemInv[0][0][0],vecsElemInv[0][0][1]);
|
||||
EigenValsAndVecs<T>::divideComplexNums(-vecsReal01,-vecsImag01, divisorReal,divisorImag, vecsElemInv[0][1][0],vecsElemInv[0][1][1]);
|
||||
EigenValsAndVecs<T>::divideComplexNums(-vecsReal10,-vecsImag10, divisorReal,divisorImag, vecsElemInv[1][0][0],vecsElemInv[1][0][1]);
|
||||
EigenValsAndVecs<T>::divideComplexNums(vecsReal00,vecsImag00, divisorReal,divisorImag, vecsElemInv[1][1][0],vecsElemInv[1][1][1]);
|
||||
|
||||
// result
|
||||
T result[2][2][2];
|
||||
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][0][0],vecsElem[0][0][1], vecsElemInv[0][0][0],vecsElemInv[0][0][1], tempReal,tempImag);
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][1][0],vecsElem[0][1][1], vecsElemInv[1][0][0],vecsElemInv[1][0][1], result[0][0][0],result[0][0][1]);
|
||||
result[0][0][0] += tempReal;
|
||||
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][0][0],vecsElem[0][0][1], vecsElemInv[0][1][0],vecsElemInv[0][1][1], tempReal,tempImag);
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[0][1][0],vecsElem[0][1][1], vecsElemInv[1][1][0],vecsElemInv[1][1][1], result[0][1][0],result[0][1][1]);
|
||||
result[0][1][0] += tempReal;
|
||||
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][0][0],vecsElem[1][0][1], vecsElemInv[0][0][0],vecsElemInv[0][0][1], tempReal,tempImag);
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][1][0],vecsElem[1][1][1], vecsElemInv[1][0][0],vecsElemInv[1][0][1], result[1][0][0],result[1][0][1]);
|
||||
result[1][0][0] += tempReal;
|
||||
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][0][0],vecsElem[1][0][1], vecsElemInv[0][1][0],vecsElemInv[0][1][1], tempReal,tempImag);
|
||||
EigenValsAndVecs<T>::multiplyComplexNums(vecsElem[1][1][0],vecsElem[1][1][1], vecsElemInv[1][1][0],vecsElemInv[1][1][1], result[1][1][0],result[1][1][1]);
|
||||
result[1][1][0] += tempReal;
|
||||
|
||||
sqrtT.r<T>(i,i) = result[0][0][0];
|
||||
sqrtT.r<T>(i,i+1) = result[0][1][0];
|
||||
sqrtT.r<T>(i+1,i) = result[1][0][0];
|
||||
sqrtT.r<T>(i+1,i+1) = result[1][1][0];
|
||||
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// all matrices are {2,2} here
|
||||
template <typename T>
|
||||
static void sqrtmQuasiTrianAuxEq(const NDArray& A, const NDArray& B, const NDArray& C, NDArray& X) {
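// solves the 2x2 Sylvester equation A*X + X*B = C for X by rewriting it as the 4x4 linear system
// tempMatrix * vec(X) = vec(C) and solving that system with full-pivoting LU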
|
||||
|
||||
NDArray tempMatrix(A.ordering(), {4,4}, A.dataType(), A.getContext());
|
||||
|
||||
tempMatrix.r<T>(0,0) = A.t<T>(0,0) + B.t<T>(0,0);
|
||||
tempMatrix.r<T>(1,1) = A.t<T>(0,0) + B.t<T>(1,1);
|
||||
tempMatrix.r<T>(2,2) = A.t<T>(1,1) + B.t<T>(0,0);
|
||||
tempMatrix.r<T>(3,3) = A.t<T>(1,1) + B.t<T>(1,1);
|
||||
tempMatrix.r<T>(0,1) = B.t<T>(1,0);
|
||||
tempMatrix.r<T>(0,2) = A.t<T>(0,1);
|
||||
tempMatrix.r<T>(1,0) = B.t<T>(0,1);
|
||||
tempMatrix.r<T>(1,3) = A.t<T>(0,1);
|
||||
tempMatrix.r<T>(2,0) = A.t<T>(1,0);
|
||||
tempMatrix.r<T>(2,3) = B.t<T>(1,0);
|
||||
tempMatrix.r<T>(3,1) = A.t<T>(1,0);
|
||||
tempMatrix.r<T>(3,2) = B.t<T>(0,1);
|
||||
tempMatrix.r<T>(0,3) = (T)0;
|
||||
tempMatrix.r<T>(1,2) = (T)0;
|
||||
tempMatrix.r<T>(2,1) = (T)0;
|
||||
tempMatrix.r<T>(3,0) = (T)0;
|
||||
|
||||
NDArray result(A.ordering(), {4,1}, A.dataType(), A.getContext());
|
||||
result.r<T>(0,0) = C.t<T>(0,0);
|
||||
result.r<T>(1,0) = C.t<T>(0,1);
|
||||
result.r<T>(2,0) = C.t<T>(1,0);
|
||||
result.r<T>(3,0) = C.t<T>(1,1);
|
||||
|
||||
FullPivLU<T>::solve(tempMatrix, result, result);
|
||||
|
||||
X.r<T>(0,0) = result.t<T>(0);
|
||||
X.r<T>(0,1) = result.t<T>(1);
|
||||
X.r<T>(1,0) = result.t<T>(2);
|
||||
X.r<T>(1,1) = result.t<T>(3);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void sqrtmQuasiTrianOffDiag(const NDArray& matrixT, NDArray& sqrtT ) {
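// computes the off-diagonal blocks of sqrt(T) block column by block column: each block is obtained
// by solving a small Sylvester-type system assembled from the already computed diagonal blocks and
// the previously computed off-diagonal blocks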
|
||||
|
||||
const int rows = matrixT.sizeAt(0);
|
||||
|
||||
for (int j = 1; j < rows; j++) {
|
||||
|
||||
if (matrixT.t<T>(j, j-1) != (T)0)
|
||||
continue;
|
||||
|
||||
for (int i = j - 1; i >= 0; i--) {
|
||||
|
||||
if (i > 0 && matrixT.t<T>(i, i-1) != (T)0)
|
||||
continue;
|
||||
|
||||
const bool iBlockIs2x2 = (i < rows - 1) && (matrixT.t<T>(i+1, i) != (T)0);
|
||||
const bool jBlockIs2x2 = (j < rows - 1) && (matrixT.t<T>(j+1, j) != (T)0);
|
||||
|
||||
if (iBlockIs2x2 && jBlockIs2x2) {
|
||||
|
||||
NDArray A = sqrtT({i,i+2, i,i+2}, true);
|
||||
NDArray B = sqrtT({j,j+2, j,j+2}, true);
|
||||
NDArray X = matrixT({i,i+2, j,j+2}, true);//.dup();
|
||||
|
||||
if (j - i > 2)
|
||||
X -= mmul(sqrtT({i,i+2, i+2,j}, true), sqrtT({i+2,j, j,j+2}, true));
|
||||
|
||||
sqrtmQuasiTrianAuxEq<T>(A, B, X, X);
|
||||
|
||||
sqrtT.syncToDevice();
|
||||
sqrtT({i,i+2, j,j+2}, true).assign(X);
|
||||
}
|
||||
else if (iBlockIs2x2 && !jBlockIs2x2) {
|
||||
|
||||
NDArray rhs = matrixT({i,i+2, j,j+1}, true);//.dup();
|
||||
|
||||
if (j - i > 2)
|
||||
rhs -= mmul(sqrtT({i,i+2, i+2,j}, true), sqrtT({i+2,j, j,j+1}, true));
|
||||
|
||||
NDArray A(matrixT.ordering(), {2,2}, matrixT.dataType(), matrixT.getContext());
|
||||
A.r<T>(0,0) = A.r<T>(1,1) = sqrtT.t<T>(j,j);
|
||||
A.r<T>(0,1) = A.r<T>(1,0) = T(0);
|
||||
A += sqrtT({i,i+2, i,i+2}, true);
|
||||
|
||||
FullPivLU<T>::solve(A,rhs,rhs);
|
||||
|
||||
// sqrtT.syncToDevice();
|
||||
sqrtT({i,i+2, j,j+1}, true).assign(rhs);
|
||||
}
|
||||
else if (!iBlockIs2x2 && jBlockIs2x2) {
|
||||
|
||||
NDArray rhs = matrixT({i,i+1, j,j+2}, true);//.dup();
|
||||
|
||||
if (j - i > 1)
|
||||
rhs -= mmul(sqrtT({i,i+1, i+1,j}, true), sqrtT({i+1,j, j,j+2}, true));
|
||||
|
||||
NDArray A(matrixT.ordering(), {2,2}, matrixT.dataType(), matrixT.getContext());
|
||||
A.r<T>(0,0) = A.r<T>(1,1) = sqrtT.t<T>(i,i);
|
||||
A.r<T>(0,1) = A.r<T>(1,0) = T(0);
|
||||
A += sqrtT({j,j+2, j,j+2}, true).transpose();
|
||||
|
||||
NDArray rhsT = rhs.transpose();
|
||||
FullPivLU<T>::solve(A,rhsT,rhsT);
|
||||
|
||||
// sqrtT.syncToDevice();
|
||||
sqrtT({i,i+1, j,j+2}, true).assign(rhs);
|
||||
}
|
||||
else if (!iBlockIs2x2 && !jBlockIs2x2) {
|
||||
|
||||
T temp = mmul(sqrtT({i,i+1, i+1,j}), sqrtT({i+1,j, j,j+1})).t<T>(0); // dot
|
||||
sqrtT.r<T>(i,j) = (matrixT.t<T>(i,j) - temp ) / (sqrtT.t<T>(i,i) + sqrtT.t<T>(j,j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Sqrtm<T>::calc(const NDArray& in, NDArray& out) {
|
||||
|
||||
if(in.rankOf() != 2 || in.sizeAt(0) != in.sizeAt(1))
|
||||
throw std::runtime_error("ops::helpers::Sqrtm::calc: input matrix must have rank 2 and be square !");
|
||||
if(!out.isSameShape(in))
|
||||
throw std::runtime_error("ops::helpers::Sqrtm::calc: output matrix must have the same shape as input one!");
|
||||
|
||||
if(in.lengthOf() == 1) {
|
||||
out.r<T>(0) = math::nd4j_sqrt<T,T>(in.t<T>(0));
|
||||
return;
|
||||
}
|
||||
|
||||
ops::helpers::Schur<T> schur(in);
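// real Schur decomposition in = U * T * U^T with T quasi upper triangular;
// the square root is then recovered as out = U * sqrt(T) * U^T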
|
||||
|
||||
const NDArray& t1 = schur._T;
|
||||
const NDArray& t2 = schur._U;
|
||||
|
||||
NDArray sqrtT = in.ulike();
|
||||
sqrtT.nullify();
|
||||
|
||||
sqrtmQuasiTrianDiag<T>(schur._T, sqrtT);
|
||||
sqrtmQuasiTrianOffDiag<T>(schur._T, sqrtT);
|
||||
|
||||
// out = U * sqrtT * U^T;
|
||||
NDArray temp = mmul(sqrtT, schur._U.transpose());
|
||||
MmulHelper::mmul(&schur._U, &temp, &out);
|
||||
}
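// rough usage sketch (illustrative only, not part of this change; assumes a square double matrix):
//   NDArray x('c', {3, 3}, sd::DataType::DOUBLE);   // fill x with data beforehand
//   NDArray out = x.ulike();
//   ops::helpers::Sqrtm<double>::calc(x, out);      // out <- principal square root of x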
|
||||
|
||||
template class ND4J_EXPORT Sqrtm<float>;
|
||||
template class ND4J_EXPORT Sqrtm<float16>;
|
||||
template class ND4J_EXPORT Sqrtm<bfloat16>;
|
||||
template class ND4J_EXPORT Sqrtm<double>;
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,160 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by Yurii Shyrma on 18.12.2017
|
||||
//
|
||||
|
||||
|
||||
#include <helpers/householder.h>
|
||||
#include <helpers/biDiagonalUp.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
BiDiagonalUp::BiDiagonalUp(const NDArray& matrix): _HHmatrix(NDArray(matrix.ordering(), {matrix.sizeAt(0), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext())),
|
||||
_HHbidiag(NDArray(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext())) {
|
||||
|
||||
// input validation
|
||||
if(matrix.rankOf() != 2 || matrix.isScalar())
|
||||
throw std::runtime_error("ops::helpers::biDiagonalizeUp constructor: input array must be 2D matrix !");
|
||||
|
||||
_HHmatrix.assign(&matrix);
|
||||
_HHbidiag.assign(0.);
|
||||
|
||||
evalData();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BiDiagonalUp::_evalData() {
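// reduces _HHmatrix (rows >= cols) to upper bidiagonal form: column i is zeroed below the diagonal
// by a left Householder reflector and row i is zeroed to the right of the superdiagonal by a right
// one; reflector data is kept in _HHmatrix, the bidiagonal values in _HHbidiag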
|
||||
|
||||
const auto rows = _HHmatrix.sizeAt(0);
|
||||
const auto cols = _HHmatrix.sizeAt(1);
|
||||
|
||||
if(rows < cols)
|
||||
throw std::runtime_error("ops::helpers::BiDiagonalizeUp::evalData method: this procedure is applicable only for input matrix with rows >= cols !");
|
||||
|
||||
T coeff, normX;
|
||||
|
||||
T x, y;
|
||||
|
||||
for(Nd4jLong i = 0; i < cols-1; ++i ) {
|
||||
|
||||
// evaluate Householder matrix nullifying columns
|
||||
NDArray column1 = _HHmatrix({i,rows, i,i+1});
|
||||
|
||||
x = _HHmatrix.t<T>(i,i);
|
||||
y = _HHbidiag.t<T>(i,i);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(column1, x, y);
|
||||
|
||||
_HHmatrix.r<T>(i, i) = x;
|
||||
_HHbidiag.r<T>(i, i) = y;
|
||||
|
||||
// multiply the corresponding matrix block by the Householder matrix from the left: P * bottomRightCorner
|
||||
NDArray bottomRightCorner1 = _HHmatrix({i,rows, i+1,cols}, true); // {i, cols}
|
||||
Householder<T>::mulLeft(bottomRightCorner1, _HHmatrix({i+1,rows, i,i+1}, true), _HHmatrix.t<T>(i,i));
|
||||
|
||||
if(i == cols-2)
|
||||
continue; // do not apply right multiplying at last iteration
|
||||
|
||||
// evaluate Householder matrix nullifying rows
|
||||
NDArray row1 = _HHmatrix({i,i+1, i+1,cols});
|
||||
|
||||
x = _HHmatrix.t<T>(i,i+1);
|
||||
y = _HHbidiag.t<T>(i,i+1);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(row1, x, y);
|
||||
|
||||
_HHmatrix.r<T>(i, i+1) = x;
|
||||
_HHbidiag.r<T>(i, i+1) = y;
|
||||
|
||||
// multiply the corresponding matrix block by the Householder matrix from the right: bottomRightCorner * P
|
||||
NDArray bottomRightCorner2 = _HHmatrix({i+1,rows, i+1,cols}, true); // {i, rows}
|
||||
|
||||
Householder<T>::mulRight(bottomRightCorner2, _HHmatrix({i,i+1, i+2,cols}, true), _HHmatrix.t<T>(i,i+1));
|
||||
}
|
||||
|
||||
NDArray row2 =_HHmatrix({cols-2,cols-1, cols-1,cols});
|
||||
|
||||
x = _HHmatrix.t<T>(cols-2,cols-1);
|
||||
y = _HHbidiag.t<T>(cols-2,cols-1);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(row2, x, y);
|
||||
|
||||
_HHmatrix.r<T>(cols-2,cols-1) = x;
|
||||
_HHbidiag.r<T>(cols-2,cols-1) = y;
|
||||
|
||||
NDArray column2 = _HHmatrix({cols-1,rows, cols-1,cols});
|
||||
|
||||
x = _HHmatrix.t<T>(cols-1,cols-1);
|
||||
y = _HHbidiag.t<T>(cols-1,cols-1);
|
||||
|
||||
Householder<T>::evalHHmatrixDataI(column2, x, y);
|
||||
|
||||
_HHmatrix.r<T>(cols-1, cols-1) = x;
|
||||
_HHbidiag.r<T>(cols-1, cols-1) = y;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void BiDiagonalUp::evalData() {
|
||||
auto xType = _HHmatrix.dataType();
|
||||
BUILD_SINGLE_SELECTOR(xType, _evalData, ();, FLOAT_TYPES);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
HHsequence BiDiagonalUp::makeHHsequence_(const char type) {
|
||||
|
||||
const int diagSize = type == 'u' ? _HHbidiag.sizeAt(0) : _HHbidiag.sizeAt(0) - 1;
|
||||
|
||||
_hhCoeffs = NDArray(_HHmatrix.ordering(), {diagSize}, _HHmatrix.dataType(), _HHmatrix.getContext());
|
||||
|
||||
if(type == 'u')
|
||||
for(int i = 0; i < diagSize; ++i)
|
||||
_hhCoeffs.r<T>(i) = _HHmatrix.t<T>(i,i);
|
||||
else
|
||||
for(int i = 0; i < diagSize; ++i)
|
||||
_hhCoeffs.r<T>(i) = _HHmatrix.t<T>(i,i+1);
|
||||
|
||||
HHsequence result(_HHmatrix, _hhCoeffs, type);
|
||||
|
||||
if(type != 'u') {
|
||||
result._diagSize = diagSize;
|
||||
result._shift = 1;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
HHsequence BiDiagonalUp::makeHHsequence(const char type) {
|
||||
auto xType = _HHmatrix.dataType();
|
||||
BUILD_SINGLE_SELECTOR(xType, return makeHHsequence_, (type);, FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void BiDiagonalUp::_evalData, (), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template HHsequence BiDiagonalUp::makeHHsequence_, (const char type), FLOAT_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,147 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by Yurii Shyrma on 11.01.2018
|
||||
//
|
||||
|
||||
#include <helpers/hhColPivQR.h>
|
||||
#include <helpers/householder.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
HHcolPivQR::HHcolPivQR(const NDArray& matrix) {
|
||||
|
||||
_qr = matrix.dup();
|
||||
_diagSize = math::nd4j_min<int>(matrix.sizeAt(0), matrix.sizeAt(1));
|
||||
_coeffs = NDArray(matrix.ordering(), {1, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
_permut = NDArray(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
evalData();
|
||||
}
|
||||
|
||||
void HHcolPivQR::evalData() {
|
||||
BUILD_SINGLE_SELECTOR(_qr.dataType(), _evalData, (), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void HHcolPivQR::_evalData() {
|
||||
|
||||
const int rows = _qr.sizeAt(0);
|
||||
const int cols = _qr.sizeAt(1);
|
||||
|
||||
NDArray transp(_qr.ordering(), {cols}/*{1, cols}*/, _qr.dataType(), _qr.getContext());
|
||||
NDArray normsUpd(_qr.ordering(), {cols}/*{1, cols}*/, _qr.dataType(), _qr.getContext());
|
||||
NDArray normsDir(_qr.ordering(), {cols}/*{1, cols}*/, _qr.dataType(), _qr.getContext());
|
||||
|
||||
int transpNum = 0;
|
||||
|
||||
for (int k = 0; k < cols; ++k)
|
||||
normsDir.r<T>(k) = normsUpd.r<T>(k) = _qr({0,0, k,k+1}).reduceNumber(reduce::Norm2).t<T>(0);
|
||||
|
||||
T normScaled = (normsUpd.reduceNumber(reduce::Max)).t<T>(0) * DataTypeUtils::eps<T>();
|
||||
T threshold1 = normScaled * normScaled / (T)rows;
|
||||
T threshold2 = math::nd4j_sqrt<T,T>(DataTypeUtils::eps<T>());
|
||||
|
||||
T nonZeroPivots = _diagSize;
|
||||
T maxPivot = 0.;
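// column-pivoted Householder QR: at step k the remaining column with the largest updated norm is
// swapped into position k, a Householder reflector zeroes its subdiagonal part, and the norms of
// the trailing columns are downdated (and recomputed whenever cancellation is detected)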
|
||||
|
||||
for(int k = 0; k < _diagSize; ++k) {
|
||||
|
||||
int biggestColIndex = normsUpd({k,-1}).indexReduceNumber(indexreduce::IndexMax).e<int>(0);
|
||||
T biggestColNorm = normsUpd({k,-1}).reduceNumber(reduce::Max).t<T>(0);
|
||||
T biggestColSqNorm = biggestColNorm * biggestColNorm;
|
||||
biggestColIndex += k;
|
||||
|
||||
if(nonZeroPivots == (T)_diagSize && biggestColSqNorm < threshold1 * (T)(rows-k))
|
||||
nonZeroPivots = k;
|
||||
|
||||
transp.r<T>(k) = (T)biggestColIndex;
|
||||
|
||||
if(k != biggestColIndex) {
|
||||
|
||||
NDArray temp1(_qr({0,0, k,k+1}));
|
||||
NDArray temp2(_qr({0,0, biggestColIndex,biggestColIndex+1}));
|
||||
temp1.swapUnsafe(temp2);
|
||||
|
||||
math::nd4j_swap<T>(normsUpd.r<T>(k), normsUpd.r<T>(biggestColIndex));
|
||||
math::nd4j_swap<T>(normsDir.r<T>(k), normsDir.r<T>(biggestColIndex));
|
||||
|
||||
++transpNum;
|
||||
}
|
||||
|
||||
T normX, c;
|
||||
NDArray qrBlock = _qr({k,rows, k,k+1});
|
||||
Householder<T>::evalHHmatrixDataI(qrBlock, c, normX);
|
||||
|
||||
_coeffs.r<T>(k) = c;
|
||||
|
||||
_qr.r<T>(k,k) = normX;
|
||||
|
||||
T max = math::nd4j_abs<T>(normX);
|
||||
if(max > maxPivot)
|
||||
maxPivot = max;
|
||||
|
||||
if(k < rows && (k+1) < cols) {
|
||||
NDArray qrBlock = _qr({k,rows, k+1,cols}, true);
|
||||
NDArray tail = _qr({k+1,rows, k, k+1}, true);
|
||||
Householder<T>::mulLeft(qrBlock, tail, _coeffs.t<T>(k));
|
||||
}
|
||||
|
||||
for (int j = k + 1; j < cols; ++j) {
|
||||
|
||||
if (normsUpd.t<T>(j) != (T)0.f) {
|
||||
|
||||
T temp = math::nd4j_abs<T>(_qr.t<T>(k, j)) / normsUpd.t<T>(j);
|
||||
temp = ((T)1. + temp) * ((T)1. - temp);
|
||||
temp = temp < (T)0. ? (T)0. : temp;
|
||||
T temp2 = temp * normsUpd.t<T>(j) * normsUpd.t<T>(j) / (normsDir.t<T>(j)*normsDir.t<T>(j));
|
||||
|
||||
if (temp2 <= threshold2) {
|
||||
if(k+1 < rows && j < cols)
|
||||
normsDir.r<T>(j) = _qr({k+1,rows, j,j+1}).reduceNumber(reduce::Norm2).t<T>(0);
|
||||
|
||||
normsUpd.r<T>(j) = normsDir.t<T>(j);
|
||||
}
|
||||
else
|
||||
normsUpd.r<T>(j) = normsUpd.t<T>(j) * math::nd4j_sqrt<T, T>(temp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_permut.setIdentity();
|
||||
|
||||
for(int k = 0; k < _diagSize; ++k) {
|
||||
|
||||
int idx = transp.e<int>(k);
|
||||
NDArray temp1 = _permut({0,0, k, k+1});
|
||||
NDArray temp2 = _permut({0,0, idx,idx+1});
|
||||
temp1.swapUnsafe(temp2);
|
||||
}
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void HHcolPivQR::_evalData, (), FLOAT_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -20,7 +20,6 @@
|
|||
|
||||
#include <helpers/hhSequence.h>
|
||||
#include <helpers/householder.h>
|
||||
#include <array/NDArrayFactory.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
|
@ -37,32 +36,24 @@ HHsequence::HHsequence(const NDArray& vectors, const NDArray& coeffs, const char
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void HHsequence::_mulLeft(NDArray& matrix) {
|
||||
void HHsequence::mulLeft_(NDArray& matrix) {
|
||||
|
||||
const int rows = _vectors.sizeAt(0);
|
||||
const int cols = _vectors.sizeAt(1);
|
||||
const int inRows = matrix.sizeAt(0);
|
||||
|
||||
NDArray* block(nullptr);
|
||||
|
||||
for(int i = _diagSize - 1; i >= 0; --i) {
|
||||
|
||||
if(_type == 'u') {
|
||||
|
||||
block = new NDArray(matrix({inRows-rows+_shift+ i,inRows, 0,0}, true));
|
||||
T _x = _coeffs.e<T>(i);
|
||||
Householder<T>::mulLeft(*block, _vectors({i + 1 + _shift, rows, i, i+1}, true), _x);
|
||||
_coeffs.p<T>(i, _x);
|
||||
NDArray block = matrix({inRows-rows+_shift+ i,inRows, 0,0}, true);
|
||||
Householder<T>::mulLeft(block, _vectors({i + 1 + _shift, rows, i, i+1}, true), _coeffs.t<T>(i));
|
||||
}
|
||||
else {
|
||||
|
||||
block = new NDArray(matrix({inRows-cols+_shift+i,inRows, 0,0}, true));
|
||||
T _x = _coeffs.e<T>(i);
|
||||
Householder<T>::mulLeft(*block, _vectors({i, i+1, i + 1 + _shift, cols}, true), _x);
|
||||
_coeffs.p<T>(i, _x);
|
||||
NDArray block = matrix({inRows-cols+_shift+i,inRows, 0,0}, true);
|
||||
Householder<T>::mulLeft(block, _vectors({i, i+1, i + 1 + _shift, cols}, true), _coeffs.t<T>(i));
|
||||
}
|
||||
|
||||
delete block;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,15 +70,14 @@ NDArray HHsequence::getTail(const int idx) const {
|
|||
return _vectors({idx, idx+1, first, -1}, true);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void HHsequence::_applyTo(NDArray& dest) {
|
||||
void HHsequence::applyTo_(NDArray& dest) {
|
||||
|
||||
int size = _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1);
|
||||
|
||||
if(dest.rankOf() != 2 || (dest.sizeAt(0) != size && dest.sizeAt(1) != size))
|
||||
dest = NDArrayFactory::create(dest.ordering(), {size, size}, dest.dataType(), dest.getContext());
|
||||
dest = NDArray(dest.ordering(), {size, size}, dest.dataType(), dest.getContext());
|
||||
dest.setIdentity();
|
||||
|
||||
for(int k = _diagSize - 1; k >= 0; --k) {
|
||||
|
@ -96,29 +86,26 @@ void HHsequence::_applyTo(NDArray& dest) {
|
|||
if(curNum < 1 || (k + 1 + _shift) >= size )
|
||||
continue;
|
||||
auto block = dest({dest.sizeAt(0)-curNum,dest.sizeAt(0), dest.sizeAt(1)-curNum,dest.sizeAt(1)}, true);
|
||||
T _x = _coeffs.e<T>(k);
|
||||
|
||||
Householder<T>::mulLeft(block, getTail(k), _x);
|
||||
|
||||
_coeffs.p<T>(k, _x);
|
||||
Householder<T>::mulLeft(block, getTail(k), _coeffs.t<T>(k));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void HHsequence::applyTo(NDArray& dest) {
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void HHsequence::applyTo(NDArray& dest) {
|
||||
auto xType = _coeffs.dataType();
|
||||
BUILD_SINGLE_SELECTOR(xType, applyTo_, (dest), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_SELECTOR(xType, _applyTo, (dest), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
void HHsequence::mulLeft(NDArray& matrix) {
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void HHsequence::mulLeft(NDArray& matrix) {
|
||||
auto xType = _coeffs.dataType();
|
||||
BUILD_SINGLE_SELECTOR(xType, mulLeft_, (matrix), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_SELECTOR(xType, _mulLeft, (matrix), FLOAT_TYPES);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void HHsequence::applyTo_, (sd::NDArray &dest), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void HHsequence::mulLeft_, (NDArray& matrix), FLOAT_TYPES);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void HHsequence::_applyTo, (sd::NDArray &dest), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void HHsequence::_mulLeft, (NDArray& matrix), FLOAT_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,218 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by Yurii Shyrma on 18.12.2017
|
||||
//
|
||||
|
||||
#include <helpers/householder.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// template <typename T>
|
||||
// NDArray Householder<T>::evalHHmatrix(const NDArray& x) {
|
||||
|
||||
// // input validation
|
||||
// if(x.rankOf() != 1 && !x.isScalar())
|
||||
// throw std::runtime_error("ops::helpers::Householder::evalHHmatrix method: iinput array must have rank = 1 or to be scalar!");
|
||||
|
||||
// const auto xLen = x.lengthOf();
|
||||
|
||||
// NDArray w(x.ordering(), {xLen, 1}, x.dataType(), x.getContext()); // column-vector
|
||||
|
||||
// NDArray xTail = xLen > 1 ? x({1,-1}) : NDArray();
|
||||
// T tailXnorm = xLen > 1 ? xTail.reduceNumber(reduce::SquaredNorm).t<T>(0) : (T)0;
|
||||
|
||||
// const auto xFirstElem = x.t<T>(0);
|
||||
|
||||
// T coeff, normX;
|
||||
|
||||
// if(tailXnorm <= DataTypeUtils::min<T>()) {
|
||||
|
||||
// normX = xFirstElem;
|
||||
// coeff = 0.f;
|
||||
// if(xLen > 1)
|
||||
// w({1,-1, 0,0}) = 0.f;
|
||||
// }
|
||||
// else {
|
||||
|
||||
// normX = math::nd4j_sqrt<T,T>(xFirstElem*xFirstElem + tailXnorm);
|
||||
|
||||
// if(xFirstElem >= (T)0.f)
|
||||
// normX = -normX; // choose opposite sign to lessen roundoff error
|
||||
|
||||
// coeff = (normX - xFirstElem) / normX;
|
||||
|
||||
// if(xLen > 1)
|
||||
// w({1,-1, 0,0}).assign(xTail / (xFirstElem - normX));
|
||||
// }
|
||||
|
||||
// w.t<T>(0) = (T)1;
|
||||
|
||||
// NDArray identity(x.ordering(), {xLen, xLen}, x.dataType(), x.getContext());
|
||||
// identity.setIdentity(); // identity matrix
|
||||
|
||||
// return identity - mmul(w, w.transpose()) * coeff;
|
||||
// }
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Householder<T>::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, T& normX) {
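// given vector x, computes coeff, tail and normX such that the Householder reflector
// P = I - coeff * w * w^T, with w = [1, tail]^T, maps x onto (normX, 0, ..., 0)^T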
|
||||
|
||||
// input validation
|
||||
if(x.rankOf() != 1 && !x.isScalar())
|
||||
throw std::runtime_error("ops::helpers::Householder::evalHHmatrixData method: input array must have rank = 1 or to be scalar!");
|
||||
|
||||
if(!x.isScalar() && x.lengthOf() != tail.lengthOf() + 1)
|
||||
throw std::runtime_error("ops::helpers::Householder::evalHHmatrixData method: input tail vector must have length less than unity compared to input x vector!");
|
||||
|
||||
const auto xLen = x.lengthOf();
|
||||
|
||||
const NDArray xTail = xLen > 1 ? x({1,-1}) : NDArray();
|
||||
|
||||
T tailXnorm = xLen > 1 ? xTail.reduceNumber(reduce::SquaredNorm).t<T>(0) : (T)0;
|
||||
|
||||
const auto xFirstElem = x.t<T>(0);
|
||||
|
||||
if(tailXnorm <= DataTypeUtils::min<T>()) {
|
||||
|
||||
normX = xFirstElem;
|
||||
coeff = (T)0.f;
|
||||
tail = (T)0.f;
|
||||
}
|
||||
else {
|
||||
|
||||
normX = math::nd4j_sqrt<T,T>(xFirstElem*xFirstElem + tailXnorm);
|
||||
|
||||
if(xFirstElem >= (T)0.f)
|
||||
normX = -normX; // choose opposite sign to lessen roundoff error
|
||||
|
||||
coeff = (normX - xFirstElem) / normX;
|
||||
|
||||
tail.assign(xTail / (xFirstElem - normX));
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Householder<T>::evalHHmatrixDataI(NDArray& x, T& coeff, T& normX) {
|
||||
|
||||
// input validation
|
||||
if(x.rankOf() != 1 && !x.isScalar())
|
||||
throw std::runtime_error("ops::helpers::Householder::evalHHmatrixDataI method: input array must have rank = 1 or to be scalar!");
|
||||
|
||||
int rows = (int)x.lengthOf()-1;
|
||||
int num = 1;
|
||||
|
||||
if(rows == 0) {
|
||||
rows = 1;
|
||||
num = 0;
|
||||
}
|
||||
|
||||
NDArray tail = x({num, -1});
|
||||
|
||||
evalHHmatrixData(x, tail, coeff, normX);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff) {
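// applies the reflector P = I - coeff * w * w^T, w = [1, tail]^T, to 'matrix' from the left in
// place, working on the partition [first row; bottom part] to avoid forming P explicitly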
|
||||
|
||||
// if(matrix.rankOf() != 2)
|
||||
// throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !";
|
||||
|
||||
if(matrix.sizeAt(0) == 1 && coeff != (T)0) {
|
||||
|
||||
matrix *= (T) 1.f - coeff;
|
||||
}
|
||||
else if(coeff != (T)0.f) {
|
||||
|
||||
NDArray bottomPart = matrix({1,matrix.sizeAt(0), 0,0}, true);
|
||||
NDArray firstRow = matrix({0,1, 0,0}, true);
|
||||
|
||||
if(tail.isColumnVector()) {
|
||||
|
||||
auto resultingRow = mmul(tail.transpose(), bottomPart);
|
||||
resultingRow += firstRow;
|
||||
resultingRow *= coeff;
|
||||
firstRow -= resultingRow;
|
||||
bottomPart -= mmul(tail, resultingRow);
|
||||
}
|
||||
else {
|
||||
|
||||
auto resultingRow = mmul(tail, bottomPart);
|
||||
resultingRow += firstRow;
|
||||
resultingRow *= coeff;
|
||||
firstRow -= resultingRow;
|
||||
bottomPart -= mmul(tail.transpose(), resultingRow);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void Householder<T>::mulRight(NDArray& matrix, const NDArray& tail, const T coeff) {
|
||||
|
||||
// if(matrix.rankOf() != 2)
|
||||
// throw "ops::helpers::Householder::mulRight method: input array must be 2D matrix !";
|
||||
|
||||
if(matrix.sizeAt(1) == 1 && coeff != (T)0) {
|
||||
matrix *= (T)1.f - coeff;
|
||||
}
|
||||
else if(coeff != (T)0.f) {
|
||||
|
||||
NDArray rightPart = matrix({0,0, 1,matrix.sizeAt(1)}, true);
|
||||
NDArray firstCol = matrix({0,0, 0,1}, true);
|
||||
|
||||
if(tail.isColumnVector()) {
|
||||
|
||||
auto resultingCol = mmul(rightPart, tail);
|
||||
resultingCol += firstCol;
|
||||
resultingCol *= coeff;
|
||||
firstCol -= resultingCol;
|
||||
rightPart -= mmul(resultingCol, tail.transpose());
|
||||
}
|
||||
else {
|
||||
|
||||
auto resultingCol = mmul(rightPart, tail.transpose());
|
||||
resultingCol += firstCol;
|
||||
resultingCol *= coeff;
|
||||
firstCol -= resultingCol;
|
||||
rightPart -= mmul(resultingCol, tail);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template class ND4J_EXPORT Householder<float>;
|
||||
template class ND4J_EXPORT Householder<float16>;
|
||||
template class ND4J_EXPORT Householder<bfloat16>;
|
||||
template class ND4J_EXPORT Householder<double>;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,8 +20,7 @@
|
|||
|
||||
#include <helpers/jacobiSVD.h>
|
||||
#include <helpers/hhColPivQR.h>
|
||||
#include <array/NDArrayFactory.h>
|
||||
|
||||
#include <helpers/MmulHelper.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
|
@ -43,27 +42,27 @@ JacobiSVD<T>::JacobiSVD(const NDArray& matrix, const bool calcU, const bool calc
|
|||
_calcV = calcV;
|
||||
_fullUV = fullUV;
|
||||
|
||||
_s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext());
|
||||
_s = NDArray(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
if(_calcU) {
|
||||
if(_fullUV)
|
||||
_u = NDArrayFactory::create(matrix.ordering(), {_rows, _rows}, matrix.dataType(), matrix.getContext());
|
||||
_u = NDArray(matrix.ordering(), {_rows, _rows}, matrix.dataType(), matrix.getContext());
|
||||
else
|
||||
_u = NDArrayFactory::create(matrix.ordering(), {_rows, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
_u = NDArray(matrix.ordering(), {_rows, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
}
|
||||
else
|
||||
_u = NDArrayFactory::create(matrix.ordering(), {_rows, 1}, matrix.dataType(), matrix.getContext());
|
||||
_u = NDArray(matrix.ordering(), {_rows, 1}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
if(_calcV) {
|
||||
if(_fullUV)
|
||||
_v = NDArrayFactory::create(matrix.ordering(), {_cols, _cols}, matrix.dataType(), matrix.getContext());
|
||||
_v = NDArray(matrix.ordering(), {_cols, _cols}, matrix.dataType(), matrix.getContext());
|
||||
else
|
||||
_v = NDArrayFactory::create(matrix.ordering(), {_cols, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
_v = NDArray(matrix.ordering(), {_cols, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
}
|
||||
else
|
||||
_v = NDArrayFactory::create(matrix.ordering(), {_cols, 1}, matrix.dataType(), matrix.getContext());
|
||||
_v = NDArray(matrix.ordering(), {_cols, 1}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
_m = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
_m = NDArray(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext());
|
||||
|
||||
evalData(matrix);
|
||||
}
|
||||
|
@ -77,16 +76,19 @@ void JacobiSVD<T>::mulRotationOnLeft(const int i, const int j, NDArray& block, c
|
|||
if(j+1 > block.sizeAt(0))
|
||||
throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnLeft: second arguments is out of array row range !");
|
||||
|
||||
auto pTemp = block({i,j+1,j-i, 0,0,0}, true, true);
|
||||
auto temp = pTemp;
|
||||
pTemp.assign(mmul(rotation, temp));
|
||||
auto temp = block({i,j+1,j-i, 0,0,0}, true, true);
|
||||
temp.assign(mmul(rotation, temp));
|
||||
|
||||
// auto pTemp = block({i,j+1,j-i, 0,0,0}, true, true);
|
||||
// auto temp = pTemp.dup();
|
||||
// pTemp.assign(mmul(rotation, temp));
|
||||
}
|
||||
else {
|
||||
|
||||
if(j+1 > block.sizeAt(0) || i+1 > block.sizeAt(0))
|
||||
throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnLeft: some or both integer arguments are out of array row range !");
|
||||
|
||||
auto temp = NDArrayFactory::create(block.ordering(), {2, block.sizeAt(1)}, block.dataType(), block.getContext());
|
||||
NDArray temp(block.ordering(), {2, block.sizeAt(1)}, block.dataType(), block.getContext());
|
||||
auto row1 = block({i,i+1, 0,0}, true);
|
||||
auto row2 = block({j,j+1, 0,0}, true);
|
||||
auto rowTemp1 = temp({0,1, 0,0}, true);
|
||||
|
@ -108,16 +110,19 @@ void JacobiSVD<T>::mulRotationOnRight(const int i, const int j, NDArray& block,
|
|||
if(j+1 > block.sizeAt(1))
|
||||
throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnRight: second argument is out of array column range !");
|
||||
|
||||
auto pTemp = block({0,0,0, i,j+1,j-i}, true, true);
|
||||
auto temp = pTemp;
|
||||
pTemp.assign(mmul(temp, rotation));
|
||||
auto temp = block({0,0,0, i,j+1,j-i}, true, true);
|
||||
temp.assign(mmul(temp, rotation));
|
||||
|
||||
// auto pTemp = block({0,0,0, i,j+1,j-i}, true, true);
|
||||
// auto temp = pTemp.dup();
|
||||
// pTemp.assign(mmul(temp, rotation));
|
||||
}
|
||||
else {
|
||||
|
||||
if(j+1 > block.sizeAt(1) || i+1 > block.sizeAt(1))
|
||||
throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnRight: some or both integer arguments are out of array column range !");
|
||||
|
||||
auto temp = NDArrayFactory::create(block.ordering(), {block.sizeAt(0), 2}, block.dataType(), block.getContext());
|
||||
NDArray temp(block.ordering(), {block.sizeAt(0), 2}, block.dataType(), block.getContext());
|
||||
auto col1 = block({0,0, i,i+1}, true);
|
||||
auto col2 = block({0,0, j,j+1}, true);
|
||||
auto colTemp1 = temp({0,0, 0,1}, true);
|
||||
|
@ -134,123 +139,148 @@ void JacobiSVD<T>::mulRotationOnRight(const int i, const int j, NDArray& block,
|
|||
template <typename T>
|
||||
bool JacobiSVD<T>::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) {
|
||||
|
||||
auto rotation = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext());
|
||||
T n = math::nd4j_sqrt<T,T>(block.e<T>(p,p) * block.e<T>(p,p) + block.e<T>(q,p) * block.e<T>(q,p));
|
||||
NDArray rotation(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext());
|
||||
|
||||
T n = math::nd4j_sqrt<T,T>(block.t<T>(p, p) * block.t<T>(p, p) + block.t<T>(q, p)*block.t<T>(q, p));
|
||||
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
const T precision = DataTypeUtils::eps<T>();
|
||||
|
||||
if(n == (T)0.f) {
|
||||
block.p(p, p, 0.f);
|
||||
block.p(q, p, 0.f);
|
||||
block.r<T>(p, p) = (T)0;
|
||||
block.r<T>(q, p) = (T)0;
|
||||
} else {
|
||||
T v = block.e<T>(p, p) / n;
|
||||
T v = block.t<T>(p, p) / n;
|
||||
|
||||
rotation.p(0, 0, v);
|
||||
rotation.p(1,1, v);
|
||||
rotation.r<T>(0,0) = rotation.r<T>(1,1) = v;
|
||||
|
||||
v = block.e<T>(q,p) / n;
|
||||
rotation.p(0, 1, v);
|
||||
v = block.t<T>(q, p) / n;
|
||||
rotation.r<T>(0,1) = v;
|
||||
|
||||
rotation.p(1,0, -rotation.template e<T>(0, 1));
|
||||
rotation.r<T>(1,0) = -rotation.template t<T>(0,1);
|
||||
mulRotationOnLeft(p, q, block, rotation);
|
||||
|
||||
if(_calcU) {
|
||||
auto temp2 = rotation.transpose();
|
||||
mulRotationOnRight(p, q, _u, temp2);
|
||||
}
|
||||
if(_calcU)
|
||||
mulRotationOnRight(p, q, _u, rotation.transpose());
|
||||
}
|
||||
|
||||
maxElem = math::nd4j_max<T>(maxElem, math::nd4j_max<T>(math::nd4j_abs<T>(block.e<T>(p,p)), math::nd4j_abs<T>(block.e<T>(q,q))));
|
||||
maxElem = math::nd4j_max<T>(maxElem, math::nd4j_max<T>(math::nd4j_abs<T>(block.t<T>(p, p)), math::nd4j_abs<T>(block.t<T>(q, q))));
|
||||
T threshold = math::nd4j_max<T>(almostZero, precision * maxElem);
|
||||
const bool condition1 = math::nd4j_abs<T>(block.e<T>(p,q)) > threshold;
|
||||
const bool condition2 = math::nd4j_abs<T>(block.e<T>(q,p)) > threshold;
|
||||
|
||||
return condition1 || condition2;
|
||||
return math::nd4j_abs<T>(block.t<T>(p, q)) > threshold || math::nd4j_abs<T>(block.t<T>(q, p)) > threshold;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
bool JacobiSVD<T>::createJacobiRotation(const T& x, const T& y, const T& z, NDArray& rotation) {
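// computes the Jacobi rotation annihilating the off-diagonal entry of the symmetric 2x2 matrix
// [[x, y], [y, z]]; returns false (and sets the identity rotation) when y is already negligible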
|
||||
|
||||
T denom = 2.* math::nd4j_abs<T>(y);
|
||||
T denom = (T)(2.f)* math::nd4j_abs<T>(y);
|
||||
|
||||
if(denom < DataTypeUtils::min<T>()) {
|
||||
|
||||
rotation.p(0,0, 1.f);
|
||||
rotation.p(1,1, 1.f);
|
||||
rotation.p(0,1, 0.f);
|
||||
rotation.p(1,0, 0.f);
|
||||
rotation.r<T>(0,0) = rotation.r<T>(1,1) = (T)1.f;
|
||||
rotation.r<T>(0,1) = rotation.r<T>(1,0) = (T)0.f;
|
||||
|
||||
return false;
|
||||
}
|
||||
else {
|
||||
|
||||
T tau = (x-z)/denom;
|
||||
T w = math::nd4j_sqrt<T,T>(tau*tau + 1.);
|
||||
T w = math::nd4j_sqrt<T,T>(tau*tau + (T)1.f);
|
||||
T t;
|
||||
|
||||
if(tau > (T)0.)
|
||||
t = 1. / (tau + w);
|
||||
t = (T)1.f / (tau + w);
|
||||
else
|
||||
t = 1. / (tau - w);
|
||||
t = (T)1.f / (tau - w);
|
||||
|
||||
T sign = t > (T)0. ? 1. : -1.;
|
||||
T n = 1. / math::nd4j_sqrt<T,T>(t*t + 1.f);
|
||||
rotation.p(0,0, n);
|
||||
rotation.p(1,1, n);
|
||||
T sign = t > (T)0. ? (T)1.f : (T)-1.f;
|
||||
|
||||
rotation.p(0,1, -sign * (y / math::nd4j_abs<T>(y)) * math::nd4j_abs<T>(t) * n);
|
||||
rotation.p(1,0, -rotation.e<T>(0,1));
|
||||
T cos = (T)1.f / math::nd4j_sqrt<T,T>(t*t + (T)1.f);
|
||||
T sin = -sign * (y / math::nd4j_abs<T>(y)) * math::nd4j_abs<T>(t) * cos;
|
||||
|
||||
rotation.r<T>(0,1) = sin;
|
||||
rotation.r<T>(1,0) = -sin;
|
||||
rotation.r<T>(0,0) = rotation.r<T>(1,1) = cos;
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void JacobiSVD<T>::createJacobiRotationGivens(const T& p, const T& q, NDArray& rotation) {
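// builds the Givens rotation R = [[cos, sin], [-sin, cos]] whose transpose maps (p, q)^T onto
// (r, 0)^T; each branch divides by the larger of |p|, |q| to avoid overflow/underflow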
|
||||
|
||||
T cos, sin;
|
||||
|
||||
if(q == (T)0) {
|
||||
|
||||
cos = p < (T)0 ? (T)-1 : (T)1;
|
||||
sin = (T)0;
|
||||
}
|
||||
else if(p == (T)0) {
|
||||
|
||||
cos = (T)0;
|
||||
sin = q < (T)0 ? (T)1 : (T)-1;
|
||||
}
|
||||
else if(math::nd4j_abs<T>(p) > math::nd4j_abs<T>(q)) {
|
||||
|
||||
T t = q / p;
|
||||
T u = math::nd4j_sqrt<T,T>((T)1 + t*t);
|
||||
if(p < (T)0)
|
||||
u = -u;
|
||||
cos = (T)1 / u;
|
||||
sin = -t * cos;
|
||||
}
|
||||
else {
|
||||
T t = p / q;
|
||||
T u = math::nd4j_sqrt<T,T>((T)1 + t*t);
|
||||
if(q < (T)0)
|
||||
u = -u;
|
||||
sin = -(T)1 / u;
|
||||
cos = -t * sin;
|
||||
}
|
||||
|
||||
rotation.r<T>(0,1) = sin;
|
||||
rotation.r<T>(1,0) = -sin;
|
||||
rotation.r<T>(0,0) = rotation.r<T>(1,1) = cos;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void JacobiSVD<T>::svd2x2(const NDArray& block, int p, int q, NDArray& left, NDArray& right) {
|
||||
|
||||
auto m = NDArrayFactory::create(block.ordering(), {2, 2}, block.dataType(), block.getContext());
|
||||
m.p<T>(0,0, block.e<T>(p,p));
|
||||
m.p<T>(0,1, block.e<T>(p,q));
|
||||
m.p<T>(1,0, block.e<T>(q,p));
|
||||
m.p<T>(1,1, block.e<T>(q,q));
|
||||
NDArray m(block.ordering(), {2, 2}, block.dataType(), block.getContext());
|
||||
m.r<T>(0,0) = block.t<T>(p,p);
|
||||
m.r<T>(0,1) = block.t<T>(p,q);
|
||||
m.r<T>(1,0) = block.t<T>(q,p);
|
||||
m.r<T>(1,1) = block.t<T>(q,q);
|
||||
|
||||
auto rotation = NDArrayFactory::create(block.ordering(), {2, 2}, block.dataType(), block.getContext());
|
||||
T t = m.e<T>(0,0) + m.e<T>(1,1);
|
||||
T d = m.e<T>(1,0) - m.e<T>(0,1);
|
||||
NDArray rotation(block.ordering(), {2, 2}, block.dataType(), block.getContext());
|
||||
T t = m.t<T>(0,0) + m.t<T>(1,1);
|
||||
T d = m.t<T>(1,0) - m.t<T>(0,1);
|
||||
|
||||
if(math::nd4j_abs<T>(d) < DataTypeUtils::min<T>()) {
|
||||
|
||||
rotation.p(0,0, 1.f);
|
||||
rotation.p(1,1, 1.f);
|
||||
rotation.p(0,1, 0.f);
|
||||
rotation.p(1,0, 0.f);
|
||||
rotation.r<T>(0,0) = rotation.r<T>(1,1) = (T)1;
|
||||
rotation.r<T>(0,1) = rotation.r<T>(1,0) = (T)0;
|
||||
}
|
||||
else {
|
||||
|
||||
T u = t / d;
|
||||
T tmp = math::nd4j_sqrt<T,T>(1. + u*u);
|
||||
rotation.p(0,0, u / tmp);
|
||||
rotation.p(1,1, u / tmp);
|
||||
rotation.p(0,1, 1.f / tmp);
|
||||
rotation.p(1,0, -rotation.e<T>(0,1));
|
||||
T tmp = math::nd4j_sqrt<T,T>((T)1.f + u*u);
|
||||
rotation.r<T>(0,0) = rotation.r<T>(1,1) = u / tmp;
|
||||
rotation.r<T>(0,1) = (T)1.f / tmp;
|
||||
rotation.r<T>(1,0) = -rotation.t<T>(0,1);
|
||||
}
|
||||
|
||||
m.assign(mmul(rotation, m));
|
||||
|
||||
auto _x = m.e<T>(0,0);
|
||||
auto _y = m.e<T>(0,1);
|
||||
auto _z = m.e<T>(1,1);
|
||||
createJacobiRotation(m.t<T>(0,0), m.t<T>(0,1), m.t<T>(1,1), right);
|
||||
|
||||
createJacobiRotation(_x, _y, _z, right);
|
||||
|
||||
m.p<T>(0, 0, _x);
|
||||
m.p<T>(0, 1, _y);
|
||||
m.p<T>(1, 1, _z);
|
||||
|
||||
auto temp = right.transpose();
|
||||
left.assign(mmul(rotation, temp));
|
||||
left.assign(mmul(rotation, right.transpose()));
|
||||
}
|
||||
|
||||
|
||||
|
@ -261,7 +291,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
|||
const T precision = (T)2.f * DataTypeUtils::eps<T>();
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
|
||||
T scale = matrix.reduceNumber(reduce::AMax).e<T>(0);
|
||||
T scale = matrix.reduceNumber(reduce::AMax).template t<T>(0);
|
||||
if(scale== (T)0.f)
|
||||
scale = (T)1.f;
|
||||
|
||||
|
@ -285,8 +315,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
|||
}
|
||||
else if(_rows < _cols) {
|
||||
|
||||
auto matrixT = matrix.transpose();
|
||||
HHcolPivQR qr(matrixT / scale);
|
||||
HHcolPivQR qr(matrix.transpose() / scale);
|
||||
_m.assign(qr._qr({0,_rows, 0,_rows}));
|
||||
_m.fillAsTriangular<T>(0., 0, 0, _m, 'l');
|
||||
_m.transposei();
|
||||
|
@ -305,7 +334,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
|||
}
|
||||
else {
|
||||
|
||||
_m.assign(static_cast<const NDArray&>(matrix({0,_diagSize, 0,_diagSize})) / scale);
|
||||
_m.assign(matrix({0,_diagSize, 0,_diagSize}) / scale);
|
||||
|
||||
if(_calcU)
|
||||
_u.setIdentity();
|
||||
|
@ -316,7 +345,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
|||
|
||||
T maxDiagElem = 0.;
|
||||
for(int i = 0; i < _diagSize; ++i) {
|
||||
T current = math::nd4j_abs<T>(_m.e<T>(i,i));
|
||||
T current = math::nd4j_abs<T>(_m.t<T>(i,i));
|
||||
if(maxDiagElem < current )
|
||||
maxDiagElem = current;
|
||||
}
|
||||
|
@ -333,29 +362,27 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
|||
|
||||
T threshold = math::nd4j_max<T>(almostZero, precision * maxDiagElem);
|
||||
|
||||
if(math::nd4j_abs<T>(_m.e<T>(p,q)) > threshold || math::nd4j_abs<T>(_m.e<T>(q,p)) > threshold){
|
||||
if(math::nd4j_abs<T>(_m.t<T>(p,q)) > threshold || math::nd4j_abs<T>(_m.t<T>(q,p)) > threshold){
|
||||
|
||||
stop = false;
|
||||
|
||||
// if(isBlock2x2NotDiag(_m, p, q, maxDiagElem))
|
||||
{
|
||||
auto rotLeft = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext());
|
||||
auto rotRight = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext());
|
||||
NDArray rotLeft(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext());
|
||||
NDArray rotRight(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext());
|
||||
svd2x2(_m, p, q, rotLeft, rotRight);
|
||||
|
||||
mulRotationOnLeft(p, q, _m, rotLeft);
|
||||
|
||||
if(_calcU) {
|
||||
auto temp = rotLeft.transpose();
|
||||
mulRotationOnRight(p, q, _u, temp);
|
||||
}
|
||||
if(_calcU)
|
||||
mulRotationOnRight(p, q, _u, rotLeft.transpose());
|
||||
|
||||
mulRotationOnRight(p, q, _m, rotRight);
|
||||
|
||||
if(_calcV)
|
||||
mulRotationOnRight(p, q, _v, rotRight);
|
||||
|
||||
maxDiagElem = math::nd4j_max<T>(maxDiagElem, math::nd4j_max<T>(math::nd4j_abs<T>(_m.e<T>(p,p)), math::nd4j_abs<T>(_m.e<T>(q,q))));
|
||||
maxDiagElem = math::nd4j_max<T>(maxDiagElem, math::nd4j_max<T>(math::nd4j_abs<T>(_m.t<T>(p,p)), math::nd4j_abs<T>(_m.t<T>(q,q))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -363,8 +390,10 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
|||
}
|
||||
|
||||
for(int i = 0; i < _diagSize; ++i) {
|
||||
_s.p(i, math::nd4j_abs<T>(_m.e<T>(i,i)));
|
||||
if(_calcU && _m.e<T>(i,i) < (T)0.) {
|
||||
|
||||
_s.r<T>(i) = math::nd4j_abs<T>(_m.t<T>(i,i));
|
||||
|
||||
if(_calcU && _m.t<T>(i,i) < (T)0.) {
|
||||
auto temp = _u({0,0, i,i+1}, true);
|
||||
temp.applyTransform(transform::Neg, temp, nullptr);
|
||||
}
|
||||
|
@ -375,7 +404,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
|||
for(int i = 0; i < _diagSize; i++) {
|
||||
|
||||
int pos = (_s({i,-1, 0,0}).indexReduceNumber(indexreduce::IndexMax, nullptr)).template e<int>(0);
|
||||
T maxSingVal = _s({i,-1, 0,0}).reduceNumber(reduce::Max).template e<T>(0);
|
||||
T maxSingVal = _s({i,-1, 0,0}).reduceNumber(reduce::Max).template t<T>(0);
|
||||
|
||||
if(maxSingVal == (T)0.)
|
||||
break;
|
||||
|
@ -384,34 +413,24 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
|||
|
||||
pos += i;
|
||||
|
||||
T _e0 = _s.e<T>(i);
|
||||
T _e1 = _s.e<T>(pos);
|
||||
_s.p(pos, _e0);
|
||||
_s.p(i, _e1);
|
||||
//math::nd4j_swap<T>(_s(i), _s(pos));
|
||||
math::nd4j_swap<T>(_s.r<T>(i), _s.r<T>(pos));
|
||||
|
||||
if(_calcU) {
|
||||
auto temp1 = _u({0,0, pos,pos+1}, true);
|
||||
auto temp2 = _u({0,0, i,i+1}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
temp1.swapUnsafe(temp2);
|
||||
}
|
||||
|
||||
if(_calcV) {
|
||||
auto temp1 = _v({0,0, pos, pos+1}, true);
|
||||
auto temp2 = _v({0,0, i, i+1}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
temp1.swapUnsafe(temp2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template class ND4J_EXPORT JacobiSVD<float>;
|
||||
template class ND4J_EXPORT JacobiSVD<float16>;
|
||||
template class ND4J_EXPORT JacobiSVD<bfloat16>;
|
|
@ -52,6 +52,7 @@ class JacobiSVD {
|
|||
bool isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem);
|
||||
|
||||
static bool createJacobiRotation(const T& x, const T& y, const T& z, NDArray& rotation);
|
||||
static void createJacobiRotationGivens(const T& p, const T& q, NDArray& rotation);
|
||||
|
||||
static void svd2x2(const NDArray& block, int p, int q, NDArray& left, NDArray& right);
|
||||
|
||||
|
|
|
@ -528,7 +528,7 @@ namespace shape {
|
|||
* Returns the element wise stride for this information
|
||||
* buffer
|
||||
*/
|
||||
ND4J_EXPORT _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer);
|
||||
ND4J_EXPORT _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *shapeInfo);
|
||||
|
||||
|
||||
/**
|
||||
|
|
|
@ -31,25 +31,39 @@ namespace sd {
|
|||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int totalThreads = gridDim.x * blockDim.x;
|
||||
|
||||
__shared__ Nd4jLong resultLength;
|
||||
__shared__ Nd4jLong resultLength, xEws, yEws;
|
||||
__shared__ bool sameOffsets, sameOrders;
|
||||
__shared__ T* input;
|
||||
__shared__ T* output;
|
||||
|
||||
if (0 == threadIdx.x) {
|
||||
resultLength = shape::length(theFirstShape);
|
||||
input = reinterpret_cast<T*>(theSecondBuffer);
|
||||
output = reinterpret_cast<T*>(theFirstBuffer);
|
||||
|
||||
sameOffsets = shape::haveSameShapeAndStrides(theFirstShape, theSecondShape);
|
||||
sameOrders = shape::order(theFirstShape) == shape::order(theSecondShape);
|
||||
|
||||
xEws = shape::elementWiseStride(theFirstShape);
|
||||
yEws = shape::elementWiseStride(theSecondShape);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = tid; i < resultLength; i += totalThreads) {
|
||||
auto xEws = shape::order(theFirstShape) == 'c'? shape::elementWiseStride(theFirstShape) :1;
|
||||
auto yEws = shape::order(theSecondShape) == 'c'? shape::elementWiseStride(theSecondShape):1;
|
||||
|
||||
auto xOffset = shape::getIndexOffset(i * xEws, theFirstShape);
|
||||
auto yOffset = shape::getIndexOffset(i * yEws, theSecondShape);
|
||||
if(sameOrders && xEws > 0 && yEws > 0) {
|
||||
sd::math::nd4j_swap(output[i*xEws], input[i*yEws]);
|
||||
}
|
||||
else if(sameOffsets) {
|
||||
const auto offset = shape::getIndexOffset(i, theFirstShape);
|
||||
sd::math::nd4j_swap(output[offset], input[offset]);
|
||||
}
|
||||
else{
|
||||
const auto xOffset = shape::getIndexOffset(i, theFirstShape);
|
||||
const auto yOffset = shape::getIndexOffset(i, theSecondShape);
|
||||
sd::math::nd4j_swap(output[xOffset], input[yOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template __global__ void swapUnsafeKernel, (void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape), LIBND4J_TYPES);
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <system/op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_sqrtm)
|
||||
#include <ops/declarable/helpers/sqrtm.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
|
||||
CONFIGURABLE_OP_IMPL(sqrtm, 1, 1, false, 0, 0) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() > 1, 0, "CONFIGURABLE_OP sqrtm: input array rank is required to be > 1, but got %i instead !", input->rankOf());
|
||||
REQUIRE_TRUE(input->sizeAt(-2) == input->sizeAt(-1), 0, "CONFIGURABLE_OP sqrtm: the last two dimensions of the input array must form square matrices, but got shape %s instead!", ShapeUtils::shapeAsString(input).c_str());
|
||||
|
||||
helpers::sqrtm(block.launchContext(), input, output);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(sqrtm) {
|
||||
getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
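As a usage illustration (not part of the patch): a minimal smoke test one could write against the op implemented above, assuming the test-style NDArrayFactory::create and DeclarableOp::evaluate helpers used elsewhere in the repository; the exact signatures and the test name are assumptions.

    #include <ops/declarable/CustomOperations.h>

    // Hypothetical smoke test for the sqrtm op; helper signatures are assumed, not taken from this patch.
    static void sqrtmSmokeTest() {
        // x = [[4, 0], [0, 9]] has the obvious square root z = [[2, 0], [0, 3]].
        auto x = sd::NDArrayFactory::create<double>('c', {2, 2}, {4.0, 0.0, 0.0, 9.0});

        sd::ops::sqrtm op;
        auto result = op.evaluate({&x});                     // assumed: returns a ResultSet holding the output
        if (result.status() == ND4J_STATUS_OK)
            result.at(0)->printIndexedBuffer("sqrtm(x)");    // expected to print 2, 0, 0, 3
    }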
@ -55,7 +55,7 @@ namespace sd {
|
|||
isLower = !isLower;
|
||||
};
|
||||
|
||||
auto res = helpers::triangularSolveFunctor(block.launchContext(), input, b, isLower, useAdjoint, z);
|
||||
auto res = helpers::triangularSolveFunctor(block.launchContext(), input, b, isLower, false, z);
|
||||
if (input != a)
|
||||
delete input;
|
||||
|
||||
|
|
|
@ -108,6 +108,20 @@ namespace sd {
|
|||
#if NOT_EXCLUDED(OP_svd)
|
||||
DECLARE_CUSTOM_OP(svd, 1, 1, false, 0, 3);
|
||||
#endif
|
||||
|
||||
/**
 * calculates the square root of a matrix, i.e. z such that
 * x[..., M, M] = z[..., M, M] * z[..., M, M]
 *
 * Input array:
 * x[..., M, M]; the necessary conditions are: rank of x >= 2 and the last two dimensions equal to each other
 *
 * Output array:
 * z - same shape as x
 */
|
||||
#if NOT_EXCLUDED(OP_sqrtm)
|
||||
DECLARE_CONFIGURABLE_OP(sqrtm, 1, 1, false, 0, 0);
|
||||
#endif
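A tiny worked instance of the contract above (plain arithmetic, not taken from the patch): with

    x = \begin{pmatrix} 5 & 4 \\ 4 & 5 \end{pmatrix}, \qquad z = \begin{pmatrix} 2 & 1 \\ 1 & 2 \end{pmatrix}, \qquad z\,z = \begin{pmatrix} 2\cdot 2 + 1\cdot 1 & 2\cdot 1 + 1\cdot 2 \\ 1\cdot 2 + 2\cdot 1 & 1\cdot 1 + 2\cdot 2 \end{pmatrix} = \begin{pmatrix} 5 & 4 \\ 4 & 5 \end{pmatrix} = x,

z is a valid output for input x and has the same shape; with leading batch dimensions the same relation holds independently for each trailing M x M block.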
}
|
||||
}
|
||||
|
||||
|
|
|
@ -118,7 +118,7 @@ static void betaIncForArray(sd::LaunchContext * context, const NDArray& a, const
|
|||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++)
|
||||
output.t<T>(i) = betaIncCore<T>(a.t<T>(i), b.t<T>(i), x.t<T>(i));
|
||||
output.r<T>(i) = betaIncCore<T>(a.t<T>(i), b.t<T>(i), x.t<T>(i));
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, xLen);
|
||||
|
|
|
@ -73,7 +73,7 @@ namespace helpers {
|
|||
bool setUp = (theSame && row >= 0 && col >= 0 && row < rowDim && col < colDim) ||
|
||||
(!theSame);
|
||||
if (setUp) {
|
||||
outMatrix->t<T>(i, j, pos) = patch->e<T>(row, col, pixel);
|
||||
outMatrix->r<T>(i, j, pos) = patch->e<T>(row, col, pixel);
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ namespace helpers {
|
|||
else if (val >= nudged_max)
|
||||
val = nudged_max;
|
||||
// quantization itself
|
||||
output->t<T>(e + i) = math::nd4j_floor<T,T>((val - nudged_min)/scale + T(0.5)) * scale + nudged_min;
|
||||
output->r<T>(e + i) = math::nd4j_floor<T,T>((val - nudged_min)/scale + T(0.5)) * scale + nudged_min;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -318,7 +318,7 @@ namespace helpers {
|
|||
}
|
||||
// copy pixel over all channels
|
||||
for (Nd4jLong e = 0; e < channels; e++)
|
||||
output->t<T>(b, y, x, e) = images->t<T>(b, inY, inX, e);
|
||||
output->r<T>(b, y, x, e) = images->t<T>(b, inY, inX, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ namespace helpers {
|
|||
|
||||
for (auto x = 0; x < lastDims.size(); x++) {
|
||||
for (auto r = 0; r < rows; r++) {
|
||||
lastDims[x]->t<T>(r,r) = (T)value;
|
||||
lastDims[x]->r<T>(r,r) = (T)value;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ namespace helpers {
|
|||
|
||||
if (theFirst != theSecond)
|
||||
for (int i = 0; i < matrix->columns(); i++) {
|
||||
math::nd4j_swap(matrix->t<T>(theFirst, i), matrix->t<T>(theSecond, i));
|
||||
math::nd4j_swap(matrix->r<T>(theFirst, i), matrix->r<T>(theSecond, i));
|
||||
}
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
|
||||
|
@ -71,12 +71,12 @@ namespace helpers {
|
|||
|
||||
auto invertDiagonals = PRAGMA_THREADS_FOR {
|
||||
for (int i = start; i < stop; i += increment)
|
||||
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
|
||||
invertedMatrix->r<T>(i, i) /= inputMatrix->t<T>(i, i);
|
||||
};
|
||||
|
||||
auto invertSubDiagonals = PRAGMA_THREADS_FOR {
|
||||
for (int i = start; i < stop; i += increment)
|
||||
invertedMatrix->t<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i));
|
||||
invertedMatrix->r<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i));
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
|
||||
|
@ -86,7 +86,7 @@ namespace helpers {
|
|||
for (int i = 1; i < n; i++) {
|
||||
for (int j = 0; j < i - 1 ; j++)
|
||||
for (int k = 0; k < i; k++)
|
||||
invertedMatrix->t<T>(i, j) -= ((invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k) / inputMatrix->t<T>(i, i)));
|
||||
invertedMatrix->r<T>(i, j) -= ((invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k) / inputMatrix->t<T>(i, i)));
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -108,13 +108,13 @@ namespace helpers {
|
|||
|
||||
auto invertDiagonals = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment)
|
||||
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
|
||||
invertedMatrix->r<T>(i, i) /= inputMatrix->t<T>(i, i);
|
||||
};
|
||||
|
||||
//PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
||||
auto invertUpDiagonals = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment)
|
||||
invertedMatrix->t<T>(i, i + 1) -= (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) /
|
||||
invertedMatrix->r<T>(i, i + 1) -= (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) /
|
||||
inputMatrix->t<T>(i, i));
|
||||
};
|
||||
|
||||
|
@ -125,7 +125,7 @@ namespace helpers {
|
|||
for (auto i = n - 2; i >= 0; i--) {
|
||||
for (auto j = i + 2; j < n; j++)
|
||||
for (auto k = i; k < n; k++)
|
||||
invertedMatrix->t<T>(i, j) -= ((invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k) / inputMatrix->t<T>(i, i)));
|
||||
invertedMatrix->r<T>(i, j) -= ((invertedMatrix->t<T>(k, j) * inputMatrix->t<T>(i, k) / inputMatrix->t<T>(i, i)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -169,10 +169,10 @@ namespace helpers {
|
|||
swapCount++;
|
||||
|
||||
for( int j = i + 1; j < rowNum; j++ ) {
|
||||
compoundMatrix.t<T>(j, i) /= compoundMatrix.t<T>(i, i);
|
||||
compoundMatrix.r<T>(j, i) /= compoundMatrix.t<T>(i, i);
|
||||
//PRAGMA_OMP_PARALLEL_FOR
|
||||
for( int k = i + 1; k < rowNum; k++ ) {
|
||||
compoundMatrix.t<T>(j, k) -= compoundMatrix.t<T>(j, i) * compoundMatrix.t<T>(i, k);
|
||||
compoundMatrix.r<T>(j, k) -= compoundMatrix.t<T>(j, i) * compoundMatrix.t<T>(i, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -190,7 +190,7 @@ namespace helpers {
|
|||
for (auto i = 0; i < rowNum; i++) {
|
||||
for (auto j = 0; j < columnNum; j++) {
|
||||
if (permutationMatrix.t<T>(i, j) != 0) {
|
||||
permutaionVector.template t<I>(i) = j;
|
||||
permutaionVector.template r<I>(i) = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -268,7 +268,7 @@ namespace helpers {
|
|||
sum += compound->t<T>(i,j) * compound->t<T>(j,k);
|
||||
|
||||
// Evaluating U(i, k)
|
||||
compound->t<T>(i, k) = input.t<T>(i, k) - sum;
|
||||
compound->r<T>(i, k) = input.t<T>(i, k) - sum;
|
||||
}
|
||||
|
||||
// Lower Triangular
|
||||
|
@ -279,7 +279,7 @@ namespace helpers {
|
|||
sum += compound->t<T>(k,j) * compound->t<T>(j, i);
|
||||
|
||||
// Evaluating L(k, i)
|
||||
compound->t<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i,i);
|
||||
compound->r<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i,i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -412,12 +412,12 @@ template <typename T>
|
|||
lowerMatrix.setIdentity(); // set up U to identity matrix
|
||||
for (int k = 1; k < n; k++) { // and then put all values under main diagonal on to it
|
||||
for (int j = 0; j < k; j++)
|
||||
lowerMatrix.template t<T>(k, j) = compound.template t<T>(k, j);
|
||||
lowerMatrix.template r<T>(k, j) = compound.template t<T>(k, j);
|
||||
}
|
||||
upperMatrix.setIdentity(); // set up U to identity matrix
|
||||
for (int k = 0; k < n; k++) { // and then put all values under main diagonal on to it
|
||||
for (int j = k; j < n; j++)
|
||||
upperMatrix.template t<T>(k, j) = compound.template e<T>(k, j);
|
||||
upperMatrix.template r<T>(k, j) = compound.template t<T>(k, j);
|
||||
}
|
||||
invertUpperMatrix(&upperMatrix, &matrix);
|
||||
|
||||
|
@ -426,7 +426,7 @@ template <typename T>
|
|||
sd::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0);
|
||||
sd::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0);
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||
output->t<T>(k) = matrix.template t<T>(row++);
|
||||
output->r<T>(k) = matrix.template t<T>(row++);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -470,7 +470,7 @@ template <typename T>
|
|||
invertLowerMatrix(&matrix, &lowerMatrix);
|
||||
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||
output->t<T>(k) = lowerMatrix.template t<T>(row++);
|
||||
output->r<T>(k) = lowerMatrix.template t<T>(row++);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -597,7 +597,7 @@ template <typename T>
|
|||
|
||||
for (Nd4jLong e = 0; e < totalCount; e++) {
|
||||
for (size_t i = 0; i < n; ++i)
|
||||
output->t<T>(e) += sd::math::nd4j_log<T,T>(sd::math::nd4j_pow<T,T,T>(matricies.at(e)->t<T>(i, i), T(2)));
|
||||
output->r<T>(e) += sd::math::nd4j_log<T,T>(sd::math::nd4j_pow<T,T,T>(matricies.at(e)->t<T>(i, i), T(2)));
|
||||
}
|
||||
return ND4J_STATUS_OK;
|
||||
}
|
||||
|
|
|
@ -47,8 +47,8 @@ static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& o
|
|||
idx = static_cast<Z>(i);
|
||||
}
|
||||
}
|
||||
// FIXME, use .r<Z>(e)
|
||||
output.t<Z>(e) = static_cast<Z>(idx);
|
||||
|
||||
output.r<Z>(e) = static_cast<Z>(idx);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ namespace helpers {
|
|||
beta != nullptr ? copyBeta->t<T>(e) * u : u);
|
||||
}
|
||||
else {
|
||||
output->t<T>(pos + e) = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
|
||||
output->r<T>(pos + e) = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
|
||||
beta != nullptr ? copyBeta->t<T>(e) * u : u);
|
||||
}
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ namespace helpers {
|
|||
if (directOut)
|
||||
outputBuf[pos + e] = x;
|
||||
else
|
||||
output->t<T>(pos + e) = x;
|
||||
output->r<T>(pos + e) = x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ namespace helpers {
|
|||
else {
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (Nd4jLong i = 0; i < output->lengthOf(); i++) {
|
||||
output->t<T>(i) = rng.relativeT<T>(i, minVal, maxVal);
|
||||
output->r<T>(i) = rng.relativeT<T>(i, minVal, maxVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,8 +54,8 @@ void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator&
|
|||
T t0 = input.t<T>(i);
|
||||
T t1 = input.t<T>(r);
|
||||
//math::nd4j_swap<T>(input(i), input(r));
|
||||
input.t<T>(i) = t1;
|
||||
input.t<T>(r) = t0;
|
||||
input.r<T>(i) = t1;
|
||||
input.r<T>(r) = t0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
@ -66,11 +66,11 @@ void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator&
|
|||
// FIXME: parallelism!!
|
||||
for(int i = firstDim-1; i > 0; --i) {
|
||||
int r = rng.relativeInt(i) % i;
|
||||
output.t<T>(i) = input.t<T>(indices[r]);
|
||||
output.r<T>(i) = input.t<T>(indices[r]);
|
||||
if(i == r)
|
||||
continue;
|
||||
|
||||
output.t<T>(r) = input.t<T>(indices[i]);
|
||||
output.r<T>(r) = input.t<T>(indices[i]);
|
||||
math::nd4j_swap<int>(indices[i], indices[r]);
|
||||
}
|
||||
rng.rewindH(firstDim-1);
|
||||
|
|
|
@ -46,7 +46,7 @@ namespace helpers {
|
|||
idx = indices->e<Nd4jLong>(e);
|
||||
val = input->t<T>(e);
|
||||
}
|
||||
output->t<T>(idx) = val;
|
||||
output->r<T>(idx) = val;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
@ -65,7 +65,7 @@ namespace helpers {
|
|||
if (indices->e<int>(i) == idx) {
|
||||
|
||||
for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) {
|
||||
maxT->t<T>(e) = sd::math::nd4j_max(maxT->t<T>(e), listOfTensors.at(i)->t<T>(e));
|
||||
maxT->r<T>(e) = sd::math::nd4j_max(maxT->t<T>(e), listOfTensors.at(i)->t<T>(e));
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
@ -96,7 +96,7 @@ namespace helpers {
|
|||
idx = indices->e<Nd4jLong>(e);
|
||||
val = input->t<T>(e);
|
||||
}
|
||||
output->t<T>(idx) = val;
|
||||
output->r<T>(idx) = val;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
@ -417,7 +417,7 @@ namespace helpers {
|
|||
for (size_t idx = 1; idx < fi->second.size(); ++idx) {
|
||||
val = sd::math::nd4j_min(val, input->t<T>(fi->second.at(idx)));
|
||||
}
|
||||
output->t<T>(fi->first) = val;
|
||||
output->r<T>(fi->first) = val;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
@ -436,7 +436,7 @@ namespace helpers {
|
|||
auto minT = listOfTensors.at(fi->second.at(idx));
|
||||
|
||||
for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) {
|
||||
outputT->t<T>(e) = sd::math::nd4j_min(minT->t<T>(e), outputT->t<T>(e));
|
||||
outputT->r<T>(e) = sd::math::nd4j_min(minT->t<T>(e), outputT->t<T>(e));
|
||||
}
|
||||
}
|
||||
//outputT->assign(maxT);
|
||||
|
@ -890,7 +890,7 @@ namespace helpers {
|
|||
for (auto e = start; e < stop; e++) {
|
||||
auto classNum = indices->e<Nd4jLong>(e);
|
||||
if (sd::math::nd4j_abs(tempRes.t<T>(classNum) - input->t<T>(e)) < 1.e-6)
|
||||
output->t<T>(e) = gradOut->t<T>(classNum);
|
||||
output->r<T>(e) = gradOut->t<T>(classNum);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -913,7 +913,7 @@ namespace helpers {
|
|||
|
||||
for (Nd4jLong e = 0; e < current->lengthOf(); e++) {
|
||||
if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->t<T>(e) - current->t<T>(e)) < 1.e-6)
|
||||
currentOut->t<T>(e) = currentGradOut->t<T>(e);
|
||||
currentOut->r<T>(e) = currentGradOut->t<T>(e);
|
||||
}
|
||||
}
|
||||
//};
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace helpers {
|
|||
for (auto i = start_x; i < stop_x; i += inc_x)
|
||||
for (auto k = start_y; k < stop_y; k += inc_y)
|
||||
if (i < input->t<I>(k))
|
||||
output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f));
|
||||
output->r<B>(k * maxIndex + i) = B(true); //, T(1.0f));
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1);
|
||||
|
|
|
@ -43,7 +43,7 @@ namespace helpers {
|
|||
for (auto batch = start; batch < stop; batch++) {
|
||||
for (Nd4jLong r = 0; r < rows; r++) {
|
||||
for (Nd4jLong c = 0; c < r; c++) {
|
||||
math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
|
||||
math::nd4j_swap(outputPart[batch]->r<T>(r, c) , outputPart[batch]->r<T>(c, r));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ namespace helpers {
|
|||
|
||||
for (auto batch = 0; batch < permutationsPart.size(); ++batch) {
|
||||
for (Nd4jLong row = 0; row < PPart[batch]->rows(); ++row) {
|
||||
PPart[batch]->t<T>(row, permutationsPart[batch]->t<int>(row)) = T(1.f);
|
||||
PPart[batch]->r<T>(row, permutationsPart[batch]->t<int>(row)) = T(1.f);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,7 +78,7 @@ namespace helpers {
|
|||
ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1});
|
||||
for (auto i = 0; i < leftLowerPart.size(); i++) {
|
||||
for (Nd4jLong r = 0; r < leftLowerPart[i]->rows(); r++)
|
||||
leftLowerPart[i]->t<T>(r,r) = (T)1.f;
|
||||
leftLowerPart[i]->r<T>(r,r) = (T)1.f;
|
||||
}
|
||||
// stage 2: triangularSolveFunctor for Lower with given b
|
||||
helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput);
|
||||
|
|
|
@ -27,911 +27,6 @@ namespace sd {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const bool calcV, const bool fullUV ) {
|
||||
|
||||
if(matrix.rankOf() != 2 || matrix.isScalar())
|
||||
throw std::runtime_error("ops::helpers::SVD constructor: input array must be 2D matrix !");
|
||||
|
||||
const int rows = matrix.sizeAt(0);
|
||||
const int cols = matrix.sizeAt(1);
|
||||
|
||||
if(cols > rows) {
|
||||
|
||||
_transp = true;
|
||||
_diagSize = rows;
|
||||
}
|
||||
else {
|
||||
|
||||
_transp = false;
|
||||
_diagSize = cols;
|
||||
}
|
||||
|
||||
_switchSize = switchSize;
|
||||
_calcU = calcU;
|
||||
_calcV = calcV;
|
||||
_fullUV = fullUV;
|
||||
|
||||
if (_transp)
|
||||
math::nd4j_swap<bool>(_calcU, _calcV);
|
||||
|
||||
_s = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize, 1}, matrix.getContext());
|
||||
_m = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.getContext());
|
||||
_m.assign(0.);
|
||||
|
||||
if (_calcU)
|
||||
_u = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext());
|
||||
else
|
||||
_u = NDArrayFactory::create<T>(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext());
|
||||
_u.assign(0.);
|
||||
|
||||
if (_calcV) {
|
||||
_v = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize, _diagSize}, matrix.getContext());
|
||||
_v.assign(0.);
|
||||
}
|
||||
|
||||
evalData(matrix);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const bool calcV, const bool fullUV, const char t) {
|
||||
|
||||
if(matrix.rankOf() != 2 || matrix.isScalar())
|
||||
throw std::runtime_error("ops::helpers::SVD constructor: input array must be 2D matrix !");
|
||||
|
||||
const int rows = matrix.sizeAt(0);
|
||||
const int cols = matrix.sizeAt(1);
|
||||
|
||||
if(cols > rows) {
|
||||
|
||||
_transp = true;
|
||||
_diagSize = rows;
|
||||
}
|
||||
else {
|
||||
|
||||
_transp = false;
|
||||
_diagSize = cols;
|
||||
}
|
||||
|
||||
_switchSize = switchSize;
|
||||
_calcU = calcU;
|
||||
_calcV = calcV;
|
||||
_fullUV = fullUV;
|
||||
|
||||
if (_transp)
|
||||
math::nd4j_swap<bool>(_calcU, _calcV);
|
||||
|
||||
_s = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize, 1}, matrix.getContext());
|
||||
_m = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.getContext());
|
||||
_m.assign(0.f);
|
||||
|
||||
if (_calcU)
|
||||
_u = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext());
|
||||
else
|
||||
_u = NDArrayFactory::create<T>(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext());
|
||||
_u.assign(0.);
|
||||
|
||||
if (_calcV) {
|
||||
_v = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize, _diagSize}, matrix.getContext());
|
||||
_v.assign(0.);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void SVD<T>::deflation1(int col1, int shift, int ind, int size) {
|
||||
|
||||
if(ind <= 0)
|
||||
throw std::runtime_error("ops::helpers::SVD::deflation1 method: input int must satisfy condition ind > 0 !");
|
||||
|
||||
int first = col1 + shift;
|
||||
T cos = _m.e<T>(first, first);
|
||||
T sin = _m.e<T>(first+ind, first);
|
||||
T denom = math::nd4j_sqrt<T, T>(cos*cos + sin*sin);
|
||||
|
||||
if (denom == (T)0.) {
|
||||
|
||||
_m.p(first+ind, first+ind, 0.f);
|
||||
return;
|
||||
}
|
||||
|
||||
cos /= denom;
|
||||
sin /= denom;
|
||||
|
||||
_m.p(first,first, denom);
|
||||
_m.p(first+ind, first, 0.f);
|
||||
_m.p(first+ind, first+ind, 0.f);
|
||||
|
||||
auto rotation = NDArrayFactory::create<T>(_m.ordering(), {2, 2}, _m.getContext());
|
||||
rotation.p(0, 0, cos);
|
||||
rotation.p(0, 1, -sin);
|
||||
rotation.p(1, 0, sin);
|
||||
rotation.p(1, 1, cos);
|
||||
|
||||
if (_calcU) {
|
||||
auto temp = _u({col1,col1+size+1, 0,0}, true);
|
||||
JacobiSVD<T>::mulRotationOnRight(col1, col1+ind, temp, rotation);
|
||||
}
|
||||
else
|
||||
JacobiSVD<T>::mulRotationOnRight(col1, col1+ind, _u, rotation);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void SVD<T>::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, int ind2, int size) {
|
||||
|
||||
if(ind1 >= ind2)
|
||||
throw std::runtime_error("ops::helpers::SVD::deflation2 method: input intes must satisfy condition ind1 < ind2 !");
|
||||
|
||||
if(size <= 0)
|
||||
throw std::runtime_error("ops::helpers::SVD::deflation2 method: input size must satisfy condition size > 0 !");
|
||||
|
||||
T cos = _m.e<T>(col1M+ind1, col1M);
|
||||
T sin = _m.e<T>(col1M+ind2, col1M);
|
||||
T denom = math::nd4j_sqrt<T,T>(cos*cos + sin*sin);
|
||||
|
||||
if (denom == (T)0.) {
|
||||
|
||||
_m.p(col1M + ind1, col1M + ind1, _m.e<T>(col1M + ind2, col1M + ind2));
|
||||
return;
|
||||
}
|
||||
|
||||
cos /= denom;
|
||||
sin /= denom;
|
||||
_m.p(col1M + ind1, col1M, denom);
|
||||
_m.p(col1M + ind2, col1M + ind2, _m.e<T>(col1M + ind1, col1M + ind1));
|
||||
_m.p(col1M + ind2, col1M, 0.f);
|
||||
|
||||
auto rotation = NDArrayFactory::create<T>(_m.ordering(), {2, 2}, _m.getContext());
|
||||
rotation.p(0,0, cos);
|
||||
rotation.p(1,1, cos);
|
||||
|
||||
rotation.p(0,1, -sin);
|
||||
rotation.p(1,0, sin);
|
||||
|
||||
if (_calcU) {
|
||||
auto temp = _u({col1U,col1U+size+1, 0,0}, true);
|
||||
JacobiSVD<T>::mulRotationOnRight(col1U+ind1, col1U+ind2, temp, rotation);
|
||||
}
|
||||
else
|
||||
JacobiSVD<T>::mulRotationOnRight(col1U+ind1, col1U+ind2, _u, rotation);
|
||||
|
||||
if (_calcV) {
|
||||
auto temp = _v({row1W,row1W+size, 0,0}, true);
|
||||
JacobiSVD<T>::mulRotationOnRight(col1W+ind1, col1W+ind2, temp, rotation);
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) inclusively
|
||||
template <typename T>
|
||||
void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int shift)
|
||||
{
|
||||
|
||||
const int len = col2 + 1 - col1;
|
||||
|
||||
auto colVec0 = new NDArray(_m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true));
|
||||
|
||||
auto diagInterval = _m({col1+shift, col1+shift+len, col1+shift,col1+shift+len}, true).diagonal('c');
|
||||
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
T maxElem;
|
||||
if(len == 1)
|
||||
maxElem = math::nd4j_abs<T>(diagInterval.template e<T>(0));
|
||||
else
|
||||
maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e<T>(0);
|
||||
T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e<T>(0);
|
||||
|
||||
T eps = math::nd4j_max<T>(almostZero, DataTypeUtils::eps<T>() * maxElem);
|
||||
T epsBig = (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(maxElem0, maxElem);
|
||||
|
||||
if(diagInterval.template e<T>(0) < epsBig)
|
||||
diagInterval.p(Nd4jLong(0), epsBig);
|
||||
|
||||
for(int i=1; i < len; ++i)
|
||||
if(math::nd4j_abs<T>(colVec0->template e<T>(i)) < eps)
|
||||
colVec0->p(i, 0.f);
|
||||
|
||||
for(int i=1; i < len; i++)
|
||||
if(diagInterval.template e<T>(i) < epsBig) {
|
||||
deflation1(col1, shift, i, len);
|
||||
for(int i = 0; i < len; ++i)
|
||||
diagInterval.p(i, _m.e<T>(col1+shift+i,col1+shift+i));
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
bool totDefl = true;
|
||||
for(int i=1; i < len; i++)
|
||||
if(colVec0->template e<T>(i) >= almostZero) {
|
||||
totDefl = false;
|
||||
break;
|
||||
}
|
||||
|
||||
int* permut = nullptr;
|
||||
ALLOCATE(permut, _m.getContext()->getWorkspace(), 3*_diagSize, int);
|
||||
{
|
||||
permut[0] = 0;
|
||||
int p = 1;
|
||||
|
||||
for(int i=1; i<len; ++i)
|
||||
if(math::nd4j_abs<T>(diagInterval.template e<T>(i)) < almostZero)
|
||||
permut[p++] = i;
|
||||
|
||||
int k = 1, m = ind+1;
|
||||
|
||||
for( ; p < len; ++p) {
|
||||
if(k > ind)
|
||||
permut[p] = m++;
|
||||
else if(m >= len)
|
||||
permut[p] = k++;
|
||||
else if(diagInterval.template e<T>(k) < diagInterval.template e<T>(m))
|
||||
permut[p] = m++;
|
||||
else
|
||||
permut[p] = k++;
|
||||
}
|
||||
}
|
||||
|
||||
if(totDefl) {
|
||||
for(int i=1; i<len; ++i) {
|
||||
int ki = permut[i];
|
||||
if(math::nd4j_abs<T>(diagInterval.template e<T>(ki)) < almostZero || diagInterval.template e<T>(0) < diagInterval.template e<T>(ki))
|
||||
permut[i-1] = permut[i];
|
||||
else {
|
||||
permut[i-1] = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int *tInd = permut + len;
|
||||
int *tCol = permut + 2*len;
|
||||
|
||||
for(int m = 0; m < len; m++) {
|
||||
tCol[m] = m;
|
||||
tInd[m] = m;
|
||||
}
|
||||
|
||||
for(int i = totDefl ? 0 : 1; i < len; i++) {
|
||||
|
||||
const int ki = permut[len - (totDefl ? i+1 : i)];
|
||||
const int jac = tCol[ki];
|
||||
|
||||
T _e0 = diagInterval.template e<T>(jac);
|
||||
//math::nd4j_swap<T>(diagInterval)(i), (*diagInterval)(jac));
|
||||
diagInterval.p(jac, diagInterval.template e<T>(i));
|
||||
diagInterval.p(i, _e0);
|
||||
|
||||
if(i!=0 && jac!=0) {
|
||||
_e0 = colVec0->template e<T>(jac);
|
||||
//math::nd4j_swap<T>((*colVec0)(i), (*colVec0)(jac));
|
||||
colVec0->p(jac, colVec0->template e<T>(i));
|
||||
colVec0->p(i, _e0);
|
||||
}
|
||||
|
||||
if (_calcU) {
|
||||
auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true);
|
||||
auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
}
|
||||
else {
|
||||
auto temp1 = _u({0,2, col1+i, col1+i+1}, true);
|
||||
auto temp2 = _u({0,2, col1+jac, col1+jac+1}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
}
|
||||
|
||||
if(_calcV) {
|
||||
auto temp1 = _v({row1W,row1W+len, col1W+i, col1W+i+1}, true);
|
||||
auto temp2 = _v({row1W,row1W+len, col1W+jac, col1W+jac+1}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
}
|
||||
|
||||
const int tI = tInd[i];
|
||||
tCol[tI] = jac;
|
||||
tCol[ki] = i;
|
||||
tInd[jac] = tI;
|
||||
tInd[i] = ki;
|
||||
}
|
||||
|
||||
RELEASE(permut, _m.getContext());
|
||||
}
|
||||
|
||||
{
|
||||
int i = len-1;
|
||||
|
||||
while(i > 0 && (math::nd4j_abs<T>(diagInterval.template e<T>(i)) < almostZero || math::nd4j_abs<T>(colVec0->template e<T>(i)) < almostZero))
|
||||
--i;
|
||||
|
||||
for(; i > 1; --i) {
|
||||
if( (diagInterval.template e<T>(i) - diagInterval.template e<T>(i-1)) < DataTypeUtils::eps<T>()*maxElem ) {
|
||||
if (math::nd4j_abs<T>(diagInterval.template e<T>(i) - diagInterval.template e<T>(i-1)) >= epsBig)
|
||||
throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !");
|
||||
deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete colVec0;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
T SVD<T>::secularEq(const T diff, const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& diagShifted, const T shift) {
|
||||
|
||||
auto len = permut.lengthOf();
|
||||
T res = 1.;
|
||||
T item;
|
||||
for(int i=0; i<len; ++i) {
|
||||
auto j = permut.e<int>(i);
|
||||
item = col0.e<T>(j) / ((diagShifted.e<T>(j) - diff) * (diag.e<T>(j) + shift + diff));
|
||||
res += item * col0.e<T>(j);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArray& permut, NDArray& singVals, NDArray& shifts, NDArray& mus) {
|
||||
|
||||
auto len = col0.lengthOf();
|
||||
auto curLen = len;
|
||||
|
||||
while(curLen > 1 && col0.e<T>(curLen-1) == (T)0.f)
|
||||
--curLen;
|
||||
|
||||
for (int k = 0; k < len; ++k) {
|
||||
|
||||
if (col0.e<T>(k) == (T)0.f || curLen==1) {
|
||||
|
||||
singVals.p(k, k==0 ? col0.e<T>(0) : diag.e<T>(k));
|
||||
mus.p(k, 0.f);
|
||||
shifts.p(k, k==0 ? col0.e<T>(0) : diag.e<T>(k));
|
||||
continue;
|
||||
}
|
||||
|
||||
T left = diag.e<T>(k);
|
||||
T right;
|
||||
|
||||
if(k==curLen-1)
|
||||
right = diag.e<T>(curLen-1) + col0.reduceNumber(reduce::Norm2).e<T>(0);
|
||||
else {
|
||||
|
||||
int l = k+1;
|
||||
while(col0.e<T>(l) == (T)0.f) {
|
||||
++l;
|
||||
if(l >= curLen)
|
||||
throw std::runtime_error("ops::helpers::SVD::calcSingVals method: l >= curLen !");
|
||||
}
|
||||
|
||||
right = diag.e<T>(l);
|
||||
}
|
||||
|
||||
T mid = left + (right - left) / (T)2.;
|
||||
T fMid = secularEq(mid, col0, diag, permut, diag, 0.);
|
||||
T shift = (k == curLen-1 || fMid > (T)0.) ? left : right;
|
||||
|
||||
auto diagShifted = diag - shift;
|
||||
|
||||
T muPrev, muCur;
|
||||
if (shift == left) {
|
||||
muPrev = (right - left) * 0.1;
|
||||
if (k == curLen-1)
|
||||
muCur = right - left;
|
||||
else
|
||||
muCur = (right - left) * 0.5;
|
||||
}
|
||||
else {
|
||||
muPrev = -(right - left) * 0.1;
|
||||
muCur = -(right - left) * 0.5;
|
||||
}
|
||||
|
||||
T fPrev = secularEq(muPrev, col0, diag, permut, diagShifted, shift);
|
||||
T fCur = secularEq(muCur, col0, diag, permut, diagShifted, shift);
|
||||
|
||||
if (math::nd4j_abs<T>(fPrev) < math::nd4j_abs<T>(fCur)) {
|
||||
math::nd4j_swap<T>(fPrev, fCur);
|
||||
math::nd4j_swap<T>(muPrev, muCur);
|
||||
}
|
||||
|
||||
bool useBisection = fPrev * fCur > (T)0.;
|
||||
while (fCur != (T).0 &&
|
||||
math::nd4j_abs<T>(muCur - muPrev) > (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(math::nd4j_abs<T>(muCur), math::nd4j_abs<T>(muPrev))
|
||||
&& math::nd4j_abs<T>(fCur - fPrev) > DataTypeUtils::eps<T>() && !useBisection) {
|
||||
|
||||
T a = (fCur - fPrev) / ((T)1./muCur - (T)1./muPrev);
|
||||
T jac = fCur - a / muCur;
|
||||
T muZero = -a/jac;
|
||||
T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift);
|
||||
|
||||
muPrev = muCur;
|
||||
fPrev = fCur;
|
||||
muCur = muZero;
|
||||
fCur = fZero;
|
||||
|
||||
if (shift == left && (muCur < (T)0. || muCur > right - left))
|
||||
useBisection = true;
|
||||
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
|
||||
useBisection = true;
|
||||
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev) && math::nd4j_abs<T>(fCur - fPrev) > (T)16. * DataTypeUtils::eps<T>())
|
||||
useBisection = true;
|
||||
}
|
||||
|
||||
|
||||
if (useBisection) {
|
||||
|
||||
T leftShifted, rightShifted;
|
||||
if (shift == left) {
|
||||
leftShifted = DataTypeUtils::min<T>();
|
||||
rightShifted = (k==curLen-1) ? right : ((right - left) * (T)0.6);
|
||||
}
|
||||
else {
|
||||
|
||||
leftShifted = -(right - left) * (T)0.6;
|
||||
rightShifted = -DataTypeUtils::min<T>();
|
||||
}
|
||||
|
||||
T fLeft = secularEq(leftShifted, col0, diag, permut, diagShifted, shift);
|
||||
T fRight = secularEq(rightShifted, col0, diag, permut, diagShifted, shift);
|
||||
// if(fLeft * fRight >= (T)0.)
|
||||
// throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. !";
|
||||
|
||||
while (rightShifted - leftShifted > (T)2.f * DataTypeUtils::eps<T>() * math::nd4j_max<T>(math::nd4j_abs<T>(leftShifted), math::nd4j_abs<T>(rightShifted))) {
|
||||
|
||||
T midShifted = (leftShifted + rightShifted) / (T)2.;
|
||||
fMid = secularEq(midShifted, col0, diag, permut, diagShifted, shift);
|
||||
if (fLeft * fMid < (T)0.)
|
||||
rightShifted = midShifted;
|
||||
else {
|
||||
leftShifted = midShifted;
|
||||
fLeft = fMid;
|
||||
}
|
||||
}
|
||||
muCur = (leftShifted + rightShifted) / (T)2.;
|
||||
}
|
||||
singVals.p(k, shift + muCur);
|
||||
shifts.p(k, shift);
|
||||
mus.p(k, muCur);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void SVD<T>::perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat) {
|
||||
|
||||
int n = col0.lengthOf();
|
||||
int m = permut.lengthOf();
|
||||
if(m==0) {
|
||||
zhat.assign(0.);
|
||||
return;
|
||||
}
|
||||
|
||||
int last = permut.e<int>(m-1);
|
||||
|
||||
for (int k = 0; k < n; ++k) {
|
||||
|
||||
if (col0.e<T>(k) == (T)0.f)
|
||||
zhat.p(k, (T)0.f);
|
||||
else {
|
||||
T dk = diag.e<T>(k);
|
||||
T prod = (singVals.e<T>(last) + dk) * (mus.e<T>(last) + (shifts.e<T>(last) - dk));
|
||||
|
||||
for(int l = 0; l<m; ++l) {
|
||||
int i = permut.e<int>(l);
|
||||
if(i!=k) {
|
||||
int j = i<k ? i : permut.e<int>(l-1);
|
||||
prod *= ((singVals.e<T>(j)+dk) / ((diag.e<T>(i)+dk))) * ((mus.e<T>(j)+(shifts.e<T>(j)-dk)) / ((diag.e<T>(i)-dk)));
|
||||
}
|
||||
}
|
||||
T tmp = math::nd4j_sqrt<T,T>(prod);
|
||||
zhat.p(k, col0.e<T>(k) > (T)0.f ? tmp : -tmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void SVD<T>::calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArray& perm, const NDArray& singVals,
|
||||
const NDArray& shifts, const NDArray& mus, NDArray& U, NDArray& V) {
|
||||
|
||||
int n = zhat.lengthOf();
|
||||
int m = perm.lengthOf();
|
||||
|
||||
for (int k = 0; k < n; ++k) {
|
||||
|
||||
auto colU = new NDArray(U({0,0, k,k+1}, true));
|
||||
*colU = 0.;
|
||||
NDArray* colV = nullptr;
|
||||
|
||||
if (_calcV) {
|
||||
colV = new NDArray(V({0,0, k,k+1}, true));
|
||||
*colV = 0.;
|
||||
}
|
||||
|
||||
if (zhat.e<T>(k) == (T)0.f) {
|
||||
colU->p(k, 1.f);
|
||||
|
||||
if (_calcV)
|
||||
colV->p(k, 1.f);
|
||||
}
|
||||
else {
|
||||
|
||||
for(int l = 0; l < m; ++l) {
|
||||
int i = perm.e<int>(l);
|
||||
U.p(i,k, zhat.e<T>(i)/(((diag.e<T>(i) - shifts.e<T>(k)) - mus.e<T>(k)) )/( (diag.e<T>(i) + singVals.e<T>(k))));
|
||||
}
|
||||
U.p(n,k, 0.f);
|
||||
*colU /= colU->reduceNumber(reduce::Norm2);
|
||||
|
||||
if (_calcV) {
|
||||
|
||||
for(int l = 1; l < m; ++l){
|
||||
int i = perm.e<T>(l);
|
||||
V.p(i,k, diag.e<T>(i) * zhat.e<T>(i) / (((diag.e<T>(i) - shifts.e<T>(k)) - mus.e<T>(k)) )/( (diag.e<T>(i) + singVals.e<T>(k))));
|
||||
}
|
||||
V.p(0,k, -1.f);
|
||||
*colV /= colV->reduceNumber(reduce::Norm2);
|
||||
}
|
||||
}
|
||||
delete colU;
|
||||
if (_calcV)
|
||||
delete colV;
|
||||
}
|
||||
|
||||
auto colU = U({0,0, n,n+1}, true);
|
||||
colU = 0.;
|
||||
colU.p(n, 1.);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDArray& V) {
|
||||
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
auto col0 = _m({col1, col1+size, col1, col1+1}, true);
|
||||
auto diag = static_cast<const NDArray&>(_m({col1, col1+size, col1, col1+size}, true).diagonal('c'));
|
||||
|
||||
diag.p(Nd4jLong(0), T(0));
|
||||
singVals = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
|
||||
U = NDArrayFactory::create<T>(_u.ordering(), {size+1, size+1}, _u.getContext());
|
||||
if (_calcV)
|
||||
V = NDArrayFactory::create<T>(_v.ordering(), {size, size}, _v.getContext());
|
||||
|
||||
int curSize = size;
|
||||
while(curSize > 1 && diag.template e<T>(curSize-1) == (T)0.f)
|
||||
--curSize;
|
||||
|
||||
int m = 0;
|
||||
std::vector<T> indices;
|
||||
for(int k = 0; k < curSize; ++k)
|
||||
if(math::nd4j_abs<T>(col0.template e<T>(k)) > almostZero)
|
||||
indices.push_back((T)k);
|
||||
|
||||
auto permut = NDArrayFactory::create<T>(_m.ordering(), {1, (int)indices.size()}, indices, _m.getContext());
|
||||
auto shifts = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
|
||||
auto mus = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
|
||||
auto zhat = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
|
||||
|
||||
calcSingVals(col0, diag, permut, singVals, shifts, mus);
|
||||
perturb(col0, diag, permut, singVals, shifts, mus, zhat);
|
||||
calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V);
|
||||
|
||||
for(int i=0; i<curSize-1; ++i) {
|
||||
|
||||
if(singVals.e<T>(i) > singVals.e<T>(i+1)) {
|
||||
T _e0 = singVals.e<T>(i);
|
||||
T _e1 = singVals.e<T>(i+1);
|
||||
//math::nd4j_swap<T>(singVals(i),singVals(i+1));
|
||||
singVals.p(i, _e1);
|
||||
singVals.p(i+1, _e0);
|
||||
|
||||
auto temp1 = U({0,0, i,i+1}, true);
|
||||
auto temp2 = U({0,0, i+1,i+2}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
|
||||
if(_calcV) {
|
||||
auto temp1 = V({0,0, i,i+1}, true);
|
||||
auto temp2 = V({0,0, i+1,i+2}, true);
|
||||
auto temp3 = temp1;
|
||||
temp1.assign(temp2);
|
||||
temp2.assign(temp3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto temp1 = singVals({0,curSize, 0,0}, true);
|
||||
for (int e = 0; e < curSize / 2; ++e) {
|
||||
T tmp = temp1.e<T>(e);
|
||||
temp1.p(e, temp1.e<T>(curSize-1-e));
|
||||
temp1.p(curSize-1-e, tmp);
|
||||
}
|
||||
|
||||
auto temp2 = U({0,0, 0,curSize}, true);
|
||||
for(int i = 0; i < curSize/2; ++i) {
|
||||
auto temp3 = temp2({0,0, i,i+1}, true);
|
||||
auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true);
|
||||
auto temp5 = temp3;
|
||||
temp3.assign(temp4);
|
||||
temp4.assign(temp5);
|
||||
}
|
||||
|
||||
if (_calcV) {
|
||||
auto temp2 = V({0,0, 0,curSize}, true);
|
||||
for(int i = 0; i < curSize/2; ++i) {
|
||||
auto temp3 = temp2({0,0, i,i+1}, true);
|
||||
auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true);
|
||||
auto temp5 = temp3;
|
||||
temp3.assign(temp4);
|
||||
temp4.assign(temp5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift) {
|
||||
|
||||
// requires rows = cols + 1;
|
||||
const int n = col2 - col1 + 1;
|
||||
const int k = n/2;
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
T alphaK;
|
||||
T betaK;
|
||||
T r0;
|
||||
T lambda, phi, c0, s0;
|
||||
auto l = NDArrayFactory::create<T>(_u.ordering(), {1, k}, _u.getContext());
|
||||
auto f = NDArrayFactory::create<T>(_u.ordering(), {1, n-k-1}, _u.getContext());
|
||||
|
||||
if(n < _switchSize) {
|
||||
|
||||
JacobiSVD<T> jac(_m({col1,col1+n+1, col1,col1+n}, true), _calcU, _calcV, _fullUV);
|
||||
|
||||
if (_calcU) {
|
||||
auto temp = _u({col1,col1+n+1, col1,col1+n+1}, true);
|
||||
temp.assign(jac._u);
|
||||
}
|
||||
else {
|
||||
auto temp1 = _u({0,1, col1,col1+n+1}, true);
|
||||
temp1.assign(jac._u({0,1, 0,0}, true));
|
||||
auto temp2 = _u({1,2, col1,col1+n+1}, true);
|
||||
temp2.assign(jac._u({n,n+1, 0,0}, true));
|
||||
}
|
||||
|
||||
if (_calcV) {
|
||||
auto temp = _v({row1W,row1W+n, col1W,col1W+n}, true);
|
||||
temp.assign(jac._v);
|
||||
}
|
||||
|
||||
auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true);
|
||||
temp.assign(0.);
|
||||
auto diag = _m.diagonal('c');
|
||||
diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
alphaK = _m.e<T>(col1 + k, col1 + k);
|
||||
betaK = _m.e<T>(col1 + k + 1, col1 + k);
|
||||
|
||||
DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift);
|
||||
DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1);
|
||||
|
||||
if (_calcU) {
|
||||
lambda = _u.e<T>(col1 + k, col1 + k);
|
||||
phi = _u.e<T>(col1 + k + 1, col2 + 1);
|
||||
}
|
||||
else {
|
||||
lambda = _u.e<T>(1, col1 + k);
|
||||
phi = _u.e<T>(0, col2 + 1);
|
||||
}
|
||||
|
||||
r0 = math::nd4j_sqrt<T, T>((math::nd4j_abs<T>(alphaK * lambda) * math::nd4j_abs<T>(alphaK * lambda)) + math::nd4j_abs<T>(betaK * phi) * math::nd4j_abs<T>(betaK * phi));
|
||||
|
||||
if(_calcU) {
|
||||
l.assign(_u({col1+k, col1+k+1, col1,col1+k}, true));
|
||||
f.assign(_u({col1+k+1,col1+k+2, col1+k+1,col1+n}, true));
|
||||
}
|
||||
else {
|
||||
l.assign(_u({1,2, col1, col1+k}, true));
|
||||
f.assign(_u({0,1, col1+k+1, col1+n}, true));
|
||||
}
|
||||
|
||||
// UofSVD.printIndexedBuffer();
|
||||
// VofSVD.printIndexedBuffer();
|
||||
// singVals.printIndexedBuffer();
|
||||
// printf("!! \n");
|
||||
|
||||
if (_calcV)
|
||||
_v.p(row1W+k, col1W, 1.f);
|
||||
|
||||
if (r0 < almostZero){
|
||||
c0 = 1.;
|
||||
s0 = 0.;
|
||||
}
|
||||
else {
|
||||
c0 = alphaK * lambda / r0;
|
||||
s0 = betaK * phi / r0;
|
||||
}
|
||||
|
||||
if (_calcU) {
|
||||
|
||||
auto temp = _u({col1,col1+k+1, col1+k,col1+k+1}, true);
|
||||
NDArray q1(temp);
|
||||
|
||||
for (int i = col1 + k - 1; i >= col1; --i) {
|
||||
auto temp = _u({col1,col1+k+1, i+1,i+2}, true);
|
||||
temp.assign(_u({col1, col1+k+1, i, i+1}, true));
|
||||
}
|
||||
|
||||
_u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0);
|
||||
_u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0));
|
||||
_u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast<const NDArray&>(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0);
|
||||
_u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0;
|
||||
}
|
||||
else {
|
||||
|
||||
T q1 = _u.e<T>(0, col1 + k);
|
||||
|
||||
for (int i = col1 + k - 1; i >= col1; --i)
|
||||
_u.p(0, i+1, _u.e<T>(0, i));
|
||||
|
||||
_u.p(0, col1, q1 * c0);
|
||||
_u.p(0, col2+1, -q1*s0);
|
||||
_u.p(1, col1, _u.e<T>(1, col2+1) * s0);
|
||||
_u.p(1, col2 + 1, _u.e<T>(1, col2 + 1) * c0);
|
||||
_u({1,2, col1+1, col1+k+1}, true) = 0.f;
|
||||
_u({0,1, col1+k+1, col1+n}, true) = 0.f;
|
||||
}
|
||||
|
||||
_m.p(col1 + shift, col1 + shift, r0);
|
||||
auto temp1 = _m({col1+shift+1,col1+shift+k+1, col1+shift,col1+shift+1}, true);
|
||||
temp1.assign(l*alphaK);
|
||||
auto temp2 = _m({col1+shift+k+1,col1+shift+n, col1+shift,col1+shift+1}, true);
|
||||
temp2.assign(f*betaK);
|
||||
|
||||
deflation(col1, col2, k, row1W, col1W, shift);
|
||||
|
||||
NDArray UofSVD, VofSVD, singVals;
|
||||
calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD);
|
||||
|
||||
if(_calcU) {
|
||||
auto pTemp = _u({col1, col1+n+1, col1,col1+n+1}, true);
|
||||
auto temp = pTemp;
|
||||
pTemp.assign(mmul(temp, UofSVD));
|
||||
}
|
||||
else {
|
||||
auto pTemp = _u({0,0, col1,col1+n+1}, true);
|
||||
auto temp = pTemp;
|
||||
pTemp.assign(mmul(temp, UofSVD));
|
||||
}
|
||||
|
||||
if (_calcV) {
|
||||
auto pTemp = _v({row1W,row1W+n, row1W,row1W+n}, true);
|
||||
auto temp = pTemp;
|
||||
pTemp.assign(mmul(temp, VofSVD));
|
||||
}
|
||||
|
||||
auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true);
|
||||
blockM = 0.f;
|
||||
auto diag = blockM.diagonal('c');
|
||||
diag.assign(singVals);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void SVD<T>::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDArray& U, const NDArray& V) {
|
||||
|
||||
if (_calcU) {
|
||||
|
||||
int colsU = _fullUV ? hhU.rows() : _diagSize;
|
||||
auto temp1 = NDArrayFactory::create<T>(_u.ordering(), {hhU.rows(), colsU}, _u.getContext());
|
||||
temp1.setIdentity();
|
||||
_u = temp1;
|
||||
|
||||
auto temp2 = _u({0,_diagSize, 0,_diagSize}, true);
|
||||
temp2.assign(V({0,_diagSize, 0,_diagSize}, true));
|
||||
const_cast<HHsequence&>(hhU).mulLeft(_u);
|
||||
}
|
||||
|
||||
if (_calcV) {
|
||||
|
||||
int colsV = _fullUV ? hhV.rows() : _diagSize;
|
||||
auto temp1 = NDArrayFactory::create<T>(_v.ordering(), {hhV.rows(), colsV}, _v.getContext());
|
||||
temp1.setIdentity();
|
||||
_v = temp1;
|
||||
|
||||
auto temp2 = _v({0,_diagSize, 0,_diagSize}, true);
|
||||
temp2.assign(U({0,_diagSize, 0,_diagSize}, true));
|
||||
const_cast<HHsequence&>(hhV).mulLeft(_v);
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void SVD<T>::evalData(const NDArray& matrix) {
|
||||
|
||||
const T almostZero = DataTypeUtils::min<T>();
|
||||
|
||||
if(matrix.sizeAt(1) < _switchSize) {
|
||||
|
||||
JacobiSVD<T> jac(matrix, _calcU, _calcV, _fullUV);
|
||||
|
||||
if(_calcU)
|
||||
_u = jac._u;
|
||||
if(_calcV)
|
||||
_v = jac._v;
|
||||
|
||||
_s.assign(jac._s);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
T scale = matrix.reduceNumber(reduce::AMax).e<T>(0);
|
||||
|
||||
if(scale == (T)0.)
|
||||
scale = 1.;
|
||||
|
||||
NDArray copy;
|
||||
if(_transp)
|
||||
copy = matrix.transpose();
|
||||
else
|
||||
copy = matrix / scale;
|
||||
|
||||
BiDiagonalUp biDiag(copy);
|
||||
|
||||
_u = 0.;
|
||||
_v = 0.;
|
||||
|
||||
auto temp1 = biDiag._HHbidiag.transpose();
|
||||
auto temp2 = _m({0,_diagSize, 0,0}, true);
|
||||
temp2.assign(temp1);
|
||||
|
||||
auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true);
|
||||
temp3.assign(0.);
|
||||
|
||||
DivideAndConquer(0, _diagSize - 1, 0, 0, 0);
|
||||
|
||||
for (int i = 0; i < _diagSize; ++i) {
|
||||
T a = math::nd4j_abs<T>(_m.e<T>(i, i));
|
||||
_s.p(i, a * scale);
|
||||
if (a < almostZero) {
|
||||
auto temp = _s({i+1,_diagSize, 0,0}, true);
|
||||
temp.assign(0.);
|
||||
break;
|
||||
}
|
||||
else if (i == _diagSize-1)
|
||||
break;
|
||||
}
|
||||
|
||||
if(_transp)
|
||||
exchangeUV(biDiag.makeHHsequence('v'), biDiag.makeHHsequence('u'), _v, _u);
|
||||
else
|
||||
exchangeUV(biDiag.makeHHsequence('u'), biDiag.makeHHsequence('v'), _u, _v);
|
||||
}
|
||||
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SVD,,FLOAT_TYPES);
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// svd operation; this function is not a method of the SVD class, it is a standalone function
|
||||
template <typename T>
|
||||
|
@ -972,9 +67,10 @@ static void svd_(const NDArray* x, const std::vector<NDArray*>& outArrs, const b
|
|||
}
|
||||
}
|
||||
|
||||
void svd(sd::LaunchContext * context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum) {
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void svd(sd::LaunchContext * context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum) {
|
||||
BUILD_SINGLE_SELECTOR(x->dataType(), svd_, (x, outArrs, fullUV, calcUV, switchNum), FLOAT_TYPES);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -73,8 +73,8 @@ namespace helpers {
|
|||
NDArray sortedVals = NDArrayFactory::create<T>('c', {k}, input->getContext());
|
||||
NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k}, input->getContext());
|
||||
for (uint pos = 0; pos < k; ++pos) {
|
||||
topIndices.t<Nd4jLong>(pos) = pos;
|
||||
topValues.t<T>(pos) = trial.t<T>(pos);
|
||||
topIndices.r<Nd4jLong>(pos) = pos;
|
||||
topValues.r<T>(pos) = trial.t<T>(pos);
|
||||
}
|
||||
//std::vector<T> sortedVals(topValues);
|
||||
sortedVals.assign(topValues);// = NDArrayFactory::create<T>('c', {k});
|
||||
|
@ -93,9 +93,9 @@ namespace helpers {
|
|||
T* topBegin = reinterpret_cast<T*>(topValues.buffer());
|
||||
T* topEnd = topBegin + k;
|
||||
auto exchangePos = std::distance(topBegin, std::find(topBegin, topEnd, sortedVals.t<T>(0)));
|
||||
topValues.t<T>(exchangePos) = val; //*exchangeIt = val;
|
||||
topIndices.t<Nd4jLong>(exchangePos) = i;
|
||||
sortedVals.t<T>(0) = val; // suppress in sorted
|
||||
topValues.r<T>(exchangePos) = val; //*exchangeIt = val;
|
||||
topIndices.r<Nd4jLong>(exchangePos) = i;
|
||||
sortedVals.r<T>(0) = val; // suppress in sorted
|
||||
//std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending order
|
||||
SpecialMethods<T>::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), false);
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ namespace helpers {
|
|||
for (Nd4jLong j = 0; j < width; j++)
|
||||
for (uint pos = 0; pos < k; ++pos)
|
||||
if (topValues.t<T>(pos) == trial.t<T>(j))
|
||||
topIndices.t<Nd4jLong>(pos) = j;
|
||||
topIndices.r<Nd4jLong>(pos) = j;
|
||||
}
|
||||
else { // else sort by indices
|
||||
std::map<Nd4jLong, T> sortValsMap;
|
||||
|
@ -121,8 +121,8 @@ namespace helpers {
|
|||
//});
|
||||
Nd4jLong e = 0;
|
||||
for (auto it = sortValsMap.begin(); it != sortValsMap.end(); ++it, e++) {
|
||||
topIndices.t<Nd4jLong>(e) = it->first;
|
||||
topValues.t<T>(e) = it->second;
|
||||
topIndices.r<Nd4jLong>(e) = it->first;
|
||||
topValues.r<T>(e) = it->second;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
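The top_k hunks above maintain the k largest values and their indices by hand, using a scratch sortedVals buffer and explicit exchanges. As a point of comparison only (plain std::vector, not the NDArray-based helper), the same idea with the standard library:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const std::vector<float> row = {0.3f, 2.5f, -1.0f, 4.0f, 1.5f, 4.0f};
    const size_t k = 3;

    std::vector<size_t> idx(row.size());
    std::iota(idx.begin(), idx.end(), 0);                       // 0, 1, 2, ...
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),  // top-k positions by value
                      [&](size_t a, size_t b) { return row[a] > row[b]; });

    for (size_t pos = 0; pos < k; ++pos)
        std::printf("top[%zu] = %f at index %zu\n", pos, row[idx[pos]], idx[pos]);
    return 0;
}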
@ -39,17 +39,17 @@ namespace helpers {
     *
     * */
    template <typename T>
-   static void lowerTriangularSolve(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
+   static void lowerTriangularSolve(sd::LaunchContext * context, NDArray const * leftInput, NDArray const* rightInput, bool const unitsOnDiag, NDArray* output) {
        auto rows = leftInput->rows();
        auto cols = rightInput->columns();
-       //output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
+       //output->r<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
        for (Nd4jLong r = 0; r < rows; r++) {
            for (Nd4jLong j = 0; j < cols; j++) {
                auto sum = rightInput->t<T>(r, j);
                for (Nd4jLong c = 0; c < r; c++) {
                    sum -= leftInput->t<T>(r, c) * output->t<T>(c, j);
                }
-               output->t<T>(r, j) = sum / leftInput->t<T>(r, r);
+               output->r<T>(r, j) = unitsOnDiag?sum: sum / leftInput->t<T>(r, r);
            }
        }
    }
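To make the loop above concrete, here is a minimal standalone C++ sketch of the same forward substitution on plain row-major arrays (hypothetical helper name, not the NDArray-based code):

#include <cstdio>
#include <vector>

// Solve L * x = b for a lower triangular L (row-major n x n), mirroring the loop above.
// When unitsOnDiag is true the diagonal is assumed to be 1.0 and the division is skipped.
std::vector<double> lowerSolve(const std::vector<double>& L, const std::vector<double>& b,
                               int n, bool unitsOnDiag) {
    std::vector<double> x(n);
    for (int r = 0; r < n; ++r) {
        double sum = b[r];
        for (int c = 0; c < r; ++c)
            sum -= L[r * n + c] * x[c];              // subtract the already-solved components
        x[r] = unitsOnDiag ? sum : sum / L[r * n + r];
    }
    return x;
}

int main() {
    // L = [[2, 0], [3, 1]], b = [4, 7]  ->  x = [2, 1]
    const std::vector<double> L = {2, 0, 3, 1}, b = {4, 7};
    for (double v : lowerSolve(L, b, 2, false)) std::printf("%f\n", v);
    return 0;
}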
@ -69,7 +69,7 @@ namespace helpers {
     * */

    template <typename T>
-   static void upperTriangularSolve(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
+   static void upperTriangularSolve(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, bool const unitsOnDiag, NDArray* output) {
        auto rows = leftInput->rows();
        auto cols = rightInput->columns();
        for (Nd4jLong r = rows; r > 0; r--) {

@ -78,11 +78,31 @@ namespace helpers {
                for (Nd4jLong c = r; c < rows; c++) {
                    sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, j);
                }
-               output->t<T>(r - 1, j) = sum / leftInput->t<T>(r - 1, r - 1);
+               output->r<T>(r - 1, j) = unitsOnDiag? sum : sum / leftInput->t<T>(r - 1, r - 1);
            }
        }
    }

+   /// triangularSolve2D - 2D implementation of triangularSolveFunctor
+   /// \tparam T - data type of the NDArray output
+   /// \param context - launch context pointer
+   /// \param leftInput - the triangular matrix T of the equation Tx = b
+   /// \param rightInput - the right-hand side b of the equation Tx = b
+   /// \param lower - true if T is lower triangular, false if it is upper triangular
+   /// \param unitsOnDiag - if true, a unit (1.0) diagonal is assumed and the division by the diagonal entry is skipped
+   /// \param output - the output vector (x in the equation Tx = b)
+   ///
+   template <typename T>
+   void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output) {
+       if (lower) {
+           lowerTriangularSolve<T>(context, &leftInput, &rightInput, unitsOnDiag, &output);
+       }
+       else {
+           upperTriangularSolve<T>(context, &leftInput, &rightInput, unitsOnDiag, &output);
+       }
+   }
+   BUILD_SINGLE_TEMPLATE(template void triangularSolve2D, (sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output), FLOAT_TYPES);

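For orientation, a hypothetical test-style use of the triangularSolve2D helper added above. The NDArray constructors mirror the tests later in this changeset, the values are a hand-checked toy system, and the null context is only a shortcut that the CPU path above never dereferences; this is a sketch, not an existing test.

// Solve A x = b with A = [[2, 0], [3, 1]] (lower triangular) and b = [[4], [7]]; x = [[2], [1]].
NDArray A('c', {2, 2}, {2., 0., 3., 1.}, sd::DataType::DOUBLE);
NDArray b('c', {2, 1}, {4., 7.}, sd::DataType::DOUBLE);
NDArray x('c', {2, 1}, sd::DataType::DOUBLE);
NDArray expected('c', {2, 1}, {2., 1.}, sd::DataType::DOUBLE);

sd::ops::helpers::triangularSolve2D<double>(nullptr /* context unused on the CPU path */,
                                            A, b, true /*lower*/, false /*unitsOnDiag*/, x);
ASSERT_TRUE(expected.equalsTo(&x));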
template <typename T>
|
||||
static int triangularSolveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {
|
||||
auto leftPart = leftInput->allTensorsAlongDimension({-2, -1});
|
||||
|
@ -92,9 +112,9 @@ namespace helpers {
|
|||
auto batchLoop = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
if (lower) {
|
||||
lowerTriangularSolve<T>(context, leftPart[i], rightPart[i], adjoint, outputPart[i]);
|
||||
lowerTriangularSolve<T>(context, leftPart[i], rightPart[i], false, outputPart[i]);
|
||||
} else {
|
||||
upperTriangularSolve<T>(context, leftPart[i], rightPart[i], adjoint, outputPart[i]);
|
||||
upperTriangularSolve<T>(context, leftPart[i], rightPart[i], false, outputPart[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -116,13 +136,13 @@ namespace helpers {
|
|||
if (!lower) {
|
||||
for (Nd4jLong r = 0; r < rows; r++) {
|
||||
for (Nd4jLong c = 0; c <= r; c++) {
|
||||
outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
|
||||
outputPart[batch]->r<T>(r, c) = inputPart[batch]->t<T>(c, r);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (Nd4jLong r = 0; r < rows; r++) {
|
||||
for (Nd4jLong c = r; c < cols; c++) {
|
||||
outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
|
||||
outputPart[batch]->r<T>(r, c) = inputPart[batch]->t<T>(c, r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ static void triuBP_(sd::LaunchContext * context, const NDArray& input, const NDA
|
|||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
if (dOdI.t<T>(i) != static_cast<T>(0.f))
|
||||
dOdI.t<T>(i) = static_cast<T>(1.f);
|
||||
dOdI.r<T>(i) = static_cast<T>(1.f);
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, dLen);
|
||||
|
|
|
@ -41,9 +41,9 @@ namespace sd {
|
|||
*
|
||||
* */
|
||||
template <typename T>
|
||||
static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
|
||||
static _CUDA_HD void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
|
||||
T const* rightInput, Nd4jLong const* rightInputShape,
|
||||
bool const adjoint, T* output, Nd4jLong const* outputShape,
|
||||
bool const unitOnDiag, T* output, const Nd4jLong* outputShape,
|
||||
Nd4jLong rows, Nd4jLong cols) {
|
||||
|
||||
for (auto r = 0; r < rows; r++) {
|
||||
|
@ -62,7 +62,7 @@ namespace sd {
|
|||
auto zcIndex = shape::getOffset(outputShape, posZ, 0);
|
||||
sum -= leftInput[xcIndex] * output[zcIndex];
|
||||
}
|
||||
output[zIndex] = sum / leftInput[xIndex];
|
||||
output[zIndex] = unitOnDiag?sum:sum / leftInput[xIndex];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -82,9 +82,9 @@ namespace sd {
|
|||
* */
|
||||
|
||||
template <typename T>
|
||||
static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
|
||||
T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output,
|
||||
Nd4jLong const* outputShape, Nd4jLong rows, Nd4jLong cols) {
|
||||
static _CUDA_HD void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
|
||||
T const* rightInput, Nd4jLong const* rightInputShape, bool const unitOnDiag, T* output,
|
||||
const Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) {
|
||||
|
||||
for (auto r = rows; r > 0; r--) {
|
||||
for (auto j = 0; j < cols; j++) {
|
||||
|
@ -101,16 +101,16 @@ namespace sd {
|
|||
auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
|
||||
sum -= leftInput[xcIndex] * output[zcIndex];
|
||||
}
|
||||
output[zIndex] = sum / leftInput[xIndex];
|
||||
output[zIndex] = unitOnDiag?sum:sum / leftInput[xIndex];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void triangularSolveKernel(T const* leftInput, Nd4jLong const* leftPartShape,
|
||||
T const* rightInput, Nd4jLong const* rightPartShape, bool const lower, bool const adjoint, T* output,
|
||||
Nd4jLong const* outputShape, Nd4jLong const* tadLeftShape, Nd4jLong const* tadLeftOffset, Nd4jLong const* tadRightShape,
|
||||
Nd4jLong const* tadRightOffset, Nd4jLong const* tadOutputShape, Nd4jLong const* tadOutputOffset, Nd4jLong batchNum) {
|
||||
T const* rightInput, Nd4jLong const* rightPartShape, bool const lower, bool const unitsOnDiag, T* output,
|
||||
const Nd4jLong* outputShape, const Nd4jLong* tadLeftShape, const Nd4jLong* tadLeftOffset, const Nd4jLong* tadRightShape,
|
||||
const Nd4jLong* tadRightOffset, const Nd4jLong* tadOutputShape, const Nd4jLong* tadOutputOffset, Nd4jLong batchNum) {
|
||||
|
||||
__shared__ Nd4jLong rows;
|
||||
__shared__ Nd4jLong cols;
|
||||
|
@ -130,16 +130,16 @@ namespace sd {
|
|||
auto pRightPart = rightInput + tadRightOffset[i];
|
||||
auto pOutputPart = output + tadOutputOffset[i];
|
||||
if (lower) {
|
||||
lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
|
||||
lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, unitsOnDiag, pOutputPart, tadOutputShape, rows, cols);
|
||||
} else {
|
||||
upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
|
||||
upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, unitsOnDiag, pOutputPart, tadOutputShape, rows, cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int triangularSolveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput,
|
||||
bool lower, bool adjoint, NDArray* output) {
|
||||
bool lower, bool unitsOnDiag, NDArray* output) {
|
||||
NDArray::prepareSpecialUse({output}, {leftInput, rightInput});
|
||||
auto leftTads = ConstantTadHelper::getInstance()->tadForDimensions(leftInput->shapeInfo(), {-2, -1});
|
||||
auto rightTads = ConstantTadHelper::getInstance()->tadForDimensions(rightInput->shapeInfo(), {-2, -1});
|
||||
|
@ -150,7 +150,7 @@ namespace sd {
|
|||
T const* rightBuf = reinterpret_cast<T const*>(rightInput->specialBuffer());
|
||||
T* outputBuf = reinterpret_cast<T*>(output->specialBuffer());
|
||||
triangularSolveKernel<T><<<128, 128, 256, *stream>>>(leftBuf, leftInput->specialShapeInfo(),
|
||||
rightBuf, rightInput->specialShapeInfo(), lower, adjoint, outputBuf, output->specialShapeInfo(),
|
||||
rightBuf, rightInput->specialShapeInfo(), lower, unitsOnDiag, outputBuf, output->specialShapeInfo(),
|
||||
leftTads.specialShapeInfo(), leftTads.specialOffsets(), rightTads.specialShapeInfo(),
|
||||
rightTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets(),
|
||||
leftTads.numberOfTads());
|
||||
|
@ -161,8 +161,41 @@ namespace sd {
|
|||
|
||||
}
|
||||
|
||||
int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE);
|
||||
/// triangularSolve2D - 2D implementation of triangularSolveFunctor
|
||||
/// \tparam T - type of NDArray output
|
||||
/// \param context - launch context pointer
|
||||
/// \param leftInput - T matrix of equation Tx = b
|
||||
/// \param rightInput - b vector of equation Tx = b
|
||||
/// \param lower - lower or upper triangular matrix
|
||||
/// \param unitsOnDiag - solve for case when only units (1.0) on diagonal is assumed
|
||||
/// \param output - output vector (x on equation Tx = b)
|
||||
///
|
||||
template <typename T>
|
||||
void triangularSolve2D(sd::LaunchContext* context, const NDArray& leftInput, const NDArray& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output) {
|
||||
|
||||
triangularSolveFunctor_<T>(context, const_cast<NDArray*>(&leftInput), const_cast<NDArray*>(&rightInput), lower, unitsOnDiag, &output);
|
||||
|
||||
// leftInput.syncToHost(); rightInput.syncToHost(); output.syncToHost();
|
||||
// T const* pLeftPart = (T const*)leftInput.getBuffer();
|
||||
// T const* pRightPart = (T const*)rightInput.getBuffer();
|
||||
// T* pOutputPart = (T*)output.buffer();
|
||||
// auto rows = leftInput.rows();
|
||||
// auto cols = leftInput.columns();
|
||||
// if (lower) {
|
||||
// lowerTriangularSolve<T>(pLeftPart, leftInput.shapeInfo(), pRightPart, rightInput.shapeInfo(), unitsOnDiag, pOutputPart, output.shapeInfo(), rows, cols);
|
||||
// } else {
|
||||
// upperTriangularSolve<T>(pLeftPart, leftInput.shapeInfo(), pRightPart, rightInput.shapeInfo(), unitsOnDiag, pOutputPart, output.shapeInfo(), rows, cols);
|
||||
// }
|
||||
// output.syncToDevice();
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void triangularSolve2D, (sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output), FLOAT_TYPES);
|
||||
// template void triangularSolve2D<float>(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output);
|
||||
// template void triangularSolve2D<bfloat16>(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output);
|
||||
// template void triangularSolve2D<float16>(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output);
|
||||
// template void triangularSolve2D<double>(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output);
|
||||
|
||||
int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool unitsOnDiag, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, unitsOnDiag, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -229,6 +262,76 @@ namespace sd {
|
|||
BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
/*
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void triangularSolve2D(sd::LaunchContext* context, NDArray const& A, NDArray const& b, bool const lower, bool const unitsOnDiag, NDArray& x) {
|
||||
|
||||
if(A.rankOf() != 2)
|
||||
throw std::runtime_error("triangularSolve2D: input matrix A must be 2D !");
|
||||
|
||||
int temp;
|
||||
|
||||
const bool isBvector = b.isCommonVector(temp);
|
||||
const bool isXvector = x.isCommonVector(temp);
|
||||
|
||||
if(A.sizeAt(0) != (isBvector ? b.lengthOf() : b.sizeAt(0)))
|
||||
throw std::runtime_error("triangularSolve2D: A and b must have the same number of rows !");
|
||||
|
||||
if(A.sizeAt(1) != (isXvector ? x.lengthOf() : x.sizeAt(0)))
|
||||
throw std::runtime_error("triangularSolve2D: columns number of array A must be equal to rows number of array x !");
|
||||
|
||||
if(isBvector) {
|
||||
|
||||
if(lower) {
|
||||
|
||||
for (int i = 0; i < A.sizeAt(0); ++i) {
|
||||
T sum = b.t<T>(i);
|
||||
for (int j = 0; j < i; ++j)
|
||||
sum -= A.t<T>(i,j) * x.t<T>(j);
|
||||
x.r<T>(i) = unitsOnDiag ? sum : sum / A.t<T>(i,i);
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
for (int i = A.sizeAt(0) - 1; i >= 0; --i) {
|
||||
T sum = b.t<T>(i);
|
||||
for (int j = i + 1; j < A.sizeAt(1); ++j)
|
||||
sum -= A.t<T>(i,j) * x.t<T>(j);
|
||||
x.r<T>(i) = unitsOnDiag ? sum : sum / A.t<T>(i,i);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
if(lower) {
|
||||
|
||||
for (int bCol = 0; bCol < b.sizeAt(1); ++bCol) {
|
||||
for (int i = 0; i < A.sizeAt(0); ++i) {
|
||||
T sum = b.t<T>(i, bCol);
|
||||
for (int j = 0; j < i; ++j)
|
||||
sum -= A.t<T>(i,j) * x.t<T>(j, bCol);
|
||||
x.r<T>(i, bCol) = unitsOnDiag ? sum : sum / A.t<T>(i,i);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
for (int bCol = 0; bCol < b.sizeAt(1); ++bCol) {
|
||||
for (int i = A.sizeAt(0) - 1; i >= 0; --i) {
|
||||
T sum = b.t<T>(i, bCol);
|
||||
for (int j = i + 1; j < A.sizeAt(1); ++j)
|
||||
sum -= A.t<T>(i,j) * x.t<T>(j, bCol);
|
||||
x.r<T>(i, bCol) = unitsOnDiag ? sum : sum / A.t<T>(i,i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void triangularSolve2D, (sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output), FLOAT_TYPES);
|
||||
*/
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ namespace sd {
|
|||
// make sure host buffer is updated
|
||||
values.syncToHost();
|
||||
indices.syncToHost();
|
||||
output.syncToHost();
|
||||
|
||||
auto rank = output.rankOf();
|
||||
|
||||
|
|
|
@ -0,0 +1,66 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//


#include <ops/declarable/CustomOperations.h>
#include <helpers/Sqrtm.h>

namespace sd {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
template <typename T>
static void sqrtm_(const NDArray* x, NDArray* z) {

    if(x->rankOf() == 2) {
        ops::helpers::Sqrtm<T>::calc(*x, *z);
    }
    else {

        auto listX = x->allTensorsAlongDimension({-2, -1});
        auto listZ = z->allTensorsAlongDimension({-2, -1});

        auto func = PRAGMA_THREADS_FOR {

            for (auto i = start; i < stop; i++)
                ops::helpers::Sqrtm<T>::calc(*listX.at(i), *listZ.at(i));
        };

        samediff::Threads::parallel_tad(func, 0, listX.size());
    }
}


//////////////////////////////////////////////////////////////////////////
void sqrtm(sd::LaunchContext* context, const NDArray* x, NDArray* z) {

    x->syncToHost();
    BUILD_SINGLE_SELECTOR(z->dataType(), sqrtm_, (x, z), FLOAT_TYPES);
    z->syncToDevice();
}


}
}
}

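The Sqrtm<T>::calc call above delegates to the library's own square-root routine. Purely as an illustration of what a principal matrix square root is (and not the algorithm used by the Sqrtm class), here is a minimal self-contained Denman-Beavers iteration for a 2x2 matrix; the input is the same matrix used in the sqrtm_1 test further below, so Y should land close to the expected values there.

#include <array>
#include <cstdio>

// Denman-Beavers iteration: Y_{k+1} = (Y_k + Z_k^{-1}) / 2, Z_{k+1} = (Z_k + Y_k^{-1}) / 2,
// with Y_0 = A, Z_0 = I; Y_k converges to sqrt(A) when A has no eigenvalues on the
// non-positive real axis.
using Mat2 = std::array<double, 4>;              // row-major 2x2 matrix

static Mat2 inv2(const Mat2& m) {                // inverse of a 2x2 matrix
    const double det = m[0] * m[3] - m[1] * m[2];
    return {{ m[3] / det, -m[1] / det, -m[2] / det, m[0] / det }};
}

static Mat2 mul2(const Mat2& a, const Mat2& b) { // 2x2 matrix product
    return {{ a[0]*b[0] + a[1]*b[2], a[0]*b[1] + a[1]*b[3],
              a[2]*b[0] + a[3]*b[2], a[2]*b[1] + a[3]*b[3] }};
}

int main() {
    const Mat2 A = {{1.3, 2.0, 0.3, 0.5}};       // matrix x2 of the sqrtm_1 test below
    Mat2 Y = A, Z = {{1, 0, 0, 1}};
    for (int k = 0; k < 50; ++k) {               // a fixed iteration count is enough here
        const Mat2 Yi = inv2(Y), Zi = inv2(Z);
        const Mat2 Yn = {{ (Y[0]+Zi[0])/2, (Y[1]+Zi[1])/2, (Y[2]+Zi[2])/2, (Y[3]+Zi[3])/2 }};
        const Mat2 Zn = {{ (Z[0]+Yi[0])/2, (Z[1]+Yi[1])/2, (Z[2]+Yi[2])/2, (Z[3]+Yi[3])/2 }};
        Y = Yn; Z = Zn;
    }
    const Mat2 check = mul2(Y, Y);               // Y*Y should reproduce A
    std::printf("sqrt(A) ~ [%f %f; %f %f]\n", Y[0], Y[1], Y[2], Y[3]);
    std::printf("Y*Y     ~ [%f %f; %f %f]\n", check[0], check[1], check[2], check[3]);
    return 0;
}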
@ -0,0 +1,39 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#ifndef LIBND4J_SQRTM_HELPER_H
#define LIBND4J_SQRTM_HELPER_H

#include <ops/declarable/helpers/helpers.h>
#include "array/NDArray.h"

namespace sd {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
void sqrtm(sd::LaunchContext* context, const NDArray* x, NDArray* z);


}
}
}

#endif //LIBND4J_SQRTM_HELPER_H

@ -26,7 +26,9 @@ namespace sd {
namespace ops {
namespace helpers {

-   int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output);
+   int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool unitsOnDiag, NDArray* output);
+   template <typename T>
+   void triangularSolve2D(sd::LaunchContext* context, const NDArray& leftInput, const NDArray& rightInput, const bool lower, const bool unitsOnDiag, NDArray& output);
    void adjointMatrix(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output);
}
}

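Several hunks in this changeset (the top_k helper earlier and the loss-gradient tests below) switch writes from the t<>() accessor to r<>(), while reads keep using t<>(). A toy sketch of that read/reference split, with hypothetical class and member names chosen only for illustration:

#include <cstdio>
#include <vector>

// Hypothetical miniature of the accessor split: t() returns a value for reading,
// r() returns a mutable reference for writing.
template <typename T>
class TinyArray {
public:
    explicit TinyArray(size_t n) : _data(n) {}
    T  t(size_t i) const { return _data[i]; }   // read-only access
    T& r(size_t i)       { return _data[i]; }   // reference access for writes
private:
    std::vector<T> _data;
};

int main() {
    TinyArray<double> w(3);
    w.r(0) = 0.5;                 // write through the reference getter
    w.r(1) = w.t(0) * 2;          // read with t(), write with r()
    std::printf("%f %f\n", w.t(0), w.t(1));
    return 0;
}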
@ -392,10 +392,10 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {
|
|||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
weights.assign(0.5);
|
||||
weights.t<double>(0) = 0.;
|
||||
weights.t<double>(1) = 0.;
|
||||
weights.t<double>(2) = 0.;
|
||||
weights.t<double>(3) = 0.;
|
||||
weights.r<double>(0) = 0.;
|
||||
weights.r<double>(1) = 0.;
|
||||
weights.r<double>(2) = 0.;
|
||||
weights.r<double>(3) = 0.;
|
||||
|
||||
|
||||
sd::ops::log_loss_grad op;
|
||||
|
@ -431,9 +431,9 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
|
|||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
weights.assign(0.5);
|
||||
weights.t<double>(0) = 0.;
|
||||
weights.t<double>(1) = 0.;
|
||||
weights.t<double>(2) = 0.;
|
||||
weights.r<double>(0) = 0.;
|
||||
weights.r<double>(1) = 0.;
|
||||
weights.r<double>(2) = 0.;
|
||||
|
||||
sd::ops::log_loss_grad op;
|
||||
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
|
||||
|
@ -2399,10 +2399,10 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) {
|
|||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
weights.assign(0.5);
|
||||
weights.t<double>(0) = 0.;
|
||||
weights.t<double>(1) = 0.;
|
||||
weights.t<double>(2) = 0.;
|
||||
weights.t<double>(3) = 0.;
|
||||
weights.r<double>(0) = 0.;
|
||||
weights.r<double>(1) = 0.;
|
||||
weights.r<double>(2) = 0.;
|
||||
weights.r<double>(3) = 0.;
|
||||
|
||||
sd::ops::mean_sqerr_loss_grad op;
|
||||
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
|
||||
|
@ -2436,9 +2436,9 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) {
|
|||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
weights.assign(0.5);
|
||||
weights.t<double>(0) = 0.;
|
||||
weights.t<double>(1) = 0.;
|
||||
weights.t<double>(2) = 0.;
|
||||
weights.r<double>(0) = 0.;
|
||||
weights.r<double>(1) = 0.;
|
||||
weights.r<double>(2) = 0.;
|
||||
|
||||
sd::ops::mean_sqerr_loss_grad op;
|
||||
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
|
||||
|
@ -2830,10 +2830,10 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) {
|
|||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
weights.assign(0.5);
|
||||
weights.t<double>(0) = 0.;
|
||||
weights.t<double>(1) = 0.;
|
||||
weights.t<double>(2) = 0.;
|
||||
weights.t<double>(3) = 0.;
|
||||
weights.r<double>(0) = 0.;
|
||||
weights.r<double>(1) = 0.;
|
||||
weights.r<double>(2) = 0.;
|
||||
weights.r<double>(3) = 0.;
|
||||
|
||||
sd::ops::absolute_difference_loss_grad op;
|
||||
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
|
||||
|
@ -2867,9 +2867,9 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) {
|
|||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
weights.assign(0.5);
|
||||
weights.t<double>(0) = 0.;
|
||||
weights.t<double>(1) = 0.;
|
||||
weights.t<double>(2) = 0.;
|
||||
weights.r<double>(0) = 0.;
|
||||
weights.r<double>(1) = 0.;
|
||||
weights.r<double>(2) = 0.;
|
||||
|
||||
sd::ops::absolute_difference_loss_grad op;
|
||||
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
|
||||
|
@ -3305,10 +3305,10 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) {
|
|||
logits.linspace(-0.08, 0.04);
|
||||
labels.linspace(1);
|
||||
weights.assign(0.5);
|
||||
weights.t<double>(0) = 0.;
|
||||
weights.t<double>(1) = 0.;
|
||||
weights.t<double>(2) = 0.;
|
||||
weights.t<double>(3) = 0.;
|
||||
weights.r<double>(0) = 0.;
|
||||
weights.r<double>(1) = 0.;
|
||||
weights.r<double>(2) = 0.;
|
||||
weights.r<double>(3) = 0.;
|
||||
|
||||
|
||||
sd::ops::sigm_cross_entropy_loss_grad op;
|
||||
|
@ -3344,9 +3344,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) {
|
|||
logits.linspace(-0.08, 0.04);
|
||||
labels.linspace(1);
|
||||
weights.assign(0.5);
|
||||
weights.t<double>(0) = 0.;
|
||||
weights.t<double>(1) = 0.;
|
||||
weights.t<double>(2) = 0.;
|
||||
weights.r<double>(0) = 0.;
|
||||
weights.r<double>(1) = 0.;
|
||||
weights.r<double>(2) = 0.;
|
||||
|
||||
sd::ops::sigm_cross_entropy_loss_grad op;
|
||||
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
|
||||
|
|
|
@ -2065,500 +2065,6 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
|
|||
#endif
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 3;
|
||||
|
||||
const int dataFormat = 0; // [sL,bS,nIn]
|
||||
const int directionMode = 0; // forward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = false; // seqLen array is not provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 3;
|
||||
|
||||
const int dataFormat = 1; // [bS,sL,nIn]
|
||||
const int directionMode = 0; // forward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = false; // seqLen array is not provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
|
||||
const auto retLastH = false; // do not return output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {
|
||||
|
||||
const int sL = 4;
|
||||
const int bS = 3;
|
||||
const int nIn = 3;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 2; // [bS, nIn, sL]
|
||||
const int directionMode = 0; // forward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {2,0,4}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 3;
|
||||
|
||||
const int dataFormat = 1; // [bS,sL,nIn]
|
||||
const int directionMode = 1; // backward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = false; // seqLen array is not provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 2; // [bS, nIn, sL]
|
||||
const int directionMode = 1; // backward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 2; // [bS, nIn, sL]
|
||||
const int directionMode = 2; // bidirectional sum
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 1; // [bS,sL,nIn]
|
||||
const int directionMode = 3; // bidirectional concat
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 3; // [sL, bS, nIn]
|
||||
const int directionMode = 4; // bidirectional extra output dim
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, batchnorm_test1) {
|
||||
|
||||
|
|
|
@ -1923,7 +1923,6 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
|
|||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, gru_1) {
|
||||
|
||||
|
@ -1960,31 +1959,67 @@ TEST_F(DeclarableOpsTests15, gru_1) {
|
|||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, gru_bp_1) {
|
||||
TEST_F(DeclarableOpsTests15, sqrtm_1) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 5;
|
||||
const int nOut = 4;
|
||||
NDArray x1('c', {1,1}, {4.}, sd::DataType::DOUBLE);
|
||||
NDArray x2('c', {2,2}, {1.3,2,0.3,.5}, sd::DataType::DOUBLE);
|
||||
NDArray x3('c', {3,3}, {0.5 ,-0.4 ,1.2 ,-2.8 ,-0.2 ,-2.1 ,-2.4 ,-2.0 ,1.1}, sd::DataType::DOUBLE);
|
||||
NDArray x4('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||
NDArray x5('c', {5,5}, {2.4 ,0.3 ,0.0 ,1.1 ,1.8 ,0.1 ,1.7 ,2.7 ,1.5 ,2.6 ,0.6 ,2.1 ,2.2 ,1.0 ,0.2 ,1.2 ,2.8 ,1.9 ,0.8 ,2.0 ,0.5 ,1.6 ,0.9 ,1.4 ,2.5}, sd::DataType::DOUBLE);
|
||||
|
||||
NDArray exp1('c', {1,1}, {2.}, sd::DataType::DOUBLE);
|
||||
NDArray exp2('c', {2,2}, {1.0163674, 1.3341597,0.200124, 0.4827035}, sd::DataType::DOUBLE);
|
||||
NDArray exp3('c', {3,3}, {6.5692188, 2.6273616,-0.1387864,-16.8404762,-7.0296495, 0.9204148,-11.4664296,-5.834273 , 2.2087478}, sd::DataType::DOUBLE);
|
||||
NDArray exp4('c', {4,4}, {1.161387 ,-1.9343154, 0.230372 , 0.8660897,0.80588 , 3.4045446,-1.0152824,-2.0369467,2.2589629, 1.9674252, 1.5109997,-1.4283141,0.0226356, 1.3032279,-1.00396 , 1.8278487}, sd::DataType::DOUBLE);
|
||||
NDArray exp5('c', {5,5}, {1.4175046,-0.4425298, 0.1846149, 0.3166522, 0.9140631,-0.1929139, 0.2889113, 1.4045273, 0.2600026, 1.552021 , 0.1372758, 0.5703854, 1.3336126, 0.3869317,-0.082492 ,
|
||||
0.8607272, 3.1792474,-0.9499947, 0.8541668,-1.4243879, 0.0081136,-0.0622248, 0.4534325, 0.4641865, 1.8132138}, sd::DataType::DOUBLE);
|
||||
|
||||
NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
sd::ops::sqrtm op;
|
||||
|
||||
NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
|
||||
auto results = op.evaluate({&x1}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
ASSERT_TRUE(exp1.isSameShape(results.at(0)));
|
||||
ASSERT_TRUE(exp1.equalsTo(results.at(0)));
|
||||
|
||||
Wx.linspace(1,-0.1);
|
||||
Wh.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
results = op.evaluate({&x2}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
ASSERT_TRUE(exp2.isSameShape(results.at(0)));
|
||||
ASSERT_TRUE(exp2.equalsTo(results.at(0)));
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {});
|
||||
const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {});
|
||||
results = op.evaluate({&x3}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
ASSERT_TRUE(exp3.isSameShape(results.at(0)));
|
||||
ASSERT_TRUE(exp3.equalsTo(results.at(0)));
|
||||
|
||||
sd::ops::gru opFF;
|
||||
sd::ops::gru_bp opBP;
|
||||
results = op.evaluate({&x4}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
ASSERT_TRUE(exp4.isSameShape(results.at(0)));
|
||||
ASSERT_TRUE(exp4.equalsTo(results.at(0)));
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
results = op.evaluate({&x5}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
ASSERT_TRUE(exp5.isSameShape(results.at(0)));
|
||||
ASSERT_TRUE(exp5.equalsTo(results.at(0)));
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, sqrtm_2) {
|
||||
|
||||
NDArray x('c', {10,10}, {-0.3 ,2.7 ,4.9 ,7.0 ,7.3 ,-1.3 ,0.5 ,9.9 ,-9.4 ,8.4 ,2.2 ,5.2 ,7.6 ,1.2 ,2.0 ,-3.8 ,2.1 ,6.1 ,1.6 ,6.9 ,5.1 ,5.3 ,6.4 ,8.7 ,0.1 ,8.5 ,
|
||||
3.3 ,1.0 ,6.8 ,0.4 ,0.7 ,3.2 ,7.4 ,6.7 ,1.1 ,7.2 ,6.0 ,7.5 ,9.7 ,5.4 ,9.0 ,6.3 ,0.0 ,4.5 ,8.3 ,7.9 ,3.0 ,6.5 ,0.6 ,8.0 ,9.5 ,3.6 ,1.9 ,6.2 ,0.9 ,4.0 ,4.1 ,
|
||||
8.1 ,3.9 ,4.3 ,4.7 ,3.7 ,3.4 ,5.8 ,10.0 ,8.6 ,9.3 ,9.1 ,4.6 ,1.4 ,7.8 ,1.5 ,7.7 ,4.2 ,9.6 ,8.2 ,-7.1 ,5.7 ,5.5 ,2.6 ,8.8 ,2.9 ,0.2 ,5.6 ,-2.5 ,8.9 ,2.8 ,0.8 ,1.5 ,3.1 ,3.5 ,4.4 ,2.4 ,9.2 ,-4.8 ,1.7 ,6.6 ,9.8 ,1.8 ,5.9}, sd::DataType::DOUBLE);
|
||||
|
||||
NDArray expZ('c', {10,10}, {1.2779038, 0.0333321, 0.8215617, 0.5736392, 1.3973911, -1.1757741,0.1990005, 1.5893778, -3.0159568, 2.5829108,0.5692253, 2.219431 , 1.022612 , -0.3131795, -0.1957848, -1.7805065,
|
||||
0.6668489, 1.1968921, 0.9781974, 1.2007764,0.7028634, 0.7496937, 2.2511438, 2.1945378, 0.2559353, 2.8948612,-0.4306994, -0.9922216, 0.3884369, -1.4174481,
|
||||
-1.6060233, 0.1571057, 1.432471 , 0.4508346, 0.0618069, -2.4511742,2.0641709, 2.4751085, 1.84787 , 3.4146313,0.7774219, 0.768369 , -0.1417226, -0.3970577, 2.9512879, 0.5474537,
|
||||
0.4991412, 0.7604095, 0.4523091, 1.7813704,2.5998339, 0.9402402, -0.82775 , 2.3637147, -0.6394584, 4.6181937,-0.1762181, -0.2820475, 0.9280713, -2.1876918,
|
||||
0.1576249, 0.336376 , 0.2017592, 0.851786 , 1.3542577, 1.2752901,2.9718476, 1.1102557, 0.0067319, -0.2652283,0.8839235, -0.2637131, 1.5687876, 0.5156139, 1.9015886, 0.9087172,
|
||||
-1.5607482, 2.4216275, 1.0399745, -0.4930439,1.3044354, 0.1690006, 0.2106909, -0.2683631, -0.4193939, 1.0233265,0.4571777, -0.2024148, 2.3564855, 1.0442339,
|
||||
1.1073322, 1.0728525, -0.5917566, 2.2267418, -1.6096582, 2.0685315,0.6800798, 0.4451858, -0.4048465, 1.2347676}, sd::DataType::DOUBLE);
|
||||
sd::ops::sqrtm op;
|
||||
|
||||
auto results = op.evaluate({&x}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
ASSERT_TRUE(expZ.isSameShape(results.at(0)));
|
||||
ASSERT_TRUE(expZ.equalsTo(results.at(0)));
|
||||
}
|
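A one-line reminder of what the expected arrays in these two tests represent (standard definition, not quoted from the op's docs): sqrtm computes a matrix Z with

    Z \, Z = X,

the principal square root when it exists, so multiplying an expected output by itself should reproduce the corresponding input up to rounding.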
||||
|
|
|
@ -241,6 +241,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) {
|
|||
ASSERT_EQ(exp, initial);
|
||||
}
|
||||
|
||||
#ifdef _RELEASE
|
||||
TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
|
||||
// [2,1,135079944,1,1,8192,1,99]
|
||||
auto initial = NDArrayFactory::create<float>('c', {1, 135079944});
|
||||
|
@ -287,6 +288,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
|
|||
|
||||
ASSERT_EQ(exp, initial);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,426 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

#include "testlayers.h"
#include <helpers/HessenbergAndSchur.h>
#include <helpers/EigenValsAndVecs.h>
#include <helpers/FullPivLU.h>
#include <ops/declarable/helpers/triangular_solve.h>
#include <helpers/Sqrtm.h>

using namespace sd;

class HelpersTests2 : public testing::Test {
public:

HelpersTests2() {

std::cout<<std::endl<<std::flush;
}

};

// #ifndef __CUDABLAS__
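
// Added overview note (sketch, not part of the original change): the helpers exercised below
// implement the chain of dense linear-algebra building blocks used by the sqrtm op:
//
//   Hessenberg:       A = Q * H * Q^T,  H upper Hessenberg, Q orthogonal
//   Schur:            A = U * T * U^T,  T quasi upper triangular (1x1 / 2x2 diagonal blocks)
//   EigenValsAndVecs: A * v = lambda * v, eigenpairs recovered from the Schur form
//   FullPivLU:        P * A * Q = L * U, used to solve A * x = b with full pivoting
//   Sqrtm:            Z with Z * Z = A, computed from the (quasi-triangular) Schur factor
//
// The exact composition is an interpretation of the includes above, not a statement of the
// implementation details.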
///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, Hessenberg_1) {

NDArray x1('c', {1,4}, {14,17,3,1}, sd::DataType::DOUBLE);
NDArray x2('c', {1,1}, {14}, sd::DataType::DOUBLE);
NDArray expQ('c', {1,1}, {1}, sd::DataType::DOUBLE);

ops::helpers::Hessenberg<double> hess1(x1);
ASSERT_TRUE(hess1._H.isSameShape(&x1));
ASSERT_TRUE(hess1._H.equalsTo(&x1));
ASSERT_TRUE(hess1._Q.isSameShape(&expQ));
ASSERT_TRUE(hess1._Q.equalsTo(&expQ));

ops::helpers::Hessenberg<double> hess2(x2);
ASSERT_TRUE(hess2._H.isSameShape(&x2));
ASSERT_TRUE(hess2._H.equalsTo(&x2));
ASSERT_TRUE(hess2._Q.isSameShape(&expQ));
ASSERT_TRUE(hess2._Q.equalsTo(&expQ));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, Hessenberg_2) {

NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE);
NDArray expQ('c', {2,2}, {1,0,0,1}, sd::DataType::DOUBLE);

ops::helpers::Hessenberg<double> hess(x);

// hess._H.printBuffer();

ASSERT_TRUE(hess._H.isSameShape(&x));
ASSERT_TRUE(hess._H.equalsTo(&x));

ASSERT_TRUE(hess._Q.isSameShape(&expQ));
ASSERT_TRUE(hess._Q.equalsTo(&expQ));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, Hessenberg_3) {

NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE);
NDArray expH('c', {3,3}, {33, -23.06939, -48.45414, -57.01061, 12.62845, 3.344058, 0, -9.655942, -5.328448}, sd::DataType::DOUBLE);
NDArray expQ('c', {3,3}, {1,0,0,0, -0.99981, -0.019295, 0, -0.019295,0.99981}, sd::DataType::DOUBLE);

ops::helpers::Hessenberg<double> hess(x);

ASSERT_TRUE(hess._H.isSameShape(&expH));
ASSERT_TRUE(hess._H.equalsTo(&expH));

ASSERT_TRUE(hess._Q.isSameShape(&expQ));
ASSERT_TRUE(hess._Q.equalsTo(&expQ));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, Hessenberg_4) {

NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
NDArray expH('c', {4,4}, {0.33, 0.4961181, 3.51599, 9.017665, -7.792702, 4.190221, 6.500328, 5.438888, 0, 3.646734, 0.4641911, -7.635502, 0,0, 5.873535, 5.105588}, sd::DataType::DOUBLE);
NDArray expQ('c', {4,4}, {1,0,0,0, 0,-0.171956, 0.336675, -0.925787, 0,-0.973988,0.0826795, 0.210976, 0, 0.147574, 0.937984,0.3137}, sd::DataType::DOUBLE);

ops::helpers::Hessenberg<double> hess(x);

ASSERT_TRUE(hess._H.isSameShape(&expH));
ASSERT_TRUE(hess._H.equalsTo(&expH));

ASSERT_TRUE(hess._Q.isSameShape(&expQ));
ASSERT_TRUE(hess._Q.equalsTo(&expQ));
}

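///////////////////////////////////////////////////////////////////
// Added sketch (not part of the original change): the tests above compare against precomputed
// H and Q; the structural invariant behind them is that H is upper Hessenberg, i.e. every entry
// strictly below the first subdiagonal is zero. The e<double>() element accessor is assumed to
// be available on NDArray, as it is elsewhere in the code base.
TEST_F(HelpersTests2, Hessenberg_structure_sketch) {

    NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);

    ops::helpers::Hessenberg<double> hess(x);

    // entries below the first subdiagonal must vanish
    for (int i = 2; i < 4; ++i)
        for (int j = 0; j <= i - 2; ++j)
            ASSERT_NEAR(hess._H.e<double>(i, j), 0.0, 1e-10);
}
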
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(HelpersTests2, Hessenberg_5) {
|
||||
|
||||
NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 ,
|
||||
-6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 ,
|
||||
0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 ,
|
||||
6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE);
|
||||
NDArray expH('c', {10,10}, {6.9, 6.125208, -8.070945, 7.219828, -9.363308, 2.181236, 5.995414, 3.892612, 4.982657, -2.088574,-12.6412, 1.212547, -6.449684, 5.162879, 0.4341714, -5.278079, -2.624011, -2.03615, 11.39619, -3.034842,
|
||||
0, -12.71931, 10.1146, 6.494434, -1.062934, 5.668906, -4.672953, -9.319893, -2.023392, 6.090341,0,0, 7.800521, -1.46286, 1.484626, -10.58252, -3.492978, 2.42187, 5.470045, 1.877265,
|
||||
0,0,0, 14.78259,-0.3147726, -5.74874, -0.377823, 3.310056, 2.242614, -5.111574,0,0,0,0, -9.709131, 3.885072, 6.762626, 4.509144, 2.390195, -4.991013,
|
||||
0,0,0,0,0, 8.126269, -12.32529, 9.030151, 1.390931, 0.8634045,0,0,0,0,0,0, -12.99477, 9.574299,-0.3098022, 4.910835,0,0,0,0,0,0,0, 14.75256, 18.95723, -5.054717,0,0,0,0,0,0,0,0, -4.577715, -5.440827,}, sd::DataType::DOUBLE);
|
||||
NDArray expQ('c', {10,10}, {1,0,0,0,0,0,0,0,0,0,0,-0.0079106,-0.38175,-0.39287,-0.26002,-0.44102,-0.071516,0.12118,0.64392,0.057562,
|
||||
0,0.28478,0.0058784,0.3837,-0.47888,0.39477,0.0036847,-0.24678,0.3229,0.47042,0,-0.031643,-0.61277,0.087648,0.12014,0.47648,-0.5288,0.060599,0.021434,-0.30102,
|
||||
0,0.23732,-0.17801,-0.31809,-0.31267,0.27595,0.30134,0.64555,-0.33392,0.13363,0,-0.023732,-0.40236,0.43089,-0.38692,-0.5178,-0.03957,-0.081667,-0.47515,-0.0077949,
|
||||
0,0.20568,-0.0169,0.36962,0.49669,-0.22475,-0.22199,0.50075,0.10454,0.46112,0,0.41926,0.30243,-0.3714,-0.16795,-0.12969,-0.67572,-0.1205,-0.26047,0.10407,
|
||||
0,-0.41135,-0.28357,-0.33858,0.18836,0.083822,-0.0068213,-0.30161,-0.24956,0.66327,0,0.68823,-0.33616,-0.12129,0.36163,-0.063256,0.34198,-0.37564,-0.048196,-0.058948}, sd::DataType::DOUBLE);
|
||||
|
||||
ops::helpers::Hessenberg<double> hess(x);
|
||||
|
||||
ASSERT_TRUE(hess._H.isSameShape(&expH));
|
||||
ASSERT_TRUE(hess._H.equalsTo(&expH));
|
||||
|
||||
ASSERT_TRUE(hess._Q.isSameShape(&expQ));
|
||||
ASSERT_TRUE(hess._Q.equalsTo(&expQ));
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, Schur_1) {

NDArray x('c', {3,3}, sd::DataType::DOUBLE);

NDArray expT('c', {3,3}, {-2.5, -2, 1, 0, 1.5, -2, 3, 4, 5}, sd::DataType::DOUBLE);
NDArray expU('c', {3,3}, {0.3, 0.2,-0.1, 0,-0.1, 0.2, -0.3,-0.4, 0.5}, sd::DataType::DOUBLE);

ops::helpers::Schur<double> schur(x);
schur._T.linspace(-3, 1);
schur._U.linspace(-0.3, 0.1);

schur.splitTwoRows(1, 0.5);

ASSERT_TRUE(schur._T.isSameShape(&expT));
ASSERT_TRUE(schur._T.equalsTo(&expT));

ASSERT_TRUE(schur._U.isSameShape(&expU));
ASSERT_TRUE(schur._U.equalsTo(&expU));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, Schur_2) {

NDArray x('c', {3,3}, sd::DataType::DOUBLE);

NDArray shift('c', {3}, sd::DataType::DOUBLE);
NDArray exp1('c', {3}, {1,-3,0}, sd::DataType::DOUBLE);
NDArray exp2('c', {3}, {3, 3,-7}, sd::DataType::DOUBLE);
NDArray exp3('c', {3}, {0.964,0.964,0.964}, sd::DataType::DOUBLE);
NDArray exp1T('c', {3,3}, {-3,-2,-1,0,1,2,3,4,5}, sd::DataType::DOUBLE);
NDArray exp2T('c', {3,3}, {-8,-2,-1,0,-4,2,3,4,0}, sd::DataType::DOUBLE);
NDArray exp3T('c', {3,3}, {-9.464102,-2,-1,0,-5.464102,2,3,4,-1.464102,}, sd::DataType::DOUBLE);

ops::helpers::Schur<double> schur(x);
// schur._U.linspace(-0.3, 0.1); // doesn't matter

schur._T.linspace(-3, 1);
double expShift =0;
schur.calcShift(1, 5, expShift, shift);
ASSERT_TRUE(schur._T.equalsTo(&exp1T));
ASSERT_TRUE(shift.isSameShape(&exp1));
ASSERT_TRUE(shift.equalsTo(&exp1));
ASSERT_TRUE(expShift == 0);

schur._T.linspace(-3, 1);
expShift = 0;
schur.calcShift(2, 10, expShift, shift);
ASSERT_TRUE(schur._T.equalsTo(&exp2T));
ASSERT_TRUE(shift.isSameShape(&exp2));
ASSERT_TRUE(shift.equalsTo(&exp2));
ASSERT_TRUE(expShift == 5);

schur._T.linspace(-3, 1);
expShift = 0;
schur.calcShift(2, 30, expShift, shift);
ASSERT_TRUE(schur._T.equalsTo(&exp3T));
ASSERT_TRUE(shift.isSameShape(&exp3));
ASSERT_TRUE(shift.equalsTo(&exp3));
ASSERT_TRUE((6.4641-0.00001) < expShift && expShift < (6.4641+0.00001));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, Schur_3) {

NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE);
NDArray expU('c', {2,2}, {1,0,0,1}, sd::DataType::DOUBLE);

ops::helpers::Schur<double> schur(x);

ASSERT_TRUE(schur._T.isSameShape(&x));
ASSERT_TRUE(schur._T.equalsTo(&x));

ASSERT_TRUE(schur._U.isSameShape(&expU));
ASSERT_TRUE(schur._U.equalsTo(&expU));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, Schur_4) {

NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE);
NDArray expT('c', {3,3}, {53.73337,-20.21406,-50.44809,0,-27.51557, 26.74307,0,0,14.0822}, sd::DataType::DOUBLE);
NDArray expU('c', {3,3}, {-0.5848506, 0.7185352, 0.3763734,-0.7978391,-0.5932709,-0.1071558,-0.1462962, 0.3629555,-0.9202504}, sd::DataType::DOUBLE);

ops::helpers::Schur<double> schur(x);

ASSERT_TRUE(schur._T.isSameShape(&expT));
ASSERT_TRUE(schur._T.equalsTo(&expT));

ASSERT_TRUE(schur._U.isSameShape(&expU));
ASSERT_TRUE(schur._U.equalsTo(&expU));
}

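///////////////////////////////////////////////////////////////////
// Added sketch (not part of the original change): Schur_4 above encodes x = U * T * U^T with U
// orthogonal, so U^T * U should be the identity. MmulHelper::mmul (helpers/MmulHelper.h) is
// assumed to be available for the matrix product in this translation unit; if it is not, the
// same check can be written against the precomputed expU values instead.
TEST_F(HelpersTests2, Schur_orthogonality_sketch) {

    NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE);
    NDArray expI('c', {3,3}, {1,0,0, 0,1,0, 0,0,1}, sd::DataType::DOUBLE);

    ops::helpers::Schur<double> schur(x);

    NDArray ut = schur._U.transpose();
    NDArray prod('c', {3,3}, sd::DataType::DOUBLE);
    MmulHelper::mmul(&ut, &schur._U, &prod);    // prod = U^T * U

    ASSERT_TRUE(prod.equalsTo(&expI));
}
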
/*
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(HelpersTests2, Schur_5) {
|
||||
|
||||
NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||
NDArray expT('c', {4,4}, {6.940177,7.201107,2.523849,-8.534745,-3.109643,5.289615,-2.940507,9.330303, 0,0,-0.1740346, 7.19851,0,0, -2.870214, -1.965758}, sd::DataType::DOUBLE);
|
||||
NDArray expU('c', {4,4}, {-0.2602141, 0.8077556,-0.3352316,-0.4091935,0.3285353,-0.4395489,-0.4714875,-0.6903338,0.7536921, 0.3005626,-0.3910435, 0.4343908,-0.5062621, -0.252962,-0.7158242, 0.4090287}, sd::DataType::DOUBLE);
|
||||
|
||||
ops::helpers::Schur<double> schur(x);
|
||||
|
||||
ASSERT_TRUE(schur._T.isSameShape(&expT));
|
||||
ASSERT_TRUE(schur._T.equalsTo(&expT));
|
||||
|
||||
ASSERT_TRUE(schur._U.isSameShape(&expU));
|
||||
ASSERT_TRUE(schur._U.equalsTo(&expU));
|
||||
}
|
||||
*/
|
||||
/*
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(HelpersTests2, Schur_6) {
|
||||
|
||||
NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 ,
|
||||
-6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 ,
|
||||
0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 ,
|
||||
6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE);
|
||||
NDArray expT('c', {10,10}, {-13.78982, 6.072464, 0.3021194, -8.455495,-0.3047058, 4.033153, 2.610364, 2.80607, -2.735616, 0.3040549,-2.188506, -12.38324, -1.167179, -4.539672, -19.08546, 1.752401,-0.1354974,-0.2747422,-0.3270464, -5.070936,
|
||||
0,0,0.5067366, 7.930223,-0.6465996, 8.659522, 1.283713, 4.551415, 12.7736, 3.4812,0,0,-9.858142, -2.905068, -6.474159, -6.247967, 0.4720073, -10.49523, 3.617189, -4.941627,
|
||||
0,0,0,0,9.461626, -4.896166, 9.339704, 4.640336, 16.8626, 2.056027,0,0,0,0,6.479812, 8.462862, 7.386285, -4.123457, -5.817095, -2.633641,0,0,0,0,0,0,13.46667, -4.907281, 4.602204, 5.198035,
|
||||
0,0,0,0,0,0, 7.176822, 16.93311, 2.195036, 1.346086,0,0,0,0,0,0,0,0, 16.86979, -3.052473,0,0,0,0,0,0,0,0,0, -5.52268}, sd::DataType::DOUBLE);
|
||||
|
||||
// NDArray expT('c', {10,10}, {-13.78982, 6.072464, 0.1926198, -8.458698,-0.3047363, 4.033151, 2.610336, 2.806096, -2.735616, 0.3040549,-2.188506, -12.38324, -1.225857, -4.52418, -19.08548, 1.752257,-0.1354946,-0.2747435,-0.3270464, -5.070936,
|
||||
// 0,0, 0.4812058, 7.886377,-0.7304318, 8.577898, 1.289673, 4.415163, 12.81936, 3.416929,0,0, -9.901988, -2.879537, -6.465196, -6.359608, 0.455452, -10.55328, 3.451505, -4.986284,
|
||||
// 0,0,0,0, 9.461614, -4.896159, 9.339602, 4.64046, 16.86265, 2.056047,0,0,0,0, 6.47982, 8.462874, 7.386396, -4.123349, -5.816967, -2.633626,
|
||||
// 0,0,0,0,0,0, 13.46665, -4.907315, 4.602182, 5.198022,0,0,0,0,0,0, 7.176788, 16.93313, 2.195081, 1.346137,0,0,0,0,0,0,0,0, 16.86979, -3.052473,0,0,0,0,0,0,0,0,0, -5.52268}, sd::DataType::DOUBLE);
|
||||
|
||||
NDArray expU('c', {10,10}, {0.1964177, 0.2165192, -0.2138164, 0.4083154, -0.1872303, -0.5087223, 0.5529025, -0.2996174,-0.08772947, 0.07126534,-0.1906247, -0.223588, 0.3574755, 0.4245914, -0.3885589,-0.07328949, -0.4176507, -0.1885168, -0.4476957, 0.1971104,
|
||||
-0.2219015, 0.3084187, 0.1069209, -0.4905009, -0.3517786, 0.1446875, 0.121738, -0.3772941, 0.1232591, 0.5353205,-0.4766346, 0.6158252, -0.1529085, 0.04780914, 0.1274182, -0.1219211, -0.3123289, -0.2219282,-0.07613826, -0.429201,
|
||||
0.2577533, -0.3356205, -0.225358, -0.1540796, 0.3155174, -0.1904664, -0.3567101, -0.6831458, 0.1244646, 0.03383783, -0.45597, -0.3350697, 0.06824276, -0.2861978,-0.06724917, -0.7046481, 0.01664764, 0.2270567, 0.2003283,-0.01544937,
|
||||
0.122865, 0.1516775, -0.4446453, -0.2338583, 0.1633447, -0.193498, -0.198088, 0.3170272, -0.5869794, 0.4013553, 0.347383, 0.3666581, 0.6890763,-0.05797414, 0.3630058, -0.319958, -0.1071812, 0.06162044, 0.03171228, 0.1275262,
|
||||
-0.2986812, 0.05382598, -0.1484276, 0.4936468, 0.362756, 0.05858297, -0.1055183, 0.1090384, 0.4217073, 0.5534347, 0.3864388, 0.2085926, -0.204135, 0.05230855, -0.5290207, -0.1548485, -0.4670302, 0.2205726, 0.4380318,-0.01626632}, sd::DataType::DOUBLE);
|
||||
|
||||
ops::helpers::Schur<double> schur(x);
|
||||
|
||||
ASSERT_TRUE(schur._T.isSameShape(&expT));
|
||||
ASSERT_TRUE(schur._T.equalsTo(&expT, 1e-3));
|
||||
|
||||
ASSERT_TRUE(schur._U.isSameShape(&expU));
|
||||
ASSERT_TRUE(schur._U.equalsTo(&expU));
|
||||
}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, EigenValsAndVecs_1) {

NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE);
NDArray expVals('c', {2,2}, {3.25,5.562149, 3.25,-5.562149}, sd::DataType::DOUBLE);
NDArray expVecs('c', {2,2,2}, {-0.3094862,-0.0973726, -0.3094862,0.0973726,0,0.9459053, 0,-0.9459053}, sd::DataType::DOUBLE);

ops::helpers::EigenValsAndVecs<double> eig(x);

ASSERT_TRUE(eig._Vals.isSameShape(&expVals));
ASSERT_TRUE(eig._Vals.equalsTo(&expVals));

ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs));
ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs));
}

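///////////////////////////////////////////////////////////////////
// Added sketch (not part of the original change): _Vals stores eigenvalues as interleaved
// (real, imag) pairs, i.e. shape {n, 2}, and _Vecs as {n, n, 2}, which is why the expected arrays
// above alternate real and imaginary parts. For a real input matrix the complex eigenvalues come
// in conjugate pairs, as in EigenValsAndVecs_1. The e<double>() accessor is assumed available.
TEST_F(HelpersTests2, EigenValsAndVecs_layout_sketch) {

    NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE);

    ops::helpers::EigenValsAndVecs<double> eig(x);

    const double re0 = eig._Vals.e<double>(0, 0);   // real part of the first eigenvalue
    const double im0 = eig._Vals.e<double>(0, 1);   // imaginary part of the first eigenvalue

    // the second eigenvalue is the complex conjugate of the first
    ASSERT_NEAR(re0,  eig._Vals.e<double>(1, 0), 1e-5);
    ASSERT_NEAR(im0, -eig._Vals.e<double>(1, 1), 1e-5);
}
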
///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, EigenValsAndVecs_2) {

NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE);
NDArray expVals('c', {3,2}, {53.73337,0, -27.51557,0, 14.0822,0}, sd::DataType::DOUBLE);
NDArray expVecs('c', {3,3,2}, {-0.5848506,0,0.5560778,0,-0.04889745,0,-0.7978391,0,-0.7683444,0,-0.8855156,0,-0.1462962,0,0.3168979,0,-0.4620293,0}, sd::DataType::DOUBLE);

ops::helpers::EigenValsAndVecs<double> eig(x);

ASSERT_TRUE(eig._Vals.isSameShape(&expVals));
ASSERT_TRUE(eig._Vals.equalsTo(&expVals));

ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs));
ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, EigenValsAndVecs_3) {

NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
NDArray expVals('c', {4,2}, {6.114896,4.659591,6.114896,-4.659591, -1.069896,4.45631,-1.069896,-4.45631}, sd::DataType::DOUBLE);
NDArray expVecs('c', {4,4,2}, {-0.2141303,0.4815241,-0.2141303,-0.4815241, 0.1035092,-0.4270603, 0.1035092,0.4270603, 0.2703519,-0.2892722, 0.2703519,0.2892722, -0.5256817,0.044061, -0.5256817,-0.044061,
0.6202137,0.05521234,0.6202137,-0.05521234, -0.5756007,0.3932209,-0.5756007,-0.3932209,-0.4166034,-0.0651337, -0.4166034,0.0651337, -0.1723716,0.1138941,-0.1723716,-0.1138941}, sd::DataType::DOUBLE);

ops::helpers::EigenValsAndVecs<double> eig(x);

ASSERT_TRUE(eig._Vals.isSameShape(&expVals));
ASSERT_TRUE(eig._Vals.equalsTo(&expVals));

ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs));
ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs));
}

/*
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(HelpersTests2, EigenValsAndVecs_4) {
|
||||
|
||||
NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 ,
|
||||
-6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 ,
|
||||
0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 ,
|
||||
6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE);
|
||||
NDArray expVals('c', {10,2}, { -13.08653,3.577011,-13.08653,-3.577011, -1.199166,8.675665,-1.199166,-8.675665,8.962244,
|
||||
5.610424, 8.962244,-5.610424, 15.19989,5.675794, 15.19989,-5.675794,16.86979,0,-5.52268,0}, sd::DataType::DOUBLE);
|
||||
NDArray expVecs('c', {10,10,2}, {0.1652385,0.1439317, 0.1652385,-0.1439317, -0.198272,0.207306, -0.198272,-0.207306, 0.1861466,-0.4599919, 0.1861466,0.4599919, 0.09384053,-0.4889922, 0.09384053,0.4889922, -0.6153314,0, -0.2180209,0,
|
||||
-0.1603652,-0.1466119, -0.1603652,0.1466119, 0.2817409,0.3301842, 0.2817409,-0.3301842, 0.09747303,-0.2218182, 0.09747303,0.2218182, 0.2318273,-0.3355113, 0.2318273,0.3355113, -0.4828878,0, -0.1451126,0,
|
||||
-0.1866771,0.1220412, -0.1866771,-0.1220412, 0.08937842,-0.3025104, 0.08937842,0.3025104, 0.2783766,0.2258364, 0.2783766,-0.2258364, -0.1413997,-0.09596012, -0.1413997,0.09596012, -0.2286925,0, 0.3290011,0,
|
||||
-0.4009741,0.238131, -0.4009741,-0.238131, -0.02772353,0.1338458, -0.02772353,-0.1338458, 0.09030543,-0.2222453, 0.09030543,0.2222453, 0.2565825,-0.2275446, 0.2565825,0.2275446, -0.2855937,0, -0.3950544,0,
|
||||
0.2168379,-0.1301121, 0.2168379,0.1301121, -0.165433,-0.1220125, -0.165433,0.1220125, -0.2685605,0.008133055,-0.2685605,-0.008133055, 0.1929395,-0.1194659, 0.1929395,0.1194659, 0.2206467,0, 0.3289105,0,
|
||||
-0.3835898,-0.2478813, -0.3835898,0.2478813, 0.1923005,-0.01036433, 0.1923005,0.01036433, -0.1711637,-0.3548358, -0.1711637,0.3548358, 0.2888441,0.09625169, 0.2888441,-0.09625169, 0.2595426,0, -0.1288072,0,
|
||||
0.1033616,0.09839151, 0.1033616,-0.09839151, -0.3080167,-0.1624564, -0.3080167,0.1624564,-0.03972293,-0.03967309, -0.03972293,0.03967309, 0.1965443,0.3025898, 0.1965443,-0.3025898, 0.04587166,0, 0.499261,0,
|
||||
0.2922398,0.2461792, 0.2922398,-0.2461792, 0.2769633,-0.2745029, 0.2769633,0.2745029, 0.1034687,-0.002947149, 0.1034687,0.002947149, -0.02611308,0.1658046, -0.02611308,-0.1658046, 0.2351063,0, -0.3787892,0,
|
||||
-0.2512689,-0.02169855, -0.2512689,0.02169855, -0.01481625,0.4376404, -0.01481625,-0.4376404, -0.2298635,-0.2360671, -0.2298635,0.2360671, 0.11004,-0.1467444, 0.11004,0.1467444, 0.1501568,0, 0.340117,0,
|
||||
0.325096,0.1712822, 0.325096,-0.1712822, -0.2412035,-0.09236849, -0.2412035,0.09236849, 0.3894343,-0.08673087, 0.3894343,0.08673087, 0.3125305,0.07128152, 0.3125305,-0.07128152, -0.2415555,0, 0.1841298,0,}, sd::DataType::DOUBLE);
|
||||
|
||||
ops::helpers::EigenValsAndVecs<double> eig(x);
|
||||
|
||||
ASSERT_TRUE(eig._Vals.isSameShape(&expVals));
|
||||
ASSERT_TRUE(eig._Vals.equalsTo(&expVals));
|
||||
|
||||
ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs));
|
||||
ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs));
|
||||
}
|
||||
*/

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, fullPivLU_1) {

NDArray a('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
NDArray b('c', {4,1}, {-5.,10,9,1}, sd::DataType::DOUBLE);

NDArray x = b.ulike();

NDArray expX('c', {4,1}, {0.8527251, -0.2545784, -1.076495, -0.8526268}, sd::DataType::DOUBLE);

ops::helpers::FullPivLU<double>::solve(a,b,x);

ASSERT_TRUE(x.equalsTo(&expX));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, fullPivLU_2) {

NDArray a('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
NDArray b('c', {4,2}, {-5.,10,9,1,1.5,-2,17,5}, sd::DataType::DOUBLE);

NDArray x = b.ulike();

NDArray expX('c', {4,2}, {1.462913, 1.835338, 0.4083664, -2.163816, -3.344481, -3.739225, 0.5156383,0.01624954}, sd::DataType::DOUBLE);

ops::helpers::FullPivLU<double>::solve(a,b,x);

ASSERT_TRUE(x.equalsTo(&expX));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests2, fullPivLU_3) {

NDArray a1('c', {4,3}, {0.33 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,2.24 ,-6.82 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
NDArray a2('c', {3,4}, {0.33 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,2.24 ,-6.82 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
NDArray b1('c', {4,2}, {-5.,10,9,1,1.5,-2,17,5}, sd::DataType::DOUBLE);
NDArray b2('c', {3,2}, {-5.,10,9,1,1.5,-2}, sd::DataType::DOUBLE);

NDArray expX1('c', {3,2}, {0.9344955,-0.5841325, 0.8768102, 1.029137, -1.098021, 1.360152}, sd::DataType::DOUBLE);
NDArray expX2('c', {4,2}, {0.3536033,0.5270184,0,0,-0.8292221,0.967515,0.01827441,2.856337}, sd::DataType::DOUBLE);

NDArray x1 = expX1.ulike();
ops::helpers::FullPivLU<double>::solve(a1,b1,x1);
ASSERT_TRUE(x1.equalsTo(&expX1));

NDArray x2 = expX2.ulike();
ops::helpers::FullPivLU<double>::solve(a2,b2,x2);
ASSERT_TRUE(x2.equalsTo(&expX2));
}

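///////////////////////////////////////////////////////////////////
// Added sketch (not part of the original change): FullPivLU<T>::solve computes x for A*x = b via
// a fully pivoted LU factorization P*A*Q = L*U, so the solution is linear in the right-hand
// side. Doubling b from fullPivLU_1 must therefore double the expected solution.
TEST_F(HelpersTests2, fullPivLU_linearity_sketch) {

    NDArray a('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
    NDArray b('c', {4,1}, {-10.,20,18,2}, sd::DataType::DOUBLE);                                       // 2 * b of fullPivLU_1
    NDArray expX('c', {4,1}, {1.7054502, -0.5091568, -2.15299, -1.7052536}, sd::DataType::DOUBLE);     // 2 * expX of fullPivLU_1

    NDArray x = b.ulike();
    ops::helpers::FullPivLU<double>::solve(a,b,x);

    ASSERT_TRUE(x.equalsTo(&expX));
}
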
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(HelpersTests2, fullPivLU_4) {
|
||||
|
||||
NDArray a('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 ,
|
||||
-6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 ,
|
||||
0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 ,
|
||||
6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {10,2}, {-5.,10,9,1,1.5,-2,17,5,3.6,0.12, -3.1,2.27,-0.5,27.3,8.9,5,-7,8,-9,10}, sd::DataType::DOUBLE);
|
||||
|
||||
NDArray x = b.ulike();
|
||||
|
||||
NDArray expX('c', {10,2}, {-0.697127, 2.58257, 2.109721,3.160622,-2.217796, -3.275736,-0.5752479, 2.475356,1.996841, -1.928947,
|
||||
2.213154,3.541014, 0.7104885, -1.981451,-3.297972,-0.4720612, 3.672657, 0.9161028, -2.322383, -1.784493}, sd::DataType::DOUBLE);
|
||||
|
||||
ops::helpers::FullPivLU<double>::solve(a,b,x);
|
||||
|
||||
ASSERT_TRUE(x.equalsTo(&expX));
|
||||
}
|
|
@ -90,6 +90,9 @@ TEST_F(NDArrayTest, NDArrayOrder1) {
auto arrayF = new NDArray(arrayC->dup('f'));
auto arrayC2 = new NDArray(arrayF->dup('c'));

arrayF->syncToHost();
arrayC2->syncToHost();

ASSERT_EQ('c', arrayC->ordering());
ASSERT_EQ('f', arrayF->ordering());
ASSERT_EQ('c', arrayC2->ordering());

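// Added sketch (not part of the original change): the syncToHost() calls added above make the
// host buffers valid before ordering is inspected; the underlying invariant is that dup() with
// an explicit order changes only the physical layout, never the logical contents.
TEST_F(NDArrayTest, NDArrayOrder1_roundtrip_sketch) {
    auto arrayC  = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto arrayF  = arrayC.dup('f');
    auto arrayC2 = arrayF.dup('c');

    ASSERT_EQ('f', arrayF.ordering());
    ASSERT_TRUE(arrayC.equalsTo(&arrayF));
    ASSERT_TRUE(arrayC.equalsTo(&arrayC2));
}
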
@ -251,7 +251,7 @@ TEST_F(NativeOpsTests, ExecPairwise_2) {
auto exp = NDArrayFactory::create<bool>('c', {5, 5});
x.assign(true);
y.assign(false);
y.t<bool>(5) = true;
y.r<bool>(5) = true;
#ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n");
#else

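// Added sketch (not part of the original change): the hunk above is part of the t<>()/r<>()
// accessor split from this change set - writes that previously went through t<T>() now use the
// mutable reference getter r<T>(), exactly as in the replaced line. A minimal illustration,
// assuming the e<float>() value getter remains available as elsewhere in NDArray:
TEST_F(NativeOpsTests, reference_accessor_sketch) {
    auto arr = NDArrayFactory::create<float>('c', {5, 5});
    arr.r<float>(5) = 1.f;                    // write through the reference accessor
    ASSERT_NEAR(1.f, arr.e<float>(5), 1e-6);  // read back via the value getter
}
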
@ -1168,6 +1168,529 @@ TEST_F(PlaygroundTests, lstmLayerCellBp_1) {
|
|||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 3;
|
||||
|
||||
const int dataFormat = 0; // [sL,bS,nIn]
|
||||
const int directionMode = 0; // forward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = false; // seqLen array is not provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
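
///////////////////////////////////////////////////////////////////
// Added note (sketch, not part of the original change): GradCheck::checkGrad compares the
// analytic gradients produced by lstmLayer_bp against numerical gradients of the forward op,
// conceptually dL/dw ~= (L(w + eps) - L(w - eps)) / (2*eps) for every input entry, and returns
// true only when both agree within tolerance. The argument packing used by all of these tests
// is the one visible above:
//
//   tArgs = {cellClip}
//   iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}
//   bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}
//
// The exact epsilon and tolerance live inside GradCheck and are not restated here.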
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 3;
|
||||
|
||||
const int dataFormat = 1; // [bS,sL,nIn]
|
||||
const int directionMode = 0; // forward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = false; // seqLen array is not provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
|
||||
const auto retLastH = false; // output at last time step is not returned
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {
|
||||
|
||||
const int sL = 4;
|
||||
const int bS = 3;
|
||||
const int nIn = 3;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 2; // [bS, nIn, sL]
|
||||
const int directionMode = 0; // forward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {2,0,4}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 3;
|
||||
|
||||
const int dataFormat = 1; // [bS,sL,nIn]
|
||||
const int directionMode = 1; // backward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = false; // seqLen array is not provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 2; // [bS, nIn, sL]
|
||||
const int directionMode = 1; // backward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 2; // [bS, nIn, sL]
|
||||
const int directionMode = 2; // bidirectional sum
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
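
///////////////////////////////////////////////////////////////////
// Added note (sketch, not part of the original change): for the bidirectional modes exercised in
// lstmLayer_bp_6 .. lstmLayer_bp_8 (directionMode 2 = sum, 3 = concat, 4 = extra output
// dimension) the parameter arrays gain a leading dimension of 2 - one slice per direction -
// which is why Wx, Wr, b, hI, cI, Wp and the last-step gradients are declared as {2, ...} above,
// and why the concat case uses dLdh of shape {bS, sL, 2*nOut} while the extra-dim case uses
// {sL, 2, bS, nOut}.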
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 1; // [bS,sL,nIn]
|
||||
const int directionMode = 3; // bidirectional concat
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 2;
|
||||
const int nOut = 2;
|
||||
|
||||
const int dataFormat = 3; // [sL, bS, nIn]
|
||||
const int directionMode = 4; // bidirectional extra output dim
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
||||
const bool hasBiases = true; // biases array is provided
|
||||
const bool hasSeqLen = true; // seqLen array is provided
|
||||
const auto hasInitH = true; // initial output is provided
|
||||
const auto hasInitC = true; // initial cell state is provided
|
||||
const auto hasPH = true; // peephole connections are provided
|
||||
const auto retFullSeq = true; // dLdh per each time step
|
||||
const auto retLastH = true; // output at last time step
|
||||
const auto retLastC = true; // cells state at last time step
|
||||
|
||||
const double cellClip = 0.5; // clipping
|
||||
|
||||
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
x.linspace(-2,0.1);
|
||||
hI.linspace(-1.5,0.1);
|
||||
cI.linspace(0.7,-0.1);
|
||||
Wx.linspace(1,-0.1);
|
||||
Wr.linspace(-1,0.1);
|
||||
Wp.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
std::vector<double> tArgs = {cellClip};
|
||||
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
|
||||
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
|
||||
|
||||
sd::ops::lstmLayer opFF;
|
||||
sd::ops::lstmLayer_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, gru_bp_1) {
|
||||
|
||||
const int sL = 3;
|
||||
const int bS = 2;
|
||||
const int nIn = 5;
|
||||
const int nOut = 4;
|
||||
|
||||
|
||||
NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::DOUBLE);
|
||||
NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::DOUBLE);
|
||||
NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::DOUBLE);
|
||||
NDArray b('c', {3*nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
|
||||
|
||||
Wx.linspace(1,-0.1);
|
||||
Wh.linspace(0.2,0.2);
|
||||
b.linspace(1,-0.15);
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {});
|
||||
const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {});
|
||||
|
||||
sd::ops::gru opFF;
|
||||
sd::ops::gru_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
}
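
///////////////////////////////////////////////////////////////////
// Added note (sketch, not part of the original change): gru/gru_bp above exercise a standard GRU
// cell. In one common formulation, with the reset/update/candidate weights packed into the
// {nIn,3*nOut} Wx and {nOut,3*nOut} Wh arrays declared above,
//
//   r_t = sigmoid(x_t*Wxr + h_{t-1}*Whr + br)
//   u_t = sigmoid(x_t*Wxu + h_{t-1}*Whu + bu)
//   c_t = tanh(x_t*Wxc + (r_t o h_{t-1})*Whc + bc)
//   h_t = (1 - u_t) o c_t + u_t o h_{t-1}
//
// The exact gate ordering inside the packed weights is an assumption here, not taken from the op.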
|
||||
|
||||
*/
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ TEST_F(RNGTests, TestGenerator_SGA_1) {
for (auto idx = 0; idx < array.lengthOf(); idx++) {
float x = generator.relativeT(idx, -sd::DataTypeUtils::template max<float>() / 10,
sd::DataTypeUtils::template max<float>() / 10);
array.t<float>(idx) = x;
array.r<float>(idx) = x;
}
auto minimum = array.reduceNumber(reduce::AMin);
minimum.printBuffer("Randomly float min on 1M array");

@ -115,7 +115,7 @@ elseif(WIN32)
set(CMAKE_CXX_FLAGS " -g -fPIC -std=c++11 -Wa,-mbig-obj")
endif()
else()
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -DLINUX_BUILD=true")
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -ffast-math -DFFAST_MATH=true -DLINUX_BUILD=true")

if ("${_RELEASE}" OR CMAKE_BUILD_TYPE STREQUAL "Release")
message("Release build for tests")

@ -225,6 +225,17 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE))
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
endif()

file(GLOB_RECURSE COMPILATION_UNITS false ../../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in)
foreach(FL_ITEM ${COMPILATION_UNITS})
string(REGEX MATCH "^(.*)\\.cpp\.in$" dummy ${FL_ITEM})
set(FL_ITEM_WLE ${CMAKE_MATCH_1})
foreach(FL_TYPE_INDEX RANGE 0 9)
#message( "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp")
configure_file( "${FL_ITEM}" "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp" @ONLY)
LIST(APPEND CUSTOMOPS_GENERIC_SOURCES ${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp )
endforeach()
endforeach()

# this function strips path from file name, basically making up short file name, i.e. file.cpp
function(SHORTNAME LONG_NAME OUTPUT)