Shugeo qr (#153)

* Added qr op implementation. Initial version.

* Fixed doc for qr op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Implementation of QR decomposition. CPU platform version.

* Added a pair of tests for qr op testing.

Signed-off-by: shugeo <sgazeos@gmail.com>

* QR implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected norm using.

* Properly calculated intermediate results with QR decomposition.

* Another step to implement QR algorithm by householder.

* Cpu implementatio for QR decomposition. The first working edition.

* Corrected test to QR decomposition.

* Added tad multithreading with QR implementation.

* Finished cpu implementation for QR decomposition helpers.

* Refactored tests and improved multithreading.

* Refactored QR cpu implementation and update cuda implementation helpers.

* Cuda QR helper implementation. The first working edition.

* Eliminated waste prints.

* Restore multithreading with cuda implementation.

* Ops names corrected

* Refactored qr op helpers to optimize.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Eliminated waste manual ticking.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored memory allocation to avoid waste memory usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored matrixMinor method both for cuda and cpu platforms.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored method of vmul to use raw buffers instead type conversion.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored temporary array of matricies.

Signed-off-by: shugeo <sgazeos@gmail.com>

Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
Co-authored-by: raver119 <raver119@gmail.com>
master
shugeo 2020-01-22 12:59:36 +02:00 committed by raver119
parent 815a2908af
commit 2717b25931
11 changed files with 588 additions and 16 deletions

View File

@ -0,0 +1,88 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by GS <sgazeos@gmail.com> at 12/20/2019
//
#include <op_boilerplate.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/qr.h>
#if NOT_EXCLUDED(OP_qr)
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(qr, 1, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto outputQ = OUTPUT_VARIABLE(0);
auto outputR = OUTPUT_VARIABLE(1);
auto fullMatricies = false;
if (block.getBArguments()->size())
fullMatricies = B_ARG(0);
REQUIRE_TRUE(input->rankOf() >=2, 0, "qr: The rank of input array should not be less than 2, but %i is given", input->rankOf());
REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || (!fullMatricies && outputQ->isSameShape(input)), 0, "qr: The last dimmensions should be equal to result Q, but %i and %i are given", outputQ->sizeAt(-1), input->sizeAt(-2));
REQUIRE_TRUE((fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), 0, "qr: The last dimmensions should be equal to result R, but %i and %i are given", outputR->sizeAt(-1), input->sizeAt(-1));
helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies);
return Status::OK();
}
DECLARE_SHAPE_FN(qr) {
auto inShape = inputShape->at(0);
Nd4jLong* shapeQ;
Nd4jLong* shapeR;
int targetRank = shape::rank(inShape); // last two dimensions will be reduced to scalar
auto fullMatricies = false;
if (block.getBArguments()->size())
fullMatricies = B_ARG(0);
auto shape = ShapeUtils::shapeAsVector(inShape);
if (!fullMatricies) { // outputs are: Q is MxN and R is NxN
shape[targetRank - 1] = shape::sizeAt(inShape, -1);
shape[targetRank - 2] = shape[targetRank - 1];
shapeQ = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape),
shape::order(inShape), targetRank,
shape::shapeOf(inShape));
shapeR = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape),
shape::order(inShape), shape);
}
else {// otherwise outputs are Q is MxM and R is MxN with zero filled rows
shape[targetRank - 1] = shape::sizeAt(inShape, -2);
shape[targetRank - 2] = shape[targetRank - 1];
shapeR = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape),
shape::order(inShape), targetRank,
shape::shapeOf(inShape));
shapeQ = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape),
shape::order(inShape), shape);
}
return SHAPELIST(shapeQ, shapeR);
}
DECLARE_TYPES(qr) {
getOpDescriptor()
->setAllowedInputTypes({ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
#endif

View File

@ -162,8 +162,24 @@ namespace nd4j {
* Input : batched tensor with rank >=2
* Output: tensor with rank lesser by 1 from input
*/
#if NOT_EXCLUDED(OP_matrix_diag_part)
DECLARE_CUSTOM_OP(matrix_diag_part, 1, 1, false, 0, 0);
#endif
/**
* QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper triangular.
* For A (MxN) Q is M x M and R is (NxN).
*
* Input :
* 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of float matricies
*
* Output:
* 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies {Qs}
* 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular matricies {Rs}
*/
#if NOT_EXCLUDED(OP_qr)
DECLARE_CUSTOM_OP(qr, 1, 2, false, 0, 0);
#endif
/**
* This operation takes 2 arrays: original values, and values to be excluded. And returns 2 arrays: values left after exclusion, and indices in original array for surivals.

View File

@ -0,0 +1,133 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/qr.h>
#include <helpers/MmulHelper.h>
#include <execution/Threads.h>
#include <NDArrayFactory.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
NDArray matrixMinor(NDArray& in, Nd4jLong col) {
NDArray m = in.ulike();
m.setIdentity();
m({col, m.rows(), col, m.columns()}).assign(in({col, m.rows(), col, m.columns()}));
return m;
}
/* m = I - v v^T */
template <typename T>
NDArray vmul(NDArray const& v, int n)
{
NDArray res('c', {n,n}, v.dataType()); // x = matrix_new(n, n);
T const* vBuf = v.getDataBuffer()->primaryAsT<T>();
T* resBuf = res.dataBuffer()->primaryAsT<T>();
auto interloop = PRAGMA_THREADS_FOR_2D {
for (int i = start_x; i < n; i += inc_x)
for (int j = start_y; j < n; j += inc_y)
resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0));
};
samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1);
return res;
}
template <typename T>
void qrSingle(NDArray* matrix, NDArray* Q, NDArray* R, bool const fullMatricies) {
Nd4jLong M = matrix->sizeAt(-2);
Nd4jLong N = matrix->sizeAt(-1);
auto resQ = fullMatricies?Q->ulike():NDArrayFactory::create<T>(matrix->ordering(), {M,M}, Q->getContext());
auto resR = fullMatricies?R->ulike():matrix->ulike();
std::vector<NDArray> q(M);
NDArray z = *matrix;
NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm
for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
e.nullify();
z = matrixMinor<T>(z, k); // minor computing for current column with given matrix z (initally is a input matrix)
// z.printIndexedBuffer("Minor!!!");
auto currentColumn = z({0, 0, k, k + 1}); // retrieve k column from z to x buffer
auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0});
if (matrix->t<T>(k,k) > T(0.f)) // negate on positive matrix diagonal element
norm *= T(-1.f);//.applyTransform(transform::Neg, nullptr, nullptr); //t<T>(0) = -norm.t<T>(0);
//e.t<T>(k) = T(1.f); // e - is filled by 0 vector except diagonal element (filled by 1)
//auto tE = e;
//tE *= norm;
// norm.printIndexedBuffer("Norm!!!");
e.p(k, norm);
e += currentColumn;// e += tE; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1
auto normE = e.reduceAlongDimension(reduce::Norm2, {0});
e /= normE;
q[k] = vmul<T>(e, M);
auto qQ = z.ulike();
MmulHelper::matmul(&q[k], &z, &qQ, false, false);
z = std::move(qQ);
}
resQ.assign(q[0]); //
// MmulHelper::matmul(&q[0], matrix, &resR, false, false);
for (int i = 1; i < N && i < M - 1; i++) {
auto tempResQ = resQ;
MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); // use mmulMxM?
resQ = std::move(tempResQ);
}
MmulHelper::matmul(&resQ, matrix, &resR, false, false);
// resR *= -1.f;
resQ.transposei();
if (fullMatricies) {
Q->assign(resQ);
R->assign(resR);
}
else {
Q->assign(resQ({0,0, 0, N}));
R->assign(resR({0,N, 0, 0}));
}
}
template <typename T>
void qr_(NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
Nd4jLong lastDim = input->rankOf() - 1;
Nd4jLong preLastDim = input->rankOf() - 2;
ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}));
ResultSet listOutR(outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}));
ResultSet listInput(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}));
auto batching = PRAGMA_THREADS_FOR {
for (auto batch = start; batch < stop; batch += increment) {
//qr here
qrSingle<T>(listInput.at(batch), listOutQ.at(batch), listOutR.at(batch), fullMatricies);
}
};
samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1);
}
void qr(nd4j::LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (input, outputQ, outputR, fullMatricies), FLOAT_TYPES);
}
}
}
}

