diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index bf9c73e7c..97d47079b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -31,8 +31,6 @@ namespace nd4j { namespace ops { namespace helpers { - nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext(); - // template // static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { // if (theFirst != theSecond) { @@ -204,7 +202,7 @@ namespace helpers { if (inputMatrix->isIdentityMatrix()) return; - auto stream = defaultContext->getCudaStream(); + auto stream = LaunchContext::defaultContext()->getCudaStream(); // invert main diagonal upvertKernel << < 1, n, 512, *stream >> > @@ -227,7 +225,7 @@ namespace helpers { static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { int n = inputMatrix->rows(); invertedMatrix->setIdentity(); - auto stream = defaultContext->getCudaStream(); + auto stream = LaunchContext::defaultContext()->getCudaStream(); if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I return; } @@ -392,7 +390,6 @@ namespace helpers { auto n = input->rows(); cusolverDnHandle_t cusolverH = nullptr; cusolverStatus_t status = cusolverDnCreate(&cusolverH); - defaultContext = context; if (CUSOLVER_STATUS_SUCCESS != status) { throw cuda_exception::build("Cannot create cuSolver handle", status); } @@ -543,9 +540,8 @@ namespace helpers { // DataType dtype = input->dataType(); // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; - defaultContext = context; auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), - defaultContext); //, block.getWorkspace()); + LaunchContext::defaultContext()); //, block.getWorkspace()); auto det = NDArrayFactory::create(1); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); @@ -578,7 +574,6 @@ namespace helpers { } int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); NDArray::registerSpecialUse({output}, {input}); @@ -586,7 +581,6 @@ namespace helpers { template int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; std::vector dims(); @@ -598,7 +592,7 @@ namespace helpers { dtype = DataType::FLOAT32; auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, - defaultContext); //, block.getWorkspace()); + LaunchContext::defaultContext()); //, block.getWorkspace()); auto det = NDArrayFactory::create(1); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); @@ -633,7 +627,6 @@ namespace helpers { } int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); NDArray::registerSpecialUse({output}, {input}); @@ -696,17 +689,16 @@ namespace helpers { template static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; auto n = input->sizeAt(-1); auto n2 = n * n; auto dtype = DataTypeUtils::fromT(); //input->dataType(); // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; - NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); - NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); - NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); - NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); - NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); + NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); + NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); + NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); + NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); + NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); @@ -745,7 +737,6 @@ namespace helpers { } int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); NDArray::registerSpecialUse({output}, {input}); @@ -788,7 +779,6 @@ namespace helpers { int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { if (!inplace) output->assign(input); - defaultContext = context; std::unique_ptr tempOutput(output->dup()); cusolverDnHandle_t handle = nullptr; auto n = input->sizeAt(-1); @@ -868,7 +858,6 @@ namespace helpers { // template int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); if (input->dataType() == DataType::DOUBLE) cholesky__(context, input, output, inplace); @@ -877,7 +866,7 @@ namespace helpers { else { std::unique_ptr tempOutput( NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, - defaultContext)); + LaunchContext::defaultContext())); tempOutput->assign(input); cholesky__(context, tempOutput.get(), tempOutput.get(), true); output->assign(tempOutput.get()); @@ -888,7 +877,6 @@ namespace helpers { int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { // BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); - defaultContext = context; return cholesky_(context, input, output, inplace); } // BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); @@ -927,7 +915,6 @@ namespace helpers { template int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); auto n2 = input->sizeAt(-1) * input->sizeAt(-2); auto stream = context->getCudaStream(); @@ -957,7 +944,6 @@ namespace helpers { } int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE); }