[WIP] Shugeo lup (#126)

* Added infrastructure for implementation op lu for both cuda and cpu platforms.

* Added implementation of helpers with lu op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored LU decomposition to use vector of permutations instead.

* Refactored helpers for lu op.

* Fixed crash with determinant op.

* Refactored cpu LU op helper.

* Added implementation for lu op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed issue with argmax on column.

* Added multithreaded behaviour for lu op helper.

* Fixed multithreaded cpu implementation helpers for lu op.

* Added cuda implementation for lu op helper.

* Finished lu helper implementation for cuda platform.

* Eliminated waste prints and comments.

* Fixed race condition and multithreading issues.

* Fixed memory leak with shape construction.

* Corrected test for lu op to avoid near zero elements on the main diagonal.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Improved test for adjust_contrast op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed issues with cuda implementation of resize_bicubic helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>
master
shugeo 2019-12-20 16:56:28 +02:00 committed by raver119
parent 6d8a063c9b
commit 67d8199165
9 changed files with 797 additions and 579 deletions

View File

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by GS <sgazeos@gmail.com> at 12/10/2019
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_matrix_inverse)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lup.h>
namespace nd4j {
namespace ops {
/**
 * lu op: LU decomposition with partial pivoting of a batch of square matrices.
 *
 * Input 0 is a tensor of rank >= 2 whose two innermost dimensions are equal.
 * Output 0 receives the packed L/U factors, output 1 the permutation vectors.
 * The actual factorization is delegated to helpers::lu.
 */
CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) {
    auto input = INPUT_VARIABLE(0);
    auto z = OUTPUT_VARIABLE(0);
    auto p = OUTPUT_VARIABLE(1);

    // Validation messages carry this op's own name so failures are attributed correctly.
    REQUIRE_TRUE(input->rankOf() >= 2, 0, "lu: The rank of input array should not be less than 2, but %i is given", input->rankOf());
    REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "lu: The last two dimensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));

    helpers::lu(block.launchContext(), input, z, p);
    return Status::OK();
}
/**
 * Shape function for the lu op.
 *
 * Output 0 (LU factors) has the same shape and data type as the input.
 * Output 1 (permutation vectors) has the input shape without its last dimension.
 * Its data type is INT32 by default; an optional integer argument 0 may select
 * INT32 or INT64 instead, matching the types allowed by DECLARE_TYPES(lu).
 */
DECLARE_SHAPE_FN(lu) {
    auto in = inputShape->at(0);
    auto shapeVector = ShapeUtils::shapeAsVector(in);
    auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace());

    // Honor the optional "permutation data type" integer argument when it is supplied.
    auto permutationType = nd4j::DataType::INT32;
    if (block.getIArguments()->size() > 0) {
        permutationType = static_cast<nd4j::DataType>(INT_ARG(0));
        REQUIRE_TRUE(permutationType == nd4j::DataType::INT32 || permutationType == nd4j::DataType::INT64, 0,
                     "lu: Permutation data type should be 32bit or 64bit int only.");
    }

    // Passing rank - 1 together with the full shape vector drops the trailing dimension.
    auto luP = ShapeBuilders::createShapeInfo(permutationType, shape::order(in), shapeVector.size() - 1,
                                              shapeVector.data(), block.workspace());
    return SHAPELIST(CONSTANT(luShape), CONSTANT(luP));
}
/**
 * Type constraints for the lu op: floating point input and LU output,
 * 32- or 64-bit integer permutation output; outputs need not share the input type.
 */
DECLARE_TYPES(lu) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes({ALL_FLOATS});
    descriptor->setAllowedOutputTypes(0, {ALL_FLOATS});
    descriptor->setAllowedOutputTypes(1, {nd4j::DataType::INT32, nd4j::DataType::INT64});
    descriptor->setSameMode(false);
}
}
}
#endif

View File

@ -1027,6 +1027,24 @@ namespace nd4j {
DECLARE_OP(matrix_inverse, 1, 1, true); DECLARE_OP(matrix_inverse, 1, 1, true);
#endif #endif
/**
* lu op. - computes the LUP decomposition of a given batch of 2D square matrices
*
* input params:
* 0 - float tensor with dimension (x * y * z * ::: * M * M)
*
* return value:
* 0 - float tensor with dimension (x * y * z * ::: * M * M) holding the LU M x M matrices
* 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M)
*
* int argument:
* 0 - data type of output permutation vector (int32 or int64), optional, default INT32
*
* Note: the declaration is compiled under the same exclusion flag as matrix_inverse.
*/
#if NOT_EXCLUDED(OP_matrix_inverse)
DECLARE_CUSTOM_OP(lu, 1, 2, false, 0, 0);
#endif
/** /**
* sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j]
* *

View File

@ -22,6 +22,8 @@
#include <MmulHelper.h> #include <MmulHelper.h>
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
#include <Status.h> #include <Status.h>
#include <execution/Threads.h>
#include <execution/Threads.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -32,15 +34,30 @@ namespace helpers {
if (theFirst != theSecond) if (theFirst != theSecond)
for (int i = 0; i < matrix->columns(); i++) { for (int i = 0; i < matrix->columns(); i++) {
T e0 = matrix->e<T>(theFirst, i); math::nd4j_swap(matrix->t<T>(theFirst, i), matrix->t<T>(theSecond, i));
T e1 = matrix->e<T>(theSecond, i);
matrix->p<T>(theFirst, i, e1);
matrix->p<T>(theSecond, i, e0);
} }
} }
BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
/**
 * Exchanges two rows of a matrix addressed through a raw buffer and its shape info.
 *
 * @param matrixBuf   buffer holding the matrix elements
 * @param matrixShape shape info used to translate {row, column} into buffer offsets
 * @param theFirst    index of the first row
 * @param theSecond   index of the second row
 *
 * The columns are distributed over threads via samediff::Threads::parallel_tad.
 */
template <typename T>
static void swapRows(T* matrixBuf, Nd4jLong* matrixShape, Nd4jLong theFirst, Nd4jLong theSecond) {
    // Nothing to do when both indices name the same row.
    if (theFirst == theSecond)
        return;

    auto columnCount = shape::sizeAt(matrixShape, -1);
    auto swapColumns = PRAGMA_THREADS_FOR {
        for (auto col = start; col < stop; col += increment) {
            Nd4jLong firstCoords[] = {theFirst, col};
            Nd4jLong secondCoords[] = {theSecond, col};
            auto firstOffset = shape::getOffset(matrixShape, firstCoords, 0);
            auto secondOffset = shape::getOffset(matrixShape, secondCoords, 0);
            math::nd4j_swap(matrixBuf[firstOffset], matrixBuf[secondOffset]);
        }
    };
    samediff::Threads::parallel_tad(swapColumns, 0, columnCount, 1);
}
void swapRows(NDArray* matrix, int theFirst, int theSecond) { void swapRows(NDArray* matrix, int theFirst, int theSecond) {
BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES);
} }
@ -106,7 +123,7 @@ namespace helpers {
} }
template <typename T> template <typename T, typename I>
static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) { static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
const int rowNum = input->rows(); const int rowNum = input->rows();
@ -132,7 +149,7 @@ namespace helpers {
} }
} }
if( pivotValue > T(0.00001)) { if( pivotValue > DataTypeUtils::min<T>()) {
swapRows(&compoundMatrix, pivot, i); swapRows(&compoundMatrix, pivot, i);
swapRows(&permutationMatrix, pivot, i); swapRows(&permutationMatrix, pivot, i);
if (pivot != i) if (pivot != i)
@ -155,14 +172,113 @@ namespace helpers {
if (swapCount % 2) determinant = -determinant; if (swapCount % 2) determinant = -determinant;
if (compound != nullptr) if (compound != nullptr)
compound->assign(compoundMatrix); compound->assign(compoundMatrix);
if (permutation != nullptr) if (permutation != nullptr) {
auto permutaionVector = NDArrayFactory::create('c', {rowNum}, DataTypeUtils::fromT<I>(), input->getContext());
for (auto i = 0; i < rowNum; i++) {
for (auto j = 0; j < columnNum; j++) {
if (permutationMatrix.t<T>(i, j) != 0) {
permutaionVector.template t<I>(i) = j;
}
}
}
if (permutationMatrix.isSameShape(permutation))
permutation->assign(permutationMatrix); permutation->assign(permutationMatrix);
else if (permutation->isSameShape(permutaionVector)) {
permutation->assign(permutaionVector);
}
}
return determinant; return determinant;
} }
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); BUILD_DOUBLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
/*
* lu decomposition with naive algorithm with partial pivoting
* */
/**
 * Finds the pivot row for a column: the row at or below the diagonal whose element
 * in that column has the largest absolute value.
 *
 * @param column         column to scan; scanning starts at row == column
 * @param compoundBuffer buffer holding the matrix elements
 * @param compoundShape  shape info used to translate {row, column} into buffer offsets
 * @return index of the row with the largest absolute value, or -1 when no scanned
 *         element has an absolute value greater than zero
 */
template <typename T, typename I>
static I argmaxCol(I column, T* compoundBuffer, Nd4jLong* compoundShape) {
    auto rowNum = shape::sizeAt(compoundShape, 0);
    auto maxValue = T(0);
    I result = -1;   // stays -1 when the whole sub-column is zero

    for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) {
        Nd4jLong xPos[] = {rowCounter, column};
        auto xIndex = shape::getOffset(compoundShape, xPos, 0);
        // Evaluate the absolute value once; strict '>' keeps the first maximum found.
        auto absValue = nd4j::math::nd4j_abs(compoundBuffer[xIndex]);
        if (absValue > maxValue) {
            maxValue = absValue;
            result = rowCounter;
        }
    }
    return result;
}
template <typename T>
void processColumns(int currentRow, int rowNum, T* compoundBuf, Nd4jLong* compoundShape) {
Nd4jLong xDiag[] = {currentRow, currentRow};
auto diagIndex = shape::getOffset(compoundShape, xDiag, 0);
auto loop = PRAGMA_THREADS_FOR {
for (int j = start; j < stop; j += increment) {
Nd4jLong xRow[] = {j, currentRow};
auto rowIndex = shape::getOffset(compoundShape, xRow, 0);
compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t<T>(i, i);
for (int k = currentRow + 1; k < rowNum; k++) {
Nd4jLong yRow[] = {j, k};
Nd4jLong yCol[] = {currentRow, k};
auto rowIndexY = shape::getOffset(compoundShape, yRow, 0);
auto colIndex = shape::getOffset(compoundShape, yCol, 0);
compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex];
}
}
};
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
}
/**
 * In-place LU factorization with partial pivoting of a single square matrix.
 *
 * @param context     launch context (not used by this routine)
 * @param compound    matrix to factorize; overwritten with the packed factors
 * @param permutation vector that is reset to 0..N-1 and then receives the row swaps
 * @param rowNum      matrix size
 * @throws std::runtime_error when no non-zero pivot is found for a column
 */
template <typename T, typename I>
static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
    // Start from the identity permutation.
    permutation->linspace(0);

    auto permBuf = permutation->bufferAsT<I>();
    auto matrixBuf = compound->bufferAsT<T>();
    auto matrixShape = compound->shapeInfo();
    auto permShape = permutation->shapeInfo();

    for (auto step = 0; step < rowNum - 1; step++) {
        auto pivotRow = argmaxCol(step, matrixBuf, matrixShape);
        if (pivotRow < 0) {
            throw std::runtime_error("helpers::luNN_: input matrix is singular.");
        }

        // Record the swap in the permutation vector, then apply it to the matrix rows.
        auto stepOffset = shape::getIndexOffset(step, permShape);
        auto pivotOffset = shape::getIndexOffset(pivotRow, permShape);
        math::nd4j_swap(permBuf[stepOffset], permBuf[pivotOffset]);
        swapRows(matrixBuf, matrixShape, step, pivotRow);

        processColumns(step, rowNum, matrixBuf, matrixShape);
    }
}
/**
 * Batched LU factorization: copies the input into the output and factorizes every
 * innermost 2D matrix of the output in place, filling the matching permutation vector.
 *
 * @param context            launch context forwarded to luNN_
 * @param input              source tensor
 * @param output             receives a copy of input, then the packed factors
 * @param permutationVectors receives one permutation vector per matrix
 */
template <typename T, typename I>
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
    auto matrixSize = input->sizeAt(-1);

    // The factorization works in place, so seed the output with the input values.
    output->assign(input);

    std::unique_ptr<ResultSet> matrices(output->allTensorsAlongDimension({-2, -1}));
    std::unique_ptr<ResultSet> pivotVectors(permutationVectors->allTensorsAlongDimension({-1}));

    auto factorizeBatch = PRAGMA_THREADS_FOR {
        for (auto batch = start; batch < stop; batch += increment) {
            luNN_<T, I>(context, matrices->at(batch), pivotVectors->at(batch), matrixSize);
        }
    };
    samediff::Threads::parallel_for(factorizeBatch, 0, matrices->size(), 1);
}
// Public entry point of the CPU lu helper: selects the lu_ instantiation that matches
// the input data type (FLOAT_TYPES) and the permutation data type (INDEXING_TYPES).
// Note: permutation is dereferenced here, so it must not be null.
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
}
// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
template <typename T> template <typename T>
static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) { static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
@ -175,7 +291,7 @@ namespace helpers {
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
matrix.p(row, input->e<T>(k)); matrix.p(row, input->e<T>(k));
output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); output->p(e, lup_<T, int>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
} }
return Status::OK(); return Status::OK();
@ -196,7 +312,7 @@ template <typename T>
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
matrix.p(row, input->e<T>(k)); matrix.p(row, input->e<T>(k));
} }
NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); NDArray det = lup_<T, int>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
if (det.e<T>(0) != 0.f) if (det.e<T>(0) != 0.f)
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0)))); output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
} }
@ -229,7 +345,7 @@ template <typename T>
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
matrix.p(row++, input->e<T>(k)); matrix.p(row++, input->e<T>(k));
} }
T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0); T det = lup_<T, int>(context, &matrix, &compound, &permutation).template e<T>(0);
// FIXME: and how this is going to work on float16? // FIXME: and how this is going to work on float16?
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) { if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
@ -274,7 +390,7 @@ template <typename T>
// check for symmetric // check for symmetric
for (Nd4jLong r = 0; r < thisMatrix->rows(); r++) for (Nd4jLong r = 0; r < thisMatrix->rows(); r++)
for (Nd4jLong c = 0; c < thisMatrix->columns(); c++) for (Nd4jLong c = 0; c < thisMatrix->columns(); c++)
if (nd4j::math::nd4j_abs(thisMatrix->e<T>(r, c) - lastMatrixList->at(i)->e<T>(c,r)) > T(1.e-6f)) return false; if (nd4j::math::nd4j_abs(thisMatrix->e<T>(r, c) - lastMatrixList->at(i)->e<T>(c,r)) > DataTypeUtils::min<T>()) return false;
NDArray output = NDArrayFactory::create<T>(0., context); NDArray output = NDArrayFactory::create<T>(0., context);
if (ND4J_STATUS_OK != determinant(context, thisMatrix, &output)) return false; if (ND4J_STATUS_OK != determinant(context, thisMatrix, &output)) return false;
@ -366,6 +482,11 @@ template <typename T>
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
} }
// Public entry point for the lup_ helper: dispatches on the input data type and the
// permutation data type; the value returned by lup_ is discarded and Status::OK() is
// always returned.
// NOTE(review): permutation is dereferenced to pick the type, so a null permutation is
// not supported here even though lup_ itself checks it for nullptr - confirm callers.
// NOTE(review): the selector uses FLOAT_NATIVE while lup_ is instantiated for
// FLOAT_TYPES - confirm this narrower set is intended.
int lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation) {
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, (context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES);
return Status::OK();
}
} }
} }
} }

