parent
b393d3fdb1
commit
ef1de6a4aa
|
@ -53,7 +53,8 @@ import static org.bytedeco.cuda.global.cusolver.*;
|
|||
* JCublas lapack
|
||||
*
|
||||
* @author Adam Gibson
|
||||
* @author Richard Corbishley
|
||||
* @author Richard Corbishley (signed)
|
||||
*
|
||||
*/
|
||||
@Slf4j
|
||||
public class JcublasLapack extends BaseLapack {
|
||||
|
@ -70,7 +71,6 @@ public class JcublasLapack extends BaseLapack {
|
|||
if (A.ordering() == 'c')
|
||||
a = A.dup('f');
|
||||
|
||||
|
||||
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
||||
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||
|
||||
|
@ -193,7 +193,7 @@ public class JcublasLapack extends BaseLapack {
|
|||
|
||||
|
||||
//=========================
|
||||
// Q R DECOMP
|
||||
// Q R DECOMP
|
||||
@Override
|
||||
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
|
||||
INDArray a = A;
|
||||
|
@ -306,8 +306,8 @@ public class JcublasLapack extends BaseLapack {
|
|||
if (r != null && r != R)
|
||||
R.assign(r);
|
||||
|
||||
log.info("A: {}", A);
|
||||
if (R != null) log.info("R: {}", R);
|
||||
log.debug("A: {}", A);
|
||||
if (R != null) log.debug("R: {}", R);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -419,16 +419,18 @@ public class JcublasLapack extends BaseLapack {
|
|||
if (r != null && r != R)
|
||||
R.assign(r);
|
||||
|
||||
log.info("A: {}", A);
|
||||
if (R != null) log.info("R: {}", R);
|
||||
log.debug("A: {}", A);
|
||||
if (R != null) log.debug("R: {}", R);
|
||||
}
|
||||
|
||||
//=========================
|
||||
// CHOLESKY DECOMP
|
||||
@Override
|
||||
public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) {
|
||||
public void spotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
|
||||
INDArray a = A;
|
||||
|
||||
int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
|
||||
if (A.dataType() != DataType.FLOAT)
|
||||
log.warn("FLOAT potrf called for " + A.dataType());
|
||||
|
||||
|
@ -489,7 +491,7 @@ public class JcublasLapack extends BaseLapack {
|
|||
if (a != A)
|
||||
A.assign(a);
|
||||
|
||||
if (uplo == 'U') {
|
||||
if (uplo == CUBLAS_FILL_MODE_UPPER ) {
|
||||
A.assign(A.transpose());
|
||||
INDArrayIndex ix[] = new INDArrayIndex[2];
|
||||
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
|
||||
|
@ -506,13 +508,15 @@ public class JcublasLapack extends BaseLapack {
|
|||
}
|
||||
}
|
||||
|
||||
log.info("A: {}", A);
|
||||
log.debug("A: {}", A);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) {
|
||||
public void dpotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
|
||||
INDArray a = A;
|
||||
|
||||
int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
|
||||
if (A.dataType() != DataType.DOUBLE)
|
||||
log.warn("DOUBLE potrf called for " + A.dataType());
|
||||
|
||||
|
@ -573,7 +577,7 @@ public class JcublasLapack extends BaseLapack {
|
|||
if (a != A)
|
||||
A.assign(a);
|
||||
|
||||
if (uplo == 'U') {
|
||||
if (uplo == CUBLAS_FILL_MODE_UPPER ) {
|
||||
A.assign(A.transpose());
|
||||
INDArrayIndex ix[] = new INDArrayIndex[2];
|
||||
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
|
||||
|
@ -590,7 +594,7 @@ public class JcublasLapack extends BaseLapack {
|
|||
}
|
||||
}
|
||||
|
||||
log.info("A: {}", A);
|
||||
log.debug("A: {}", A);
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue