[WIP] Shugeo lup (#126)
* Added infrastructure for implementation op lu for both cuda and cpu platforms. * Added implementation of helpers with lu op. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored LU decomposition to use vector of permutations instead. * Refactored helpers for lu op. * Fixed crash with determinant op. * Refactored cpu LU op heleper. * Added implementation for lu op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed issue with argmax on column. * Added multithreaded behaviour for lu op helper. * Fixed multithreaded cpu implementation helpers for lu op. * Added cuda implementation for lu op helper. * Finished lu helper implementation for cuda platform. * Eliminated waste prints and comments. * Fixed race condition and multithreading issues. * Fixed memory leak with shape construction. * Corrected test for lu op to avoid near zero elements on the main diagonal." Signed-off-by: shugeo <sgazeos@gmail.com> * Improved test for adjust_constast op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed issues with cuda implementation of resize_bicubic helpers. Signed-off-by: shugeo <sgazeos@gmail.com>master
parent
6d8a063c9b
commit
67d8199165
|
@ -0,0 +1,59 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by GS <sgazeos@gmail.com> at 12/10/2019
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_matrix_inverse)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/lup.h>
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
auto p = OUTPUT_VARIABLE(1);
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_inverse: The rank of input array should not less than 2, but %i is given", input->rankOf());
|
||||
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_inverse: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
|
||||
|
||||
helpers::lu(block.launchContext(), input, z, p);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(lu) {
|
||||
auto in = inputShape->at(0);
|
||||
auto shapeVector = ShapeUtils::shapeAsVector(in);
|
||||
auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace());
|
||||
auto luP = ShapeBuilders::createShapeInfo(nd4j::DataType::INT32, shape::order(in), shapeVector.size() - 1,
|
||||
shapeVector.data(), block.workspace());
|
||||
return SHAPELIST(CONSTANT(luShape), CONSTANT(luP));
|
||||
}
|
||||
|
||||
DECLARE_TYPES(lu) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(1, {nd4j::DataType::INT32, nd4j::DataType::INT64})
|
||||
->setSameMode(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -1027,6 +1027,24 @@ namespace nd4j {
|
|||
DECLARE_OP(matrix_inverse, 1, 1, true);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* lu op. - make LUP decomposition of given batch of 2D square matricies
|
||||
*
|
||||
* input params:
|
||||
* 0 - float tensor with dimension (x * y * z * ::: * M * M)
|
||||
*
|
||||
* return value:
|
||||
* 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M matricies in it
|
||||
* 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M)
|
||||
*
|
||||
* int argument:
|
||||
* 0 - data type of output permutaion vector (int32 or int64), optional, default INT32
|
||||
*/
|
||||
|
||||
#if NOT_EXCLUDED(OP_matrix_inverse)
|
||||
DECLARE_CUSTOM_OP(lu, 1, 2, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j]
|
||||
*
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include <MmulHelper.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <Status.h>
|
||||
#include <execution/Threads.h>
|
||||
#include <execution/Threads.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -32,15 +34,30 @@ namespace helpers {
|
|||
|
||||
if (theFirst != theSecond)
|
||||
for (int i = 0; i < matrix->columns(); i++) {
|
||||
T e0 = matrix->e<T>(theFirst, i);
|
||||
T e1 = matrix->e<T>(theSecond, i);
|
||||
|
||||
matrix->p<T>(theFirst, i, e1);
|
||||
matrix->p<T>(theSecond, i, e0);
|
||||
math::nd4j_swap(matrix->t<T>(theFirst, i), matrix->t<T>(theSecond, i));
|
||||
}
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
|
||||
|
||||
template <typename T>
|
||||
static void swapRows(T* matrixBuf, Nd4jLong* matrixShape, Nd4jLong theFirst, Nd4jLong theSecond) {
|
||||
if (theFirst != theSecond) {
|
||||
auto n = shape::sizeAt(matrixShape, -1);
|
||||
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
Nd4jLong theFirstPos[] = {theFirst, i};
|
||||
Nd4jLong theSecondPos[] = {theSecond, i};
|
||||
auto theFirstIndex = shape::getOffset(matrixShape, theFirstPos, 0);
|
||||
auto theSecondIndex = shape::getOffset(matrixShape, theSecondPos, 0);
|
||||
math::nd4j_swap(matrixBuf[theFirstIndex], matrixBuf[theSecondIndex]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_tad(loop, 0, n, 1);
|
||||
}
|
||||
}
|
||||
|
||||
void swapRows(NDArray* matrix, int theFirst, int theSecond) {
|
||||
BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES);
|
||||
}
|
||||
|
@ -106,7 +123,7 @@ namespace helpers {
|
|||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename I>
|
||||
static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
|
||||
|
||||
const int rowNum = input->rows();
|
||||
|
@ -132,7 +149,7 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
|
||||
if( pivotValue > T(0.00001)) {
|
||||
if( pivotValue > DataTypeUtils::min<T>()) {
|
||||
swapRows(&compoundMatrix, pivot, i);
|
||||
swapRows(&permutationMatrix, pivot, i);
|
||||
if (pivot != i)
|
||||
|
@ -155,14 +172,113 @@ namespace helpers {
|
|||
if (swapCount % 2) determinant = -determinant;
|
||||
if (compound != nullptr)
|
||||
compound->assign(compoundMatrix);
|
||||
if (permutation != nullptr)
|
||||
if (permutation != nullptr) {
|
||||
auto permutaionVector = NDArrayFactory::create('c', {rowNum}, DataTypeUtils::fromT<I>(), input->getContext());
|
||||
for (auto i = 0; i < rowNum; i++) {
|
||||
for (auto j = 0; j < columnNum; j++) {
|
||||
if (permutationMatrix.t<T>(i, j) != 0) {
|
||||
permutaionVector.template t<I>(i) = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (permutationMatrix.isSameShape(permutation))
|
||||
permutation->assign(permutationMatrix);
|
||||
else if (permutation->isSameShape(permutaionVector)) {
|
||||
permutation->assign(permutaionVector);
|
||||
}
|
||||
}
|
||||
return determinant;
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
|
||||
/*
|
||||
* lu decomposition with naive algorithm with partial pivoting
|
||||
* */
|
||||
template <typename T, typename I>
|
||||
static I argmaxCol(I column, T* compoundBuffer, Nd4jLong* compoundShape) {
|
||||
auto rowNum = shape::sizeAt(compoundShape, 0);
|
||||
Nd4jLong xInitial[] = {column, column};
|
||||
auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0);
|
||||
auto maxValue = T(0); //nd4j::math::nd4j_abs(compoundBuffer[xInitialIndex]);
|
||||
auto result = -1;
|
||||
//auto loop = PRAGMA_THREADS_FOR {
|
||||
auto start = column, stop = rowNum, increment = 1;
|
||||
for (auto rowCounter = start; rowCounter < stop; rowCounter += increment) {
|
||||
Nd4jLong xPos[] = {rowCounter, column};
|
||||
auto xIndex = shape::getOffset(compoundShape, xPos, 0);
|
||||
if (nd4j::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) {
|
||||
maxValue = nd4j::math::nd4j_max(maxValue, nd4j::math::nd4j_abs(compoundBuffer[xIndex]));
|
||||
result = rowCounter;
|
||||
}
|
||||
}
|
||||
//};
|
||||
//samediff::Threads::parallel_for(loop, column, rowNum, 1);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void processColumns(int currentRow, int rowNum, T* compoundBuf, Nd4jLong* compoundShape) {
|
||||
Nd4jLong xDiag[] = {currentRow, currentRow};
|
||||
auto diagIndex = shape::getOffset(compoundShape, xDiag, 0);
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (int j = start; j < stop; j += increment) {
|
||||
Nd4jLong xRow[] = {j, currentRow};
|
||||
auto rowIndex = shape::getOffset(compoundShape, xRow, 0);
|
||||
compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t<T>(i, i);
|
||||
for (int k = currentRow + 1; k < rowNum; k++) {
|
||||
Nd4jLong yRow[] = {j, k};
|
||||
Nd4jLong yCol[] = {currentRow, k};
|
||||
auto rowIndexY = shape::getOffset(compoundShape, yRow, 0);
|
||||
auto colIndex = shape::getOffset(compoundShape, yCol, 0);
|
||||
compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex];
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
|
||||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
|
||||
|
||||
//const int rowNum = compound->rows();
|
||||
// const int columnNum = output->columns();
|
||||
permutation->linspace(0);
|
||||
auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
|
||||
auto compoundBuf = compound->bufferAsT<T>();
|
||||
auto compoundShape = compound->shapeInfo();
|
||||
auto permutationShape = permutation->shapeInfo();
|
||||
for (auto i = 0; i < rowNum - 1; i++) {
|
||||
auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
|
||||
if (pivotIndex < 0) {
|
||||
throw std::runtime_error("helpers::luNN_: input matrix is singular.");
|
||||
}
|
||||
math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
|
||||
swapRows(compoundBuf, compoundShape, i, pivotIndex);
|
||||
|
||||
processColumns(i, rowNum, compoundBuf, compoundShape);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
|
||||
auto n = input->sizeAt(-1);
|
||||
|
||||
output->assign(input); // fill up output tensor with zeros
|
||||
std::unique_ptr<ResultSet> outputs(output->allTensorsAlongDimension({-2, -1}));
|
||||
std::unique_ptr<ResultSet> permutations(permutationVectors->allTensorsAlongDimension({-1}));
|
||||
auto loop = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
luNN_<T, I>(context, outputs->at(i), permutations->at(i), n);
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(loop, 0, outputs->size(), 1);
|
||||
}
|
||||
|
||||
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
|
||||
}
|
||||
|
||||
// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
|
||||
|
||||
template <typename T>
|
||||
static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
@ -175,7 +291,7 @@ namespace helpers {
|
|||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
|
||||
matrix.p(row, input->e<T>(k));
|
||||
output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
||||
output->p(e, lup_<T, int>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -196,7 +312,7 @@ template <typename T>
|
|||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
|
||||
matrix.p(row, input->e<T>(k));
|
||||
}
|
||||
NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
|
||||
NDArray det = lup_<T, int>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
|
||||
if (det.e<T>(0) != 0.f)
|
||||
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
|
||||
}
|
||||
|
@ -229,7 +345,7 @@ template <typename T>
|
|||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||
matrix.p(row++, input->e<T>(k));
|
||||
}
|
||||
T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
|
||||
T det = lup_<T, int>(context, &matrix, &compound, &permutation).template e<T>(0);
|
||||
|
||||
// FIXME: and how this is going to work on float16?
|
||||
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
|
||||
|
@ -274,7 +390,7 @@ template <typename T>
|
|||
// check for symmetric
|
||||
for (Nd4jLong r = 0; r < thisMatrix->rows(); r++)
|
||||
for (Nd4jLong c = 0; c < thisMatrix->columns(); c++)
|
||||
if (nd4j::math::nd4j_abs(thisMatrix->e<T>(r, c) - lastMatrixList->at(i)->e<T>(c,r)) > T(1.e-6f)) return false;
|
||||
if (nd4j::math::nd4j_abs(thisMatrix->e<T>(r, c) - lastMatrixList->at(i)->e<T>(c,r)) > DataTypeUtils::min<T>()) return false;
|
||||
|
||||
NDArray output = NDArrayFactory::create<T>(0., context);
|
||||
if (ND4J_STATUS_OK != determinant(context, thisMatrix, &output)) return false;
|
||||
|
@ -366,6 +482,11 @@ template <typename T>
|
|||
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
int lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, (context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -689,11 +689,17 @@ namespace helpers {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) {
|
||||
static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) {
|
||||
// auto numChannels = pResizerState->channels;
|
||||
|
||||
for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) {
|
||||
auto pInput = inputPtr + b * inBatchWidth;
|
||||
float* cachedValue;
|
||||
for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) {
|
||||
if (threadIdx.x == 0) {
|
||||
extern __shared__ char sharedChar[];
|
||||
cachedValue = reinterpret_cast<float*>(sharedChar);
|
||||
}
|
||||
auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels;
|
||||
auto pOutput = &outputPtr[pos];
|
||||
struct WeightsAndIndices yWai;
|
||||
|
@ -846,20 +852,20 @@ namespace helpers {
|
|||
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err);
|
||||
}
|
||||
|
||||
float* cachedValue = nullptr;
|
||||
size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels);
|
||||
if (cachedSize) {
|
||||
err = cudaMalloc(reinterpret_cast<void**>(&cachedValue), cachedSize);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build(
|
||||
"helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err);
|
||||
}
|
||||
err = cudaMemset(cachedValue, 0, cachedSize);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build(
|
||||
"helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err);
|
||||
}
|
||||
}
|
||||
// float* cachedValue = nullptr;
|
||||
// size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels);
|
||||
// if (cachedSize) {
|
||||
// err = cudaMalloc(reinterpret_cast<void**>(&cachedValue), cachedSize);
|
||||
// if (err != 0) {
|
||||
// throw cuda_exception::build(
|
||||
// "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err);
|
||||
// }
|
||||
// err = cudaMemset(cachedValue, 0, cachedSize);
|
||||
// if (err != 0) {
|
||||
// throw cuda_exception::build(
|
||||
// "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err);
|
||||
// }
|
||||
// }
|
||||
|
||||
WeightsAndIndices* xWais; //(resizerState.outWidth);
|
||||
err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth);
|
||||
|
@ -878,7 +884,7 @@ namespace helpers {
|
|||
}
|
||||
const T* pInput = image->getDataBuffer()->specialAsT<T>();
|
||||
float* pOutput = output->dataBuffer()->specialAsT<float>(); //_data.data();
|
||||
bicubicInterpolateWithCachingKernel<T><<<128, 1, 512, *stream>>>(coeffsTable, cachedValue, pInput,
|
||||
bicubicInterpolateWithCachingKernel<T><<<128, 1, 512, *stream>>>(coeffsTable, pInput,
|
||||
resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput);
|
||||
err = cudaStreamSynchronize(*stream);
|
||||
if (err != 0) {
|
||||
|
@ -889,11 +895,11 @@ namespace helpers {
|
|||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err);
|
||||
}
|
||||
if (cachedSize)
|
||||
err = cudaFree(cachedValue);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err);
|
||||
}
|
||||
// if (cachedSize)
|
||||
// err = cudaFree(cachedValue);
|
||||
// if (err != 0) {
|
||||
// throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err);
|
||||
// }
|
||||
|
||||
err = cudaFree(xWais);
|
||||
if (err != 0) {
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <Status.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
#include <ShapeUtils.h>
|
||||
//#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||
|
||||
#include <cusolverDn.h>
|
||||
#include <cuda_exception.h>
|
||||
|
@ -336,7 +337,7 @@ namespace helpers {
|
|||
//
|
||||
// input - A matrix nxn
|
||||
// compound - C matrix L + U - I, or main diagonal and lower - L matrix, from the 2nd diagonal - U matrix
|
||||
template<typename T>
|
||||
template<typename T, typename I>
|
||||
static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
|
||||
auto stream = context->getCudaStream();
|
||||
auto n = input->rows();
|
||||
|
@ -383,7 +384,7 @@ namespace helpers {
|
|||
err);
|
||||
}
|
||||
|
||||
if (permutation == nullptr)
|
||||
if (permutation == nullptr) {
|
||||
status = cusolverDnDgetrf(
|
||||
cusolverH,
|
||||
n,
|
||||
|
@ -393,9 +394,15 @@ namespace helpers {
|
|||
d_work,
|
||||
nullptr,
|
||||
d_info);
|
||||
|
||||
if (status != CUSOLVER_STATUS_SUCCESS) {
|
||||
throw cuda_exception::build("helpers::lup_: LU factorization is failed due ",
|
||||
status);
|
||||
}
|
||||
}
|
||||
else {
|
||||
NDArray permutVector('c', {n}, nd4j::DataType::INT32, context);
|
||||
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
|
||||
int* permutationBuf = permutVector.dataBuffer()->specialAsT<int>();
|
||||
status = cusolverDnDgetrf(
|
||||
cusolverH,
|
||||
n,
|
||||
|
@ -405,9 +412,21 @@ namespace helpers {
|
|||
d_work,
|
||||
permutationBuf,
|
||||
d_info);
|
||||
if (status != CUSOLVER_STATUS_SUCCESS) {
|
||||
throw cuda_exception::build("helpers::lup_: LU factorization is failed due ",
|
||||
status);
|
||||
}
|
||||
|
||||
if (permutation->rankOf() == 2) {
|
||||
fillUpPermutation<double> <<< n, n, 1024, *stream >>>
|
||||
(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
|
||||
permutation->tickWriteDevice();
|
||||
}
|
||||
else {
|
||||
permutVector.tickWriteDevice();
|
||||
input->tickWriteDevice();
|
||||
compound->assign(input);
|
||||
permutation->assign(permutVector);
|
||||
}
|
||||
}
|
||||
err = cudaFree(d_work);
|
||||
if (err) {
|
||||
|
@ -448,7 +467,7 @@ namespace helpers {
|
|||
nullptr,
|
||||
d_info);
|
||||
else {
|
||||
NDArray permutVector('c', {n}, nd4j::DataType::INT32, context);
|
||||
NDArray permutVector('c', {n}, DataType::INT32, context);
|
||||
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
|
||||
status = cusolverDnSgetrf(
|
||||
cusolverH,
|
||||
|
@ -459,10 +478,17 @@ namespace helpers {
|
|||
d_work,
|
||||
permutationBuf,
|
||||
d_info);
|
||||
fillUpPermutation<T> <<< n, n, 128, *stream >> >
|
||||
if (permutation->rankOf() == 2) {
|
||||
fillUpPermutation<I> <<< n, n, 128, *stream >>>
|
||||
(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
|
||||
permutation->tickWriteDevice();
|
||||
}
|
||||
else {
|
||||
input->tickWriteDevice();
|
||||
compound->assign(input);
|
||||
permutation->assign(permutVector);
|
||||
}
|
||||
}
|
||||
err = cudaFree(d_work);
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer",
|
||||
|
@ -484,8 +510,116 @@ namespace helpers {
|
|||
}
|
||||
// ------------------------------------------------------------------------------------------------------------------ //
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
|
||||
BUILD_DOUBLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE, INDEXING_TYPES);
|
||||
|
||||
template <typename T>
|
||||
static __device__ void swapRows(T* matrix, Nd4jLong* shape, Nd4jLong theFirst, Nd4jLong theSecond, Nd4jLong n) {
|
||||
if (theFirst != theSecond) {
|
||||
for (auto i = 0; i < n; i++) {
|
||||
Nd4jLong theFirstPos[] = {theFirst, i};
|
||||
Nd4jLong theSecondPos[] = {theSecond, i};
|
||||
auto theFirstIndex = shape::getOffset(shape, theFirstPos, 0);
|
||||
auto theSecondIndex = shape::getOffset(shape, theSecondPos, 0);
|
||||
math::nd4j_swap(matrix[theFirstIndex], matrix[theSecondIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ void processColumns(Nd4jLong currentRow, Nd4jLong rowNum, T* compoundBuf, Nd4jLong* compoundShape) {
|
||||
Nd4jLong xDiag[] = {currentRow, currentRow};
|
||||
auto diagIndex = shape::getOffset(compoundShape, xDiag, 0);
|
||||
for (auto j = currentRow + 1; j < rowNum; j++) {
|
||||
Nd4jLong xRow[] = {j, currentRow};
|
||||
auto rowIndex = shape::getOffset(compoundShape, xRow, 0);
|
||||
compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t<T>(i, i);
|
||||
for (auto k = currentRow + 1; k < rowNum; k++) {
|
||||
Nd4jLong yRow[] = {j, k};
|
||||
Nd4jLong yCol[] = {currentRow, k};
|
||||
auto rowIndexY = shape::getOffset(compoundShape, yRow, 0);
|
||||
auto colIndex = shape::getOffset(compoundShape, yCol, 0);
|
||||
compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ Nd4jLong argmaxCol(Nd4jLong column, T* compoundBuffer, Nd4jLong* compoundShape) {
|
||||
auto rowNum = shape::sizeAt(compoundShape, 0);
|
||||
Nd4jLong xInitial[] = {column, column};
|
||||
auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0);
|
||||
auto maxValue = T(0); //nd4j::math::nd4j_abs(compoundBuffer[xInitialIndex]);
|
||||
auto result = -1LL;
|
||||
|
||||
for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) {
|
||||
Nd4jLong xPos[] = {rowCounter, column};
|
||||
auto xIndex = shape::getOffset(compoundShape, xPos, 0);
|
||||
if (nd4j::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) {
|
||||
maxValue = nd4j::math::nd4j_max(maxValue, nd4j::math::nd4j_abs(compoundBuffer[xIndex]));
|
||||
result = rowCounter;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
static __device__ int luNN(T* matrix, Nd4jLong* shape, I* permutation, Nd4jLong* permuShape, Nd4jLong n) {
|
||||
|
||||
for (auto i = 0; i < n - 1; i++) {
|
||||
auto pivotIndex = argmaxCol(i, matrix, shape);
|
||||
if (pivotIndex < 0) {
|
||||
return -1;//throw std::runtime_error("helpers::luNN_: input matrix is singular.");
|
||||
}
|
||||
math::nd4j_swap(permutation[shape::getIndexOffset(i, permuShape)], permutation[shape::getIndexOffset(pivotIndex, permuShape)]);
|
||||
swapRows(matrix, shape, (Nd4jLong)i, pivotIndex, n);
|
||||
|
||||
processColumns(i, n, matrix, shape);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
static __global__ void luBatchedKernel(T* outputBuf, Nd4jLong* outputShape, I* permutations, Nd4jLong* permuShape,
|
||||
Nd4jLong* outputTadShape, Nd4jLong* outputTadOffsets, Nd4jLong* permuTadShape, Nd4jLong* permuTadOffsets,
|
||||
Nd4jLong batchNum) {
|
||||
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (auto b = start; b < batchNum; b += step) {
|
||||
T* matrix = outputBuf + outputTadOffsets[b];
|
||||
I* permutation = permutations + permuTadOffsets[b];
|
||||
|
||||
if (0 != luNN(matrix, outputTadShape, permutation, permuTadShape, shape::length(permuTadShape))) break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
|
||||
auto n = input->sizeAt(-1);
|
||||
auto stream = context->getCudaStream();
|
||||
auto iota = NDArrayFactory::create<int>('c', {n});
|
||||
iota.linspace(0); iota.syncToDevice();
|
||||
|
||||
output->assign(input); // fill up output tensor with zeros
|
||||
output->tickWriteDevice();
|
||||
permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &iota, permutationVectors, true, nullptr);
|
||||
permutationVectors->tickWriteDevice();
|
||||
|
||||
auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1});
|
||||
auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-1});
|
||||
auto batchNum = tads.numberOfTads();
|
||||
luBatchedKernel<T,I><<<batchNum, 256, 1024, *stream>>>(reinterpret_cast<T*>(output->platformBuffer()),
|
||||
output->specialShapeInfo(), reinterpret_cast<I*>(permutationVectors->platformBuffer()),
|
||||
permutationVectors->specialShapeInfo(), tads.specialShapeInfo(), tads.specialOffsets(),
|
||||
permutaionTads.specialShapeInfo(), permutaionTads.specialOffsets(), batchNum);
|
||||
}
|
||||
|
||||
void lu(LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutations) {
|
||||
NDArray::prepareSpecialUse({output, permutations}, {input});
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), permutations->dataType(), lu_, (context, input, output, permutations), FLOAT_NATIVE, INDEXING_TYPES);
|
||||
NDArray::registerSpecialUse({output, permutations}, {input});
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------------------------ //
|
||||
template<typename T>
|
||||
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
|
@ -509,7 +643,7 @@ namespace helpers {
|
|||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
// else
|
||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||
lup_<T, int>(context, &matrix, nullptr, nullptr);
|
||||
// else
|
||||
// lup_<float>(context, &matrix, nullptr, nullptr);
|
||||
auto offset = shape::getIndexOffset(e, output->shapeInfo());
|
||||
|
@ -557,7 +691,7 @@ namespace helpers {
|
|||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||
lup_<T, int>(context, &matrix, nullptr, nullptr);
|
||||
// else
|
||||
// lup_<float>(context, &matrix, nullptr, nullptr);
|
||||
auto offset = shape::getIndexOffset(e, output->shapeInfo());
|
||||
|
@ -638,7 +772,7 @@ namespace helpers {
|
|||
matrix.tickWriteDevice();
|
||||
//compound.assign(matrix);
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||
lup_<T, int>(context, &matrix, nullptr, nullptr);
|
||||
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n);
|
||||
lower.tickWriteDevice();
|
||||
upper.tickWriteDevice();
|
||||
|
@ -861,6 +995,14 @@ namespace helpers {
|
|||
BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
/*
|
||||
* lup - batched input, batched outputs
|
||||
* */
|
||||
int lup(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_,(context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// BUILD_SINGLE_TEMPLATE(template int logdetFunctor_,
|
||||
// (nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE);
|
||||
}
|
||||
|
|
|
@ -26,9 +26,8 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
T lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation);
|
||||
|
||||
int lup(nd4j::LaunchContext* context, NDArray* input, NDArray* lu, NDArray* permutation);
|
||||
void lu(nd4j::LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation);
|
||||
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
||||
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
||||
|
||||
|
|
|
@ -1050,7 +1050,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
|
|||
|
||||
|
||||
auto testData = NDArrayFactory::create<float>('c', {2,9,9,1}, {
|
||||
0.230286f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f,
|
||||
0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f,
|
||||
0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f,
|
||||
0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f,
|
||||
0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f,
|
||||
|
@ -1081,7 +1081,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
|
|||
NDArray* result = results->at(0);
|
||||
|
||||
// result->printBuffer("Resized to 9x9");
|
||||
// expected.printBuffer("Expect for 9x9");
|
||||
// testData.printBuffer("Expect for 9x9");
|
||||
ASSERT_TRUE(testData.isSameShape(result));
|
||||
ASSERT_TRUE(testData.equalsTo(result));
|
||||
delete results;
|
||||
|
|
|
@ -2424,3 +2424,256 @@ TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) {
|
|||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_1) {
|
||||
|
||||
auto in = NDArrayFactory::create<double>('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7});
|
||||
auto pExp = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
|
||||
nd4j::ops::lu op;
|
||||
|
||||
auto res = op.execute({&in}, {}, {});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
auto p = res->at(1);
|
||||
// z->printIndexedBuffer("Triangulars");
|
||||
// p->printIndexedBuffer("Permutaions");
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
ASSERT_TRUE(pExp.equalsTo(p));
|
||||
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_2) {
|
||||
auto in = NDArrayFactory::create<double>('c', {3,3}, {1, 0, 0, 2, 3, 0, 4, 5, 6});
|
||||
|
||||
auto expLU = NDArrayFactory::create<double>('c', {3,3}, {4., 5., 6., 0.25, -1.25, -1.5, 0.5, -0.4, -3.6});
|
||||
auto expP = NDArrayFactory::create<int>({2, 0, 1});
|
||||
nd4j::ops::lu op;
|
||||
|
||||
auto res = op.execute({&in}, {}, {});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
auto p = res->at(1);
|
||||
// z->printIndexedBuffer("Triangulars2");
|
||||
// p->printIndexedBuffer("Permutaions2");
|
||||
ASSERT_TRUE(expLU.equalsTo(z));
|
||||
ASSERT_TRUE(expP.equalsTo(p));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_3) {
|
||||
auto in = NDArrayFactory::create<double>('c', {3,3}, {1,2,3,4,7,9, 11, 12, 13});
|
||||
|
||||
auto expLU = NDArrayFactory::create<double>('c', {3,3}, {
|
||||
11., 12., 13.,
|
||||
0.36363637, 2.6363635, 4.272727,
|
||||
0.09090909, 0.3448276, 0.34482753});
|
||||
|
||||
auto expP = NDArrayFactory::create<int>({2, 1, 0});
|
||||
nd4j::ops::lu op;
|
||||
|
||||
auto res = op.execute({&in}, {}, {});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
auto p = res->at(1);
|
||||
// z->printIndexedBuffer("Triangulars3");
|
||||
// p->printIndexedBuffer("Permutaions3");
|
||||
ASSERT_TRUE(expLU.equalsTo(z));
|
||||
ASSERT_TRUE(expP.equalsTo(p));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_4) {
|
||||
|
||||
auto in = NDArrayFactory::create<double>('c', {10,10}, {
|
||||
1., 2., 3., 4., 5., 6., 7., 8., 1., 15.,
|
||||
5., 1., 13., 4., 15., 1., 17., 9., 11., 25.,
|
||||
1., 9., 1., 4., 5., 2., 13., 10, 21., 15.,
|
||||
3., 9., 4., 1., 5., 3., 7., 1, 1., 5.,
|
||||
2., 3., 2., 5., 4., 4., 7., 3, 3., 4.,
|
||||
0., 1., 3., 3., 5., 1., 3., 1, 31., 15.,
|
||||
2., 1., 4., 3., 1., 5., 1., 2, 31., 35.,
|
||||
3., 4., 3., 3., 4., 4., 4., 1., 3., 1.,
|
||||
1., 1., 1., 1., 5., 6., 5., 4., 3., 2.,
|
||||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
|
||||
|
||||
auto expLU = NDArrayFactory::create<double>('c', {10,10}, {
|
||||
5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0,
|
||||
0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0,
|
||||
0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636,
|
||||
0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957,
|
||||
0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323,
|
||||
0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387,
|
||||
0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300,
|
||||
0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119,
|
||||
0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178,
|
||||
0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695
|
||||
});
|
||||
|
||||
auto expP = NDArrayFactory::create<int>({1, 2, 7, 3, 6, 8, 5, 4, 0, 9});
|
||||
nd4j::ops::lu op;
|
||||
|
||||
auto res = op.execute({&in}, {}, {});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
auto p = res->at(1);
|
||||
// z->printBuffer("Triangulars4");
|
||||
// expLU.printBuffer("TriangulExp4");
|
||||
// p->printBuffer("Permutaions4");
|
||||
|
||||
ASSERT_TRUE(expLU.equalsTo(z));
|
||||
ASSERT_TRUE(expP.equalsTo(p));
|
||||
delete res;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_5) {
|
||||
|
||||
auto in = NDArrayFactory::create<double>('c', {2, 10,10}, {
|
||||
1., 2., 3., 4., 5., 6., 7., 8., 1., 15.,
|
||||
5., 1., 13., 4., 15., 1., 17., 9., 11., 25.,
|
||||
1., 9., 1., 4., 5., 2., 13., 10, 21., 15.,
|
||||
3., 9., 4., 1., 5., 3., 7., 1, 1., 5.,
|
||||
2., 3., 2., 5., 4., 4., 7., 3, 3., 4.,
|
||||
0., 1., 3., 3., 5., 1., 3., 1, 31., 15.,
|
||||
2., 1., 4., 3., 1., 5., 1., 2, 31., 35.,
|
||||
3., 4., 3., 3., 4., 4., 4., 1., 3., 1.,
|
||||
1., 1., 1., 1., 5., 6., 5., 4., 3., 2.,
|
||||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
|
||||
1., 2., 3., 4., 5., 6., 7., 8., 1., 15.,
|
||||
5., 1., 13., 4., 15., 1., 17., 9., 11., 25.,
|
||||
1., 9., 1., 4., 5., 2., 13., 10, 21., 15.,
|
||||
3., 9., 4., 1., 5., 3., 7., 1, 1., 5.,
|
||||
2., 3., 2., 5., 4., 4., 7., 3, 3., 4.,
|
||||
0., 1., 3., 3., 5., 1., 3., 1, 31., 15.,
|
||||
2., 1., 4., 3., 1., 5., 1., 2, 31., 35.,
|
||||
3., 4., 3., 3., 4., 4., 4., 1., 3., 1.,
|
||||
1., 1., 1., 1., 5., 6., 5., 4., 3., 2.,
|
||||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
|
||||
});
|
||||
|
||||
auto expLU = NDArrayFactory::create<double>('c', {2, 10,10}, {
|
||||
5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0,
|
||||
0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0,
|
||||
0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636,
|
||||
0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957,
|
||||
0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323,
|
||||
0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387,
|
||||
0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300,
|
||||
0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119,
|
||||
0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178,
|
||||
0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695,
|
||||
|
||||
5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0,
|
||||
0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0,
|
||||
0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636,
|
||||
0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957,
|
||||
0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323,
|
||||
0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387,
|
||||
0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300,
|
||||
0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119,
|
||||
0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178,
|
||||
0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695
|
||||
|
||||
});
|
||||
|
||||
auto expP = NDArrayFactory::create<int>('c', {2, 10}, {
|
||||
1, 2, 7, 3, 6, 8, 5, 4, 0, 9,
|
||||
1, 2, 7, 3, 6, 8, 5, 4, 0, 9
|
||||
});
|
||||
nd4j::ops::lu op;
|
||||
|
||||
auto res = op.execute({&in}, {}, {});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
auto p = res->at(1);
|
||||
// z->printBuffer("Triangulars5");
|
||||
// expLU.printBuffer("TriangulExp5");
|
||||
// p->printBuffer("Permutaions5");
|
||||
|
||||
ASSERT_TRUE(expLU.equalsTo(z));
|
||||
ASSERT_TRUE(expP.equalsTo(p));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_1_2) {
|
||||
|
||||
auto in = NDArrayFactory::create<double>('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.,1., 2., 3., 0., 2., 3., 0., 0., 7.});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7, 1., 2., 3., 0., 2., 3., 0., 0., 7.});
|
||||
|
||||
nd4j::ops::lu op;
|
||||
|
||||
auto res = op.execute({&in}, {}, {});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
auto p = res->at(1);
|
||||
// z->printIndexedBuffer("Triangulars (2,3,3)");
|
||||
// p->printIndexedBuffer("Permutaions (2,3,3)");
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_3_2) {
|
||||
|
||||
auto in = NDArrayFactory::create<double>('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,1,2,3,4,7,9, 11, 12, 13});
|
||||
|
||||
auto expLU = NDArrayFactory::create<double>('c', {2, 3,3}, {
|
||||
11., 12., 13.,
|
||||
0.36363637, 2.6363635, 4.272727,
|
||||
0.09090909, 0.3448276, 0.34482753,
|
||||
|
||||
11., 12., 13.,
|
||||
0.36363637, 2.6363635, 4.272727,
|
||||
0.09090909, 0.3448276, 0.34482753
|
||||
});
|
||||
|
||||
auto expP = NDArrayFactory::create<int>('c', {2,3}, {2, 1, 0, 2, 1, 0});
|
||||
nd4j::ops::lu op;
|
||||
|
||||
auto res = op.execute({&in}, {}, {});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
auto p = res->at(1);
|
||||
// z->printIndexedBuffer("Triangulars3_2");
|
||||
// p->printIndexedBuffer("Permutaions3_2");
|
||||
|
||||
ASSERT_TRUE(expLU.equalsTo(z));
|
||||
ASSERT_TRUE(expP.equalsTo(p));
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, LU_Test_3_3) {
|
||||
|
||||
auto in = NDArrayFactory::create<double>('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,13,2,3,4,7,9, 11, 12, 1});
|
||||
auto expLU = NDArrayFactory::create<double>('c', {2, 3,3}, {
|
||||
11., 12., 13.,
|
||||
0.36363637, 2.6363635, 4.272727,
|
||||
0.09090909, 0.3448276, 0.34482753,
|
||||
|
||||
13., 2., 3.,
|
||||
0.84615386, 10.307693, -1.5384617,
|
||||
0.30769232, 0.619403, 9.029851});
|
||||
|
||||
auto expP = NDArrayFactory::create<int>('c', {2,3}, {2, 1, 0, 0, 2, 1});
|
||||
nd4j::ops::lu op;
|
||||
|
||||
auto res = op.execute({&in}, {}, {});
|
||||
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
|
||||
auto z = res->at(0);
|
||||
auto p = res->at(1);
|
||||
// z->printIndexedBuffer("Triangulars3_3");
|
||||
// p->printIndexedBuffer("Permutaions3_3");
|
||||
|
||||
ASSERT_TRUE(expLU.equalsTo(z));
|
||||
ASSERT_TRUE(expP.equalsTo(p));
|
||||
delete res;
|
||||
}
|
||||
|
|
|
@ -293,268 +293,77 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) {
|
|||
.7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f,
|
||||
0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f});
|
||||
auto e = NDArrayFactory::create<float>('c', {8, 8, 3, 1}, {
|
||||
1.0218375f,
|
||||
1.0666375f,
|
||||
0.9130375f,
|
||||
|
||||
-0.07396251f,
|
||||
0.91843754f,
|
||||
-0.17496246f,
|
||||
|
||||
0.47543746f,
|
||||
1.2492375f,
|
||||
0.55643755f,
|
||||
|
||||
1.3110375f,
|
||||
-0.36456245f,
|
||||
1.0518374f,
|
||||
|
||||
0.7824375f,
|
||||
0.57523745f,
|
||||
-0.21656245f,
|
||||
|
||||
0.0816375f,
|
||||
-0.2261625f,
|
||||
0.40323752f,
|
||||
|
||||
1.4520376f,
|
||||
0.6868375f,
|
||||
0.81723756f,
|
||||
|
||||
-0.17576247f,
|
||||
0.81423753f,
|
||||
-0.08656245f,
|
||||
|
||||
|
||||
-0.36249164f,
|
||||
0.45590833f,
|
||||
1.1925083f,
|
||||
|
||||
0.00650835f,
|
||||
1.4861084f,
|
||||
1.2079083f,
|
||||
|
||||
0.05270836f,
|
||||
0.37350836f,
|
||||
0.94130826f,
|
||||
|
||||
1.0715083f,
|
||||
0.6103083f,
|
||||
0.9825083f,
|
||||
|
||||
0.07370833f,
|
||||
-0.4518917f,
|
||||
-0.39889166f,
|
||||
|
||||
-0.3354917f,
|
||||
1.2213084f,
|
||||
1.0345083f,
|
||||
|
||||
-0.3132917f,
|
||||
0.78470826f,
|
||||
0.23390833f,
|
||||
|
||||
0.6943083f,
|
||||
0.68170834f,
|
||||
-0.09989169f,
|
||||
|
||||
|
||||
0.8352709f,
|
||||
1.3798709f,
|
||||
0.15507084f,
|
||||
|
||||
0.26607084f,
|
||||
-0.10792917f,
|
||||
1.2302709f,
|
||||
|
||||
0.6448709f,
|
||||
-0.29992914f,
|
||||
1.3534708f,
|
||||
|
||||
0.86607087f,
|
||||
0.37607086f,
|
||||
0.04027084f,
|
||||
|
||||
0.40087086f,
|
||||
0.59507084f,
|
||||
0.9416709f,
|
||||
|
||||
0.53127086f,
|
||||
-0.01712915f,
|
||||
1.4610709f,
|
||||
|
||||
-0.17152917f,
|
||||
-0.13992918f,
|
||||
0.6242708f,
|
||||
|
||||
-0.42192918f,
|
||||
0.38387084f,
|
||||
-0.15752912f,
|
||||
|
||||
|
||||
0.3311833f,
|
||||
0.00618333f,
|
||||
0.17538333f,
|
||||
|
||||
0.10418332f,
|
||||
0.8365834f,
|
||||
0.27098334f,
|
||||
|
||||
1.2421833f,
|
||||
-0.1114167f,
|
||||
1.0153834f,
|
||||
|
||||
0.9523833f,
|
||||
0.8317833f,
|
||||
0.9633833f,
|
||||
|
||||
0.6501833f,
|
||||
0.04258335f,
|
||||
0.9999833f,
|
||||
|
||||
-0.40181667f,
|
||||
0.11418331f,
|
||||
0.47938335f,
|
||||
|
||||
1.1057833f,
|
||||
-0.29761666f,
|
||||
1.0779834f,
|
||||
|
||||
0.5243833f,
|
||||
-0.32181668f,
|
||||
1.1833833f,
|
||||
|
||||
|
||||
0.73157084f,
|
||||
0.4317708f,
|
||||
0.7283708f,
|
||||
|
||||
1.2297708f,
|
||||
0.4307708f,
|
||||
0.85377085f,
|
||||
|
||||
0.05977082f,
|
||||
-0.09282917f,
|
||||
0.33957082f,
|
||||
|
||||
1.0751709f,
|
||||
0.2119708f,
|
||||
0.51897085f,
|
||||
|
||||
-0.25302917f,
|
||||
1.1723708f,
|
||||
-0.12562919f,
|
||||
|
||||
1.1993709f,
|
||||
0.5257708f,
|
||||
0.40517086f,
|
||||
|
||||
0.53197086f,
|
||||
0.8441708f,
|
||||
0.02617085f,
|
||||
|
||||
-0.0208292f,
|
||||
0.8711709f,
|
||||
0.04137081f,
|
||||
|
||||
|
||||
0.74936247f,
|
||||
0.6085625f,
|
||||
0.8997625f,
|
||||
|
||||
-0.08743751f,
|
||||
0.18576252f,
|
||||
-0.17563748f,
|
||||
|
||||
0.5991625f,
|
||||
-0.0038375f,
|
||||
0.07576251f,
|
||||
|
||||
0.42536253f,
|
||||
-0.22823751f,
|
||||
0.36296248f,
|
||||
|
||||
0.81456256f,
|
||||
-0.16183749f,
|
||||
0.5161625f,
|
||||
|
||||
-0.21183747f,
|
||||
0.7429625f,
|
||||
0.6217625f,
|
||||
|
||||
0.17656249f,
|
||||
0.02616251f,
|
||||
-0.17923748f,
|
||||
|
||||
1.4659625f,
|
||||
0.40016252f,
|
||||
0.28356248f,
|
||||
|
||||
|
||||
0.4195791f,
|
||||
0.8745791f,
|
||||
0.36637908f,
|
||||
|
||||
0.50597906f,
|
||||
-0.17942089f,
|
||||
0.16917908f,
|
||||
|
||||
1.0235791f,
|
||||
1.3699791f,
|
||||
-0.11382091f,
|
||||
|
||||
-0.0918209f,
|
||||
0.7757791f,
|
||||
0.09017909f,
|
||||
|
||||
1.3807791f,
|
||||
-0.15202093f,
|
||||
1.3875791f,
|
||||
|
||||
-0.1712209f,
|
||||
1.3989791f,
|
||||
0.43777913f,
|
||||
|
||||
0.7855791f,
|
||||
0.1423791f,
|
||||
1.4711791f,
|
||||
|
||||
0.6455791f,
|
||||
0.6211791f,
|
||||
-0.48062086f,
|
||||
|
||||
|
||||
0.10189578f,
|
||||
0.5628958f,
|
||||
0.68909574f,
|
||||
|
||||
0.96649575f,
|
||||
-0.09370419f,
|
||||
1.3466958f,
|
||||
|
||||
1.4584957f,
|
||||
1.3544958f,
|
||||
-0.3829042f,
|
||||
|
||||
0.11269578f,
|
||||
-0.47890422f,
|
||||
1.0436958f,
|
||||
|
||||
0.6128957f,
|
||||
0.27209583f,
|
||||
0.2714958f,
|
||||
|
||||
0.21889582f,
|
||||
0.08789578f,
|
||||
1.1296958f,
|
||||
|
||||
0.4596958f,
|
||||
0.39309582f,
|
||||
0.8344958f,
|
||||
|
||||
0.71149576f,
|
||||
-0.4799042f,
|
||||
0.4880958f
|
||||
1.0218375f, 1.0666375f, 0.9130375f,
|
||||
-0.07396251f, 0.91843754f, -0.17496246f,
|
||||
0.47543746f, 1.2492375f, 0.55643755f,
|
||||
1.3110375f, -0.36456245f, 1.0518374f,
|
||||
0.7824375f, 0.57523745f, -0.21656245f,
|
||||
0.0816375f, -0.2261625f, 0.40323752f,
|
||||
1.4520376f, 0.6868375f, 0.81723756f,
|
||||
-0.17576247f, 0.81423753f, -0.08656245f,
|
||||
|
||||
-0.36249164f, 0.45590833f, 1.1925083f,
|
||||
0.00650835f, 1.4861084f, 1.2079083f,
|
||||
0.05270836f, 0.37350836f, 0.94130826f,
|
||||
1.0715083f, 0.6103083f, 0.9825083f,
|
||||
0.07370833f, -0.4518917f, -0.39889166f,
|
||||
-0.3354917f, 1.2213084f, 1.0345083f,
|
||||
-0.3132917f, 0.78470826f, 0.23390833f,
|
||||
0.6943083f, 0.68170834f, -0.09989169f,
|
||||
|
||||
0.8352709f, 1.3798709f, 0.15507084f,
|
||||
0.26607084f, -0.10792917f, 1.2302709f,
|
||||
0.6448709f, -0.29992914f, 1.3534708f,
|
||||
0.86607087f, 0.37607086f, 0.04027084f,
|
||||
0.40087086f, 0.59507084f, 0.9416709f,
|
||||
0.53127086f, -0.01712915f, 1.4610709f,
|
||||
-0.17152917f, -0.13992918f, 0.6242708f,
|
||||
-0.42192918f, 0.38387084f, -0.15752912f,
|
||||
|
||||
0.3311833f, 0.00618333f, 0.17538333f,
|
||||
0.10418332f, 0.8365834f, 0.27098334f,
|
||||
1.2421833f, -0.1114167f, 1.0153834f,
|
||||
0.9523833f, 0.8317833f, 0.9633833f,
|
||||
0.6501833f, 0.04258335f, 0.9999833f,
|
||||
-0.40181667f, 0.11418331f, 0.47938335f,
|
||||
1.1057833f, -0.29761666f, 1.0779834f,
|
||||
0.5243833f, -0.32181668f, 1.1833833f,
|
||||
|
||||
0.73157084f, 0.4317708f, 0.7283708f,
|
||||
1.2297708f, 0.4307708f, 0.85377085f,
|
||||
0.05977082f, -0.09282917f, 0.33957082f,
|
||||
1.0751709f, 0.2119708f, 0.51897085f,
|
||||
-0.25302917f, 1.1723708f, -0.12562919f,
|
||||
1.1993709f, 0.5257708f, 0.40517086f,
|
||||
0.53197086f, 0.8441708f, 0.02617085f,
|
||||
-0.0208292f, 0.8711709f, 0.04137081f,
|
||||
|
||||
0.74936247f, 0.6085625f, 0.8997625f,
|
||||
-0.08743751f, 0.18576252f, -0.17563748f,
|
||||
0.5991625f, -0.0038375f, 0.07576251f,
|
||||
0.42536253f, -0.22823751f, 0.36296248f,
|
||||
0.81456256f, -0.16183749f, 0.5161625f,
|
||||
-0.21183747f, 0.7429625f, 0.6217625f,
|
||||
0.17656249f, 0.02616251f, -0.17923748f,
|
||||
1.4659625f, 0.40016252f, 0.28356248f,
|
||||
|
||||
0.4195791f, 0.8745791f, 0.36637908f,
|
||||
0.50597906f, -0.17942089f, 0.16917908f,
|
||||
1.0235791f, 1.3699791f, -0.11382091f,
|
||||
-0.0918209f, 0.7757791f, 0.09017909f,
|
||||
1.3807791f, -0.15202093f, 1.3875791f,
|
||||
-0.1712209f, 1.3989791f, 0.43777913f,
|
||||
0.7855791f, 0.1423791f, 1.4711791f,
|
||||
0.6455791f, 0.6211791f, -0.48062086f,
|
||||
|
||||
0.10189578f, 0.5628958f, 0.68909574f,
|
||||
0.96649575f, -0.09370419f, 1.3466958f,
|
||||
1.4584957f, 1.3544958f, -0.3829042f,
|
||||
0.11269578f, -0.47890422f, 1.0436958f,
|
||||
0.6128957f, 0.27209583f, 0.2714958f,
|
||||
0.21889582f, 0.08789578f, 1.1296958f,
|
||||
0.4596958f, 0.39309582f, 0.8344958f,
|
||||
0.71149576f, -0.4799042f, 0.4880958f
|
||||
});
|
||||
|
||||
nd4j::ops::adjust_contrast op;
|
||||
|
@ -587,268 +396,79 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
|
|||
.7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f,
|
||||
0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f});
|
||||
auto e = NDArrayFactory::create<double>('c', {8, 8, 3, 1}, {
|
||||
1.0218375 ,
|
||||
1.0666375 ,
|
||||
0.9130375 ,
|
||||
|
||||
-0.07396251,
|
||||
0.91843754,
|
||||
-0.17496246,
|
||||
|
||||
0.47543746,
|
||||
1.2492375 ,
|
||||
0.55643755,
|
||||
|
||||
1.3110375 ,
|
||||
-0.36456245,
|
||||
1.0518374 ,
|
||||
|
||||
0.7824375 ,
|
||||
0.57523745,
|
||||
-0.21656245,
|
||||
|
||||
0.0816375 ,
|
||||
-0.2261625 ,
|
||||
0.40323752,
|
||||
|
||||
1.4520376 ,
|
||||
0.6868375 ,
|
||||
0.81723756,
|
||||
|
||||
-0.17576247,
|
||||
0.81423753,
|
||||
-0.08656245,
|
||||
|
||||
|
||||
-0.36249164,
|
||||
0.45590833,
|
||||
1.1925083 ,
|
||||
|
||||
0.00650835,
|
||||
1.4861084 ,
|
||||
1.2079083 ,
|
||||
|
||||
0.05270836,
|
||||
0.37350836,
|
||||
0.94130826,
|
||||
|
||||
1.0715083 ,
|
||||
0.6103083 ,
|
||||
0.9825083 ,
|
||||
|
||||
0.07370833,
|
||||
-0.4518917 ,
|
||||
-0.39889166,
|
||||
|
||||
-0.3354917 ,
|
||||
1.2213084 ,
|
||||
1.0345083 ,
|
||||
|
||||
-0.3132917 ,
|
||||
0.78470826,
|
||||
0.23390833,
|
||||
|
||||
0.6943083 ,
|
||||
0.68170834,
|
||||
-0.09989169,
|
||||
|
||||
|
||||
0.8352709 ,
|
||||
1.3798709 ,
|
||||
0.15507084,
|
||||
|
||||
0.26607084,
|
||||
-0.10792917,
|
||||
1.2302709 ,
|
||||
|
||||
0.6448709 ,
|
||||
-0.29992914,
|
||||
1.3534708 ,
|
||||
|
||||
0.86607087,
|
||||
0.37607086,
|
||||
0.04027084,
|
||||
|
||||
0.40087086,
|
||||
0.59507084,
|
||||
0.9416709 ,
|
||||
|
||||
0.53127086,
|
||||
-0.01712915,
|
||||
1.4610709 ,
|
||||
|
||||
-0.17152917,
|
||||
-0.13992918,
|
||||
0.6242708 ,
|
||||
|
||||
-0.42192918,
|
||||
0.38387084,
|
||||
-0.15752912,
|
||||
|
||||
|
||||
0.3311833 ,
|
||||
0.00618333,
|
||||
0.17538333,
|
||||
|
||||
0.10418332,
|
||||
0.8365834 ,
|
||||
0.27098334,
|
||||
|
||||
1.2421833 ,
|
||||
-0.1114167 ,
|
||||
1.0153834 ,
|
||||
|
||||
0.9523833 ,
|
||||
0.8317833 ,
|
||||
0.9633833 ,
|
||||
|
||||
0.6501833 ,
|
||||
0.04258335,
|
||||
0.9999833 ,
|
||||
|
||||
-0.40181667,
|
||||
0.11418331,
|
||||
0.47938335,
|
||||
|
||||
1.1057833 ,
|
||||
-0.29761666,
|
||||
1.0779834 ,
|
||||
|
||||
0.5243833 ,
|
||||
-0.32181668,
|
||||
1.1833833 ,
|
||||
|
||||
|
||||
0.73157084,
|
||||
0.4317708 ,
|
||||
0.7283708 ,
|
||||
|
||||
1.2297708 ,
|
||||
0.4307708 ,
|
||||
0.85377085,
|
||||
|
||||
0.05977082,
|
||||
-0.09282917,
|
||||
0.33957082,
|
||||
|
||||
1.0751709 ,
|
||||
0.2119708 ,
|
||||
0.51897085,
|
||||
|
||||
-0.25302917,
|
||||
1.1723708 ,
|
||||
-0.12562919,
|
||||
|
||||
1.1993709 ,
|
||||
0.5257708 ,
|
||||
0.40517086,
|
||||
|
||||
0.53197086,
|
||||
0.8441708 ,
|
||||
0.02617085,
|
||||
|
||||
-0.0208292 ,
|
||||
0.8711709 ,
|
||||
0.04137081,
|
||||
|
||||
|
||||
0.74936247,
|
||||
0.6085625 ,
|
||||
0.8997625 ,
|
||||
|
||||
-0.08743751,
|
||||
0.18576252,
|
||||
-0.17563748,
|
||||
|
||||
0.5991625 ,
|
||||
-0.0038375 ,
|
||||
0.07576251,
|
||||
|
||||
0.42536253,
|
||||
-0.22823751,
|
||||
0.36296248,
|
||||
|
||||
0.81456256,
|
||||
-0.16183749,
|
||||
0.5161625 ,
|
||||
|
||||
-0.21183747,
|
||||
0.7429625 ,
|
||||
0.6217625 ,
|
||||
|
||||
0.17656249,
|
||||
0.02616251,
|
||||
-0.17923748,
|
||||
|
||||
1.4659625 ,
|
||||
0.40016252,
|
||||
0.28356248,
|
||||
|
||||
|
||||
0.4195791 ,
|
||||
0.8745791 ,
|
||||
0.36637908,
|
||||
|
||||
0.50597906,
|
||||
-0.17942089,
|
||||
0.16917908,
|
||||
|
||||
1.0235791 ,
|
||||
1.3699791 ,
|
||||
-0.11382091,
|
||||
|
||||
-0.0918209 ,
|
||||
0.7757791 ,
|
||||
0.09017909,
|
||||
|
||||
1.3807791 ,
|
||||
-0.15202093,
|
||||
1.3875791 ,
|
||||
|
||||
-0.1712209 ,
|
||||
1.3989791 ,
|
||||
0.43777913,
|
||||
|
||||
0.7855791 ,
|
||||
0.1423791 ,
|
||||
1.4711791 ,
|
||||
|
||||
0.6455791 ,
|
||||
0.6211791 ,
|
||||
-0.48062086,
|
||||
|
||||
|
||||
0.10189578,
|
||||
0.5628958 ,
|
||||
0.68909574,
|
||||
|
||||
0.96649575,
|
||||
-0.09370419,
|
||||
1.3466958 ,
|
||||
|
||||
1.4584957 ,
|
||||
1.3544958 ,
|
||||
-0.3829042 ,
|
||||
|
||||
0.11269578,
|
||||
-0.47890422,
|
||||
1.0436958 ,
|
||||
|
||||
0.6128957 ,
|
||||
0.27209583,
|
||||
0.2714958 ,
|
||||
|
||||
0.21889582,
|
||||
0.08789578,
|
||||
1.1296958 ,
|
||||
|
||||
0.4596958 ,
|
||||
0.39309582,
|
||||
0.8344958 ,
|
||||
|
||||
0.71149576,
|
||||
-0.4799042,
|
||||
0.4880958
|
||||
1.0218375, 1.0666375 , 0.9130375 ,
|
||||
-0.07396251, 0.91843754, -0.17496246,
|
||||
0.47543746, 1.2492375 , 0.55643755,
|
||||
1.3110375 , -0.36456245, 1.0518374 ,
|
||||
0.7824375 , 0.57523745, -0.21656245,
|
||||
0.0816375 , -0.2261625 , 0.40323752,
|
||||
1.4520376 , 0.6868375 , 0.81723756,
|
||||
-0.17576247, 0.81423753, -0.08656245,
|
||||
|
||||
-0.36249164, 0.45590833, 1.1925083 ,
|
||||
0.00650835, 1.4861084 , 1.2079083 ,
|
||||
0.05270836, 0.37350836, 0.94130826,
|
||||
1.0715083 , 0.6103083 , 0.9825083 ,
|
||||
0.07370833, -0.4518917 , -0.39889166,
|
||||
-0.3354917 , 1.2213084 , 1.0345083 ,
|
||||
-0.3132917 , 0.78470826, 0.23390833,
|
||||
0.6943083 , 0.68170834, -0.09989169,
|
||||
|
||||
0.8352709 , 1.3798709 , 0.15507084,
|
||||
0.26607084, -0.10792917, 1.2302709 ,
|
||||
0.6448709 , -0.29992914, 1.3534708 ,
|
||||
0.86607087, 0.37607086, 0.04027084,
|
||||
0.40087086, 0.59507084, 0.9416709 ,
|
||||
0.53127086, -0.01712915, 1.4610709 ,
|
||||
-0.17152917, -0.13992918, 0.6242708 ,
|
||||
-0.42192918, 0.38387084, -0.15752912,
|
||||
|
||||
|
||||
0.3311833 , 0.00618333, 0.17538333,
|
||||
0.10418332, 0.8365834 , 0.27098334,
|
||||
1.2421833 , -0.1114167 , 1.0153834 ,
|
||||
0.9523833 , 0.8317833 , 0.9633833 ,
|
||||
0.6501833 , 0.04258335, 0.9999833 ,
|
||||
-0.40181667, 0.11418331, 0.47938335,
|
||||
1.1057833 , -0.29761666, 1.0779834 ,
|
||||
0.5243833 , -0.32181668, 1.1833833 ,
|
||||
|
||||
0.73157084, 0.4317708 , 0.7283708 ,
|
||||
1.2297708 , 0.4307708 , 0.85377085,
|
||||
0.05977082, -0.09282917, 0.33957082,
|
||||
1.0751709 , 0.2119708 , 0.51897085,
|
||||
-0.25302917, 1.1723708 , -0.12562919,
|
||||
1.1993709 , 0.5257708 , 0.40517086,
|
||||
0.53197086, 0.8441708 , 0.02617085,
|
||||
-0.0208292 , 0.8711709 , 0.04137081,
|
||||
|
||||
0.74936247, 0.6085625 , 0.8997625 ,
|
||||
-0.08743751, 0.18576252, -0.17563748,
|
||||
0.5991625 , -0.0038375 , 0.07576251,
|
||||
0.42536253, -0.22823751, 0.36296248,
|
||||
0.81456256, -0.16183749, 0.5161625 ,
|
||||
-0.21183747, 0.7429625 , 0.6217625 ,
|
||||
0.17656249, 0.02616251, -0.17923748,
|
||||
1.4659625 , 0.40016252, 0.28356248,
|
||||
|
||||
0.4195791 , 0.8745791 , 0.36637908,
|
||||
0.50597906, -0.17942089, 0.16917908,
|
||||
1.0235791 , 1.3699791 , -0.11382091,
|
||||
-0.0918209 , 0.7757791 , 0.09017909,
|
||||
1.3807791 , -0.15202093, 1.3875791 ,
|
||||
-0.1712209 , 1.3989791 , 0.43777913,
|
||||
0.7855791 , 0.1423791 , 1.4711791 ,
|
||||
0.6455791 , 0.6211791 , -0.48062086,
|
||||
|
||||
|
||||
0.10189578, 0.5628958 , 0.68909574,
|
||||
0.96649575, -0.09370419, 1.3466958 ,
|
||||
1.4584957 , 1.3544958 , -0.3829042 ,
|
||||
0.11269578, -0.47890422, 1.0436958 ,
|
||||
0.6128957 , 0.27209583, 0.2714958 ,
|
||||
0.21889582, 0.08789578, 1.1296958 ,
|
||||
0.4596958 , 0.39309582, 0.8344958 ,
|
||||
0.71149576, -0.4799042, 0.4880958
|
||||
});
|
||||
// x.linspace(1.);
|
||||
nd4j::ops::adjust_contrast_v2 op;
|
||||
|
|
Loading…
Reference in New Issue