View File

@ -689,11 +689,17 @@ namespace helpers {
} }
template <typename T> template <typename T>
static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) { static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) {
// auto numChannels = pResizerState->channels; // auto numChannels = pResizerState->channels;
for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) {
auto pInput = inputPtr + b * inBatchWidth; auto pInput = inputPtr + b * inBatchWidth;
float* cachedValue;
for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) { for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) {
if (threadIdx.x == 0) {
extern __shared__ char sharedChar[];
cachedValue = reinterpret_cast<float*>(sharedChar);
}
auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels; auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels;
auto pOutput = &outputPtr[pos]; auto pOutput = &outputPtr[pos];
struct WeightsAndIndices yWai; struct WeightsAndIndices yWai;
@ -846,20 +852,20 @@ namespace helpers {
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err); throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err);
} }
float* cachedValue = nullptr; // float* cachedValue = nullptr;
size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels); // size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels);
if (cachedSize) { // if (cachedSize) {
err = cudaMalloc(reinterpret_cast<void**>(&cachedValue), cachedSize); // err = cudaMalloc(reinterpret_cast<void**>(&cachedValue), cachedSize);
if (err != 0) { // if (err != 0) {
throw cuda_exception::build( // throw cuda_exception::build(
"helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err); // "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err);
} // }
err = cudaMemset(cachedValue, 0, cachedSize); // err = cudaMemset(cachedValue, 0, cachedSize);
if (err != 0) { // if (err != 0) {
throw cuda_exception::build( // throw cuda_exception::build(
"helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err); // "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err);
} // }
} // }
WeightsAndIndices* xWais; //(resizerState.outWidth); WeightsAndIndices* xWais; //(resizerState.outWidth);
err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth); err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth);
@ -878,7 +884,7 @@ namespace helpers {
} }
const T* pInput = image->getDataBuffer()->specialAsT<T>(); const T* pInput = image->getDataBuffer()->specialAsT<T>();
float* pOutput = output->dataBuffer()->specialAsT<float>(); //_data.data(); float* pOutput = output->dataBuffer()->specialAsT<float>(); //_data.data();
bicubicInterpolateWithCachingKernel<T><<<128, 1, 512, *stream>>>(coeffsTable, cachedValue, pInput, bicubicInterpolateWithCachingKernel<T><<<128, 1, 512, *stream>>>(coeffsTable, pInput,
resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput); resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput);
err = cudaStreamSynchronize(*stream); err = cudaStreamSynchronize(*stream);
if (err != 0) { if (err != 0) {
@ -889,11 +895,11 @@ namespace helpers {
if (err != 0) { if (err != 0) {
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err); throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err);
} }
if (cachedSize) // if (cachedSize)
err = cudaFree(cachedValue); // err = cudaFree(cachedValue);
if (err != 0) { // if (err != 0) {
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err); // throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err);
} // }
err = cudaFree(xWais); err = cudaFree(xWais);
if (err != 0) { if (err != 0) {

View File

@ -24,6 +24,7 @@
#include <Status.h> #include <Status.h>
#include <ConstantTadHelper.h> #include <ConstantTadHelper.h>
#include <ShapeUtils.h> #include <ShapeUtils.h>
//#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <cusolverDn.h> #include <cusolverDn.h>
#include <cuda_exception.h> #include <cuda_exception.h>
@ -336,7 +337,7 @@ namespace helpers {
// //
// input - A matrix nxn // input - A matrix nxn
// compound - C matrix L + U - I, or main diagonal and lower - L matrix, from the 2nd diagonal - U matrix // compound - C matrix L + U - I, or main diagonal and lower - L matrix, from the 2nd diagonal - U matrix
template<typename T> template<typename T, typename I>
static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
auto n = input->rows(); auto n = input->rows();
@ -383,7 +384,7 @@ namespace helpers {
err); err);
} }
if (permutation == nullptr) if (permutation == nullptr) {
status = cusolverDnDgetrf( status = cusolverDnDgetrf(
cusolverH, cusolverH,
n, n,
@ -393,9 +394,15 @@ namespace helpers {
d_work, d_work,
nullptr, nullptr,
d_info); d_info);
if (status != CUSOLVER_STATUS_SUCCESS) {
throw cuda_exception::build("helpers::lup_: LU factorization is failed due ",
status);
}
}
else { else {
NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); NDArray permutVector('c', {n}, nd4j::DataType::INT32, context);
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer()); int* permutationBuf = permutVector.dataBuffer()->specialAsT<int>();
status = cusolverDnDgetrf( status = cusolverDnDgetrf(
cusolverH, cusolverH,
n, n,
@ -405,9 +412,21 @@ namespace helpers {
d_work, d_work,
permutationBuf, permutationBuf,
d_info); d_info);
if (status != CUSOLVER_STATUS_SUCCESS) {
throw cuda_exception::build("helpers::lup_: LU factorization is failed due ",
status);
}
if (permutation->rankOf() == 2) {
fillUpPermutation<double> <<< n, n, 1024, *stream >>> fillUpPermutation<double> <<< n, n, 1024, *stream >>>
(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
permutation->tickWriteDevice(); }
else {
permutVector.tickWriteDevice();
input->tickWriteDevice();
compound->assign(input);
permutation->assign(permutVector);
}
} }
err = cudaFree(d_work); err = cudaFree(d_work);
if (err) { if (err) {
@ -448,7 +467,7 @@ namespace helpers {
nullptr, nullptr,
d_info); d_info);
else { else {
NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); NDArray permutVector('c', {n}, DataType::INT32, context);
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer()); int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
status = cusolverDnSgetrf( status = cusolverDnSgetrf(
cusolverH, cusolverH,
@ -459,10 +478,17 @@ namespace helpers {
d_work, d_work,
permutationBuf, permutationBuf,
d_info); d_info);
fillUpPermutation<T> <<< n, n, 128, *stream >> > if (permutation->rankOf() == 2) {
fillUpPermutation<I> <<< n, n, 128, *stream >>>
(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
permutation->tickWriteDevice(); permutation->tickWriteDevice();
} }
else {
input->tickWriteDevice();
compound->assign(input);
permutation->assign(permutVector);
}
}
err = cudaFree(d_work); err = cudaFree(d_work);
if (err) { if (err) {
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer",
@ -484,8 +510,116 @@ namespace helpers {
} }
// ------------------------------------------------------------------------------------------------------------------ // // ------------------------------------------------------------------------------------------------------------------ //
BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE); BUILD_DOUBLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE, INDEXING_TYPES);
/**
 * Device routine: exchanges two rows of an n-column matrix addressed through a raw
 * buffer and its shape info. The columns are processed sequentially by the caller.
 *
 * @param matrix    buffer holding the matrix elements
 * @param shape     shape info used to translate {row, column} into buffer offsets
 * @param theFirst  index of the first row
 * @param theSecond index of the second row
 * @param n         number of columns to swap
 */
template <typename T>
static __device__ void swapRows(T* matrix, Nd4jLong* shape, Nd4jLong theFirst, Nd4jLong theSecond, Nd4jLong n) {
    // Nothing to do when both indices name the same row.
    if (theFirst == theSecond)
        return;

    for (auto col = 0; col < n; col++) {
        Nd4jLong firstCoords[] = {theFirst, col};
        Nd4jLong secondCoords[] = {theSecond, col};
        auto firstOffset = shape::getOffset(shape, firstCoords, 0);
        auto secondOffset = shape::getOffset(shape, secondCoords, 0);
        math::nd4j_swap(matrix[firstOffset], matrix[secondOffset]);
    }
}
/**
 * Device routine: one elimination step of the in-place LU factorization.
 *
 * For every row below currentRow the element in column currentRow is divided by the
 * diagonal element, and that multiple of row currentRow is subtracted from the
 * remaining columns of the row.
 *
 * @param currentRow    index of the pivot row/column
 * @param rowNum        matrix size
 * @param compoundBuf   buffer holding the matrix elements
 * @param compoundShape shape info used to translate {row, column} into buffer offsets
 */
template <typename T>
static __device__ void processColumns(Nd4jLong currentRow, Nd4jLong rowNum, T* compoundBuf, Nd4jLong* compoundShape) {
    Nd4jLong pivotCoords[] = {currentRow, currentRow};
    auto pivotOffset = shape::getOffset(compoundShape, pivotCoords, 0);

    for (auto row = currentRow + 1; row < rowNum; row++) {
        Nd4jLong multiplierCoords[] = {row, currentRow};
        auto multiplierOffset = shape::getOffset(compoundShape, multiplierCoords, 0);
        compoundBuf[multiplierOffset] /= compoundBuf[pivotOffset];

        for (auto col = currentRow + 1; col < rowNum; col++) {
            Nd4jLong targetCoords[] = {row, col};
            Nd4jLong pivotRowCoords[] = {currentRow, col};
            auto targetOffset = shape::getOffset(compoundShape, targetCoords, 0);
            auto pivotRowOffset = shape::getOffset(compoundShape, pivotRowCoords, 0);
            compoundBuf[targetOffset] -= compoundBuf[multiplierOffset] * compoundBuf[pivotRowOffset];
        }
    }
}
/**
 * Device routine: finds the pivot row for a column, i.e. the row at or below the
 * diagonal whose element in that column has the largest absolute value.
 *
 * @param column         column to scan; scanning starts at row == column
 * @param compoundBuffer buffer holding the matrix elements
 * @param compoundShape  shape info used to translate {row, column} into buffer offsets
 * @return index of the row with the largest absolute value, or -1 when no scanned
 *         element has an absolute value greater than zero
 */
template <typename T>
__device__ Nd4jLong argmaxCol(Nd4jLong column, T* compoundBuffer, Nd4jLong* compoundShape) {
    auto rowNum = shape::sizeAt(compoundShape, 0);
    auto maxValue = T(0);
    Nd4jLong result = -1;   // stays -1 when the whole sub-column is zero

    for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) {
        Nd4jLong xPos[] = {rowCounter, column};
        auto xIndex = shape::getOffset(compoundShape, xPos, 0);
        // Evaluate the absolute value once; strict '>' keeps the first maximum found.
        auto absValue = nd4j::math::nd4j_abs(compoundBuffer[xIndex]);
        if (absValue > maxValue) {
            maxValue = absValue;
            result = rowCounter;
        }
    }
    return result;
}
/**
 * Device routine: in-place LU factorization with partial pivoting of one n x n matrix.
 *
 * @param matrix      buffer holding the matrix; overwritten with the packed factors
 * @param shape       shape info of the matrix
 * @param permutation permutation vector updated with every row swap
 * @param permuShape  shape info of the permutation vector
 * @param n           matrix size
 * @return 0 on success, -1 when no non-zero pivot is found for a column
 */
template <typename T, typename I>
static __device__ int luNN(T* matrix, Nd4jLong* shape, I* permutation, Nd4jLong* permuShape, Nd4jLong n) {
    for (auto step = 0; step < n - 1; step++) {
        auto pivotRow = argmaxCol(step, matrix, shape);
        // Exceptions are not available here, so a singular input is reported by status code.
        if (pivotRow < 0) {
            return -1;
        }

        // Record the swap in the permutation vector, then apply it to the matrix rows.
        auto stepOffset = shape::getIndexOffset(step, permuShape);
        auto pivotOffset = shape::getIndexOffset(pivotRow, permuShape);
        math::nd4j_swap(permutation[stepOffset], permutation[pivotOffset]);
        swapRows(matrix, shape, (Nd4jLong)step, pivotRow, n);

        processColumns(step, n, matrix, shape);
    }
    return 0;
}
/**
 * Kernel: factorizes a batch of square matrices in place, one matrix per thread.
 *
 * Each thread starts at its global index and advances by the total number of launched
 * threads. A thread stops processing further matrices as soon as luNN reports a failure.
 *
 * @param outputBuf        buffer with all matrices (overwritten with the factors)
 * @param outputShape      shape info of the whole output (not used by the kernel)
 * @param permutations     buffer with all permutation vectors
 * @param permuShape       shape info of the whole permutation tensor (not used by the kernel)
 * @param outputTadShape   shape info of a single matrix
 * @param outputTadOffsets buffer offset of every matrix
 * @param permuTadShape    shape info of a single permutation vector
 * @param permuTadOffsets  buffer offset of every permutation vector
 * @param batchNum         number of matrices
 */
template <typename T, typename I>
static __global__ void luBatchedKernel(T* outputBuf, Nd4jLong* outputShape, I* permutations, Nd4jLong* permuShape,
                                       Nd4jLong* outputTadShape, Nd4jLong* outputTadOffsets, Nd4jLong* permuTadShape, Nd4jLong* permuTadOffsets,
                                       Nd4jLong batchNum) {
    auto firstBatch = blockIdx.x * blockDim.x + threadIdx.x;
    auto stride = blockDim.x * gridDim.x;

    for (auto batch = firstBatch; batch < batchNum; batch += stride) {
        T* currentMatrix = outputBuf + outputTadOffsets[batch];
        I* currentPermutation = permutations + permuTadOffsets[batch];
        // The matrix size equals the length of one permutation vector.
        auto status = luNN(currentMatrix, outputTadShape, currentPermutation, permuTadShape, shape::length(permuTadShape));
        if (status != 0)
            break;
    }
}
/**
 * CUDA batched LU factorization.
 *
 * Copies the input into the output, initializes every permutation vector with
 * 0..n-1 and launches luBatchedKernel, which factorizes every innermost 2D matrix
 * of the output in place.
 *
 * @param context            launch context providing the CUDA stream
 * @param input              source tensor
 * @param output             receives a copy of input, then the packed factors
 * @param permutationVectors receives one permutation vector per matrix
 */
template <typename T, typename I>
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
    auto n = input->sizeAt(-1);
    auto stream = context->getCudaStream();

    // Identity permutation 0..n-1, broadcast below into every permutation vector.
    auto iota = NDArrayFactory::create<int>('c', {n});
    iota.linspace(0); iota.syncToDevice();

    // The factorization works in place, so seed the output with the input values.
    output->assign(input);
    output->tickWriteDevice();
    permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &iota, permutationVectors, true, nullptr);
    permutationVectors->tickWriteDevice();

    auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1});
    // Permutation sub-arrays must be derived from the permutation tensor itself:
    // its shape, strides and data type differ from those of the output tensor.
    auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions(permutationVectors->shapeInfo(), {-1});
    auto batchNum = tads.numberOfTads();

    luBatchedKernel<T,I><<<batchNum, 256, 1024, *stream>>>(reinterpret_cast<T*>(output->platformBuffer()),
            output->specialShapeInfo(), reinterpret_cast<I*>(permutationVectors->platformBuffer()),
            permutationVectors->specialShapeInfo(), tads.specialShapeInfo(), tads.specialOffsets(),
            permutaionTads.specialShapeInfo(), permutaionTads.specialOffsets(), batchNum);
}
// Public entry point of the CUDA lu helper: brackets the type-dispatched lu_ call with
// NDArray::prepareSpecialUse / registerSpecialUse, naming output and permutations as
// written arrays and input as the read array.
// Dispatch is on the input data type (FLOAT_NATIVE) and the permutation data type
// (INDEXING_TYPES); permutations is dereferenced, so it must not be null.
void lu(LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutations) {
NDArray::prepareSpecialUse({output, permutations}, {input});
BUILD_DOUBLE_SELECTOR(input->dataType(), permutations->dataType(), lu_, (context, input, output, permutations), FLOAT_NATIVE, INDEXING_TYPES);
NDArray::registerSpecialUse({output, permutations}, {input});
}
// ------------------------------------------------------------------------------------------------------------------ // // ------------------------------------------------------------------------------------------------------------------ //
template<typename T> template<typename T>
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
@ -509,7 +643,7 @@ namespace helpers {
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else // else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
lup_<T>(context, &matrix, nullptr, nullptr); lup_<T, int>(context, &matrix, nullptr, nullptr);
// else // else
// lup_<float>(context, &matrix, nullptr, nullptr); // lup_<float>(context, &matrix, nullptr, nullptr);
auto offset = shape::getIndexOffset(e, output->shapeInfo()); auto offset = shape::getIndexOffset(e, output->shapeInfo());
@ -557,7 +691,7 @@ namespace helpers {
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
lup_<T>(context, &matrix, nullptr, nullptr); lup_<T, int>(context, &matrix, nullptr, nullptr);
// else // else
// lup_<float>(context, &matrix, nullptr, nullptr); // lup_<float>(context, &matrix, nullptr, nullptr);
auto offset = shape::getIndexOffset(e, output->shapeInfo()); auto offset = shape::getIndexOffset(e, output->shapeInfo());
@ -638,7 +772,7 @@ namespace helpers {
matrix.tickWriteDevice(); matrix.tickWriteDevice();
//compound.assign(matrix); //compound.assign(matrix);
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
lup_<T>(context, &matrix, nullptr, nullptr); lup_<T, int>(context, &matrix, nullptr, nullptr);
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n);
lower.tickWriteDevice(); lower.tickWriteDevice();
upper.tickWriteDevice(); upper.tickWriteDevice();
@ -861,6 +995,14 @@ namespace helpers {
BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE);
} }
/*
 * lup - batched input, batched outputs.
 *
 * Public, type-erased entry point for the LU-with-permutation helper. It forwards
 * (context, input, compound, permutation) to the typed lup_ helper through
 * BUILD_DOUBLE_SELECTOR, keyed on two runtime data types:
 *   - input->dataType(), selected from the FLOAT_NATIVE type list;
 *   - permutation->dataType(), selected from the INDEXING_TYPES type list.
 *
 * NOTE: permutation is dereferenced unconditionally to read its data type, so it
 * must be non-null when calling this wrapper.
 * NOTE(review): the selector is invoked without a `return`, so whatever lup_
 * produces appears to be discarded and Status::OK() is returned regardless -
 * confirm against the BUILD_DOUBLE_SELECTOR expansion.
 * */
int lup(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_,(context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES);
return Status::OK();
}
// BUILD_SINGLE_TEMPLATE(template int logdetFunctor_, // BUILD_SINGLE_TEMPLATE(template int logdetFunctor_,
// (nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE); // (nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE);
} }

View File

@ -26,9 +26,8 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> int lup(nd4j::LaunchContext* context, NDArray* input, NDArray* lu, NDArray* permutation);
T lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation); void lu(nd4j::LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation);
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output); int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output); int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);

View File

