Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-09-01 13:17:36 +09:00 committed by raver119
parent b393d3fdb1
commit ef1de6a4aa
1 changed files with 17 additions and 13 deletions

View File

@ -53,7 +53,8 @@ import static org.bytedeco.cuda.global.cusolver.*;
* JCublas lapack * JCublas lapack
* *
* @author Adam Gibson * @author Adam Gibson
* @author Richard Corbishley * @author Richard Corbishley (signed)
*
*/ */
@Slf4j @Slf4j
public class JcublasLapack extends BaseLapack { public class JcublasLapack extends BaseLapack {
@ -70,7 +71,6 @@ public class JcublasLapack extends BaseLapack {
if (A.ordering() == 'c') if (A.ordering() == 'c')
a = A.dup('f'); a = A.dup('f');
if (Nd4j.getExecutioner() instanceof GridExecutioner) if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
@ -306,8 +306,8 @@ public class JcublasLapack extends BaseLapack {
if (r != null && r != R) if (r != null && r != R)
R.assign(r); R.assign(r);
log.info("A: {}", A); log.debug("A: {}", A);
if (R != null) log.info("R: {}", R); if (R != null) log.debug("R: {}", R);
} }
@Override @Override
@ -419,16 +419,18 @@ public class JcublasLapack extends BaseLapack {
if (r != null && r != R) if (r != null && r != R)
R.assign(r); R.assign(r);
log.info("A: {}", A); log.debug("A: {}", A);
if (R != null) log.info("R: {}", R); if (R != null) log.debug("R: {}", R);
} }
//========================= //=========================
// CHOLESKY DECOMP // CHOLESKY DECOMP
@Override @Override
public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) { public void spotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
INDArray a = A; INDArray a = A;
int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
if (A.dataType() != DataType.FLOAT) if (A.dataType() != DataType.FLOAT)
log.warn("FLOAT potrf called for " + A.dataType()); log.warn("FLOAT potrf called for " + A.dataType());
@ -489,7 +491,7 @@ public class JcublasLapack extends BaseLapack {
if (a != A) if (a != A)
A.assign(a); A.assign(a);
if (uplo == 'U') { if (uplo == CUBLAS_FILL_MODE_UPPER ) {
A.assign(A.transpose()); A.assign(A.transpose());
INDArrayIndex ix[] = new INDArrayIndex[2]; INDArrayIndex ix[] = new INDArrayIndex[2];
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) { for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
@ -506,13 +508,15 @@ public class JcublasLapack extends BaseLapack {
} }
} }
log.info("A: {}", A); log.debug("A: {}", A);
} }
@Override @Override
public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) { public void dpotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
INDArray a = A; INDArray a = A;
int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
if (A.dataType() != DataType.DOUBLE) if (A.dataType() != DataType.DOUBLE)
log.warn("DOUBLE potrf called for " + A.dataType()); log.warn("DOUBLE potrf called for " + A.dataType());
@ -573,7 +577,7 @@ public class JcublasLapack extends BaseLapack {
if (a != A) if (a != A)
A.assign(a); A.assign(a);
if (uplo == 'U') { if (uplo == CUBLAS_FILL_MODE_UPPER ) {
A.assign(A.transpose()); A.assign(A.transpose());
INDArrayIndex ix[] = new INDArrayIndex[2]; INDArrayIndex ix[] = new INDArrayIndex[2];
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) { for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
@ -590,7 +594,7 @@ public class JcublasLapack extends BaseLapack {
} }
} }
log.info("A: {}", A); log.debug("A: {}", A);
} }