parent
b091e972ef
commit
841eeb56c5
|
@ -31,8 +31,6 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
|
||||
|
||||
// template <typename T>
|
||||
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
||||
// if (theFirst != theSecond) {
|
||||
|
@ -204,7 +202,7 @@ namespace helpers {
|
|||
|
||||
if (inputMatrix->isIdentityMatrix()) return;
|
||||
|
||||
auto stream = defaultContext->getCudaStream();
|
||||
auto stream = LaunchContext::defaultContext()->getCudaStream();
|
||||
|
||||
// invert main diagonal
|
||||
upvertKernel<T> << < 1, n, 512, *stream >> >
|
||||
|
@ -227,7 +225,7 @@ namespace helpers {
|
|||
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||
int n = inputMatrix->rows();
|
||||
invertedMatrix->setIdentity();
|
||||
auto stream = defaultContext->getCudaStream();
|
||||
auto stream = LaunchContext::defaultContext()->getCudaStream();
|
||||
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
|
||||
return;
|
||||
}
|
||||
|
@ -392,7 +390,6 @@ namespace helpers {
|
|||
auto n = input->rows();
|
||||
cusolverDnHandle_t cusolverH = nullptr;
|
||||
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
||||
defaultContext = context;
|
||||
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
||||
}
|
||||
|
@ -543,9 +540,8 @@ namespace helpers {
|
|||
// DataType dtype = input->dataType();
|
||||
// if (dtype != DataType::DOUBLE)
|
||||
// dtype = DataType::FLOAT32;
|
||||
defaultContext = context;
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
|
||||
defaultContext); //, block.getWorkspace());
|
||||
LaunchContext::defaultContext()); //, block.getWorkspace());
|
||||
auto det = NDArrayFactory::create<T>(1);
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
@ -578,7 +574,6 @@ namespace helpers {
|
|||
}
|
||||
|
||||
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
|
@ -586,7 +581,6 @@ namespace helpers {
|
|||
|
||||
template<typename T>
|
||||
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
std::vector<int> dims();
|
||||
|
@ -598,7 +592,7 @@ namespace helpers {
|
|||
dtype = DataType::FLOAT32;
|
||||
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
|
||||
defaultContext); //, block.getWorkspace());
|
||||
LaunchContext::defaultContext()); //, block.getWorkspace());
|
||||
auto det = NDArrayFactory::create<T>(1);
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
@ -633,7 +627,6 @@ namespace helpers {
|
|||
}
|
||||
|
||||
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
|
@ -696,17 +689,16 @@ namespace helpers {
|
|||
|
||||
template<typename T>
|
||||
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
auto n = input->sizeAt(-1);
|
||||
auto n2 = n * n;
|
||||
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
|
||||
// if (dtype != DataType::DOUBLE)
|
||||
// dtype = DataType::FLOAT32;
|
||||
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||
{input->rankOf() - 2,
|
||||
input->rankOf() - 1});
|
||||
|
@ -745,7 +737,6 @@ namespace helpers {
|
|||
}
|
||||
|
||||
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
|
@ -788,7 +779,6 @@ namespace helpers {
|
|||
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||
if (!inplace)
|
||||
output->assign(input);
|
||||
defaultContext = context;
|
||||
std::unique_ptr<NDArray> tempOutput(output->dup());
|
||||
cusolverDnHandle_t handle = nullptr;
|
||||
auto n = input->sizeAt(-1);
|
||||
|
@ -868,7 +858,6 @@ namespace helpers {
|
|||
|
||||
// template <typename T>
|
||||
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
if (input->dataType() == DataType::DOUBLE)
|
||||
cholesky__<double>(context, input, output, inplace);
|
||||
|
@ -877,7 +866,7 @@ namespace helpers {
|
|||
else {
|
||||
std::unique_ptr<NDArray> tempOutput(
|
||||
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
|
||||
defaultContext));
|
||||
LaunchContext::defaultContext()));
|
||||
tempOutput->assign(input);
|
||||
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
|
||||
output->assign(tempOutput.get());
|
||||
|
@ -888,7 +877,6 @@ namespace helpers {
|
|||
|
||||
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
||||
defaultContext = context;
|
||||
return cholesky_(context, input, output, inplace);
|
||||
}
|
||||
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
||||
|
@ -927,7 +915,6 @@ namespace helpers {
|
|||
|
||||
template<typename T>
|
||||
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
|
||||
auto stream = context->getCudaStream();
|
||||
|
@ -957,7 +944,6 @@ namespace helpers {
|
|||
}
|
||||
|
||||
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue