lup context fix (#164)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-24 16:57:48 +03:00 committed by GitHub
parent 841eeb56c5
commit ece6a17b11
2 changed files with 57 additions and 85 deletions
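
The diff below removes the file-scope defaultContext pointer that the CPU helpers reassigned on every call and instead threads an explicit LaunchContext* argument through the templated implementations and the BUILD_SINGLE_SELECTOR dispatch. As a rough, standalone sketch of that pattern (the LaunchContext struct and the scale_/scale names here are hypothetical stand-ins, not libnd4j's API):

// Minimal sketch of the refactor: pass the context down instead of storing it globally.
#include <cstdio>

struct LaunchContext { int deviceId = 0; };   // stand-in for nd4j::LaunchContext

// Before the fix, a shared mutable pointer (e.g. static LaunchContext* defaultContext)
// was reassigned by each public entry point and read by the templated helpers.
// After the fix, the context is an ordinary parameter all the way down:
template <typename T>
static int scale_(LaunchContext* context, const T* input, T* output, int n) {
    // every temporary and every launch can now use the caller's context
    for (int i = 0; i < n; ++i)
        output[i] = input[i] * T(2);
    return 0;
}

int scale(LaunchContext* context, const float* input, float* output, int n) {
    // mirrors BUILD_SINGLE_SELECTOR(..., (context, input, output), ...): the context
    // is forwarded to the typed implementation rather than stashed in a global.
    return scale_<float>(context, input, output, n);
}

int main() {
    LaunchContext ctx;
    float in[3] = {1.f, 2.f, 3.f}, out[3];
    scale(&ctx, in, out, 3);
    std::printf("%f %f %f\n", out[0], out[1], out[2]);
    return 0;
}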

View File

@@ -26,7 +26,6 @@
namespace nd4j {
namespace ops {
namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
template <typename T>
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
@@ -108,14 +107,14 @@ namespace helpers {
template <typename T>
static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) {
static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
const int rowNum = input->rows();
const int columnNum = input->columns();
NDArray determinant = NDArrayFactory::create<T>(1.f);
NDArray compoundMatrix = *input; // copy
NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
permutationMatrix.setIdentity();
T pivotValue; // = T(0.0);
@@ -161,46 +160,43 @@ namespace helpers {
return determinant;
}
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
template <typename T>
static int determinant_(NDArray* input, NDArray* output) {
static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
matrix.p(row, input->e<T>(k));
output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr));
output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
}
return Status::OK();
}
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
}
template <typename T>
int logAbsDeterminant_(NDArray* input, NDArray* output) {
int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
matrix.p(row, input->e<T>(k));
}
NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr);
NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
if (det.e<T>(0) != 0.f)
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
}
@@ -208,25 +204,23 @@ template <typename T>
return ND4J_STATUS_OK;
}
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
}
template <typename T>
static int inverse_(NDArray* input, NDArray* output) {
static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) {
auto n = input->sizeAt(-1);
auto n2 = n * n;
auto totalCount = output->lengthOf() / n2;
output->assign(0.f); // fill up output tensor with zeros
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
for (int e = 0; e < totalCount; e++) {
if (e)
@@ -235,7 +229,7 @@ template <typename T>
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
matrix.p(row++, input->e<T>(k));
}
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
// FIXME: and how this is going to work on float16?
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
@@ -268,8 +262,7 @@ template <typename T>
}
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
}
template <typename T>
@@ -296,14 +289,13 @@ template <typename T>
return true;
}
BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES);
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
}
template <typename T>
int cholesky_(NDArray* input, NDArray* output, bool inplace) {
int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) {
auto n = input->sizeAt(-1);
auto n2 = n * n;
@@ -311,8 +303,8 @@ template <typename T>
if (!inplace)
output->assign(0.f); // fill up output tensor with zeros only inplace=false
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context));
for (int e = 0; e < totalCount; e++) {
@@ -346,14 +338,13 @@ template <typename T>
}
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
}
template <typename T>
int logdetFunctor_(NDArray* input, NDArray* output) {
int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) {
std::unique_ptr<NDArray> tempOutput(input->dup());
int res = cholesky_<T>(input, tempOutput.get(), false);
int res = cholesky_<T>(context, input, tempOutput.get(), false);
if (res != ND4J_STATUS_OK)
return res;
auto n = input->sizeAt(-1);
@@ -372,7 +363,7 @@ template <typename T>
}
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
}
}
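
A practical consequence of the CPU-side change above: because determinant, inverse, cholesky and logdetFunctor no longer write to a shared defaultContext before dispatching, two ops can run with different contexts at the same time without racing on that pointer. A hypothetical standalone illustration (the helper function and the plain LaunchContext struct below are not library code):

// Sketch: each worker passes its own context explicitly, so there is no shared
// mutable "default" for concurrent ops to stomp on.
#include <cstdio>
#include <thread>

struct LaunchContext { int deviceId; };       // stand-in for nd4j::LaunchContext

static void helper(LaunchContext* context, int workItem) {
    // temporaries would be created against this context, as in
    // NDArrayFactory::create(..., context) after the change
    std::printf("item %d handled on device %d\n", workItem, context->deviceId);
}

int main() {
    LaunchContext ctx0{0}, ctx1{1};
    std::thread t0(helper, &ctx0, 10);
    std::thread t1(helper, &ctx1, 20);
    t0.join();
    t1.join();
    return 0;
}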

View File

@@ -196,36 +196,33 @@ namespace helpers {
}
template<typename T>
static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) {
static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
int n = inputMatrix->rows();
invertedMatrix->setIdentity();
if (inputMatrix->isIdentityMatrix()) return;
auto stream = LaunchContext::defaultContext()->getCudaStream();
auto stream = context->getCudaStream();
// invert main diagonal
upvertKernel<T> << < 1, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invert the second diagonal
invertKernelLow<T> << < 1, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertLowKernel<T><<< n, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
}
void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
}
template<typename T>
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) {
int n = inputMatrix->rows();
invertedMatrix->setIdentity();
auto stream = LaunchContext::defaultContext()->getCudaStream();
auto stream = context->getCudaStream();
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
return;
}
@@ -235,13 +232,12 @@ namespace helpers {
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertedMatrix->tickWriteDevice();
invertedMatrix->printIndexedBuffer("Step1 UP inversion");
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
}
void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
}
@@ -525,23 +521,19 @@ namespace helpers {
input->tickWriteDevice();
}
BUILD_SINGLE_TEMPLATE(template void lup_,
(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
FLOAT_NATIVE);
BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
template<typename T>
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2, input->rankOf() - 1});
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
// DataType dtype = input->dataType();
// if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
LaunchContext::defaultContext()); //, block.getWorkspace());
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input});
@@ -550,8 +542,7 @@ namespace helpers {
for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2;
// if (matrix.dataType() == input->dataType())
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@@ -584,15 +575,13 @@ namespace helpers {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2, input->rankOf() - 1});
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
DataType dtype = input->dataType();
if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
LaunchContext::defaultContext()); //, block.getWorkspace());
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input});
@@ -601,8 +590,7 @@ namespace helpers {
for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2;
// if (matrix.dataType() == input->dataType())
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@@ -614,8 +602,7 @@ namespace helpers {
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
// if (matrix.dataType() == input->dataType())
determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(inputBuf, outputBuf, n);
determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n);
// else
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
}
@@ -694,11 +681,11 @@ namespace helpers {
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
// if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32;
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext());
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context);
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2,
input->rankOf() - 1});
@@ -708,20 +695,17 @@ namespace helpers {
auto stream = context->getCudaStream();
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
fillMatrix<T, T> << < 1, n2, 1024, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
i * n2, n);
fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
matrix.tickWriteDevice();
compound.assign(matrix);
lup_<T>(context, &compound, nullptr, nullptr);
fillLowerUpperKernel<T> << < n, n, 1024, *stream >> >
(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
matrix.assign(0);
invertUpperMatrix(&upper, &matrix); // U^{-1}
invertUpperMatrix(context, &upper, &matrix); // U^{-1}
matrix.tickWriteDevice();
// matrix.printIndexedBuffer("Upper Inverted");
compound.assign(0);
invertLowerMatrix(&lower, &compound); // L{-1}
invertLowerMatrix(context, &lower, &compound); // L{-1}
compound.tickWriteDevice();
// compound.printIndexedBuffer("Lower Inverted");
// matrix.tickWriteDevice();
@@ -729,9 +713,7 @@ namespace helpers {
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
upper.tickWriteDevice();
// upper.printIndexedBuffer("Full inverted");
returnMatrix<T> << < 1, n2, 1024, *stream >> >
(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
i * n2, n);
returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
}
return Status::OK();
}
@@ -865,8 +847,7 @@ namespace helpers {
cholesky__<float>(context, input, output, inplace);
else {
std::unique_ptr<NDArray> tempOutput(
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
LaunchContext::defaultContext()));
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context));
tempOutput->assign(input);
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
output->assign(tempOutput.get());
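
On the CUDA side the same idea applies to kernel launches: the stream now comes from the LaunchContext* handed in by the op (context->getCudaStream()) rather than from LaunchContext::defaultContext(). A minimal standalone sketch of launching on a caller-supplied stream (Ctx, scaleKernel and scaleOnContext are hypothetical names, not part of libnd4j):

// Standalone CUDA sketch: the stream travels with the context object passed by
// the caller, mirroring context->getCudaStream() in the diff above.
#include <cstdio>
#include <cuda_runtime.h>

struct Ctx { cudaStream_t stream; };          // stand-in for nd4j::LaunchContext

__global__ void scaleKernel(float* data, float factor, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] *= factor;
}

// the helper receives the context instead of reaching for a global default
void scaleOnContext(Ctx* context, float* deviceData, float factor, int n) {
    scaleKernel<<<(n + 255) / 256, 256, 0, context->stream>>>(deviceData, factor, n);
}

int main() {
    Ctx ctx;
    cudaStreamCreate(&ctx.stream);

    float host[4] = {1.f, 2.f, 3.f, 4.f};
    float* device = nullptr;
    cudaMalloc(&device, sizeof(host));
    cudaMemcpyAsync(device, host, sizeof(host), cudaMemcpyHostToDevice, ctx.stream);

    scaleOnContext(&ctx, device, 2.f, 4);

    cudaMemcpyAsync(host, device, sizeof(host), cudaMemcpyDeviceToHost, ctx.stream);
    cudaStreamSynchronize(ctx.stream);
    std::printf("%f %f %f %f\n", host[0], host[1], host[2], host[3]);

    cudaFree(device);
    cudaStreamDestroy(ctx.stream);
    return 0;
}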