Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-09-01 13:17:36 +09:00 committed by raver119
parent b393d3fdb1
commit ef1de6a4aa
1 changed files with 17 additions and 13 deletions

View File

@ -53,7 +53,8 @@ import static org.bytedeco.cuda.global.cusolver.*;
* JCublas lapack
*
* @author Adam Gibson
* @author Richard Corbishley
* @author Richard Corbishley (signed)
*
*/
@Slf4j
public class JcublasLapack extends BaseLapack {
@ -70,7 +71,6 @@ public class JcublasLapack extends BaseLapack {
if (A.ordering() == 'c')
a = A.dup('f');
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
@ -193,7 +193,7 @@ public class JcublasLapack extends BaseLapack {
//=========================
// Q R DECOMP
// Q R DECOMP
@Override
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
INDArray a = A;
@ -306,8 +306,8 @@ public class JcublasLapack extends BaseLapack {
if (r != null && r != R)
R.assign(r);
log.info("A: {}", A);
if (R != null) log.info("R: {}", R);
log.debug("A: {}", A);
if (R != null) log.debug("R: {}", R);
}
@Override
@ -419,16 +419,18 @@ public class JcublasLapack extends BaseLapack {
if (r != null && r != R)
R.assign(r);
log.info("A: {}", A);
if (R != null) log.info("R: {}", R);
log.debug("A: {}", A);
if (R != null) log.debug("R: {}", R);
}
//=========================
// CHOLESKY DECOMP
@Override
public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) {
public void spotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
INDArray a = A;
int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
if (A.dataType() != DataType.FLOAT)
log.warn("FLOAT potrf called for " + A.dataType());
@ -489,7 +491,7 @@ public class JcublasLapack extends BaseLapack {
if (a != A)
A.assign(a);
if (uplo == 'U') {
if (uplo == CUBLAS_FILL_MODE_UPPER ) {
A.assign(A.transpose());
INDArrayIndex ix[] = new INDArrayIndex[2];
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
@ -506,13 +508,15 @@ public class JcublasLapack extends BaseLapack {
}
}
log.info("A: {}", A);
log.debug("A: {}", A);
}
@Override
public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) {
public void dpotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
INDArray a = A;
int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
if (A.dataType() != DataType.DOUBLE)
log.warn("DOUBLE potrf called for " + A.dataType());
@ -573,7 +577,7 @@ public class JcublasLapack extends BaseLapack {
if (a != A)
A.assign(a);
if (uplo == 'U') {
if (uplo == CUBLAS_FILL_MODE_UPPER ) {
A.assign(A.transpose());
INDArrayIndex ix[] = new INDArrayIndex[2];
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
@ -590,7 +594,7 @@ public class JcublasLapack extends BaseLapack {
}
}
log.info("A: {}", A);
log.debug("A: {}", A);
}