From ece6a17b1197d8ee57487b57253c92906b1bf2f5 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 24 Aug 2019 16:57:48 +0300 Subject: [PATCH] lup context fix (#164) Signed-off-by: raver119 --- .../ops/declarable/helpers/cpu/lup.cpp | 61 ++++++-------- .../ops/declarable/helpers/cuda/lup.cu | 81 +++++++------------ 2 files changed, 57 insertions(+), 85 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index ee9a78cee..1e3c798e2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -26,7 +26,6 @@ namespace nd4j { namespace ops { namespace helpers { - nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext(); template static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { @@ -108,14 +107,14 @@ namespace helpers { template - static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) { + static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) { const int rowNum = input->rows(); const int columnNum = input->columns(); NDArray determinant = NDArrayFactory::create(1.f); NDArray compoundMatrix = *input; // copy - NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides + NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides permutationMatrix.setIdentity(); T pivotValue; // = T(0.0); @@ -161,46 +160,43 @@ namespace helpers { return determinant; } - BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); + BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); template - static int determinant_(NDArray* input, NDArray* output) { + static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace()); for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) matrix.p(row, input->e(k)); - output->p(e, lup_(&matrix, (NDArray*)nullptr, (NDArray*)nullptr)); + output->p(e, lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); } return Status::OK(); } - BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES); - int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - defaultContext = context; - BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES); } template - int logAbsDeterminant_(NDArray* input, NDArray* output) { + int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; - NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace()); + NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace()); for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { matrix.p(row, input->e(k)); } - NDArray det = lup_(&matrix, (NDArray*)nullptr, (NDArray*)nullptr); + NDArray det = lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); if (det.e(0) != 0.f) output->p(e, nd4j::math::nd4j_log(nd4j::math::nd4j_abs(det.t(0)))); } @@ -208,25 +204,23 @@ template return ND4J_STATUS_OK; } - BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES); - int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES); } template - static int inverse_(NDArray* input, NDArray* output) { + static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) { auto n = input->sizeAt(-1); auto n2 = n * n; auto totalCount = output->lengthOf() / n2; output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); //, block.getWorkspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); //, block.getWorkspace()); - auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); - auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); - auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); + auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); for (int e = 0; e < totalCount; e++) { if (e) @@ -235,7 +229,7 @@ template for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { matrix.p(row++, input->e(k)); } - T det = lup_(&matrix, &compound, &permutation).template e(0); + T det = lup_(context, &matrix, &compound, &permutation).template e(0); // FIXME: and how this is going to work on float16? if (nd4j::math::nd4j_abs(det) < T(0.000001)) { @@ -268,8 +262,7 @@ template } int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - defaultContext = context; - BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES); } template @@ -296,14 +289,13 @@ template return true; } - BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES); bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES); } template - int cholesky_(NDArray* input, NDArray* output, bool inplace) { + int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) { auto n = input->sizeAt(-1); auto n2 = n * n; @@ -311,8 +303,8 @@ template if (!inplace) output->assign(0.f); // fill up output tensor with zeros only inplace=false - std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace()); - std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext)); + std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace()); + std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context)); for (int e = 0; e < totalCount; e++) { @@ -346,14 +338,13 @@ template } int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { - defaultContext = context; - BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); } template - int logdetFunctor_(NDArray* input, NDArray* output) { + int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) { std::unique_ptr tempOutput(input->dup()); - int res = cholesky_(input, tempOutput.get(), false); + int res = cholesky_(context, input, tempOutput.get(), false); if (res != ND4J_STATUS_OK) return res; auto n = input->sizeAt(-1); @@ -372,7 +363,7 @@ template } int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 97d47079b..f11b56745 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -196,36 +196,33 @@ namespace helpers { } template - static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) { + static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { int n = inputMatrix->rows(); invertedMatrix->setIdentity(); if (inputMatrix->isIdentityMatrix()) return; - auto stream = LaunchContext::defaultContext()->getCudaStream(); + auto stream = context->getCudaStream(); // invert main diagonal - upvertKernel << < 1, n, 512, *stream >> > - (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + upvertKernel<<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); // invert the second diagonal - invertKernelLow << < 1, n, 512, *stream >> > - (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertKernelLow<<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); // invertKernelLow<<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - invertLowKernel<<< n, n, 512, *stream >> > - (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertLowKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { + void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); + BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); } template - static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { + static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) { int n = inputMatrix->rows(); invertedMatrix->setIdentity(); - auto stream = LaunchContext::defaultContext()->getCudaStream(); + auto stream = context->getCudaStream(); if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I return; } @@ -235,13 +232,12 @@ namespace helpers { inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); invertedMatrix->tickWriteDevice(); invertedMatrix->printIndexedBuffer("Step1 UP inversion"); - invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), - inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { + void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); + BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); } @@ -525,23 +521,19 @@ namespace helpers { input->tickWriteDevice(); } - BUILD_SINGLE_TEMPLATE(template void lup_, - (LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), - FLOAT_NATIVE); + BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE); template static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; std::vector dims(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), - {input->rankOf() - 2, input->rankOf() - 1}); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); // DataType dtype = input->dataType(); // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), - LaunchContext::defaultContext()); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); auto det = NDArrayFactory::create(1); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); @@ -550,8 +542,7 @@ namespace helpers { for (int e = 0; e < output->lengthOf(); e++) { Nd4jLong pos = e * n2; // if (matrix.dataType() == input->dataType()) - fillMatrix << < launchDims.x, launchDims.y, launchDims.z, *stream >> > - (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); + fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); @@ -584,15 +575,13 @@ namespace helpers { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; std::vector dims(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), - {input->rankOf() - 2, input->rankOf() - 1}); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); DataType dtype = input->dataType(); if (dtype != DataType::DOUBLE) dtype = DataType::FLOAT32; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, - LaunchContext::defaultContext()); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); auto det = NDArrayFactory::create(1); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); @@ -601,8 +590,7 @@ namespace helpers { for (int e = 0; e < output->lengthOf(); e++) { Nd4jLong pos = e * n2; // if (matrix.dataType() == input->dataType()) - fillMatrix << < launchDims.x, launchDims.y, launchDims.z, *stream >> > - (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); + fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); @@ -614,8 +602,7 @@ namespace helpers { auto inputBuf = reinterpret_cast(matrix.specialBuffer()); auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; // if (matrix.dataType() == input->dataType()) - determinantLogKernel << < launchDims.x, launchDims.y, launchDims.z, *stream >> > - (inputBuf, outputBuf, n); + determinantLogKernel<<>>(inputBuf, outputBuf, n); // else // determinantLogKernel<<>> (inputBuf, outputBuf, n); } @@ -694,11 +681,11 @@ namespace helpers { auto dtype = DataTypeUtils::fromT(); //input->dataType(); // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; - NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); - NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); - NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); - NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); - NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); + NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); @@ -708,20 +695,17 @@ namespace helpers { auto stream = context->getCudaStream(); for (auto i = 0LL; i < packX.numberOfTads(); i++) { - fillMatrix << < 1, n2, 1024, *stream >> > - (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - i * n2, n); + fillMatrix<<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); matrix.tickWriteDevice(); compound.assign(matrix); lup_(context, &compound, nullptr, nullptr); - fillLowerUpperKernel << < n, n, 1024, *stream >> > - (lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); + fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); matrix.assign(0); - invertUpperMatrix(&upper, &matrix); // U^{-1} + invertUpperMatrix(context, &upper, &matrix); // U^{-1} matrix.tickWriteDevice(); // matrix.printIndexedBuffer("Upper Inverted"); compound.assign(0); - invertLowerMatrix(&lower, &compound); // L{-1} + invertLowerMatrix(context, &lower, &compound); // L{-1} compound.tickWriteDevice(); // compound.printIndexedBuffer("Lower Inverted"); // matrix.tickWriteDevice(); @@ -729,9 +713,7 @@ namespace helpers { nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); upper.tickWriteDevice(); // upper.printIndexedBuffer("Full inverted"); - returnMatrix << < 1, n2, 1024, *stream >> > - (output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), - i * n2, n); + returnMatrix <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); } return Status::OK(); } @@ -865,8 +847,7 @@ namespace helpers { cholesky__(context, input, output, inplace); else { std::unique_ptr tempOutput( - NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, - LaunchContext::defaultContext())); + NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context)); tempOutput->assign(input); cholesky__(context, tempOutput.get(), tempOutput.get(), true); output->assign(tempOutput.get());