parent
841eeb56c5
commit
ece6a17b11
|
@ -26,7 +26,6 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
|
||||
|
||||
template <typename T>
|
||||
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
|
||||
|
@ -108,14 +107,14 @@ namespace helpers {
|
|||
|
||||
|
||||
template <typename T>
|
||||
static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) {
|
||||
static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
|
||||
|
||||
const int rowNum = input->rows();
|
||||
const int columnNum = input->columns();
|
||||
|
||||
NDArray determinant = NDArrayFactory::create<T>(1.f);
|
||||
NDArray compoundMatrix = *input; // copy
|
||||
NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
|
||||
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
|
||||
permutationMatrix.setIdentity();
|
||||
|
||||
T pivotValue; // = T(0.0);
|
||||
|
@ -161,46 +160,43 @@ namespace helpers {
|
|||
return determinant;
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
static int determinant_(NDArray* input, NDArray* output) {
|
||||
static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
||||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
|
||||
|
||||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
|
||||
matrix.p(row, input->e<T>(k));
|
||||
output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
||||
output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||
|
||||
/**
 * Public entry point: dispatches on input->dataType() to determinant_<T>
 * over FLOAT_TYPES, forwarding the caller's launch context.
 *
 * The context is passed as an argument; no shared global state is written,
 * so the call no longer changes the context seen by other helpers.
 *
 * @param context launch context forwarded to the typed implementation
 * @param input   source array
 * @param output  destination array
 * @return whatever determinant_ returns
 */
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
    BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
}
|
||||
|
||||
template <typename T>
|
||||
int logAbsDeterminant_(NDArray* input, NDArray* output) {
|
||||
int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
||||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
|
||||
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
|
||||
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
|
||||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
|
||||
matrix.p(row, input->e<T>(k));
|
||||
}
|
||||
NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr);
|
||||
NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
|
||||
if (det.e<T>(0) != 0.f)
|
||||
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
|
||||
}
|
||||
|
@ -208,25 +204,23 @@ template <typename T>
|
|||
return ND4J_STATUS_OK;
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||
|
||||
/**
 * Public entry point: dispatches on input->dataType() to logAbsDeterminant_<T>
 * over FLOAT_TYPES, forwarding the caller's launch context.
 *
 * @param context launch context forwarded to the typed implementation
 * @param input   source array
 * @param output  destination array
 * @return whatever logAbsDeterminant_ returns
 */
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
    BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
}
|
||||
|
||||
template <typename T>
|
||||
static int inverse_(NDArray* input, NDArray* output) {
|
||||
static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
||||
auto n = input->sizeAt(-1);
|
||||
auto n2 = n * n;
|
||||
auto totalCount = output->lengthOf() / n2;
|
||||
|
||||
output->assign(0.f); // fill up output tensor with zeros
|
||||
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
|
||||
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
|
||||
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
|
||||
for (int e = 0; e < totalCount; e++) {
|
||||
if (e)
|
||||
|
@ -235,7 +229,7 @@ template <typename T>
|
|||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||
matrix.p(row++, input->e<T>(k));
|
||||
}
|
||||
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
|
||||
T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
|
||||
|
||||
// FIXME: and how this is going to work on float16?
|
||||
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
|
||||
|
@ -268,8 +262,7 @@ template <typename T>
|
|||
}
|
||||
|
||||
/**
 * Public entry point: dispatches on input->dataType() to inverse_<T>
 * over FLOAT_TYPES, forwarding the caller's launch context.
 *
 * The context is passed as an argument; no shared global state is written.
 *
 * @param context launch context forwarded to the typed implementation
 * @param input   source array
 * @param output  destination array
 * @return whatever inverse_ returns
 */
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
    BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -296,14 +289,13 @@ template <typename T>
|
|||
|
||||
return true;
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES);
|
||||
|
||||
// Public entry point: dispatches on input->dataType() to checkCholeskyInput_
// over FLOAT_TYPES, passing context and input through unchanged.
// NOTE(review): BUILD_SINGLE_SELECTOR is defined outside this view; it is
// presumably a switch over the listed types - confirm before relying on it.
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int cholesky_(NDArray* input, NDArray* output, bool inplace) {
|
||||
int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) {
|
||||
|
||||
auto n = input->sizeAt(-1);
|
||||
auto n2 = n * n;
|
||||
|
@ -311,8 +303,8 @@ template <typename T>
|
|||
if (!inplace)
|
||||
output->assign(0.f); // fill up output tensor with zeros only inplace=false
|
||||
|
||||
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
|
||||
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
|
||||
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
|
||||
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context));
|
||||
|
||||
for (int e = 0; e < totalCount; e++) {
|
||||
|
||||
|
@ -346,14 +338,13 @@ template <typename T>
|
|||
}
|
||||
|
||||
/**
 * Public entry point: dispatches on input->dataType() to cholesky_<T>
 * over FLOAT_TYPES, forwarding the caller's launch context.
 *
 * The context is passed as an argument; no shared global state is written.
 *
 * @param context launch context forwarded to the typed implementation
 * @param input   source array
 * @param output  destination array
 * @param inplace forwarded unchanged to cholesky_
 * @return whatever cholesky_ returns
 */
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
    BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
}
|
||||
|
||||
template <typename T>
|
||||
int logdetFunctor_(NDArray* input, NDArray* output) {
|
||||
int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
std::unique_ptr<NDArray> tempOutput(input->dup());
|
||||
int res = cholesky_<T>(input, tempOutput.get(), false);
|
||||
int res = cholesky_<T>(context, input, tempOutput.get(), false);
|
||||
if (res != ND4J_STATUS_OK)
|
||||
return res;
|
||||
auto n = input->sizeAt(-1);
|
||||
|
@ -372,7 +363,7 @@ template <typename T>
|
|||
}
|
||||
|
||||
/**
 * Public entry point: dispatches on input->dataType() to logdetFunctor_<T>
 * over FLOAT_TYPES, forwarding the caller's launch context.
 *
 * @param context launch context forwarded to the typed implementation
 * @param input   source array
 * @param output  destination array
 * @return whatever logdetFunctor_ returns
 */
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
    BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
}
|
||||
|
||||
}
|
||||
|
|
|
@ -196,36 +196,33 @@ namespace helpers {
|
|||
}
|
||||
|
||||
/**
 * Typed worker for invertLowerMatrix(): resets invertedMatrix to identity and,
 * unless inputMatrix is itself the identity, launches three kernels on the
 * caller's CUDA stream to fill in the result.
 *
 * @param context        launch context that supplies the CUDA stream
 * @param inputMatrix    source matrix; n is taken from rows()
 * @param invertedMatrix destination matrix, overwritten
 */
template<typename T>
static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
    int n = inputMatrix->rows();
    invertedMatrix->setIdentity();

    // identity input: the identity already written to invertedMatrix is the result
    if (inputMatrix->isIdentityMatrix()) return;

    // use the stream of the context we were given, not the process-wide default one
    auto stream = context->getCudaStream();

    // invert main diagonal
    upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
    // invert the second diagonal
    invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
//        invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
    invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
}
|
||||
|
||||
/**
 * Public entry point: brackets the typed invertLowerMatrix_ dispatch with
 * prepareSpecialUse / registerSpecialUse, with invertedMatrix in the first
 * list and inputMatrix in the second.
 *
 * @param context        launch context forwarded to the typed implementation
 * @param inputMatrix    source matrix; its data type selects the instantiation (FLOAT_NATIVE)
 * @param invertedMatrix destination matrix
 */
void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
    NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
    BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
    NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
}
|
||||
|
||||
template<typename T>
|
||||
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||
static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||
int n = inputMatrix->rows();
|
||||
invertedMatrix->setIdentity();
|
||||
auto stream = LaunchContext::defaultContext()->getCudaStream();
|
||||
auto stream = context->getCudaStream();
|
||||
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
|
||||
return;
|
||||
}
|
||||
|
@ -235,13 +232,12 @@ namespace helpers {
|
|||
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertedMatrix->tickWriteDevice();
|
||||
invertedMatrix->printIndexedBuffer("Step1 UP inversion");
|
||||
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
|
||||
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
}
|
||||
|
||||
/**
 * Public entry point: brackets the typed invertUpperMatrix_ dispatch with
 * prepareSpecialUse / registerSpecialUse, with invertedMatrix in the first
 * list and inputMatrix in the second.
 *
 * @param context        launch context forwarded to the typed implementation
 * @param inputMatrix    source matrix
 * @param invertedMatrix destination matrix; its data type selects the instantiation (FLOAT_NATIVE)
 */
