parent
b393d3fdb1
commit
ef1de6a4aa
|
@ -53,7 +53,8 @@ import static org.bytedeco.cuda.global.cusolver.*;
|
||||||
* JCublas lapack
|
* JCublas lapack
|
||||||
*
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
* @author Richard Corbishley
|
* @author Richard Corbishley (signed)
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class JcublasLapack extends BaseLapack {
|
public class JcublasLapack extends BaseLapack {
|
||||||
|
@ -70,7 +71,6 @@ public class JcublasLapack extends BaseLapack {
|
||||||
if (A.ordering() == 'c')
|
if (A.ordering() == 'c')
|
||||||
a = A.dup('f');
|
a = A.dup('f');
|
||||||
|
|
||||||
|
|
||||||
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
||||||
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||||
|
|
||||||
|
@ -306,8 +306,8 @@ public class JcublasLapack extends BaseLapack {
|
||||||
if (r != null && r != R)
|
if (r != null && r != R)
|
||||||
R.assign(r);
|
R.assign(r);
|
||||||
|
|
||||||
log.info("A: {}", A);
|
log.debug("A: {}", A);
|
||||||
if (R != null) log.info("R: {}", R);
|
if (R != null) log.debug("R: {}", R);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -419,16 +419,18 @@ public class JcublasLapack extends BaseLapack {
|
||||||
if (r != null && r != R)
|
if (r != null && r != R)
|
||||||
R.assign(r);
|
R.assign(r);
|
||||||
|
|
||||||
log.info("A: {}", A);
|
log.debug("A: {}", A);
|
||||||
if (R != null) log.info("R: {}", R);
|
if (R != null) log.debug("R: {}", R);
|
||||||
}
|
}
|
||||||
|
|
||||||
//=========================
|
//=========================
|
||||||
// CHOLESKY DECOMP
|
// CHOLESKY DECOMP
|
||||||
@Override
|
@Override
|
||||||
public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) {
|
public void spotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
|
||||||
INDArray a = A;
|
INDArray a = A;
|
||||||
|
|
||||||
|
int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||||
|
|
||||||
if (A.dataType() != DataType.FLOAT)
|
if (A.dataType() != DataType.FLOAT)
|
||||||
log.warn("FLOAT potrf called for " + A.dataType());
|
log.warn("FLOAT potrf called for " + A.dataType());
|
||||||
|
|
||||||
|
@ -489,7 +491,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
if (a != A)
|
if (a != A)
|
||||||
A.assign(a);
|
A.assign(a);
|
||||||
|
|
||||||
if (uplo == 'U') {
|
if (uplo == CUBLAS_FILL_MODE_UPPER ) {
|
||||||
A.assign(A.transpose());
|
A.assign(A.transpose());
|
||||||
INDArrayIndex ix[] = new INDArrayIndex[2];
|
INDArrayIndex ix[] = new INDArrayIndex[2];
|
||||||
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
|
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
|
||||||
|
@ -506,13 +508,15 @@ public class JcublasLapack extends BaseLapack {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.info("A: {}", A);
|
log.debug("A: {}", A);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) {
|
public void dpotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
|
||||||
INDArray a = A;
|
INDArray a = A;
|
||||||
|
|
||||||
|
int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||||
|
|
||||||
if (A.dataType() != DataType.DOUBLE)
|
if (A.dataType() != DataType.DOUBLE)
|
||||||
log.warn("DOUBLE potrf called for " + A.dataType());
|
log.warn("DOUBLE potrf called for " + A.dataType());
|
||||||
|
|
||||||
|
@ -573,7 +577,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
if (a != A)
|
if (a != A)
|
||||||
A.assign(a);
|
A.assign(a);
|
||||||
|
|
||||||
if (uplo == 'U') {
|
if (uplo == CUBLAS_FILL_MODE_UPPER ) {
|
||||||
A.assign(A.transpose());
|
A.assign(A.transpose());
|
||||||
INDArrayIndex ix[] = new INDArrayIndex[2];
|
INDArrayIndex ix[] = new INDArrayIndex[2];
|
||||||
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
|
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
|
||||||
|
@ -590,7 +594,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.info("A: {}", A);
|
log.debug("A: {}", A);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue