lup context fix (#164)

Signed-off-by: raver119 <raver119@gmail.com>
raver119 2019-08-24 16:57:48 +03:00 committed by GitHub
parent 841eeb56c5
commit ece6a17b11
2 changed files with 57 additions and 85 deletions

View File

@@ -26,7 +26,6 @@
  namespace nd4j {
  namespace ops {
  namespace helpers {
- nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
  template <typename T>
  static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
@@ -108,14 +107,14 @@ namespace helpers {
  template <typename T>
- static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) {
+ static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
  const int rowNum = input->rows();
  const int columnNum = input->columns();
  NDArray determinant = NDArrayFactory::create<T>(1.f);
  NDArray compoundMatrix = *input; // copy
- NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
+ NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
  permutationMatrix.setIdentity();
  T pivotValue; // = T(0.0);
@@ -161,46 +160,43 @@ namespace helpers {
  return determinant;
  }
- BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
+ BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
  template <typename T>
- static int determinant_(NDArray* input, NDArray* output) {
+ static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
  Nd4jLong n = input->sizeAt(-1);
  Nd4jLong n2 = n * n;
- auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
+ auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
  for (int e = 0; e < output->lengthOf(); e++) {
  for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
  matrix.p(row, input->e<T>(k));
- output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr));
+ output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
  }
  return Status::OK();
  }
- BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
  int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
- defaultContext = context;
- BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
  }
  template <typename T>
- int logAbsDeterminant_(NDArray* input, NDArray* output) {
+ int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) {
  Nd4jLong n = input->sizeAt(-1);
  Nd4jLong n2 = n * n;
- NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
+ NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
  for (int e = 0; e < output->lengthOf(); e++) {
  for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
  matrix.p(row, input->e<T>(k));
  }
- NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr);
+ NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
  if (det.e<T>(0) != 0.f)
  output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
  }
@@ -208,25 +204,23 @@ template <typename T>
  return ND4J_STATUS_OK;
  }
- BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
  int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
- BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
  }
  template <typename T>
- static int inverse_(NDArray* input, NDArray* output) {
+ static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) {
  auto n = input->sizeAt(-1);
  auto n2 = n * n;
  auto totalCount = output->lengthOf() / n2;
  output->assign(0.f); // fill up output tensor with zeros
- auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
- auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
- auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
- auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
- auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
+ auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
+ auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
+ auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
+ auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
+ auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
  for (int e = 0; e < totalCount; e++) {
  if (e)
@@ -235,7 +229,7 @@ template <typename T>
  for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
  matrix.p(row++, input->e<T>(k));
  }
- T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
+ T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
  // FIXME: and how this is going to work on float16?
  if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
@@ -268,8 +262,7 @@ template <typename T>
  }
  int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
- defaultContext = context;
- BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
  }
  template <typename T>
@@ -296,14 +289,13 @@ template <typename T>
  return true;
  }
- BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES);
  bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
  BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
  }
  template <typename T>
- int cholesky_(NDArray* input, NDArray* output, bool inplace) {
+ int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) {
  auto n = input->sizeAt(-1);
  auto n2 = n * n;
@@ -311,8 +303,8 @@ template <typename T>
  if (!inplace)
  output->assign(0.f); // fill up output tensor with zeros only inplace=false
- std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
- std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
+ std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
+ std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context));
  for (int e = 0; e < totalCount; e++) {
@@ -346,14 +338,13 @@ template <typename T>
  }
  int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
- defaultContext = context;
- BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
  }
  template <typename T>
- int logdetFunctor_(NDArray* input, NDArray* output) {
+ int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) {
  std::unique_ptr<NDArray> tempOutput(input->dup());
- int res = cholesky_<T>(input, tempOutput.get(), false);
+ int res = cholesky_<T>(context, input, tempOutput.get(), false);
  if (res != ND4J_STATUS_OK)
  return res;
  auto n = input->sizeAt(-1);
@@ -372,7 +363,7 @@ template <typename T>
  }
  int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
- BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
  }
  }
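
Every hunk in this file follows one pattern: the file-scope nd4j::LaunchContext* defaultContext and the defaultContext = context; assignments in the public dispatchers are removed, and the LaunchContext* is instead passed explicitly from the entry points through BUILD_SINGLE_SELECTOR into the templated helpers (lup_, determinant_, inverse_, cholesky_, logdetFunctor_). A minimal stand-alone sketch of that pattern follows; LaunchCtx, determinantImpl and determinant are hypothetical stand-ins, not the real libnd4j API:

    // Sketch only: illustrative types/names, not the nd4j helpers themselves.
    #include <cstdio>

    struct LaunchCtx { int deviceId = 0; };   // stands in for nd4j::LaunchContext

    // The templated helper now receives the context it should work with ...
    template <typename T>
    static int determinantImpl(LaunchCtx* ctx, const T* data, int n) {
        std::printf("LU on device %d, n = %d\n", ctx->deviceId, n);
        return 0;
    }

    // ... and the public entry point forwards its argument instead of
    // stashing it in a mutable global (the old `defaultContext = context;`).
    int determinant(LaunchCtx* ctx, const float* data, int n) {
        return determinantImpl<float>(ctx, data, n);
    }

    int main() {
        LaunchCtx ctx{0};
        const float identity2x2[4] = {1.f, 0.f, 0.f, 1.f};
        return determinant(&ctx, identity2x2, 2);
    }

Threading the pointer through the call chain also keeps concurrent calls that carry different contexts from racing on a shared global.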

View File

@@ -196,36 +196,33 @@ namespace helpers {
  }
  template<typename T>
- static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) {
+ static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
  int n = inputMatrix->rows();
  invertedMatrix->setIdentity();
  if (inputMatrix->isIdentityMatrix()) return;
- auto stream = LaunchContext::defaultContext()->getCudaStream();
+ auto stream = context->getCudaStream();
  // invert main diagonal
- upvertKernel<T> << < 1, n, 512, *stream >> >
-         (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
+ upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
  // invert the second diagonal
- invertKernelLow<T> << < 1, n, 512, *stream >> >
-         (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
+ invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
  // invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
- invertLowKernel<T><<< n, n, 512, *stream >> >
-         (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
+ invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
  }
- void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
+ void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
  NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
- BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
+ BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
  NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
  }
  template<typename T>
- static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
+ static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) {
  int n = inputMatrix->rows();
  invertedMatrix->setIdentity();
- auto stream = LaunchContext::defaultContext()->getCudaStream();
+ auto stream = context->getCudaStream();
  if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
  return;
  }
@@ -235,13 +232,12 @@ namespace helpers {
  inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
  invertedMatrix->tickWriteDevice();
  invertedMatrix->printIndexedBuffer("Step1 UP inversion");
- invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
-         inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
+ invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
  }
- void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
+ void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
  NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
- BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
+ BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
  NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
  }
@@ -525,23 +521,19 @@ namespace helpers {
  input->tickWriteDevice();
  }
- BUILD_SINGLE_TEMPLATE(template void lup_,
-         (LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
-         FLOAT_NATIVE);
+ BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
  template<typename T>
  static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
  Nd4jLong n = input->sizeAt(-1);
  Nd4jLong n2 = n * n;
  std::vector<int> dims();
- auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
-         {input->rankOf() - 2, input->rankOf() - 1});
+ auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
  //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
  // DataType dtype = input->dataType();
  // if (dtype != DataType::DOUBLE)
  // dtype = DataType::FLOAT32;
- auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
-         LaunchContext::defaultContext()); //, block.getWorkspace());
+ auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
  auto det = NDArrayFactory::create<T>(1);
  auto stream = context->getCudaStream();
  NDArray::prepareSpecialUse({output}, {input});
@@ -550,8 +542,7 @@ namespace helpers {
  for (int e = 0; e < output->lengthOf(); e++) {
  Nd4jLong pos = e * n2;
  // if (matrix.dataType() == input->dataType())
- fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
-         (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
+ fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
  // else
  // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@@ -584,15 +575,13 @@ namespace helpers {
  Nd4jLong n = input->sizeAt(-1);
  Nd4jLong n2 = n * n;
  std::vector<int> dims();
- auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
-         {input->rankOf() - 2, input->rankOf() - 1});
+ auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
  //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
  DataType dtype = input->dataType();
  if (dtype != DataType::DOUBLE)
  dtype = DataType::FLOAT32;
- auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
-         LaunchContext::defaultContext()); //, block.getWorkspace());
+ auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
  auto det = NDArrayFactory::create<T>(1);
  auto stream = context->getCudaStream();
  NDArray::prepareSpecialUse({output}, {input});
@@ -601,8 +590,7 @@ namespace helpers {
  for (int e = 0; e < output->lengthOf(); e++) {
  Nd4jLong pos = e * n2;
  // if (matrix.dataType() == input->dataType())
- fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
-         (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
+ fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
  // else
  // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@@ -614,8 +602,7 @@ namespace helpers {
  auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
  auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
  // if (matrix.dataType() == input->dataType())
- determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
-         (inputBuf, outputBuf, n);
+ determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n);
  // else
  // determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
  }
@@ -694,11 +681,11 @@ namespace helpers {
  auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
  // if (dtype != DataType::DOUBLE)
  // dtype = DataType::FLOAT32;
- NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
- NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
- NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
- NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
- NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
+ NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context);
+ NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context);
+ NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context);
+ NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context);
+ NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context);
  auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
          {input->rankOf() - 2,
          input->rankOf() - 1});
@@ -708,20 +695,17 @@ namespace helpers {
  auto stream = context->getCudaStream();
  for (auto i = 0LL; i < packX.numberOfTads(); i++) {
- fillMatrix<T, T> << < 1, n2, 1024, *stream >> >
-         (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
-         i * n2, n);
+ fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
  matrix.tickWriteDevice();
  compound.assign(matrix);
  lup_<T>(context, &compound, nullptr, nullptr);
- fillLowerUpperKernel<T> << < n, n, 1024, *stream >> >
-         (lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
+ fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
  matrix.assign(0);
- invertUpperMatrix(&upper, &matrix); // U^{-1}
+ invertUpperMatrix(context, &upper, &matrix); // U^{-1}
  matrix.tickWriteDevice();
  // matrix.printIndexedBuffer("Upper Inverted");
  compound.assign(0);
- invertLowerMatrix(&lower, &compound); // L{-1}
+ invertLowerMatrix(context, &lower, &compound); // L{-1}
  compound.tickWriteDevice();
  // compound.printIndexedBuffer("Lower Inverted");
  // matrix.tickWriteDevice();
@@ -729,9 +713,7 @@ namespace helpers {
  nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
  upper.tickWriteDevice();
  // upper.printIndexedBuffer("Full inverted");
- returnMatrix<T> << < 1, n2, 1024, *stream >> >
-         (output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
-         i * n2, n);
+ returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
  }
  return Status::OK();
  }
@@ -865,8 +847,7 @@ namespace helpers {
  cholesky__<float>(context, input, output, inplace);
  else {
  std::unique_ptr<NDArray> tempOutput(
-         NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
-         LaunchContext::defaultContext()));
+         NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context));
  tempOutput->assign(input);
  cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
  output->assign(tempOutput.get());
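
In the CUDA file the same threading has two visible effects: kernel launches take their stream from the caller's context (context->getCudaStream()) instead of LaunchContext::defaultContext(), and the temporary NDArrays (matrix, upper, lower, compound, permutation) are created against that context. A small stand-alone CUDA sketch of launching on a caller-supplied stream; Ctx, scaleKernel and scaleOnContext are hypothetical names, not part of the library:

    #include <cstdio>
    #include <cuda_runtime.h>

    // Stand-in for "the context owns the stream the helpers must launch on".
    struct Ctx { cudaStream_t stream; };

    __global__ void scaleKernel(float* buf, int n, float f) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) buf[i] *= f;
    }

    // Helper receives the context and launches on its stream,
    // rather than asking a global/default context for one.
    void scaleOnContext(Ctx* ctx, float* devBuf, int n, float f) {
        scaleKernel<<<(n + 255) / 256, 256, 0, ctx->stream>>>(devBuf, n, f);
    }

    int main() {
        const int n = 8;
        float host[n];
        for (int i = 0; i < n; i++) host[i] = float(i);

        Ctx ctx{};
        cudaStreamCreate(&ctx.stream);

        float* dev = nullptr;
        cudaMalloc(&dev, n * sizeof(float));
        cudaMemcpyAsync(dev, host, n * sizeof(float), cudaMemcpyHostToDevice, ctx.stream);
        scaleOnContext(&ctx, dev, n, 2.f);
        cudaMemcpyAsync(host, dev, n * sizeof(float), cudaMemcpyDeviceToHost, ctx.stream);
        cudaStreamSynchronize(ctx.stream);

        std::printf("host[3] = %f\n", host[3]);   // expect 6.0
        cudaFree(dev);
        cudaStreamDestroy(ctx.stream);
        return 0;
    }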