parent
							
								
									841eeb56c5
								
							
						
					
					
						commit
						ece6a17b11
					
				| @ -26,7 +26,6 @@ | ||||
| namespace nd4j { | ||||
| namespace ops { | ||||
| namespace helpers { | ||||
|     nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext(); | ||||
| 
 | ||||
|     template <typename T> | ||||
|     static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { | ||||
| @ -108,14 +107,14 @@ namespace helpers { | ||||
| 
 | ||||
| 
 | ||||
|     template <typename T> | ||||
|     static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) { | ||||
|     static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) { | ||||
| 
 | ||||
|         const int rowNum = input->rows(); | ||||
|         const int columnNum = input->columns(); | ||||
| 
 | ||||
|         NDArray determinant = NDArrayFactory::create<T>(1.f); | ||||
|         NDArray compoundMatrix = *input; // copy
 | ||||
|         NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
 | ||||
|         NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
 | ||||
|         permutationMatrix.setIdentity(); | ||||
| 
 | ||||
|         T pivotValue; // = T(0.0);
 | ||||
| @ -161,46 +160,43 @@ namespace helpers { | ||||
|         return determinant; | ||||
|     } | ||||
| 
 | ||||
|     BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); | ||||
|     BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     template <typename T> | ||||
|     static int determinant_(NDArray* input, NDArray* output) { | ||||
|     static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) { | ||||
| 
 | ||||
|         Nd4jLong n = input->sizeAt(-1); | ||||
|         Nd4jLong n2 = n * n; | ||||
| 
 | ||||
|         auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
 | ||||
|         auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
 | ||||
| 
 | ||||
|         for (int e = 0; e < output->lengthOf(); e++) { | ||||
|             for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) | ||||
|                 matrix.p(row, input->e<T>(k)); | ||||
|             output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr)); | ||||
|             output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); | ||||
|         } | ||||
| 
 | ||||
|         return Status::OK(); | ||||
|     } | ||||
| 
 | ||||
|     BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES); | ||||
| 
 | ||||
|     int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { | ||||
|         defaultContext = context; | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES); | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES); | ||||
|     } | ||||
| 
 | ||||
| template <typename T> | ||||
|     int logAbsDeterminant_(NDArray* input, NDArray* output) { | ||||
|     int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) { | ||||
| 
 | ||||
|         Nd4jLong n = input->sizeAt(-1); | ||||
|         Nd4jLong n2 = n * n; | ||||
| 
 | ||||
|         NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
 | ||||
|         NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
 | ||||
|         for (int e = 0; e < output->lengthOf(); e++) { | ||||
|             for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { | ||||
|                 matrix.p(row, input->e<T>(k)); | ||||
|             } | ||||
| 	    NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr); | ||||
| 	    NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); | ||||
| 	    if (det.e<T>(0) != 0.f) | ||||
|              	output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0)))); | ||||
|         } | ||||
| @ -208,25 +204,23 @@ template <typename T> | ||||
|         return ND4J_STATUS_OK; | ||||
|     } | ||||
| 
 | ||||
|     BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES); | ||||
| 
 | ||||
|     int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES); | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES); | ||||
|     } | ||||
| 
 | ||||
|     template <typename T> | ||||
|     static int inverse_(NDArray* input, NDArray* output) { | ||||
|     static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) { | ||||
| 
 | ||||
|         auto n = input->sizeAt(-1); | ||||
|         auto n2 = n * n; | ||||
|         auto totalCount = output->lengthOf() / n2; | ||||
| 
 | ||||
|         output->assign(0.f); // fill up output tensor with zeros
 | ||||
|         auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
 | ||||
|         auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
 | ||||
|         auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); | ||||
|         auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); | ||||
|         auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); | ||||
|         auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
 | ||||
|         auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
 | ||||
|         auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); | ||||
|         auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); | ||||
|         auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); | ||||
| 
 | ||||
|         for (int e = 0; e < totalCount; e++) { | ||||
|             if (e) | ||||
| @ -235,7 +229,7 @@ template <typename T> | ||||
|             for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { | ||||
|                 matrix.p(row++, input->e<T>(k)); | ||||
|             } | ||||
|             T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0); | ||||
|             T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0); | ||||
| 
 | ||||
|             // FIXME: and how this is going to work on float16?
 | ||||
|             if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) { | ||||
| @ -268,8 +262,7 @@ template <typename T> | ||||
|     } | ||||
| 
 | ||||
|     int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { | ||||
|         defaultContext = context; | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES); | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES); | ||||
|     } | ||||
| 
 | ||||
|     template <typename T> | ||||
| @ -296,14 +289,13 @@ template <typename T> | ||||
| 
 | ||||
|         return true; | ||||
|     } | ||||
|     BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES); | ||||
| 
 | ||||
|     bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES); | ||||
|     } | ||||
| 
 | ||||
