get rid of context variable

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-24 16:18:38 +03:00
parent b091e972ef
commit 841eeb56c5
1 changed files with 10 additions and 24 deletions

View File

@ -31,8 +31,6 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
// template <typename T> // template <typename T>
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { // static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
// if (theFirst != theSecond) { // if (theFirst != theSecond) {
@ -204,7 +202,7 @@ namespace helpers {
if (inputMatrix->isIdentityMatrix()) return; if (inputMatrix->isIdentityMatrix()) return;
auto stream = defaultContext->getCudaStream(); auto stream = LaunchContext::defaultContext()->getCudaStream();
// invert main diagonal // invert main diagonal
upvertKernel<T> << < 1, n, 512, *stream >> > upvertKernel<T> << < 1, n, 512, *stream >> >
@ -227,7 +225,7 @@ namespace helpers {
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
int n = inputMatrix->rows(); int n = inputMatrix->rows();
invertedMatrix->setIdentity(); invertedMatrix->setIdentity();
auto stream = defaultContext->getCudaStream(); auto stream = LaunchContext::defaultContext()->getCudaStream();
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
return; return;
} }
@ -392,7 +390,6 @@ namespace helpers {
auto n = input->rows(); auto n = input->rows();
cusolverDnHandle_t cusolverH = nullptr; cusolverDnHandle_t cusolverH = nullptr;
cusolverStatus_t status = cusolverDnCreate(&cusolverH); cusolverStatus_t status = cusolverDnCreate(&cusolverH);
defaultContext = context;
if (CUSOLVER_STATUS_SUCCESS != status) { if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("Cannot create cuSolver handle", status); throw cuda_exception::build("Cannot create cuSolver handle", status);
} }
@ -543,9 +540,8 @@ namespace helpers {
// DataType dtype = input->dataType(); // DataType dtype = input->dataType();
// if (dtype != DataType::DOUBLE) // if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32; // dtype = DataType::FLOAT32;
defaultContext = context;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
defaultContext); //, block.getWorkspace()); LaunchContext::defaultContext()); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
@ -578,7 +574,6 @@ namespace helpers {
} }
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -586,7 +581,6 @@ namespace helpers {
template<typename T> template<typename T>
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) { int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
std::vector<int> dims(); std::vector<int> dims();
@ -598,7 +592,7 @@ namespace helpers {
dtype = DataType::FLOAT32; dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
defaultContext); //, block.getWorkspace()); LaunchContext::defaultContext()); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
@ -633,7 +627,6 @@ namespace helpers {
} }
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -696,17 +689,16 @@ namespace helpers {
template<typename T> template<typename T>
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
auto n2 = n * n; auto n2 = n * n;
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType(); auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
// if (dtype != DataType::DOUBLE) // if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32; // dtype = DataType::FLOAT32;
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2, {input->rankOf() - 2,
input->rankOf() - 1}); input->rankOf() - 1});
@ -745,7 +737,6 @@ namespace helpers {
} }
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -788,7 +779,6 @@ namespace helpers {
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
if (!inplace) if (!inplace)
output->assign(input); output->assign(input);
defaultContext = context;
std::unique_ptr<NDArray> tempOutput(output->dup()); std::unique_ptr<NDArray> tempOutput(output->dup());
cusolverDnHandle_t handle = nullptr; cusolverDnHandle_t handle = nullptr;
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
@ -868,7 +858,6 @@ namespace helpers {
// template <typename T> // template <typename T>
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
if (input->dataType() == DataType::DOUBLE) if (input->dataType() == DataType::DOUBLE)
cholesky__<double>(context, input, output, inplace); cholesky__<double>(context, input, output, inplace);
@ -877,7 +866,7 @@ namespace helpers {
else { else {
std::unique_ptr<NDArray> tempOutput( std::unique_ptr<NDArray> tempOutput(
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
defaultContext)); LaunchContext::defaultContext()));
tempOutput->assign(input); tempOutput->assign(input);
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true); cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
output->assign(tempOutput.get()); output->assign(tempOutput.get());
@ -888,7 +877,6 @@ namespace helpers {
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); // BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
defaultContext = context;
return cholesky_(context, input, output, inplace); return cholesky_(context, input, output, inplace);
} }
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); // BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
@ -927,7 +915,6 @@ namespace helpers {
template<typename T> template<typename T>
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
auto n2 = input->sizeAt(-1) * input->sizeAt(-2); auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
@ -957,7 +944,6 @@ namespace helpers {
} }
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
} }