From ef1de6a4aad3672920d17bb02709d68d687e11f5 Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Sun, 1 Sep 2019 13:17:36 +0900 Subject: [PATCH] rcorbish #8617 (#8188) Signed-off-by: Robert Altena --- .../linalg/jcublas/blas/JcublasLapack.java | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java index 74a8fc99c..3eade74e9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java @@ -53,7 +53,8 @@ import static org.bytedeco.cuda.global.cusolver.*; * JCublas lapack * * @author Adam Gibson - * @author Richard Corbishley + * @author Richard Corbishley (signed) + * */ @Slf4j public class JcublasLapack extends BaseLapack { @@ -70,7 +71,6 @@ public class JcublasLapack extends BaseLapack { if (A.ordering() == 'c') a = A.dup('f'); - if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); @@ -193,7 +193,7 @@ public class JcublasLapack extends BaseLapack { //========================= -// Q R DECOMP + // Q R DECOMP @Override public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) { INDArray a = A; @@ -306,8 +306,8 @@ public class JcublasLapack extends BaseLapack { if (r != null && r != R) R.assign(r); - log.info("A: {}", A); - if (R != null) log.info("R: {}", R); + log.debug("A: {}", A); + if (R != null) log.debug("R: {}", R); } @Override @@ -419,16 +419,18 @@ public class JcublasLapack extends BaseLapack { if (r != null && r != R) R.assign(r); - log.info("A: {}", A); - if (R != null) log.info("R: {}", R); + log.debug("A: {}", A); + if (R != null) log.debug("R: {}", R); } //========================= // CHOLESKY DECOMP @Override - public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) { + public void spotrf(byte _uplo, int N, INDArray A, INDArray INFO) { INDArray a = A; + int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + if (A.dataType() != DataType.FLOAT) log.warn("FLOAT potrf called for " + A.dataType()); @@ -489,7 +491,7 @@ public class JcublasLapack extends BaseLapack { if (a != A) A.assign(a); - if (uplo == 'U') { + if (uplo == CUBLAS_FILL_MODE_UPPER ) { A.assign(A.transpose()); INDArrayIndex ix[] = new INDArrayIndex[2]; for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) { @@ -506,13 +508,15 @@ public class JcublasLapack extends BaseLapack { } } - log.info("A: {}", A); + log.debug("A: {}", A); } @Override - public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) { + public void dpotrf(byte _uplo, int N, INDArray A, INDArray INFO) { INDArray a = A; + int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + if (A.dataType() != DataType.DOUBLE) log.warn("DOUBLE potrf called for " + A.dataType()); @@ -573,7 +577,7 @@ public class JcublasLapack extends BaseLapack { if (a != A) A.assign(a); - if (uplo == 'U') { + if (uplo == CUBLAS_FILL_MODE_UPPER ) { A.assign(A.transpose()); INDArrayIndex ix[] = new INDArrayIndex[2]; for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) { @@ -590,7 +594,7 @@ public class JcublasLapack extends BaseLapack { } } - log.info("A: {}", A); + log.debug("A: {}", A); }