View File

@ -0,0 +1,180 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/qr.h>
#include <NDArrayFactory.h>
#include <MmulHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static __global__ void matrixMinorKernel(T* outBuffer, Nd4jLong* outShape, T* inBuffer, Nd4jLong* inShape, Nd4jLong column, Nd4jLong rows, Nd4jLong columns) {
// auto tid = threadIdx.x + blockDim.x * blockIdx.x;
// auto step = blockDim.x * gridDim.x;
// if (threadIdx.x == 0) {
// for (auto i = tid; i < column; i += step) {
// Nd4jLong diagPos[] = {i, i};
// auto zIndex = shape::getOffset(outShape, diagPos);
// outBuffer[zIndex] = T(1.f);
// }
// }
// __syncthreads();
for (auto i = blockIdx.x; i < rows; i += gridDim.x)
for (auto j = threadIdx.x; j < columns; j += blockDim.x) {
Nd4jLong pos[] = {i,j};
auto zIndex = shape::getOffset(outShape, pos);
auto xIndex = shape::getOffset(inShape, pos);
if (i < column || j < column) {
outBuffer[zIndex] = i != j?T(0.f):T(1.f);
}
else
outBuffer[zIndex] = inBuffer[xIndex]; //m.t<T>(i,j) = in.t<T>(i,j);
}
}
template <typename T>
NDArray matrixMinor(LaunchContext* context, NDArray& in, Nd4jLong col) {
NDArray m = in.ulike();
m.setIdentity();
m({col, m.rows(), col, m.columns()}).assign(in({col, m.rows(), col, m.columns()}));
// auto stream = context->getCudaStream();
// matrixMinorKernel<T><<<128, 128, 256, *stream>>>(m.dataBuffer()->specialAsT<T>(), m.specialShapeInfo(),
// matrixMinorKernel<T><<<128, 128, 256, *stream>>>(m.dataBuffer()->specialAsT<T>(), m.specialShapeInfo(),
// reinterpret_cast<T*>(in.specialBuffer()), in.specialShapeInfo(), col, in.rows(), in.columns());
//
m.tickWriteDevice();
return m;
}
/* m = I - v v^T */
template <typename T>
static __global__ void vmulKernel(T* resBuf, Nd4jLong* resShape, T const* vBuff, Nd4jLong const* vShape, Nd4jLong n) {
for (auto i = blockIdx.x; i < n; i += gridDim.x)
for (auto j = threadIdx.x; j < n; j += blockDim.x) {
Nd4jLong posR[] = {i, j};
auto indexR = shape::getOffset(resShape, posR);
auto indexX = shape::getIndexOffset(i, vShape);
auto indexY = shape::getIndexOffset(j, vShape);
resBuf[indexR] = T(-2.f) * vBuff[indexX] * vBuff[indexY] + (i != j?T(0.f):T(1.f));
}
}
template <typename T>
NDArray vmul(LaunchContext* context, NDArray const& v, int n)
{
NDArray res('c', {n,n}, v.dataType(), context); // x = matrix_new(n, n);
auto stream = context->getCudaStream();
vmulKernel<T><<<128, 128, 128, *stream>>>(res.dataBuffer()->specialAsT<T>(), res.specialShapeInfo(),
reinterpret_cast<T const*>(v.getSpecialBuffer()), v.getSpecialShapeInfo(), n);
return res;
}
template <typename T>
static bool diagonalIsPositive(NDArray* matrix, Nd4jLong k) {
T hVal;
Nd4jLong pos[] = {k, k};
auto shift = shape::getOffset(matrix->shapeInfo(), pos);
cudaMemcpy(&hVal, matrix->specialBuffer(), sizeof(T), cudaMemcpyDeviceToHost);
return hVal > T(0.f);
}
template <typename T>
void qrSingle(LaunchContext* context, NDArray* matrix, NDArray* Q, NDArray* R, bool const fullMatricies) {
Nd4jLong M = matrix->sizeAt(0);
Nd4jLong N = matrix->sizeAt(1);
auto resQ = fullMatricies?Q->ulike():NDArrayFactory::create<T>(matrix->ordering(), {M,M}, Q->getContext());
auto resR = fullMatricies?R->ulike():matrix->ulike();
std::vector<NDArray> q(M);
NDArray z = *matrix;
NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm
for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
e.nullify();
z = matrixMinor<T>(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix)
auto currentColumn = z({0, 0, k, k + 1}); // retrieve k column from z to x buffer
auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0});
if (diagonalIsPositive<T>(matrix, k)) //matrix->t<T>(k,k) > T(0.f)) // negate on positive matrix diagonal element
norm.applyTransform(transform::Neg, norm); // *= -1.f;//-norm.t<T>(0);
e.p(k, norm); // e - is filled by 0 vector except diagonal element (filled by 1)
e += currentColumn; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1
auto normE = e.reduceAlongDimension(reduce::Norm2, {0});
e /= normE;
q[k] = vmul<T>(context, e, M);
auto qQ = z.ulike();
MmulHelper::matmul(&q[k], &z, &qQ, false, false);
z = std::move(qQ);
}
resQ.assign(q[0]); //
// MmulHelper::matmul(&q[0], matrix, &resR, false, false);
for (int i = 1; i < N && i < M - 1; i++) {
auto tempResQ = resQ;
MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false);
resQ = std::move(tempResQ);
}
MmulHelper::matmul(&resQ, matrix, &resR, false, false);
// resR *= -1.f;
resQ.transposei();
if (fullMatricies) {
Q->assign(resQ);
R->assign(resR);
}
else {
Q->assign(resQ({0, 0, 0, N}));
R->assign(resR({0, N, 0, 0}));
}
}
template <typename T>
void qr_(LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
Nd4jLong lastDim = input->rankOf() - 1;
Nd4jLong preLastDim = input->rankOf() - 2;
NDArray::prepareSpecialUse({outputQ, outputR}, {input});
ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}));
ResultSet listOutR(outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}));
ResultSet listInput(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}));
auto start = 0;
auto stop = listInput.size();
auto increment = 1;
for (auto batch = start; batch < stop; batch += increment) {
//qr here
qrSingle<T>(context, listInput.at(batch), listOutQ.at(batch), listOutR.at(batch), fullMatricies);
}
NDArray::registerSpecialUse({outputQ, outputR}, {input});
}
void qr(nd4j::LaunchContext* context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) {
BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (context, input, outputQ, outputR, fullMatricies), FLOAT_TYPES);
}
}
}
}