void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
    NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
    BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
    // Close the bracket with registerSpecialUse, as invertLowerMatrix does; the
    // previous code called prepareSpecialUse a second time here.
    NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
}
|
||||
|
||||
|
@ -525,23 +521,19 @@ namespace helpers {
|
|||
input->tickWriteDevice();
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void lup_,
|
||||
(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
|
||||
FLOAT_NATIVE);
|
||||
BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
|
||||
|
||||
template<typename T>
|
||||
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
std::vector<int> dims();
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||
{input->rankOf() - 2, input->rankOf() - 1});
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
||||
// DataType dtype = input->dataType();
|
||||
// if (dtype != DataType::DOUBLE)
|
||||
// dtype = DataType::FLOAT32;
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
|
||||
LaunchContext::defaultContext()); //, block.getWorkspace());
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
auto det = NDArrayFactory::create<T>(1);
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
@ -550,8 +542,7 @@ namespace helpers {
|
|||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
Nd4jLong pos = e * n2;
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
// else
|
||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
|
||||
|
@ -584,15 +575,13 @@ namespace helpers {
|
|||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
std::vector<int> dims();
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||
{input->rankOf() - 2, input->rankOf() - 1});
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
||||
DataType dtype = input->dataType();
|
||||
if (dtype != DataType::DOUBLE)
|
||||
dtype = DataType::FLOAT32;
|
||||
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
|
||||
LaunchContext::defaultContext()); //, block.getWorkspace());
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
|
||||
auto det = NDArrayFactory::create<T>(1);
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
@ -601,8 +590,7 @@ namespace helpers {
|
|||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
Nd4jLong pos = e * n2;
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
// else
|
||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
|
||||
|
@ -614,8 +602,7 @@ namespace helpers {
|
|||
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
|
||||
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
||||
(inputBuf, outputBuf, n);
|
||||
determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n);
|
||||
// else
|
||||
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||
}
|
||||
|
@ -694,11 +681,11 @@ namespace helpers {
|
|||
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
|
||||
// if (dtype != DataType::DOUBLE)
|
||||
// dtype = DataType::FLOAT32;
|
||||
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
|
||||
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||
{input->rankOf() - 2,
|
||||
input->rankOf() - 1});
|
||||
|
@ -708,20 +695,17 @@ namespace helpers {
|
|||
auto stream = context->getCudaStream();
|
||||
|
||||
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
|
||||
fillMatrix<T, T> << < 1, n2, 1024, *stream >> >
|
||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
|
||||
i * n2, n);
|
||||
fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
|
||||
matrix.tickWriteDevice();
|
||||
compound.assign(matrix);
|
||||
lup_<T>(context, &compound, nullptr, nullptr);
|
||||
fillLowerUpperKernel<T> << < n, n, 1024, *stream >> >
|
||||
(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
|
||||
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
|
||||
matrix.assign(0);
|
||||
invertUpperMatrix(&upper, &matrix); // U^{-1}
|
||||
invertUpperMatrix(context, &upper, &matrix); // U^{-1}
|
||||
matrix.tickWriteDevice();
|
||||
// matrix.printIndexedBuffer("Upper Inverted");
|
||||
compound.assign(0);
|
||||
invertLowerMatrix(&lower, &compound); // L{-1}
|
||||
invertLowerMatrix(context, &lower, &compound); // L{-1}
|
||||
compound.tickWriteDevice();
|
||||
// compound.printIndexedBuffer("Lower Inverted");
|
||||
// matrix.tickWriteDevice();
|
||||
|
@ -729,9 +713,7 @@ namespace helpers {
|
|||
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
|
||||
upper.tickWriteDevice();
|
||||
// upper.printIndexedBuffer("Full inverted");
|
||||
returnMatrix<T> << < 1, n2, 1024, *stream >> >
|
||||
(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
|
||||
i * n2, n);
|
||||
returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -865,8 +847,7 @@ namespace helpers {
|
|||
cholesky__<float>(context, input, output, inplace);
|
||||
else {
|
||||
std::unique_ptr<NDArray> tempOutput(
|
||||
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
|
||||
LaunchContext::defaultContext()));
|
||||
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context));
|
||||
tempOutput->assign(input);
|
||||
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
|
||||
output->assign(tempOutput.get());
|
||||
|
|
Loading…
Reference in New Issue