@ -1050,7 +1050,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
auto testData = NDArrayFactory::create<float>('c', {2,9,9,1}, { auto testData = NDArrayFactory::create<float>('c', {2,9,9,1}, {
0.230286f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, 0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f,
0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f, 0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f,
0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f, 0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f,
0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f, 0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f,
@ -1081,7 +1081,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
NDArray* result = results->at(0); NDArray* result = results->at(0);
// result->printBuffer("Resized to 9x9"); // result->printBuffer("Resized to 9x9");
// expected.printBuffer("Expect for 9x9"); // testData.printBuffer("Expect for 9x9");
ASSERT_TRUE(testData.isSameShape(result)); ASSERT_TRUE(testData.isSameShape(result));
ASSERT_TRUE(testData.equalsTo(result)); ASSERT_TRUE(testData.equalsTo(result));
delete results; delete results;

View File

@ -2424,3 +2424,256 @@ TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) {
ASSERT_TRUE(exp.equalsTo(res->at(0))); ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res; delete res;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_1) {
    // Upper-triangular 3x3 input: the expected first output repeats the input
    // values and the expected second output is the identity permutation.
    auto matrix = NDArrayFactory::create<double>('c', {3,3}, {
            1., 2., 3.,
            0., 2., 3.,
            0., 0., 7.});
    auto expectedLU = NDArrayFactory::create<double>('c', {3,3}, {
            1., 2., 3.,
            0., 2., 3.,
            0., 0., 7});
    auto expectedP = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});

    nd4j::ops::lu op;
    auto results = op.execute({&matrix}, {}, {});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    // Output 0 is compared against the expected matrix, output 1 against the
    // expected permutation vector.
    auto lu = results->at(0);
    auto permutation = results->at(1);
    ASSERT_TRUE(expectedLU.equalsTo(lu));
    ASSERT_TRUE(expectedP.equalsTo(permutation));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_2) {
    // Lower-triangular 3x3 input; the expected permutation is {2, 0, 1}.
    auto matrix = NDArrayFactory::create<double>('c', {3,3}, {
            1, 0, 0,
            2, 3, 0,
            4, 5, 6});
    auto expectedLU = NDArrayFactory::create<double>('c', {3,3}, {
            4., 5., 6.,
            0.25, -1.25, -1.5,
            0.5, -0.4, -3.6});
    auto expectedP = NDArrayFactory::create<int>({2, 0, 1});

    nd4j::ops::lu op;
    auto results = op.execute({&matrix}, {}, {});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    // Output 0 is the matrix result, output 1 the permutation vector.
    auto lu = results->at(0);
    auto permutation = results->at(1);
    ASSERT_TRUE(expectedLU.equalsTo(lu));
    ASSERT_TRUE(expectedP.equalsTo(permutation));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_3) {
    // Dense 3x3 input; the expected permutation is {2, 1, 0}.
    auto matrix = NDArrayFactory::create<double>('c', {3,3}, {
            1, 2, 3,
            4, 7, 9,
            11, 12, 13});
    auto expectedLU = NDArrayFactory::create<double>('c', {3,3}, {
            11., 12., 13.,
            0.36363637, 2.6363635, 4.272727,
            0.09090909, 0.3448276, 0.34482753});
    auto expectedP = NDArrayFactory::create<int>({2, 1, 0});

    nd4j::ops::lu op;
    auto results = op.execute({&matrix}, {}, {});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    // Output 0 is the matrix result, output 1 the permutation vector.
    auto lu = results->at(0);
    auto permutation = results->at(1);
    ASSERT_TRUE(expectedLU.equalsTo(lu));
    ASSERT_TRUE(expectedP.equalsTo(permutation));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_4) {
    // Single 10x10 matrix.
    auto matrix = NDArrayFactory::create<double>('c', {10,10}, {
            1., 2., 3., 4., 5., 6., 7., 8., 1., 15.,
            5., 1., 13., 4., 15., 1., 17., 9., 11., 25.,
            1., 9., 1., 4., 5., 2., 13., 10, 21., 15.,
            3., 9., 4., 1., 5., 3., 7., 1, 1., 5.,
            2., 3., 2., 5., 4., 4., 7., 3, 3., 4.,
            0., 1., 3., 3., 5., 1., 3., 1, 31., 15.,
            2., 1., 4., 3., 1., 5., 1., 2, 31., 35.,
            3., 4., 3., 3., 4., 4., 4., 1., 3., 1.,
            1., 1., 1., 1., 5., 6., 5., 4., 3., 2.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});

    // Expected first output (10x10) for the matrix above.
    auto expectedLU = NDArrayFactory::create<double>('c', {10,10}, {
            5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0,
            0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0,
            0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636,
            0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957,
            0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323,
            0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387,
            0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300,
            0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, -9.022119,
            0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178,
            0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695
    });
    auto expectedP = NDArrayFactory::create<int>({1, 2, 7, 3, 6, 8, 5, 4, 0, 9});

    nd4j::ops::lu op;
    auto results = op.execute({&matrix}, {}, {});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    // Output 0 is the matrix result, output 1 the permutation vector.
    auto lu = results->at(0);
    auto permutation = results->at(1);
    ASSERT_TRUE(expectedLU.equalsTo(lu));
    ASSERT_TRUE(expectedP.equalsTo(permutation));

    delete results;
}
TEST_F(DeclarableOpsTests12, LU_Test_5) {
    // Batched variant: shape {2, 10, 10} holding the same 10x10 matrix twice.
    auto batch = NDArrayFactory::create<double>('c', {2, 10,10}, {
            // batch item 0
            1., 2., 3., 4., 5., 6., 7., 8., 1., 15.,
            5., 1., 13., 4., 15., 1., 17., 9., 11., 25.,
            1., 9., 1., 4., 5., 2., 13., 10, 21., 15.,
            3., 9., 4., 1., 5., 3., 7., 1, 1., 5.,
            2., 3., 2., 5., 4., 4., 7., 3, 3., 4.,
            0., 1., 3., 3., 5., 1., 3., 1, 31., 15.,
            2., 1., 4., 3., 1., 5., 1., 2, 31., 35.,
            3., 4., 3., 3., 4., 4., 4., 1., 3., 1.,
            1., 1., 1., 1., 5., 6., 5., 4., 3., 2.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            // batch item 1
            1., 2., 3., 4., 5., 6., 7., 8., 1., 15.,
            5., 1., 13., 4., 15., 1., 17., 9., 11., 25.,
            1., 9., 1., 4., 5., 2., 13., 10, 21., 15.,
            3., 9., 4., 1., 5., 3., 7., 1, 1., 5.,
            2., 3., 2., 5., 4., 4., 7., 3, 3., 4.,
            0., 1., 3., 3., 5., 1., 3., 1, 31., 15.,
            2., 1., 4., 3., 1., 5., 1., 2, 31., 35.,
            3., 4., 3., 3., 4., 4., 4., 1., 3., 1.,
            1., 1., 1., 1., 5., 6., 5., 4., 3., 2.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
    });

    // Expected first output: the same 10x10 result repeated for both batch items.
    auto expectedLU = NDArrayFactory::create<double>('c', {2, 10,10}, {
            // batch item 0
            5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0,
            0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0,
            0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636,
            0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957,
            0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323,
            0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387,
            0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300,
            0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, -9.022119,
            0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178,
            0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695,
            // batch item 1
            5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0,
            0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0,
            0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636,
            0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957,
            0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323,
            0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387,
            0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300,
            0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, -9.022119,
            0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178,
            0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695
    });

    // One permutation row per batch item.
    auto expectedP = NDArrayFactory::create<int>('c', {2, 10}, {
            1, 2, 7, 3, 6, 8, 5, 4, 0, 9,
            1, 2, 7, 3, 6, 8, 5, 4, 0, 9
    });

    nd4j::ops::lu op;
    auto results = op.execute({&batch}, {}, {});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    // Output 0 is the matrix result, output 1 the permutation rows.
    auto lu = results->at(0);
    auto permutation = results->at(1);
    ASSERT_TRUE(expectedLU.equalsTo(lu));
    ASSERT_TRUE(expectedP.equalsTo(permutation));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_1_2) {
    // Batched counterpart of LU_Test_1: shape {2, 3, 3} holding the same
    // upper-triangular matrix twice.
    auto in = NDArrayFactory::create<double>('c', {2, 3,3}, {
            1., 2., 3., 0., 2., 3., 0., 0., 7.,
            1., 2., 3., 0., 2., 3., 0., 0., 7.});
    auto exp = NDArrayFactory::create<double>('c', {2, 3,3}, {
            1., 2., 3., 0., 2., 3., 0., 0., 7,
            1., 2., 3., 0., 2., 3., 0., 0., 7.});
    // LU_Test_1 expects the identity permutation for this matrix; the batched
    // permutation output is laid out one row per matrix (see LU_Test_3_2).
    auto expP = NDArrayFactory::create<int>('c', {2,3}, {0, 1, 2, 0, 1, 2});

    nd4j::ops::lu op;
    auto res = op.execute({&in}, {}, {});
    ASSERT_EQ(res->status(), ND4J_STATUS_OK);

    auto z = res->at(0);
    auto p = res->at(1);
    // z->printIndexedBuffer("Triangulars (2,3,3)");
    // p->printIndexedBuffer("Permutations (2,3,3)");

    // Previously the permutation output was fetched but never verified.
    ASSERT_TRUE(exp.equalsTo(z));
    ASSERT_TRUE(expP.equalsTo(p));

    delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_3_2) {
    // Batched counterpart of LU_Test_3: shape {2, 3, 3} holding the same
    // dense matrix twice.
    auto batch = NDArrayFactory::create<double>('c', {2, 3,3}, {
            1, 2, 3, 4, 7, 9, 11, 12, 13,
            1, 2, 3, 4, 7, 9, 11, 12, 13});
    auto expectedLU = NDArrayFactory::create<double>('c', {2, 3,3}, {
            // batch item 0
            11., 12., 13.,
            0.36363637, 2.6363635, 4.272727,
            0.09090909, 0.3448276, 0.34482753,
            // batch item 1
            11., 12., 13.,
            0.36363637, 2.6363635, 4.272727,
            0.09090909, 0.3448276, 0.34482753
    });
    // One permutation row per batch item.
    auto expectedP = NDArrayFactory::create<int>('c', {2,3}, {
            2, 1, 0,
            2, 1, 0});

    nd4j::ops::lu op;
    auto results = op.execute({&batch}, {}, {});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    // Output 0 is the matrix result, output 1 the permutation rows.
    auto lu = results->at(0);
    auto permutation = results->at(1);
    ASSERT_TRUE(expectedLU.equalsTo(lu));
    ASSERT_TRUE(expectedP.equalsTo(permutation));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, LU_Test_3_3) {
    // Batch of two different 3x3 matrices, each with its own expected
    // permutation row.
    auto batch = NDArrayFactory::create<double>('c', {2, 3,3}, {
            1, 2, 3, 4, 7, 9, 11, 12, 13,
            13, 2, 3, 4, 7, 9, 11, 12, 1});
    auto expectedLU = NDArrayFactory::create<double>('c', {2, 3,3}, {
            // batch item 0
            11., 12., 13.,
            0.36363637, 2.6363635, 4.272727,
            0.09090909, 0.3448276, 0.34482753,
            // batch item 1
            13., 2., 3.,
            0.84615386, 10.307693, -1.5384617,
            0.30769232, 0.619403, 9.029851});
    auto expectedP = NDArrayFactory::create<int>('c', {2,3}, {
            2, 1, 0,
            0, 2, 1});

    nd4j::ops::lu op;
    auto results = op.execute({&batch}, {}, {});
    ASSERT_EQ(results->status(), ND4J_STATUS_OK);

    // Output 0 is the matrix result, output 1 the permutation rows.
    auto lu = results->at(0);
    auto permutation = results->at(1);
    ASSERT_TRUE(expectedLU.equalsTo(lu));
    ASSERT_TRUE(expectedP.equalsTo(permutation));

    delete results;
}

View File

@ -293,268 +293,77 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) {
.7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f,
0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f});
auto e = NDArrayFactory::create<float>('c', {8, 8, 3, 1}, { auto e = NDArrayFactory::create<float>('c', {8, 8, 3, 1}, {
1.0218375f, 1.0218375f, 1.0666375f, 0.9130375f,
1.0666375f, -0.07396251f, 0.91843754f, -0.17496246f,
0.9130375f, 0.47543746f, 1.2492375f, 0.55643755f,
1.3110375f, -0.36456245f, 1.0518374f,
-0.07396251f, 0.7824375f, 0.57523745f, -0.21656245f,
0.91843754f, 0.0816375f, -0.2261625f, 0.40323752f,
-0.17496246f, 1.4520376f, 0.6868375f, 0.81723756f,
-0.17576247f, 0.81423753f, -0.08656245f,
0.47543746f,
1.2492375f, -0.36249164f, 0.45590833f, 1.1925083f,
0.55643755f, 0.00650835f, 1.4861084f, 1.2079083f,
0.05270836f, 0.37350836f, 0.94130826f,
1.3110375f, 1.0715083f, 0.6103083f, 0.9825083f,
-0.36456245f, 0.07370833f, -0.4518917f, -0.39889166f,
1.0518374f, -0.3354917f, 1.2213084f, 1.0345083f,
-0.3132917f, 0.78470826f, 0.23390833f,
0.7824375f, 0.6943083f, 0.68170834f, -0.09989169f,
0.57523745f,
-0.21656245f, 0.8352709f, 1.3798709f, 0.15507084f,
0.26607084f, -0.10792917f, 1.2302709f,
0.0816375f, 0.6448709f, -0.29992914f, 1.3534708f,
-0.2261625f, 0.86607087f, 0.37607086f, 0.04027084f,
0.40323752f, 0.40087086f, 0.59507084f, 0.9416709f,
0.53127086f, -0.01712915f, 1.4610709f,
1.4520376f, -0.17152917f, -0.13992918f, 0.6242708f,
0.6868375f, -0.42192918f, 0.38387084f, -0.15752912f,
0.81723756f,
0.3311833f, 0.00618333f, 0.17538333f,
-0.17576247f, 0.10418332f, 0.8365834f, 0.27098334f,
0.81423753f, 1.2421833f, -0.1114167f, 1.0153834f,
-0.08656245f, 0.9523833f, 0.8317833f, 0.9633833f,
0.6501833f, 0.04258335f, 0.9999833f,
-0.40181667f, 0.11418331f, 0.47938335f,
-0.36249164f, 1.1057833f, -0.29761666f, 1.0779834f,
0.45590833f, 0.5243833f, -0.32181668f, 1.1833833f,
1.1925083f,
0.73157084f, 0.4317708f, 0.7283708f,
0.00650835f, 1.2297708f, 0.4307708f, 0.85377085f,
1.4861084f, 0.05977082f, -0.09282917f, 0.33957082f,
1.2079083f, 1.0751709f, 0.2119708f, 0.51897085f,
-0.25302917f, 1.1723708f, -0.12562919f,
0.05270836f, 1.1993709f, 0.5257708f, 0.40517086f,
0.37350836f, 0.53197086f, 0.8441708f, 0.02617085f,
0.94130826f, -0.0208292f, 0.8711709f, 0.04137081f,
1.0715083f, 0.74936247f, 0.6085625f, 0.8997625f,
0.6103083f, -0.08743751f, 0.18576252f, -0.17563748f,
0.9825083f, 0.5991625f, -0.0038375f, 0.07576251f,
0.42536253f, -0.22823751f, 0.36296248f,
0.07370833f, 0.81456256f, -0.16183749f, 0.5161625f,
-0.4518917f, -0.21183747f, 0.7429625f, 0.6217625f,
-0.39889166f, 0.17656249f, 0.02616251f, -0.17923748f,
1.4659625f, 0.40016252f, 0.28356248f,
-0.3354917f,
1.2213084f, 0.4195791f, 0.8745791f, 0.36637908f,
1.0345083f, 0.50597906f, -0.17942089f, 0.16917908f,
1.0235791f, 1.3699791f, -0.11382091f,
-0.3132917f, -0.0918209f, 0.7757791f, 0.09017909f,
0.78470826f, 1.3807791f, -0.15202093f, 1.3875791f,
0.23390833f, -0.1712209f, 1.3989791f, 0.43777913f,
0.7855791f, 0.1423791f, 1.4711791f,
0.6943083f, 0.6455791f, 0.6211791f, -0.48062086f,
0.68170834f,
-0.09989169f, 0.10189578f, 0.5628958f, 0.68909574f,
0.96649575f, -0.09370419f, 1.3466958f,
1.4584957f, 1.3544958f, -0.3829042f,
0.8352709f, 0.11269578f, -0.47890422f, 1.0436958f,
1.3798709f, 0.6128957f, 0.27209583f, 0.2714958f,
0.15507084f, 0.21889582f, 0.08789578f, 1.1296958f,
0.4596958f, 0.39309582f, 0.8344958f,
0.26607084f, 0.71149576f, -0.4799042f, 0.4880958f
-0.10792917f,
1.2302709f,
0.6448709f,
-0.29992914f,
1.3534708f,
0.86607087f,
0.37607086f,
0.04027084f,
0.40087086f,
0.59507084f,
0.9416709f,
0.53127086f,
-0.01712915f,
1.4610709f,
-0.17152917f,
-0.13992918f,
0.6242708f,
-0.42192918f,
0.38387084f,
-0.15752912f,
0.3311833f,
0.00618333f,
0.17538333f,
0.10418332f,
0.8365834f,
0.27098334f,
1.2421833f,
-0.1114167f,
1.0153834f,
0.9523833f,
0.8317833f,
0.9633833f,
0.6501833f,
0.04258335f,
0.9999833f,
-0.40181667f,
0.11418331f,
0.47938335f,
1.1057833f,
-0.29761666f,
1.0779834f,
0.5243833f,
-0.32181668f,
1.1833833f,
0.73157084f,
0.4317708f,
0.7283708f,
1.2297708f,
0.4307708f,
0.85377085f,
0.05977082f,
-0.09282917f,
0.33957082f,
1.0751709f,
0.2119708f,
0.51897085f,
-0.25302917f,
1.1723708f,
-0.12562919f,
1.1993709f,
0.5257708f,
0.40517086f,
0.53197086f,
0.8441708f,
0.02617085f,
-0.0208292f,
0.8711709f,
0.04137081f,
0.74936247f,
0.6085625f,
0.8997625f,
-0.08743751f,
0.18576252f,
-0.17563748f,
0.5991625f,
-0.0038375f,
0.07576251f,
0.42536253f,
-0.22823751f,
0.36296248f,
0.81456256f,
-0.16183749f,
0.5161625f,
-0.21183747f,
0.7429625f,
0.6217625f,
0.17656249f,
0.02616251f,
-0.17923748f,
1.4659625f,
0.40016252f,
0.28356248f,
0.4195791f,
0.8745791f,
0.36637908f,
0.50597906f,
-0.17942089f,
0.16917908f,
1.0235791f,
1.3699791f,
-0.11382091f,
-0.0918209f,
0.7757791f,
0.09017909f,
1.3807791f,
-0.15202093f,
1.3875791f,
-0.1712209f,
1.3989791f,
0.43777913f,
0.7855791f,
0.1423791f,
1.4711791f,
0.6455791f,
0.6211791f,
-0.48062086f,
0.10189578f,
0.5628958f,
0.68909574f,
0.96649575f,
-0.09370419f,
1.3466958f,
1.4584957f,
1.3544958f,
-0.3829042f,
0.11269578f,
-0.47890422f,
1.0436958f,
0.6128957f,
0.27209583f,
0.2714958f,
0.21889582f,
0.08789578f,
1.1296958f,
0.4596958f,
0.39309582f,
0.8344958f,
0.71149576f,
-0.4799042f,
0.4880958f
}); });
nd4j::ops::adjust_contrast op; nd4j::ops::adjust_contrast op;
@ -587,268 +396,79 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
.7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f,
0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f});
auto e = NDArrayFactory::create<double>('c', {8, 8, 3, 1}, { auto e = NDArrayFactory::create<double>('c', {8, 8, 3, 1}, {
1.0218375 , 1.0218375, 1.0666375 , 0.9130375 ,
1.0666375 , -0.07396251, 0.91843754, -0.17496246,
0.9130375 , 0.47543746, 1.2492375 , 0.55643755,
1.3110375 , -0.36456245, 1.0518374 ,
-0.07396251, 0.7824375 , 0.57523745, -0.21656245,
0.91843754, 0.0816375 , -0.2261625 , 0.40323752,
-0.17496246, 1.4520376 , 0.6868375 , 0.81723756,
-0.17576247, 0.81423753, -0.08656245,
0.47543746,
1.2492375 , -0.36249164, 0.45590833, 1.1925083 ,
0.55643755, 0.00650835, 1.4861084 , 1.2079083 ,
0.05270836, 0.37350836, 0.94130826,
1.3110375 , 1.0715083 , 0.6103083 , 0.9825083 ,
-0.36456245, 0.07370833, -0.4518917 , -0.39889166,
1.0518374 , -0.3354917 , 1.2213084 , 1.0345083 ,
-0.3132917 , 0.78470826, 0.23390833,
0.7824375 , 0.6943083 , 0.68170834, -0.09989169,
0.57523745,
-0.21656245, 0.8352709 , 1.3798709 , 0.15507084,
0.26607084, -0.10792917, 1.2302709 ,
0.0816375 , 0.6448709 , -0.29992914, 1.3534708 ,
-0.2261625 , 0.86607087, 0.37607086, 0.04027084,
0.40323752, 0.40087086, 0.59507084, 0.9416709 ,
0.53127086, -0.01712915, 1.4610709 ,
1.4520376 , -0.17152917, -0.13992918, 0.6242708 ,
0.6868375 , -0.42192918, 0.38387084, -0.15752912,
0.81723756,
-0.17576247, 0.3311833 , 0.00618333, 0.17538333,
0.81423753, 0.10418332, 0.8365834 , 0.27098334,
-0.08656245, 1.2421833 , -0.1114167 , 1.0153834 ,
0.9523833 , 0.8317833 , 0.9633833 ,
0.6501833 , 0.04258335, 0.9999833 ,
-0.36249164, -0.40181667, 0.11418331, 0.47938335,
0.45590833, 1.1057833 , -0.29761666, 1.0779834 ,
1.1925083 , 0.5243833 , -0.32181668, 1.1833833 ,
0.00650835, 0.73157084, 0.4317708 , 0.7283708 ,
1.4861084 , 1.2297708 , 0.4307708 , 0.85377085,
1.2079083 , 0.05977082, -0.09282917, 0.33957082,
1.0751709 , 0.2119708 , 0.51897085,
0.05270836, -0.25302917, 1.1723708 , -0.12562919,
0.37350836, 1.1993709 , 0.5257708 , 0.40517086,
0.94130826, 0.53197086, 0.8441708 , 0.02617085,
-0.0208292 , 0.8711709 , 0.04137081,
1.0715083 ,
0.6103083 , 0.74936247, 0.6085625 , 0.8997625 ,
0.9825083 , -0.08743751, 0.18576252, -0.17563748,
0.5991625 , -0.0038375 , 0.07576251,
0.07370833, 0.42536253, -0.22823751, 0.36296248,
-0.4518917 , 0.81456256, -0.16183749, 0.5161625 ,
-0.39889166, -0.21183747, 0.7429625 , 0.6217625 ,
0.17656249, 0.02616251, -0.17923748,
-0.3354917 , 1.4659625 , 0.40016252, 0.28356248,
1.2213084 ,
1.0345083 , 0.4195791 , 0.8745791 , 0.36637908,
0.50597906, -0.17942089, 0.16917908,
-0.3132917 , 1.0235791 , 1.3699791 , -0.11382091,
0.78470826, -0.0918209 , 0.7757791 , 0.09017909,
0.23390833, 1.3807791 , -0.15202093, 1.3875791 ,
-0.1712209 , 1.3989791 , 0.43777913,
0.6943083 , 0.7855791 , 0.1423791 , 1.4711791 ,
0.68170834, 0.6455791 , 0.6211791 , -0.48062086,
-0.09989169,
0.10189578, 0.5628958 , 0.68909574,
0.8352709 , 0.96649575, -0.09370419, 1.3466958 ,
1.3798709 , 1.4584957 , 1.3544958 , -0.3829042 ,
0.15507084, 0.11269578, -0.47890422, 1.0436958 ,
0.6128957 , 0.27209583, 0.2714958 ,
0.26607084, 0.21889582, 0.08789578, 1.1296958 ,
-0.10792917, 0.4596958 , 0.39309582, 0.8344958 ,
1.2302709 , 0.71149576, -0.4799042, 0.4880958
0.6448709 ,
-0.29992914,
1.3534708 ,
0.86607087,
0.37607086,
0.04027084,
0.40087086,
0.59507084,
0.9416709 ,
0.53127086,
-0.01712915,
1.4610709 ,
-0.17152917,
-0.13992918,
0.6242708 ,
-0.42192918,
0.38387084,
-0.15752912,
0.3311833 ,
0.00618333,
0.17538333,
0.10418332,
0.8365834 ,
0.27098334,
1.2421833 ,
-0.1114167 ,
1.0153834 ,
0.9523833 ,
0.8317833 ,
0.9633833 ,
0.6501833 ,
0.04258335,
0.9999833 ,
-0.40181667,
0.11418331,
0.47938335,
1.1057833 ,
-0.29761666,
1.0779834 ,
0.5243833 ,
-0.32181668,
1.1833833 ,
0.73157084,
0.4317708 ,
0.7283708 ,
1.2297708 ,
0.4307708 ,
0.85377085,
0.05977082,
-0.09282917,
0.33957082,
1.0751709 ,
0.2119708 ,
0.51897085,
-0.25302917,
1.1723708 ,
-0.12562919,
1.1993709 ,
0.5257708 ,
0.40517086,
0.53197086,
0.8441708 ,
0.02617085,
-0.0208292 ,
0.8711709 ,
0.04137081,
0.74936247,
0.6085625 ,
0.8997625 ,
-0.08743751,
0.18576252,
-0.17563748,
0.5991625 ,
-0.0038375 ,
0.07576251,
0.42536253,
-0.22823751,
0.36296248,
0.81456256,
-0.16183749,
0.5161625 ,
-0.21183747,
0.7429625 ,
0.6217625 ,
0.17656249,
0.02616251,
-0.17923748,
1.4659625 ,
0.40016252,
0.28356248,
0.4195791 ,
0.8745791 ,
0.36637908,
0.50597906,
-0.17942089,
0.16917908,
1.0235791 ,
1.3699791 ,
-0.11382091,
-0.0918209 ,
0.7757791 ,
0.09017909,
1.3807791 ,
-0.15202093,
1.3875791 ,
-0.1712209 ,
1.3989791 ,
0.43777913,
0.7855791 ,
0.1423791 ,
1.4711791 ,
0.6455791 ,
0.6211791 ,
-0.48062086,
0.10189578,
0.5628958 ,
0.68909574,
0.96649575,
-0.09370419,
1.3466958 ,
1.4584957 ,
1.3544958 ,
-0.3829042 ,
0.11269578,
-0.47890422,
1.0436958 ,
0.6128957 ,
0.27209583,
0.2714958 ,
0.21889582,
0.08789578,
1.1296958 ,
0.4596958 ,
0.39309582,
0.8344958 ,
0.71149576,
-0.4799042,
0.4880958
}); });
// x.linspace(1.); // x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op; nd4j::ops::adjust_contrast_v2 op;