parent 841eeb56c5
commit ece6a17b11

@@ -26,7 +26,6 @@
 namespace nd4j {
 namespace ops {
 namespace helpers {
-nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
 
 template <typename T>
 static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
@@ -108,14 +107,14 @@ namespace helpers {
 
 
 template <typename T>
-static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) {
+static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
 
 const int rowNum = input->rows();
 const int columnNum = input->columns();
 
 NDArray determinant = NDArrayFactory::create<T>(1.f);
 NDArray compoundMatrix = *input; // copy
-NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
+NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
 permutationMatrix.setIdentity();
 
 T pivotValue; // = T(0.0);
@@ -161,46 +160,43 @@ namespace helpers {
 return determinant;
 }
 
-BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
+BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
 
 
 
 template <typename T>
-static int determinant_(NDArray* input, NDArray* output) {
+static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
 
 Nd4jLong n = input->sizeAt(-1);
 Nd4jLong n2 = n * n;
 
-auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
+auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
 
 for (int e = 0; e < output->lengthOf(); e++) {
 for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
 matrix.p(row, input->e<T>(k));
-output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr));
+output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
 }
 
 return Status::OK();
 }
 
-BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
-
 int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
-defaultContext = context;
-BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
+BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
 }
 
 template <typename T>
-int logAbsDeterminant_(NDArray* input, NDArray* output) {
+int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) {
 
 Nd4jLong n = input->sizeAt(-1);
 Nd4jLong n2 = n * n;
 
-NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
+NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
 for (int e = 0; e < output->lengthOf(); e++) {
 for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
 matrix.p(row, input->e<T>(k));
 }
-NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr);
+NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
 if (det.e<T>(0) != 0.f)
 output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
 }
@@ -208,25 +204,23 @@ template <typename T>
 return ND4J_STATUS_OK;
 }
 
-BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
-
 int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
-BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES);
+BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
 }
 
 template <typename T>
-static int inverse_(NDArray* input, NDArray* output) {
+static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) {
 
 auto n = input->sizeAt(-1);
 auto n2 = n * n;
 auto totalCount = output->lengthOf() / n2;
 
 output->assign(0.f); // fill up output tensor with zeros
-auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
-auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
-auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
-auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
-auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
+auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
+auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
+auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
+auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
+auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
 
 for (int e = 0; e < totalCount; e++) {
 if (e)
@@ -235,7 +229,7 @@ template <typename T>
 for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
 matrix.p(row++, input->e<T>(k));
 }
-T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
+T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
 
 // FIXME: and how this is going to work on float16?
 if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
@@ -268,8 +262,7 @@ template <typename T>
 }
 
 int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
-defaultContext = context;
-BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
+BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
 }
 
 template <typename T>
@@ -296,14 +289,13 @@ template <typename T>
 
 return true;
 }
-BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES);
 
 bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
 BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
 }
 
 template <typename T>
-int cholesky_(NDArray* input, NDArray* output, bool inplace) {
+int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) {
 
 auto n = input->sizeAt(-1);
 auto n2 = n * n;
@@ -311,8 +303,8 @@ template <typename T>
 if (!inplace)
 output->assign(0.f); // fill up output tensor with zeros only inplace=false
 
-std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
-std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
+std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
+std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context));
 
 for (int e = 0; e < totalCount; e++) {
 
@@ -346,14 +338,13 @@ template <typename T>
 }
 
 int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
-defaultContext = context;
-BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
+BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
 }
 
 template <typename T>
-int logdetFunctor_(NDArray* input, NDArray* output) {
+int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) {
 std::unique_ptr<NDArray> tempOutput(input->dup());
-int res = cholesky_<T>(input, tempOutput.get(), false);
+int res = cholesky_<T>(context, input, tempOutput.get(), false);
 if (res != ND4J_STATUS_OK)
 return res;
 auto n = input->sizeAt(-1);
@@ -372,7 +363,7 @@ template <typename T>
 }
 
 int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
-BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES);
+BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
 }
 
 }
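All of the hunks above make one change: the file-static `defaultContext` is deleted and an explicit `LaunchContext*` is threaded from the public entry points (`determinant`, `logAbsDeterminant`, `inverse`, `cholesky`, `logdetFunctor`) through `BUILD_SINGLE_SELECTOR` into the templated `*_` implementations. A minimal, self-contained sketch of that pattern follows; the `LaunchContext` and `NDArray` structs here are hypothetical stand-ins for the real nd4j classes, and the single hard-wired `float` instantiation stands in for the type dispatch the macro performs.

#include <cstdio>

// Hypothetical stand-ins for nd4j::LaunchContext and NDArray -- illustration only,
// not the real libnd4j classes.
struct LaunchContext { int deviceId = 0; };
struct NDArray { /* payload omitted */ };

// Templated implementation: it used to read the file-static defaultContext;
// after this change the context arrives as its first argument.
template <typename T>
static int determinant_(LaunchContext* context, NDArray* input, NDArray* output) {
    std::printf("determinant_ running against device %d\n", context->deviceId);
    (void)input; (void)output;   // real work elided
    return 0;                    // stands in for Status::OK()
}

// Public entry point: plays the role of the BUILD_SINGLE_SELECTOR dispatch,
// forwarding the caller's context instead of stashing it in a global.
int determinant(LaunchContext* context, NDArray* input, NDArray* output) {
    return determinant_<float>(context, input, output);  // one instantiation keeps the sketch short
}

int main() {
    LaunchContext ctx;
    ctx.deviceId = 1;
    NDArray in, out;
    return determinant(&ctx, &in, &out);
}

Passing the context as an argument removes the hidden shared state, so concurrent calls can target different devices or streams without racing on a global. The hunks below make the matching change in the CUDA implementation of the same helpers, where the context is also the source of the CUDA stream used for kernel launches.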
@@ -196,36 +196,33 @@ namespace helpers {
 }
 
 template<typename T>
-static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) {
+static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
 int n = inputMatrix->rows();
 invertedMatrix->setIdentity();
 
 if (inputMatrix->isIdentityMatrix()) return;
 
-auto stream = LaunchContext::defaultContext()->getCudaStream();
+auto stream = context->getCudaStream();
 
 // invert main diagonal
-upvertKernel<T> << < 1, n, 512, *stream >> >
-(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
+upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
 // invert the second diagonal
-invertKernelLow<T> << < 1, n, 512, *stream >> >
-(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
+invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
 // invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
-invertLowKernel<T><<< n, n, 512, *stream >> >
-(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
+invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
 }
 
-void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
+void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
 NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
-BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
+BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
 NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
 }
 
 template<typename T>
-static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
+static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) {
 int n = inputMatrix->rows();
 invertedMatrix->setIdentity();
-auto stream = LaunchContext::defaultContext()->getCudaStream();
+auto stream = context->getCudaStream();
 if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
 return;
 }
@@ -235,13 +232,12 @@ namespace helpers {
 inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
 invertedMatrix->tickWriteDevice();
 invertedMatrix->printIndexedBuffer("Step1 UP inversion");
-invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
-inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
+invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
 }
 
-void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
+void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
 NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
-BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
+BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
 NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
 }
 
@@ -525,23 +521,19 @@ namespace helpers {
 input->tickWriteDevice();
 }
 
-BUILD_SINGLE_TEMPLATE(template void lup_,
-(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
-FLOAT_NATIVE);
+BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
 
 template<typename T>
 static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
 Nd4jLong n = input->sizeAt(-1);
 Nd4jLong n2 = n * n;
 std::vector<int> dims();
-auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
-{input->rankOf() - 2, input->rankOf() - 1});
+auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
 //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
 // DataType dtype = input->dataType();
 // if (dtype != DataType::DOUBLE)
 // dtype = DataType::FLOAT32;
-auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
-LaunchContext::defaultContext()); //, block.getWorkspace());
+auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
 auto det = NDArrayFactory::create<T>(1);
 auto stream = context->getCudaStream();
 NDArray::prepareSpecialUse({output}, {input});
@@ -550,8 +542,7 @@ namespace helpers {
 for (int e = 0; e < output->lengthOf(); e++) {
 Nd4jLong pos = e * n2;
 // if (matrix.dataType() == input->dataType())
-fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
-(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
+fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
 // else
 // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
 
@@ -584,15 +575,13 @@ namespace helpers {
 Nd4jLong n = input->sizeAt(-1);
 Nd4jLong n2 = n * n;
 std::vector<int> dims();
-auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
-{input->rankOf() - 2, input->rankOf() - 1});
+auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
 //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
 DataType dtype = input->dataType();
 if (dtype != DataType::DOUBLE)
 dtype = DataType::FLOAT32;
 
-auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
-LaunchContext::defaultContext()); //, block.getWorkspace());
+auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
 auto det = NDArrayFactory::create<T>(1);
 auto stream = context->getCudaStream();
 NDArray::prepareSpecialUse({output}, {input});
@@ -601,8 +590,7 @@ namespace helpers {
 for (int e = 0; e < output->lengthOf(); e++) {
 Nd4jLong pos = e * n2;
 // if (matrix.dataType() == input->dataType())
-fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
-(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
+fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
 // else
 // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
 
@@ -614,8 +602,7 @@ namespace helpers {
 auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
 auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
 // if (matrix.dataType() == input->dataType())
-determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
-(inputBuf, outputBuf, n);
+determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n);
 // else
 // determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
 }
@@ -694,11 +681,11 @@ namespace helpers {
 auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
 // if (dtype != DataType::DOUBLE)
 // dtype = DataType::FLOAT32;
-NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
-NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
-NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
-NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
-NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
+NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context);
+NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context);
+NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context);
+NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context);
+NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context);
 auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
 {input->rankOf() - 2,
 input->rankOf() - 1});
@@ -708,20 +695,17 @@ namespace helpers {
 auto stream = context->getCudaStream();
 
 for (auto i = 0LL; i < packX.numberOfTads(); i++) {
-fillMatrix<T, T> << < 1, n2, 1024, *stream >> >
-(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
-i * n2, n);
+fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
 matrix.tickWriteDevice();
 compound.assign(matrix);
 lup_<T>(context, &compound, nullptr, nullptr);
-fillLowerUpperKernel<T> << < n, n, 1024, *stream >> >
-(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
+fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
 matrix.assign(0);
-invertUpperMatrix(&upper, &matrix); // U^{-1}
+invertUpperMatrix(context, &upper, &matrix); // U^{-1}
 matrix.tickWriteDevice();
 // matrix.printIndexedBuffer("Upper Inverted");
 compound.assign(0);
-invertLowerMatrix(&lower, &compound); // L{-1}
+invertLowerMatrix(context, &lower, &compound); // L{-1}
 compound.tickWriteDevice();
 // compound.printIndexedBuffer("Lower Inverted");
 // matrix.tickWriteDevice();
@@ -729,9 +713,7 @@ namespace helpers {
 nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
 upper.tickWriteDevice();
 // upper.printIndexedBuffer("Full inverted");
-returnMatrix<T> << < 1, n2, 1024, *stream >> >
-(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
-i * n2, n);
+returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
 }
 return Status::OK();
 }
@@ -865,8 +847,7 @@ namespace helpers {
 cholesky__<float>(context, input, output, inplace);
 else {
 std::unique_ptr<NDArray> tempOutput(
-NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
-LaunchContext::defaultContext()));
+NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context));
 tempOutput->assign(input);
 cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
 output->assign(tempOutput.get());
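In the CUDA hunks the context does one more job: it supplies the stream for every kernel launch, so `context->getCudaStream()` replaces `LaunchContext::defaultContext()->getCudaStream()` and the launches that were split across two lines are joined onto one. A small stand-alone sketch of that launch pattern follows; the `LaunchContext` stand-in and the `invertDiagonalSketch` kernel are hypothetical simplifications, not the real `upvertKernel`/`invertLowKernel` from the diff.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in: the only thing the helper needs from the context here is a stream.
struct LaunchContext {
    cudaStream_t stream{};
    cudaStream_t* getCudaStream() { return &stream; }
};

// Toy kernel: reciprocal of the main diagonal of an n x n matrix (illustrative only).
__global__ void invertDiagonalSketch(float* out, const float* in, int n) {
    int i = threadIdx.x;
    if (i < n) out[i * n + i] = 1.0f / in[i * n + i];
}

static void invertDiagonal(LaunchContext* context, float* out, const float* in, int n) {
    // The point of the diff: take the stream from the context that was passed in,
    // not from a global default context.
    auto stream = context->getCudaStream();
    invertDiagonalSketch<<<1, n, 0, *stream>>>(out, in, n);
}

int main() {
    const int n = 4;
    LaunchContext context;
    cudaStreamCreate(&context.stream);

    float *in = nullptr, *out = nullptr;
    cudaMallocManaged(&in, n * n * sizeof(float));
    cudaMallocManaged(&out, n * n * sizeof(float));
    for (int i = 0; i < n; ++i) in[i * n + i] = float(i + 1);

    invertDiagonal(&context, out, in, n);
    cudaStreamSynchronize(context.stream);

    for (int i = 0; i < n; ++i) std::printf("%g\n", out[i * n + i]);

    cudaFree(in); cudaFree(out);
    cudaStreamDestroy(context.stream);
    return 0;
}

Because every launch goes onto the stream owned by the caller's context, work submitted through different contexts stays on different streams and synchronization remains under the caller's control.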