|     template <typename T> | ||||
|     int cholesky_(NDArray* input, NDArray* output, bool inplace) { | ||||
|     int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) { | ||||
| 
 | ||||
|         auto n = input->sizeAt(-1); | ||||
|         auto n2 = n * n; | ||||
| @ -311,8 +303,8 @@ template <typename T> | ||||
|         if (!inplace) | ||||
|              output->assign(0.f); // fill up output tensor with zeros only inplace=false
 | ||||
| 
 | ||||
|         std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
 | ||||
|         std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext)); | ||||
|         std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
 | ||||
|         std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context)); | ||||
| 
 | ||||
|         for (int e = 0; e < totalCount; e++) { | ||||
| 
 | ||||
| @ -346,14 +338,13 @@ template <typename T> | ||||
|     } | ||||
| 
 | ||||
|     int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { | ||||
|         defaultContext = context; | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); | ||||
|     } | ||||
| 
 | ||||
|     template <typename T> | ||||
|     int logdetFunctor_(NDArray* input, NDArray* output) { | ||||
|     int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) { | ||||
|         std::unique_ptr<NDArray> tempOutput(input->dup()); | ||||
|         int res = cholesky_<T>(input, tempOutput.get(), false); | ||||
|         int res = cholesky_<T>(context, input, tempOutput.get(), false); | ||||
|         if (res != ND4J_STATUS_OK) | ||||
|             return res; | ||||
|         auto n = input->sizeAt(-1); | ||||
| @ -372,7 +363,7 @@ template <typename T> | ||||
|     } | ||||
| 
 | ||||
|     int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES); | ||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
|  | ||||
| @ -196,36 +196,33 @@ namespace helpers { | ||||
|     } | ||||
| 
 | ||||
|     template<typename T> | ||||
|     static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) { | ||||
|     static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { | ||||
|         int n = inputMatrix->rows(); | ||||
|         invertedMatrix->setIdentity(); | ||||
| 
 | ||||
|         if (inputMatrix->isIdentityMatrix()) return; | ||||
| 
 | ||||
|         auto stream = LaunchContext::defaultContext()->getCudaStream(); | ||||
|         auto stream = context->getCudaStream(); | ||||
| 
 | ||||
|         // invert main diagonal | ||||
|         upvertKernel<T> << < 1, n, 512, *stream >> > | ||||
|                                         (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|         upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|         // invert the second diagonal | ||||
|         invertKernelLow<T> << < 1, n, 512, *stream >> > | ||||
|                                            (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|         invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
| //        invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|         invertLowKernel<T><<< n, n, 512, *stream >> > | ||||
|                                            (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|         invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|     } | ||||
| 
 | ||||
|     void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { | ||||
|     void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { | ||||
|         NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); | ||||
|         BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); | ||||
|         BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); | ||||
|         NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); | ||||
|     } | ||||
| 
 | ||||
|     template<typename T> | ||||
|     static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { | ||||
|     static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) { | ||||
|         int n = inputMatrix->rows(); | ||||
|         invertedMatrix->setIdentity(); | ||||
|         auto stream = LaunchContext::defaultContext()->getCudaStream(); | ||||
|         auto stream = context->getCudaStream(); | ||||
|         if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I | ||||
|             return; | ||||
|         } | ||||
| @ -235,13 +232,12 @@ namespace helpers { | ||||
|                 inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|         invertedMatrix->tickWriteDevice(); | ||||
|         invertedMatrix->printIndexedBuffer("Step1 UP inversion"); | ||||
|         invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), | ||||
|                 inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|         invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||
|     } | ||||
| 
 | ||||
|     void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { | ||||
|     void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { | ||||
|         NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); | ||||
|         BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); | ||||
|         BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); | ||||
|         NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); | ||||
|     } | ||||
| 
 | ||||
| @ -525,23 +521,19 @@ namespace helpers { | ||||
|             input->tickWriteDevice(); | ||||
|         } | ||||
| 
 | ||||
|         BUILD_SINGLE_TEMPLATE(template void lup_, | ||||
|                               (LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), | ||||
|                               FLOAT_NATIVE); | ||||
|         BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE); | ||||
| 
 | ||||
|         template<typename T> | ||||
|         static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { | ||||
|             Nd4jLong n = input->sizeAt(-1); | ||||
|             Nd4jLong n2 = n * n; | ||||
|             std::vector<int> dims(); | ||||
|             auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), | ||||
|                                                                             {input->rankOf() - 2, input->rankOf() - 1}); | ||||
|             auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); | ||||
|             //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); | ||||
| //        DataType dtype = input->dataType(); | ||||
| //        if (dtype != DataType::DOUBLE) | ||||
| //            dtype = DataType::FLOAT32; | ||||
|             auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), | ||||
|                                                  LaunchContext::defaultContext()); //, block.getWorkspace()); | ||||
|             auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace()); | ||||
|             auto det = NDArrayFactory::create<T>(1); | ||||
|             auto stream = context->getCudaStream(); | ||||
|             NDArray::prepareSpecialUse({output}, {input}); | ||||
| @ -550,8 +542,7 @@ namespace helpers { | ||||
|             for (int e = 0; e < output->lengthOf(); e++) { | ||||
|                 Nd4jLong pos = e * n2; | ||||
| //            if (matrix.dataType() == input->dataType()) | ||||
|                 fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> > | ||||
|                                                                                 (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||
|                 fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||
| //            else | ||||
| //                fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||
| 
 | ||||
| @ -584,15 +575,13 @@ namespace helpers { | ||||
|             Nd4jLong n = input->sizeAt(-1); | ||||
|             Nd4jLong n2 = n * n; | ||||
|             std::vector<int> dims(); | ||||
|             auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), | ||||
|                                                                             {input->rankOf() - 2, input->rankOf() - 1}); | ||||
|             auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); | ||||
|             //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); | ||||
|             DataType dtype = input->dataType(); | ||||
|             if (dtype != DataType::DOUBLE) | ||||
|                 dtype = DataType::FLOAT32; | ||||
| 
 | ||||
|             auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, | ||||
|                                                  LaunchContext::defaultContext()); //, block.getWorkspace()); | ||||
|             auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); | ||||
|             auto det = NDArrayFactory::create<T>(1); | ||||
|             auto stream = context->getCudaStream(); | ||||
|             NDArray::prepareSpecialUse({output}, {input}); | ||||
| @ -601,8 +590,7 @@ namespace helpers { | ||||
|             for (int e = 0; e < output->lengthOf(); e++) { | ||||
|                 Nd4jLong pos = e * n2; | ||||
| //            if (matrix.dataType() == input->dataType()) | ||||
|                 fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> > | ||||
|                                                                                 (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||
|                 fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||
| //            else | ||||
| //                fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||
| 
 | ||||
| @ -614,8 +602,7 @@ namespace helpers { | ||||
|                 auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer()); | ||||
|                 auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset; | ||||
| //            if (matrix.dataType() == input->dataType()) | ||||
|                 determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> > | ||||
|                                                                                        (inputBuf, outputBuf, n); | ||||
|                 determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n); | ||||
| //            else | ||||
| //                determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); | ||||
|             } | ||||
| @ -694,11 +681,11 @@ namespace helpers { | ||||
|             auto dtype = DataTypeUtils::fromT<T>(); //input->dataType(); | ||||
| //            if (dtype != DataType::DOUBLE) | ||||
| //                dtype = DataType::FLOAT32; | ||||
|             NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); | ||||
|             NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); | ||||
|             NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); | ||||
|             NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); | ||||
|             NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, LaunchContext::defaultContext()); | ||||
|             NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context); | ||||
|             NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context); | ||||
|             NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context); | ||||
|             NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context); | ||||
|             NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context); | ||||
|             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), | ||||
|                                                                                   {input->rankOf() - 2, | ||||
|                                                                                    input->rankOf() - 1}); | ||||
| @ -708,20 +695,17 @@ namespace helpers { | ||||
|             auto stream = context->getCudaStream(); | ||||
| 
 | ||||
|             for (auto i = 0LL; i < packX.numberOfTads(); i++) { | ||||
|                 fillMatrix<T, T> << < 1, n2, 1024, *stream >> > | ||||
|                                                    (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), | ||||
|                                                            i * n2, n); | ||||
|                 fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); | ||||
|                 matrix.tickWriteDevice(); | ||||
|                 compound.assign(matrix); | ||||
|                 lup_<T>(context, &compound, nullptr, nullptr); | ||||
|                 fillLowerUpperKernel<T> << < n, n, 1024, *stream >> > | ||||
|                                                          (lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); | ||||
|                 fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); | ||||
|                 matrix.assign(0); | ||||
|                 invertUpperMatrix(&upper, &matrix); // U^{-1} | ||||
|                 invertUpperMatrix(context, &upper, &matrix); // U^{-1} | ||||
|                 matrix.tickWriteDevice(); | ||||
| //                matrix.printIndexedBuffer("Upper Inverted"); | ||||
|                 compound.assign(0); | ||||
|                 invertLowerMatrix(&lower, &compound); // L{-1} | ||||
|                 invertLowerMatrix(context, &lower, &compound); // L{-1} | ||||
|                 compound.tickWriteDevice(); | ||||
| //                compound.printIndexedBuffer("Lower Inverted"); | ||||
| //                matrix.tickWriteDevice(); | ||||
| @ -729,9 +713,7 @@ namespace helpers { | ||||
|                 nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); | ||||
|                 upper.tickWriteDevice(); | ||||
| //                upper.printIndexedBuffer("Full inverted"); | ||||
|                 returnMatrix<T> << < 1, n2, 1024, *stream >> > | ||||
|                                                      (output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), | ||||
|                                                              i * n2, n); | ||||
|                 returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); | ||||
|             } | ||||
|             return Status::OK(); | ||||
|         } | ||||
| @ -865,8 +847,7 @@ namespace helpers { | ||||
|                 cholesky__<float>(context, input, output, inplace); | ||||
|             else { | ||||
|                 std::unique_ptr<NDArray> tempOutput( | ||||
|                         NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, | ||||
|                                                 LaunchContext::defaultContext())); | ||||
|                         NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context)); | ||||
|                 tempOutput->assign(input); | ||||
|                 cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true); | ||||
|                 output->assign(tempOutput.get()); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user