From 67d8199165cb96d08bdcc94ef41f4860b5825cdd Mon Sep 17 00:00:00 2001 From: shugeo Date: Fri, 20 Dec 2019 16:56:28 +0200 Subject: [PATCH] [WIP] Shugeo lup (#126) * Added infrastructure for implementation op lu for both cuda and cpu platforms. * Added implementation of helpers with lu op. Signed-off-by: shugeo * Refactored LU decomposition to use vector of permutations instead. * Refactored helpers for lu op. * Fixed crash with determinant op. * Refactored cpu LU op helper. * Added implementation for lu op. Signed-off-by: shugeo * Fixed issue with argmax on column. * Added multithreaded behaviour for lu op helper. * Fixed multithreaded cpu implementation helpers for lu op. * Added cuda implementation for lu op helper. * Finished lu helper implementation for cuda platform. * Eliminated waste prints and comments. * Fixed race condition and multithreading issues. * Fixed memory leak with shape construction. * Corrected test for lu op to avoid near zero elements on the main diagonal. Signed-off-by: shugeo * Improved test for adjust_contrast op. Signed-off-by: shugeo * Fixed issues with cuda implementation of resize_bicubic helpers. 
Signed-off-by: shugeo --- .../ops/declarable/generic/parity_ops/lup.cpp | 59 ++ .../ops/declarable/headers/parity_ops.h | 18 + .../ops/declarable/helpers/cpu/lup.cpp | 149 +++- .../declarable/helpers/cuda/image_resize.cu | 48 +- .../ops/declarable/helpers/cuda/lup.cu | 172 ++++- libnd4j/include/ops/declarable/helpers/lup.h | 5 +- .../layers_tests/DeclarableOpsTests11.cpp | 4 +- .../layers_tests/DeclarableOpsTests12.cpp | 253 +++++++ .../layers_tests/DeclarableOpsTests15.cpp | 668 ++++-------------- 9 files changed, 797 insertions(+), 579 deletions(-) create mode 100644 libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp new file mode 100644 index 000000000..83b4a42d9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by GS at 12/10/2019 +// + +#include +#if NOT_EXCLUDED(OP_matrix_inverse) + +#include +#include +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + auto p = OUTPUT_VARIABLE(1); + + REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_inverse: The rank of input array should not less than 2, but %i is given", input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_inverse: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); + + helpers::lu(block.launchContext(), input, z, p); + return Status::OK(); + } + + DECLARE_SHAPE_FN(lu) { + auto in = inputShape->at(0); + auto shapeVector = ShapeUtils::shapeAsVector(in); + auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace()); + auto luP = ShapeBuilders::createShapeInfo(nd4j::DataType::INT32, shape::order(in), shapeVector.size() - 1, + shapeVector.data(), block.workspace()); + return SHAPELIST(CONSTANT(luShape), CONSTANT(luP)); + } + + DECLARE_TYPES(lu) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {nd4j::DataType::INT32, nd4j::DataType::INT64}) + ->setSameMode(false); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index e56ba9d6e..cbaae52f7 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1027,6 +1027,24 @@ namespace nd4j { DECLARE_OP(matrix_inverse, 1, 1, true); #endif + /** + * lu op. 
- make LUP decomposition of given batch of 2D square matricies + * + * input params: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M matricies in it + * 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M) + * + * int argument: + * 0 - data type of output permutaion vector (int32 or int64), optional, default INT32 + */ + + #if NOT_EXCLUDED(OP_matrix_inverse) + DECLARE_CUSTOM_OP(lu, 1, 2, false, 0, 0); + #endif + /** * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] * diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 76817078b..d706eaff3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include namespace nd4j { namespace ops { @@ -32,15 +34,30 @@ namespace helpers { if (theFirst != theSecond) for (int i = 0; i < matrix->columns(); i++) { - T e0 = matrix->e(theFirst, i); - T e1 = matrix->e(theSecond, i); - - matrix->p(theFirst, i, e1); - matrix->p(theSecond, i, e0); + math::nd4j_swap(matrix->t(theFirst, i), matrix->t(theSecond, i)); } } BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); + template + static void swapRows(T* matrixBuf, Nd4jLong* matrixShape, Nd4jLong theFirst, Nd4jLong theSecond) { + if (theFirst != theSecond) { + auto n = shape::sizeAt(matrixShape, -1); + + auto loop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + Nd4jLong theFirstPos[] = {theFirst, i}; + Nd4jLong theSecondPos[] = {theSecond, i}; + auto theFirstIndex = shape::getOffset(matrixShape, theFirstPos, 0); + auto theSecondIndex = shape::getOffset(matrixShape, theSecondPos, 0); + 
math::nd4j_swap(matrixBuf[theFirstIndex], matrixBuf[theSecondIndex]); + } + }; + + samediff::Threads::parallel_tad(loop, 0, n, 1); + } + } + void swapRows(NDArray* matrix, int theFirst, int theSecond) { BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES); } @@ -106,7 +123,7 @@ namespace helpers { } - template + template static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) { const int rowNum = input->rows(); @@ -132,7 +149,7 @@ namespace helpers { } } - if( pivotValue > T(0.00001)) { + if( pivotValue > DataTypeUtils::min()) { swapRows(&compoundMatrix, pivot, i); swapRows(&permutationMatrix, pivot, i); if (pivot != i) @@ -155,14 +172,113 @@ namespace helpers { if (swapCount % 2) determinant = -determinant; if (compound != nullptr) compound->assign(compoundMatrix); - if (permutation != nullptr) - permutation->assign(permutationMatrix); + if (permutation != nullptr) { + auto permutaionVector = NDArrayFactory::create('c', {rowNum}, DataTypeUtils::fromT(), input->getContext()); + for (auto i = 0; i < rowNum; i++) { + for (auto j = 0; j < columnNum; j++) { + if (permutationMatrix.t(i, j) != 0) { + permutaionVector.template t(i) = j; + } + } + } + if (permutationMatrix.isSameShape(permutation)) + permutation->assign(permutationMatrix); + else if (permutation->isSameShape(permutaionVector)) { + permutation->assign(permutaionVector); + } + } return determinant; } - BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); + /* + * lu decomposition with naive algorithm with partial pivoting + * */ + template + static I argmaxCol(I column, T* compoundBuffer, Nd4jLong* compoundShape) { + auto rowNum = shape::sizeAt(compoundShape, 0); + Nd4jLong 
xInitial[] = {column, column}; + auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); + auto maxValue = T(0); //nd4j::math::nd4j_abs(compoundBuffer[xInitialIndex]); + auto result = -1; + //auto loop = PRAGMA_THREADS_FOR { + auto start = column, stop = rowNum, increment = 1; + for (auto rowCounter = start; rowCounter < stop; rowCounter += increment) { + Nd4jLong xPos[] = {rowCounter, column}; + auto xIndex = shape::getOffset(compoundShape, xPos, 0); + if (nd4j::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { + maxValue = nd4j::math::nd4j_max(maxValue, nd4j::math::nd4j_abs(compoundBuffer[xIndex])); + result = rowCounter; + } + } + //}; + //samediff::Threads::parallel_for(loop, column, rowNum, 1); + return result; + } + template + void processColumns(int currentRow, int rowNum, T* compoundBuf, Nd4jLong* compoundShape) { + Nd4jLong xDiag[] = {currentRow, currentRow}; + auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); + auto loop = PRAGMA_THREADS_FOR { + for (int j = start; j < stop; j += increment) { + Nd4jLong xRow[] = {j, currentRow}; + auto rowIndex = shape::getOffset(compoundShape, xRow, 0); + compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t(i, i); + for (int k = currentRow + 1; k < rowNum; k++) { + Nd4jLong yRow[] = {j, k}; + Nd4jLong yCol[] = {currentRow, k}; + auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); + auto colIndex = shape::getOffset(compoundShape, yCol, 0); + compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; + } + } + }; + samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1); + } + template + static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) { + + //const int rowNum = compound->rows(); +// const int columnNum = output->columns(); + permutation->linspace(0); + auto permutationBuf = permutation->bufferAsT(); //dataBuffer()->primaryAsT(); + auto compoundBuf = compound->bufferAsT(); + auto compoundShape = 
compound->shapeInfo(); + auto permutationShape = permutation->shapeInfo(); + for (auto i = 0; i < rowNum - 1; i++) { + auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape); + if (pivotIndex < 0) { + throw std::runtime_error("helpers::luNN_: input matrix is singular."); + } + math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]); + swapRows(compoundBuf, compoundShape, i, pivotIndex); + + processColumns(i, rowNum, compoundBuf, compoundShape); + } + } + + template + static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { + auto n = input->sizeAt(-1); + + output->assign(input); // fill up output tensor with zeros + std::unique_ptr outputs(output->allTensorsAlongDimension({-2, -1})); + std::unique_ptr permutations(permutationVectors->allTensorsAlongDimension({-1})); + auto loop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + luNN_(context, outputs->at(i), permutations->at(i), n); + } + }; + samediff::Threads::parallel_for(loop, 0, outputs->size(), 1); + } + + void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES); + } + +// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); template static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) { @@ -175,7 +291,7 @@ namespace helpers { for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) matrix.p(row, input->e(k)); - output->p(e, lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); + output->p(e, lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); } return Status::OK(); @@ -196,7 
+312,7 @@ template for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { matrix.p(row, input->e(k)); } - NDArray det = lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); + NDArray det = lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); if (det.e(0) != 0.f) output->p(e, nd4j::math::nd4j_log(nd4j::math::nd4j_abs(det.t(0)))); } @@ -229,7 +345,7 @@ template for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { matrix.p(row++, input->e(k)); } - T det = lup_(context, &matrix, &compound, &permutation).template e(0); + T det = lup_(context, &matrix, &compound, &permutation).template e(0); // FIXME: and how this is going to work on float16? if (nd4j::math::nd4j_abs(det) < T(0.000001)) { @@ -274,7 +390,7 @@ template // check for symmetric for (Nd4jLong r = 0; r < thisMatrix->rows(); r++) for (Nd4jLong c = 0; c < thisMatrix->columns(); c++) - if (nd4j::math::nd4j_abs(thisMatrix->e(r, c) - lastMatrixList->at(i)->e(c,r)) > T(1.e-6f)) return false; + if (nd4j::math::nd4j_abs(thisMatrix->e(r, c) - lastMatrixList->at(i)->e(c,r)) > DataTypeUtils::min()) return false; NDArray output = NDArrayFactory::create(0., context); if (ND4J_STATUS_OK != determinant(context, thisMatrix, &output)) return false; @@ -366,6 +482,11 @@ template BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES); } + int lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, (context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES); + return Status::OK(); + } + } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index ab3a96801..b8cd35261 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -689,11 +689,17 @@ namespace helpers { } 
template - static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) { + static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) { // auto numChannels = pResizerState->channels; + for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { auto pInput = inputPtr + b * inBatchWidth; + float* cachedValue; for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) { + if (threadIdx.x == 0) { + extern __shared__ char sharedChar[]; + cachedValue = reinterpret_cast(sharedChar); + } auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels; auto pOutput = &outputPtr[pos]; struct WeightsAndIndices yWai; @@ -846,20 +852,20 @@ namespace helpers { throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err); } - float* cachedValue = nullptr; - size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels); - if (cachedSize) { - err = cudaMalloc(reinterpret_cast(&cachedValue), cachedSize); - if (err != 0) { - throw cuda_exception::build( - "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err); - } - err = cudaMemset(cachedValue, 0, cachedSize); - if (err != 0) { - throw cuda_exception::build( - "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err); - } - } +// float* cachedValue = nullptr; +// size_t cachedSize = sizeof(float) * (numChannels == 3 ? 
0 : 4 * numChannels); +// if (cachedSize) { +// err = cudaMalloc(reinterpret_cast(&cachedValue), cachedSize); +// if (err != 0) { +// throw cuda_exception::build( +// "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err); +// } +// err = cudaMemset(cachedValue, 0, cachedSize); +// if (err != 0) { +// throw cuda_exception::build( +// "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err); +// } +// } WeightsAndIndices* xWais; //(resizerState.outWidth); err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth); @@ -878,7 +884,7 @@ namespace helpers { } const T* pInput = image->getDataBuffer()->specialAsT(); float* pOutput = output->dataBuffer()->specialAsT(); //_data.data(); - bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>(coeffsTable, cachedValue, pInput, + bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>(coeffsTable, pInput, resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput); err = cudaStreamSynchronize(*stream); if (err != 0) { @@ -889,11 +895,11 @@ namespace helpers { if (err != 0) { throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err); } - if (cachedSize) - err = cudaFree(cachedValue); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err); - } +// if (cachedSize) +// err = cudaFree(cachedValue); +// if (err != 0) { +// throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err); +// } err = cudaFree(xWais); if (err != 0) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 568b9a9bc..4e5d9e85e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -24,6 +24,7 @@ #include #include 
#include +//#include #include #include @@ -336,7 +337,7 @@ namespace helpers { // // input - A matrix nxn // compound - C matrix L + U - I, or main diagonal and lower - L matrix, from the 2nd diagonal - U matrix - template + template static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { auto stream = context->getCudaStream(); auto n = input->rows(); @@ -383,7 +384,7 @@ namespace helpers { err); } - if (permutation == nullptr) + if (permutation == nullptr) { status = cusolverDnDgetrf( cusolverH, n, @@ -393,9 +394,15 @@ namespace helpers { d_work, nullptr, d_info); + + if (status != CUSOLVER_STATUS_SUCCESS) { + throw cuda_exception::build("helpers::lup_: LU factorization is failed due ", + status); + } + } else { NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); - int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); + int* permutationBuf = permutVector.dataBuffer()->specialAsT(); status = cusolverDnDgetrf( cusolverH, n, @@ -405,9 +412,21 @@ namespace helpers { d_work, permutationBuf, d_info); - fillUpPermutation << < n, n, 1024, *stream >> > - (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - permutation->tickWriteDevice(); + if (status != CUSOLVER_STATUS_SUCCESS) { + throw cuda_exception::build("helpers::lup_: LU factorization is failed due ", + status); + } + + if (permutation->rankOf() == 2) { + fillUpPermutation <<< n, n, 1024, *stream >>> + (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); + } + else { + permutVector.tickWriteDevice(); + input->tickWriteDevice(); + compound->assign(input); + permutation->assign(permutVector); + } } err = cudaFree(d_work); if (err) { @@ -448,7 +467,7 @@ namespace helpers { nullptr, d_info); else { - NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); + NDArray permutVector('c', {n}, DataType::INT32, context); int *permutationBuf = 
reinterpret_cast(permutVector.specialBuffer()); status = cusolverDnSgetrf( cusolverH, @@ -459,9 +478,16 @@ namespace helpers { d_work, permutationBuf, d_info); - fillUpPermutation <<< n, n, 128, *stream >> > - (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - permutation->tickWriteDevice(); + if (permutation->rankOf() == 2) { + fillUpPermutation <<< n, n, 128, *stream >>> + (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); + permutation->tickWriteDevice(); + } + else { + input->tickWriteDevice(); + compound->assign(input); + permutation->assign(permutVector); + } } err = cudaFree(d_work); if (err) { @@ -484,8 +510,116 @@ namespace helpers { } // ------------------------------------------------------------------------------------------------------------------ // - BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE); + BUILD_DOUBLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE, INDEXING_TYPES); + template + static __device__ void swapRows(T* matrix, Nd4jLong* shape, Nd4jLong theFirst, Nd4jLong theSecond, Nd4jLong n) { + if (theFirst != theSecond) { + for (auto i = 0; i < n; i++) { + Nd4jLong theFirstPos[] = {theFirst, i}; + Nd4jLong theSecondPos[] = {theSecond, i}; + auto theFirstIndex = shape::getOffset(shape, theFirstPos, 0); + auto theSecondIndex = shape::getOffset(shape, theSecondPos, 0); + math::nd4j_swap(matrix[theFirstIndex], matrix[theSecondIndex]); + } + } + } + + template + static __device__ void processColumns(Nd4jLong currentRow, Nd4jLong rowNum, T* compoundBuf, Nd4jLong* compoundShape) { + Nd4jLong xDiag[] = {currentRow, currentRow}; + auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); + for (auto j = currentRow + 1; j < rowNum; j++) { + Nd4jLong xRow[] = {j, currentRow}; + auto rowIndex = 
shape::getOffset(compoundShape, xRow, 0); + compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t(i, i); + for (auto k = currentRow + 1; k < rowNum; k++) { + Nd4jLong yRow[] = {j, k}; + Nd4jLong yCol[] = {currentRow, k}; + auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); + auto colIndex = shape::getOffset(compoundShape, yCol, 0); + compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; + } + } + } + + template + __device__ Nd4jLong argmaxCol(Nd4jLong column, T* compoundBuffer, Nd4jLong* compoundShape) { + auto rowNum = shape::sizeAt(compoundShape, 0); + Nd4jLong xInitial[] = {column, column}; + auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); + auto maxValue = T(0); //nd4j::math::nd4j_abs(compoundBuffer[xInitialIndex]); + auto result = -1LL; + + for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) { + Nd4jLong xPos[] = {rowCounter, column}; + auto xIndex = shape::getOffset(compoundShape, xPos, 0); + if (nd4j::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { + maxValue = nd4j::math::nd4j_max(maxValue, nd4j::math::nd4j_abs(compoundBuffer[xIndex])); + result = rowCounter; + } + } + return result; + } + + template + static __device__ int luNN(T* matrix, Nd4jLong* shape, I* permutation, Nd4jLong* permuShape, Nd4jLong n) { + + for (auto i = 0; i < n - 1; i++) { + auto pivotIndex = argmaxCol(i, matrix, shape); + if (pivotIndex < 0) { + return -1;//throw std::runtime_error("helpers::luNN_: input matrix is singular."); + } + math::nd4j_swap(permutation[shape::getIndexOffset(i, permuShape)], permutation[shape::getIndexOffset(pivotIndex, permuShape)]); + swapRows(matrix, shape, (Nd4jLong)i, pivotIndex, n); + + processColumns(i, n, matrix, shape); + } + return 0; + } + + template + static __global__ void luBatchedKernel(T* outputBuf, Nd4jLong* outputShape, I* permutations, Nd4jLong* permuShape, + Nd4jLong* outputTadShape, Nd4jLong* outputTadOffsets, Nd4jLong* permuTadShape, Nd4jLong* permuTadOffsets, + 
Nd4jLong batchNum) { + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto b = start; b < batchNum; b += step) { + T* matrix = outputBuf + outputTadOffsets[b]; + I* permutation = permutations + permuTadOffsets[b]; + + if (0 != luNN(matrix, outputTadShape, permutation, permuTadShape, shape::length(permuTadShape))) break; + } + } + + template + static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { + auto n = input->sizeAt(-1); + auto stream = context->getCudaStream(); + auto iota = NDArrayFactory::create('c', {n}); + iota.linspace(0); iota.syncToDevice(); + + output->assign(input); // fill up output tensor with zeros + output->tickWriteDevice(); + permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &iota, permutationVectors, true, nullptr); + permutationVectors->tickWriteDevice(); + + auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); + auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-1}); + auto batchNum = tads.numberOfTads(); + luBatchedKernel<<>>(reinterpret_cast(output->platformBuffer()), + output->specialShapeInfo(), reinterpret_cast(permutationVectors->platformBuffer()), + permutationVectors->specialShapeInfo(), tads.specialShapeInfo(), tads.specialOffsets(), + permutaionTads.specialShapeInfo(), permutaionTads.specialOffsets(), batchNum); + } + + void lu(LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutations) { + NDArray::prepareSpecialUse({output, permutations}, {input}); + BUILD_DOUBLE_SELECTOR(input->dataType(), permutations->dataType(), lu_, (context, input, output, permutations), FLOAT_NATIVE, INDEXING_TYPES); + NDArray::registerSpecialUse({output, permutations}, {input}); + } // ------------------------------------------------------------------------------------------------------------------ // template static int 
determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { @@ -509,7 +643,7 @@ namespace helpers { fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); - lup_(context, &matrix, nullptr, nullptr); + lup_(context, &matrix, nullptr, nullptr); // else // lup_(context, &matrix, nullptr, nullptr); auto offset = shape::getIndexOffset(e, output->shapeInfo()); @@ -557,7 +691,7 @@ namespace helpers { // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); + lup_(context, &matrix, nullptr, nullptr); // else // lup_(context, &matrix, nullptr, nullptr); auto offset = shape::getIndexOffset(e, output->shapeInfo()); @@ -638,7 +772,7 @@ namespace helpers { matrix.tickWriteDevice(); //compound.assign(matrix); // if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); + lup_(context, &matrix, nullptr, nullptr); fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); lower.tickWriteDevice(); upper.tickWriteDevice(); @@ -849,7 +983,7 @@ namespace helpers { auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.getShapeInfo(), {tempOutput.rankOf() - 2, tempOutput.rankOf() - 1}); - logDetKernel <<< 128, 512, 256, *stream >>>(inputBuf, tempOutput.specialShapeInfo(), + logDetKernel <<<128, 512, 256, *stream>>>(inputBuf, tempOutput.specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo()); output->tickWriteDevice(); @@ -861,6 +995,14 @@ namespace helpers { 
BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE); } + /* + * lup - batched input, batched outputs + * */ + int lup(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_,(context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES); + return Status::OK(); + } + // BUILD_SINGLE_TEMPLATE(template int logdetFunctor_, // (nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE); } diff --git a/libnd4j/include/ops/declarable/helpers/lup.h b/libnd4j/include/ops/declarable/helpers/lup.h index 96ec9bec1..ae10e6136 100644 --- a/libnd4j/include/ops/declarable/helpers/lup.h +++ b/libnd4j/include/ops/declarable/helpers/lup.h @@ -26,9 +26,8 @@ namespace nd4j { namespace ops { namespace helpers { - template - T lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation); - + int lup(nd4j::LaunchContext* context, NDArray* input, NDArray* lu, NDArray* permutation); + void lu(nd4j::LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation); int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output); int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 7b08bfbe4..37bcba233 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -1050,7 +1050,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { auto testData = NDArrayFactory::create('c', {2,9,9,1}, { - 0.230286f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, + 0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 
0.134639814f, 0.368967354f, 0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f, 0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f, 0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f, @@ -1081,7 +1081,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { NDArray* result = results->at(0); // result->printBuffer("Resized to 9x9"); -// expected.printBuffer("Expect for 9x9"); +// testData.printBuffer("Expect for 9x9"); ASSERT_TRUE(testData.isSameShape(result)); ASSERT_TRUE(testData.equalsTo(result)); delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 0d205c2db..0710e5506 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -2424,3 +2424,256 @@ TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) { ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_1) { + + auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7}); + auto pExp = NDArrayFactory::create('c', {3}, {0, 1, 2}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars"); +// p->printIndexedBuffer("Permutaions"); + + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(pExp.equalsTo(p)); + + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_2) { + auto in = 
NDArrayFactory::create('c', {3,3}, {1, 0, 0, 2, 3, 0, 4, 5, 6}); + + auto expLU = NDArrayFactory::create('c', {3,3}, {4., 5., 6., 0.25, -1.25, -1.5, 0.5, -0.4, -3.6}); + auto expP = NDArrayFactory::create({2, 0, 1}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars2"); +// p->printIndexedBuffer("Permutaions2"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3) { + auto in = NDArrayFactory::create('c', {3,3}, {1,2,3,4,7,9, 11, 12, 13}); + + auto expLU = NDArrayFactory::create('c', {3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753}); + + auto expP = NDArrayFactory::create({2, 1, 0}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars3"); +// p->printIndexedBuffer("Permutaions3"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4) { + + auto in = NDArrayFactory::create('c', {10,10}, { + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + + auto expLU = NDArrayFactory::create('c', {10,10}, { + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, 
-1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 + }); + + auto expP = NDArrayFactory::create({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printBuffer("Triangulars4"); +// expLU.printBuffer("TriangulExp4"); +// p->printBuffer("Permutaions4"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +TEST_F(DeclarableOpsTests12, LU_Test_5) { + + auto in = NDArrayFactory::create('c', {2, 10,10}, { + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 
1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. + }); + + auto expLU = NDArrayFactory::create('c', {2, 10,10}, { + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695, + + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 
9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 + + }); + + auto expP = NDArrayFactory::create('c', {2, 10}, { + 1, 2, 7, 3, 6, 8, 5, 4, 0, 9, + 1, 2, 7, 3, 6, 8, 5, 4, 0, 9 + }); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printBuffer("Triangulars5"); +// expLU.printBuffer("TriangulExp5"); +// p->printBuffer("Permutaions5"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_1_2) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.,1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7, 1., 2., 3., 0., 2., 3., 0., 0., 7.}); + + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars (2,3,3)"); +// p->printIndexedBuffer("Permutaions (2,3,3)"); + ASSERT_TRUE(exp.equalsTo(res->at(0))); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3_2) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,1,2,3,4,7,9, 11, 12, 13}); + + auto expLU = NDArrayFactory::create('c', {2, 3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753, + + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753 + }); + + auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 2, 1, 0}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + 
ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars3_2"); +// p->printIndexedBuffer("Permutaions3_2"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3_3) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,13,2,3,4,7,9, 11, 12, 1}); + auto expLU = NDArrayFactory::create('c', {2, 3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753, + + 13., 2., 3., + 0.84615386, 10.307693, -1.5384617, + 0.30769232, 0.619403, 9.029851}); + + auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 0, 2, 1}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars3_3"); +// p->printIndexedBuffer("Permutaions3_3"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index d87acc439..5697a5257 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -293,268 +293,77 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { - 1.0218375f, - 1.0666375f, - 0.9130375f, - - -0.07396251f, - 0.91843754f, - -0.17496246f, - - 0.47543746f, - 1.2492375f, - 0.55643755f, - - 1.3110375f, - -0.36456245f, - 1.0518374f, - - 0.7824375f, - 0.57523745f, - -0.21656245f, - - 0.0816375f, - -0.2261625f, - 0.40323752f, - - 
1.4520376f, - 0.6868375f, - 0.81723756f, - - -0.17576247f, - 0.81423753f, - -0.08656245f, - - - -0.36249164f, - 0.45590833f, - 1.1925083f, - - 0.00650835f, - 1.4861084f, - 1.2079083f, - - 0.05270836f, - 0.37350836f, - 0.94130826f, - - 1.0715083f, - 0.6103083f, - 0.9825083f, - - 0.07370833f, - -0.4518917f, - -0.39889166f, - - -0.3354917f, - 1.2213084f, - 1.0345083f, - - -0.3132917f, - 0.78470826f, - 0.23390833f, - - 0.6943083f, - 0.68170834f, - -0.09989169f, - - - 0.8352709f, - 1.3798709f, - 0.15507084f, - - 0.26607084f, - -0.10792917f, - 1.2302709f, - - 0.6448709f, - -0.29992914f, - 1.3534708f, - - 0.86607087f, - 0.37607086f, - 0.04027084f, - - 0.40087086f, - 0.59507084f, - 0.9416709f, - - 0.53127086f, - -0.01712915f, - 1.4610709f, - - -0.17152917f, - -0.13992918f, - 0.6242708f, - - -0.42192918f, - 0.38387084f, - -0.15752912f, - - - 0.3311833f, - 0.00618333f, - 0.17538333f, - - 0.10418332f, - 0.8365834f, - 0.27098334f, - - 1.2421833f, - -0.1114167f, - 1.0153834f, - - 0.9523833f, - 0.8317833f, - 0.9633833f, - - 0.6501833f, - 0.04258335f, - 0.9999833f, - - -0.40181667f, - 0.11418331f, - 0.47938335f, - - 1.1057833f, - -0.29761666f, - 1.0779834f, - - 0.5243833f, - -0.32181668f, - 1.1833833f, - - - 0.73157084f, - 0.4317708f, - 0.7283708f, - - 1.2297708f, - 0.4307708f, - 0.85377085f, - - 0.05977082f, - -0.09282917f, - 0.33957082f, - - 1.0751709f, - 0.2119708f, - 0.51897085f, - - -0.25302917f, - 1.1723708f, - -0.12562919f, - - 1.1993709f, - 0.5257708f, - 0.40517086f, - - 0.53197086f, - 0.8441708f, - 0.02617085f, - - -0.0208292f, - 0.8711709f, - 0.04137081f, - - - 0.74936247f, - 0.6085625f, - 0.8997625f, - - -0.08743751f, - 0.18576252f, - -0.17563748f, - - 0.5991625f, - -0.0038375f, - 0.07576251f, - - 0.42536253f, - -0.22823751f, - 0.36296248f, - - 0.81456256f, - -0.16183749f, - 0.5161625f, - - -0.21183747f, - 0.7429625f, - 0.6217625f, - - 0.17656249f, - 0.02616251f, - -0.17923748f, - - 1.4659625f, - 0.40016252f, - 0.28356248f, - - - 0.4195791f, - 0.8745791f, - 
0.36637908f, - - 0.50597906f, - -0.17942089f, - 0.16917908f, - - 1.0235791f, - 1.3699791f, - -0.11382091f, - - -0.0918209f, - 0.7757791f, - 0.09017909f, - - 1.3807791f, - -0.15202093f, - 1.3875791f, - - -0.1712209f, - 1.3989791f, - 0.43777913f, - - 0.7855791f, - 0.1423791f, - 1.4711791f, - - 0.6455791f, - 0.6211791f, - -0.48062086f, - - - 0.10189578f, - 0.5628958f, - 0.68909574f, - - 0.96649575f, - -0.09370419f, - 1.3466958f, - - 1.4584957f, - 1.3544958f, - -0.3829042f, - - 0.11269578f, - -0.47890422f, - 1.0436958f, - - 0.6128957f, - 0.27209583f, - 0.2714958f, - - 0.21889582f, - 0.08789578f, - 1.1296958f, - - 0.4596958f, - 0.39309582f, - 0.8344958f, - - 0.71149576f, - -0.4799042f, - 0.4880958f + 1.0218375f, 1.0666375f, 0.9130375f, + -0.07396251f, 0.91843754f, -0.17496246f, + 0.47543746f, 1.2492375f, 0.55643755f, + 1.3110375f, -0.36456245f, 1.0518374f, + 0.7824375f, 0.57523745f, -0.21656245f, + 0.0816375f, -0.2261625f, 0.40323752f, + 1.4520376f, 0.6868375f, 0.81723756f, + -0.17576247f, 0.81423753f, -0.08656245f, + + -0.36249164f, 0.45590833f, 1.1925083f, + 0.00650835f, 1.4861084f, 1.2079083f, + 0.05270836f, 0.37350836f, 0.94130826f, + 1.0715083f, 0.6103083f, 0.9825083f, + 0.07370833f, -0.4518917f, -0.39889166f, + -0.3354917f, 1.2213084f, 1.0345083f, + -0.3132917f, 0.78470826f, 0.23390833f, + 0.6943083f, 0.68170834f, -0.09989169f, + + 0.8352709f, 1.3798709f, 0.15507084f, + 0.26607084f, -0.10792917f, 1.2302709f, + 0.6448709f, -0.29992914f, 1.3534708f, + 0.86607087f, 0.37607086f, 0.04027084f, + 0.40087086f, 0.59507084f, 0.9416709f, + 0.53127086f, -0.01712915f, 1.4610709f, + -0.17152917f, -0.13992918f, 0.6242708f, + -0.42192918f, 0.38387084f, -0.15752912f, + + 0.3311833f, 0.00618333f, 0.17538333f, + 0.10418332f, 0.8365834f, 0.27098334f, + 1.2421833f, -0.1114167f, 1.0153834f, + 0.9523833f, 0.8317833f, 0.9633833f, + 0.6501833f, 0.04258335f, 0.9999833f, + -0.40181667f, 0.11418331f, 0.47938335f, + 1.1057833f, -0.29761666f, 1.0779834f, + 0.5243833f, -0.32181668f, 1.1833833f, 
+ + 0.73157084f, 0.4317708f, 0.7283708f, + 1.2297708f, 0.4307708f, 0.85377085f, + 0.05977082f, -0.09282917f, 0.33957082f, + 1.0751709f, 0.2119708f, 0.51897085f, + -0.25302917f, 1.1723708f, -0.12562919f, + 1.1993709f, 0.5257708f, 0.40517086f, + 0.53197086f, 0.8441708f, 0.02617085f, + -0.0208292f, 0.8711709f, 0.04137081f, + + 0.74936247f, 0.6085625f, 0.8997625f, + -0.08743751f, 0.18576252f, -0.17563748f, + 0.5991625f, -0.0038375f, 0.07576251f, + 0.42536253f, -0.22823751f, 0.36296248f, + 0.81456256f, -0.16183749f, 0.5161625f, + -0.21183747f, 0.7429625f, 0.6217625f, + 0.17656249f, 0.02616251f, -0.17923748f, + 1.4659625f, 0.40016252f, 0.28356248f, + + 0.4195791f, 0.8745791f, 0.36637908f, + 0.50597906f, -0.17942089f, 0.16917908f, + 1.0235791f, 1.3699791f, -0.11382091f, + -0.0918209f, 0.7757791f, 0.09017909f, + 1.3807791f, -0.15202093f, 1.3875791f, + -0.1712209f, 1.3989791f, 0.43777913f, + 0.7855791f, 0.1423791f, 1.4711791f, + 0.6455791f, 0.6211791f, -0.48062086f, + + 0.10189578f, 0.5628958f, 0.68909574f, + 0.96649575f, -0.09370419f, 1.3466958f, + 1.4584957f, 1.3544958f, -0.3829042f, + 0.11269578f, -0.47890422f, 1.0436958f, + 0.6128957f, 0.27209583f, 0.2714958f, + 0.21889582f, 0.08789578f, 1.1296958f, + 0.4596958f, 0.39309582f, 0.8344958f, + 0.71149576f, -0.4799042f, 0.4880958f }); nd4j::ops::adjust_contrast op; @@ -587,268 +396,79 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { - 1.0218375 , - 1.0666375 , - 0.9130375 , - - -0.07396251, - 0.91843754, - -0.17496246, - - 0.47543746, - 1.2492375 , - 0.55643755, - - 1.3110375 , - -0.36456245, - 1.0518374 , - - 0.7824375 , - 0.57523745, - -0.21656245, - - 0.0816375 , - -0.2261625 , - 0.40323752, - - 1.4520376 , - 0.6868375 , - 0.81723756, - - -0.17576247, - 0.81423753, - -0.08656245, - - - 
-0.36249164, - 0.45590833, - 1.1925083 , - - 0.00650835, - 1.4861084 , - 1.2079083 , - - 0.05270836, - 0.37350836, - 0.94130826, - - 1.0715083 , - 0.6103083 , - 0.9825083 , - - 0.07370833, - -0.4518917 , - -0.39889166, - - -0.3354917 , - 1.2213084 , - 1.0345083 , - - -0.3132917 , - 0.78470826, - 0.23390833, - - 0.6943083 , - 0.68170834, - -0.09989169, - - - 0.8352709 , - 1.3798709 , - 0.15507084, - - 0.26607084, - -0.10792917, - 1.2302709 , - - 0.6448709 , - -0.29992914, - 1.3534708 , - - 0.86607087, - 0.37607086, - 0.04027084, - - 0.40087086, - 0.59507084, - 0.9416709 , - - 0.53127086, - -0.01712915, - 1.4610709 , - - -0.17152917, - -0.13992918, - 0.6242708 , - - -0.42192918, - 0.38387084, - -0.15752912, - - - 0.3311833 , - 0.00618333, - 0.17538333, - - 0.10418332, - 0.8365834 , - 0.27098334, - - 1.2421833 , - -0.1114167 , - 1.0153834 , - - 0.9523833 , - 0.8317833 , - 0.9633833 , - - 0.6501833 , - 0.04258335, - 0.9999833 , - - -0.40181667, - 0.11418331, - 0.47938335, - - 1.1057833 , - -0.29761666, - 1.0779834 , - - 0.5243833 , - -0.32181668, - 1.1833833 , - - - 0.73157084, - 0.4317708 , - 0.7283708 , - - 1.2297708 , - 0.4307708 , - 0.85377085, - - 0.05977082, - -0.09282917, - 0.33957082, - - 1.0751709 , - 0.2119708 , - 0.51897085, - - -0.25302917, - 1.1723708 , - -0.12562919, - - 1.1993709 , - 0.5257708 , - 0.40517086, - - 0.53197086, - 0.8441708 , - 0.02617085, - - -0.0208292 , - 0.8711709 , - 0.04137081, - - - 0.74936247, - 0.6085625 , - 0.8997625 , - - -0.08743751, - 0.18576252, - -0.17563748, - - 0.5991625 , - -0.0038375 , - 0.07576251, - - 0.42536253, - -0.22823751, - 0.36296248, - - 0.81456256, - -0.16183749, - 0.5161625 , - - -0.21183747, - 0.7429625 , - 0.6217625 , - - 0.17656249, - 0.02616251, - -0.17923748, - - 1.4659625 , - 0.40016252, - 0.28356248, - - - 0.4195791 , - 0.8745791 , - 0.36637908, - - 0.50597906, - -0.17942089, - 0.16917908, - - 1.0235791 , - 1.3699791 , - -0.11382091, - - -0.0918209 , - 0.7757791 , - 0.09017909, - - 1.3807791 , - 
-0.15202093, - 1.3875791 , - - -0.1712209 , - 1.3989791 , - 0.43777913, - - 0.7855791 , - 0.1423791 , - 1.4711791 , - - 0.6455791 , - 0.6211791 , - -0.48062086, - - - 0.10189578, - 0.5628958 , - 0.68909574, - - 0.96649575, - -0.09370419, - 1.3466958 , - - 1.4584957 , - 1.3544958 , - -0.3829042 , - - 0.11269578, - -0.47890422, - 1.0436958 , - - 0.6128957 , - 0.27209583, - 0.2714958 , - - 0.21889582, - 0.08789578, - 1.1296958 , - - 0.4596958 , - 0.39309582, - 0.8344958 , - - 0.71149576, - -0.4799042, - 0.4880958 + 1.0218375, 1.0666375 , 0.9130375 , + -0.07396251, 0.91843754, -0.17496246, + 0.47543746, 1.2492375 , 0.55643755, + 1.3110375 , -0.36456245, 1.0518374 , + 0.7824375 , 0.57523745, -0.21656245, + 0.0816375 , -0.2261625 , 0.40323752, + 1.4520376 , 0.6868375 , 0.81723756, + -0.17576247, 0.81423753, -0.08656245, + + -0.36249164, 0.45590833, 1.1925083 , + 0.00650835, 1.4861084 , 1.2079083 , + 0.05270836, 0.37350836, 0.94130826, + 1.0715083 , 0.6103083 , 0.9825083 , + 0.07370833, -0.4518917 , -0.39889166, + -0.3354917 , 1.2213084 , 1.0345083 , + -0.3132917 , 0.78470826, 0.23390833, + 0.6943083 , 0.68170834, -0.09989169, + + 0.8352709 , 1.3798709 , 0.15507084, + 0.26607084, -0.10792917, 1.2302709 , + 0.6448709 , -0.29992914, 1.3534708 , + 0.86607087, 0.37607086, 0.04027084, + 0.40087086, 0.59507084, 0.9416709 , + 0.53127086, -0.01712915, 1.4610709 , + -0.17152917, -0.13992918, 0.6242708 , + -0.42192918, 0.38387084, -0.15752912, + + + 0.3311833 , 0.00618333, 0.17538333, + 0.10418332, 0.8365834 , 0.27098334, + 1.2421833 , -0.1114167 , 1.0153834 , + 0.9523833 , 0.8317833 , 0.9633833 , + 0.6501833 , 0.04258335, 0.9999833 , + -0.40181667, 0.11418331, 0.47938335, + 1.1057833 , -0.29761666, 1.0779834 , + 0.5243833 , -0.32181668, 1.1833833 , + + 0.73157084, 0.4317708 , 0.7283708 , + 1.2297708 , 0.4307708 , 0.85377085, + 0.05977082, -0.09282917, 0.33957082, + 1.0751709 , 0.2119708 , 0.51897085, + -0.25302917, 1.1723708 , -0.12562919, + 1.1993709 , 0.5257708 , 0.40517086, + 
0.53197086, 0.8441708 , 0.02617085, + -0.0208292 , 0.8711709 , 0.04137081, + + 0.74936247, 0.6085625 , 0.8997625 , + -0.08743751, 0.18576252, -0.17563748, + 0.5991625 , -0.0038375 , 0.07576251, + 0.42536253, -0.22823751, 0.36296248, + 0.81456256, -0.16183749, 0.5161625 , + -0.21183747, 0.7429625 , 0.6217625 , + 0.17656249, 0.02616251, -0.17923748, + 1.4659625 , 0.40016252, 0.28356248, + + 0.4195791 , 0.8745791 , 0.36637908, + 0.50597906, -0.17942089, 0.16917908, + 1.0235791 , 1.3699791 , -0.11382091, + -0.0918209 , 0.7757791 , 0.09017909, + 1.3807791 , -0.15202093, 1.3875791 , + -0.1712209 , 1.3989791 , 0.43777913, + 0.7855791 , 0.1423791 , 1.4711791 , + 0.6455791 , 0.6211791 , -0.48062086, + + + 0.10189578, 0.5628958 , 0.68909574, + 0.96649575, -0.09370419, 1.3466958 , + 1.4584957 , 1.3544958 , -0.3829042 , + 0.11269578, -0.47890422, 1.0436958 , + 0.6128957 , 0.27209583, 0.2714958 , + 0.21889582, 0.08789578, 1.1296958 , + 0.4596958 , 0.39309582, 0.8344958 , + 0.71149576, -0.4799042, 0.4880958 }); // x.linspace(1.); nd4j::ops::adjust_contrast_v2 op;