View File

@ -0,0 +1,35 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#ifndef __QR__H_HELPERS__
#define __QR__H_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
void qr(nd4j::LaunchContext * context, NDArray* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies);
}
}
}
#endif

View File

@ -2684,13 +2684,14 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_3) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_4_1) {
auto in = NDArrayFactory::create<float>('c', {2, 2,2}, {0.7788f, 0.8012f,
0.7244f, 0.2309f,
0.7271f, 0.1804f,
0.5056f, 0.8925f});
auto in = NDArrayFactory::create<float>('c', {2, 2,2}, {
0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f
});
auto expLU = NDArrayFactory::create<float>('c', {2, 2,2}, {
0.7788f, 0.8012f, 0.930149f, -0.514335f,
0.7271f, 0.1804f, 0.695365f, 0.767056f
0.7788f, 0.8012f, 0.930149f, -0.514335f,
0.7271f, 0.1804f, 0.695365f, 0.767056f
});
auto expP = NDArrayFactory::create<int>('c', {2,2}, {0, 1, 0, 1});
@ -2711,10 +2712,11 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_4_2) {
auto in = NDArrayFactory::create<float>('c', {2, 2,2}, {0.7788f, 0.8012f,
0.7244f, 0.2309f,
0.7271f, 0.1804f,
0.5056f, 0.8925f});
auto in = NDArrayFactory::create<float>('c', {2, 2,2}, {
0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f
});
auto expLU = NDArrayFactory::create<float>('c', {2, 2,2}, {
0.7788f, 0.8012f, 0.930149f, -0.514335f,
0.7271f, 0.1804f, 0.695365f, 0.767056f
@ -2735,6 +2737,124 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_2) {
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, QR_Test_1) {
auto in = NDArrayFactory::create<double>('c', {5,3}, {
12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.
});
auto expQ = NDArrayFactory::create<double>('c', {5, 5}, {
0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485
});
auto expR = NDArrayFactory::create<double>('c', {5,3}, {
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. });
nd4j::ops::qr op;
auto res = op.execute({&in}, {}, {}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto q = res->at(0);
auto r = res->at(1);
// q->printIndexedBuffer("Orthogonal 5x5");
// expQ.printBuffer("Orthogonal Exp");
// r->printIndexedBuffer("Upper triangular 5x3");
// expR.printBuffer("Upper triangular Exp");
// q->printShapeInfo("Q shape");
// r->printShapeInfo("R shape");
nd4j::ops::matmul opMul;
auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false);
auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
ASSERT_TRUE(exp->isSameShape(in));
// ASSERT_TRUE(q->isSameShape(expQ));
//ASSERT_TRUE(expQ.equalsTo(q));
ASSERT_TRUE(exp->equalsTo(in));
delete res2;
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, QR_Test_1_1) {
auto in = NDArrayFactory::create<double>('c', {4, 5, 3}, {
12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.,
12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.,
12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.,
12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.
});
auto expQ = NDArrayFactory::create<double>('c', {4, 5, 5}, {
0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485,
0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485,
0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485,
0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485
});
auto expR = NDArrayFactory::create<double>('c', {4, 5,3}, {
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0.,
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0.,
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0.,
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0.
});
nd4j::ops::qr op;
auto res = op.execute({&in}, {}, {}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto q = res->at(0);
auto r = res->at(1);
// q->printIndexedBuffer("Orthogonal 5x5");
// expQ.printBuffer("Orthogonal Exp");
// r->printIndexedBuffer("Upper triangular 5x3");
// expR.printBuffer("Upper triangular Exp");
// q->printShapeInfo("Q shape");
// r->printShapeInfo("R shape");
nd4j::ops::matmul opMul;
auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false);
auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
ASSERT_TRUE(exp->isSameShape(in));
// ASSERT_TRUE(q->isSameShape(expQ));
//ASSERT_TRUE(expQ.equalsTo(q));
ASSERT_TRUE(exp->equalsTo(in));
delete res2;
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, QR_Test_2) {
auto in = NDArrayFactory::create<double>('c', {5,3}, {
12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.
});
auto expQ = NDArrayFactory::create<double>('c', {5, 3}, {
0.8464148, 0.3912908, -0.3431241, -0.42320737, -0.9040873, 0.02927014, 0.28213826, -0.17042054, -0.93285596, 0.07053456, -0.01404065, 0.00109937, -0.14106913, 0.0166551, 0.10577161
});
auto expR = NDArrayFactory::create<double>('c', {3,3}, {
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546
});
nd4j::ops::qr op;
auto res = op.execute({&in}, {}, {}, {false});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto q = res->at(0);
auto r = res->at(1);
ASSERT_TRUE(q->isSameShape(expQ));
ASSERT_TRUE(r->isSameShape(expR));
// q->printIndexedBuffer("Orthogonal 5x5");
// r->printIndexedBuffer("Upper triangular 5x3");
nd4j::ops::matmul opMul;
auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false);
auto exp = res2->at(0);//->printIndexedBuffer("Result as result");
ASSERT_TRUE(exp->isSameShape(in));
ASSERT_TRUE(exp->equalsTo(in));
delete res2;
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) {
@ -2883,7 +3003,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
z->printIndexedBuffer("TriangularSolve with adjoint");
// z->printIndexedBuffer("TriangularSolve with adjoint");
ASSERT_TRUE(exp.equalsTo(z));
delete res;

View File

@ -83,7 +83,7 @@ public class BroadcastMax extends BaseBroadcastOp {
@Override
public String tensorflowName() {
return "max";
return "Max";
}
@Override

View File

@ -83,7 +83,7 @@ public class BroadcastMin extends BaseBroadcastOp {
@Override
public String tensorflowName() {
return "min";
return "Min";
}
@Override

View File

@ -77,7 +77,7 @@ public class BroadcastMulOp extends BaseBroadcastOp {
@Override
public String tensorflowName() {
return "mul";
return "Mul";
}
@Override

View File

@ -326,6 +326,6 @@ public class TensorMmul extends DynamicCustomOp {
@Override
public String tensorflowName() {
return "matmul";
return "MatMul";
}
}

View File

@ -63,7 +63,7 @@ public class RSubOp extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
return "sub";
return "Sub";
}
public RSubOp( INDArray[] inputs, INDArray[] outputs) {