parent
b091e972ef
commit
841eeb56c5
|
@ -31,8 +31,6 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
|
|
||||||
|
|
||||||
// template <typename T>
|
// template <typename T>
|
||||||
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
||||||
// if (theFirst != theSecond) {
|
// if (theFirst != theSecond) {
|
||||||
|
@ -204,7 +202,7 @@ namespace helpers {
|
||||||
|
|
||||||
if (inputMatrix->isIdentityMatrix()) return;
|
if (inputMatrix->isIdentityMatrix()) return;
|
||||||
|
|
||||||
auto stream = defaultContext->getCudaStream();
|
auto stream = LaunchContext::defaultContext()->getCudaStream();
|
||||||
|
|
||||||
// invert main diagonal
|
// invert main diagonal
|
||||||
upvertKernel<T> << < 1, n, 512, *stream >> >
|
upvertKernel<T> << < 1, n, 512, *stream >> >
|
||||||
|
@ -227,7 +225,7 @@ namespace helpers {
|
||||||
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||||
int n = inputMatrix->rows();
|
int n = inputMatrix->rows();
|
||||||
invertedMatrix->setIdentity();
|
invertedMatrix->setIdentity();
|
||||||
auto stream = defaultContext->getCudaStream();
|
auto stream = LaunchContext::defaultContext()->getCudaStream();
|
||||||
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
|
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -392,7 +390,6 @@ namespace helpers {
|
||||||
auto n = input->rows();
|
auto n = input->rows();
|
||||||
cusolverDnHandle_t cusolverH = nullptr;
|
cusolverDnHandle_t cusolverH = nullptr;
|
||||||
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
||||||
defaultContext = context;
|
|
||||||
if (CUSOLVER_STATUS_SUCCESS != status) {
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
||||||
}
|
}
|
||||||
|
@ -543,9 +540,8 @@ namespace helpers {
|
||||||
// DataType dtype = input->dataType();
|
// DataType dtype = input->dataType();
|
||||||
// if (dtype != DataType::DOUBLE)
|
// if (dtype != DataType::DOUBLE)
|
||||||
// dtype = DataType::FLOAT32;
|
// dtype = DataType::FLOAT32;
|
||||||
defaultContext = context;
|
|
||||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
|
||||||
defaultContext); //, block.getWorkspace());
|
LaunchContext::defaultContext()); //, block.getWorkspace());
|
||||||
auto det = NDArrayFactory::create<T>(1);
|
auto det = NDArrayFactory::create<T>(1);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
@ -578,7 +574,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
@ -586,7 +581,6 @@ namespace helpers {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
|
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
Nd4jLong n = input->sizeAt(-1);
|
Nd4jLong n = input->sizeAt(-1);
|
||||||
Nd4jLong n2 = n * n;
|
Nd4jLong n2 = n * n;
|
||||||
std::vector<int> dims();
|
std::vector<int> dims();
|
||||||
|
@ -598,7 +592,7 @@ namespace helpers {
|
||||||
dtype = DataType::FLOAT32;
|
dtype = DataType::FLOAT32;
|
||||||
|
|
||||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
|
||||||
defaultContext); //, block.getWorkspace());
|
LaunchContext::defaultContext()); //, block.getWorkspace());
|
||||||
auto det = NDArrayFactory::create<T>(1);
|
auto det = NDArrayFactory::create<T>(1);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
@ -633,7 +627,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
@ -696,17 +689,16 @@ namespace helpers {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
auto n2 = n * n;
|
auto n2 = n * n;
|
||||||
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
|
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
|
||||||
// if (dtype != DataType::DOUBLE)
|
// if (dtype != DataType::DOUBLE)
|
||||||
// dtype = DataType::FLOAT32;
|
// dtype = DataType::FLOAT32;
|
||||||
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||||
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||||
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||||
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||||
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||||
{input->rankOf() - 2,
|
{input->rankOf() - 2,
|
||||||
input->rankOf() - 1});
|
input->rankOf() - 1});
|
||||||
|
@ -745,7 +737,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
@ -788,7 +779,6 @@ namespace helpers {
|
||||||
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||||
if (!inplace)
|
if (!inplace)
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
defaultContext = context;
|
|
||||||
std::unique_ptr<NDArray> tempOutput(output->dup());
|
std::unique_ptr<NDArray> tempOutput(output->dup());
|
||||||
cusolverDnHandle_t handle = nullptr;
|
cusolverDnHandle_t handle = nullptr;
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
|
@ -868,7 +858,6 @@ namespace helpers {
|
||||||
|
|
||||||
// template <typename T>
|
// template <typename T>
|
||||||
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
if (input->dataType() == DataType::DOUBLE)
|
if (input->dataType() == DataType::DOUBLE)
|
||||||
cholesky__<double>(context, input, output, inplace);
|
cholesky__<double>(context, input, output, inplace);
|
||||||
|
@ -877,7 +866,7 @@ namespace helpers {
|
||||||
else {
|
else {
|
||||||
std::unique_ptr<NDArray> tempOutput(
|
std::unique_ptr<NDArray> tempOutput(
|
||||||
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
|
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
|
||||||
defaultContext));
|
LaunchContext::defaultContext()));
|
||||||
tempOutput->assign(input);
|
tempOutput->assign(input);
|
||||||
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
|
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
|
||||||
output->assign(tempOutput.get());
|
output->assign(tempOutput.get());
|
||||||
|
@ -888,7 +877,6 @@ namespace helpers {
|
||||||
|
|
||||||
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||||
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
||||||
defaultContext = context;
|
|
||||||
return cholesky_(context, input, output, inplace);
|
return cholesky_(context, input, output, inplace);
|
||||||
}
|
}
|
||||||
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
||||||
|
@ -927,7 +915,6 @@ namespace helpers {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
|
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
@ -957,7 +944,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue