Sparse matrix refactoring. (#8238)
* remove sparse method from INDArray. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove gemm Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove useage of n4j.sparseFactory Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Nd4j.sparseFactory removed. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * sparseNDArray deleted. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * iremove more sparse calls and constants. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove SparseBlasWrapper. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * delete BaseSparseBlaswrapper. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove 3 sparse factory classes. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * delete SparseCPULevel. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * deletes JcusparseLevel, CUDASparselevel. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * delete nativeCPU sparse classes. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * removes sparse methods from NDArrayFactory. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * more deletes. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * delete (ignored) tests. BaseSparseNDArray. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * deletes ISparseNDArray. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove sparse methods from indArray. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * deletes sparse classes. Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
979ef13c0b
commit
83d958d536
|
@ -48,15 +48,6 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, X, Y);
|
OpProfiler.getInstance().processBlasCall(false, X, Y);
|
||||||
|
|
||||||
if (X.isSparse() && !Y.isSparse()) {
|
|
||||||
return Nd4j.getSparseBlasWrapper().level1().dot(n, alpha, X, Y);
|
|
||||||
} else if (!X.isSparse() && Y.isSparse()) {
|
|
||||||
return Nd4j.getSparseBlasWrapper().level1().dot(n, alpha, Y, X);
|
|
||||||
} else if (X.isSparse() && Y.isSparse()) {
|
|
||||||
// TODO - MKL doesn't contain such routines
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (X.data().dataType() == DataType.DOUBLE) {
|
if (X.data().dataType() == DataType.DOUBLE) {
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y);
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y);
|
||||||
return ddot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y));
|
return ddot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y));
|
||||||
|
@ -100,14 +91,9 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
@Override
|
@Override
|
||||||
public double nrm2(INDArray arr) {
|
public double nrm2(INDArray arr) {
|
||||||
|
|
||||||
if (arr.isSparse()) {
|
|
||||||
return Nd4j.getSparseBlasWrapper().level1().nrm2(arr);
|
|
||||||
}
|
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, arr);
|
OpProfiler.getInstance().processBlasCall(false, arr);
|
||||||
if (arr.isSparse()) {
|
|
||||||
return Nd4j.getSparseBlasWrapper().level1().nrm2(arr);
|
|
||||||
}
|
|
||||||
if (arr.data().dataType() == DataType.DOUBLE) {
|
if (arr.data().dataType() == DataType.DOUBLE) {
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
|
||||||
return dnrm2(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
|
return dnrm2(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
|
||||||
|
@ -127,9 +113,6 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
@Override
|
@Override
|
||||||
public double asum(INDArray arr) {
|
public double asum(INDArray arr) {
|
||||||
|
|
||||||
if (arr.isSparse()) {
|
|
||||||
return Nd4j.getSparseBlasWrapper().level1().asum(arr);
|
|
||||||
}
|
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, arr);
|
OpProfiler.getInstance().processBlasCall(false, arr);
|
||||||
|
|
||||||
|
@ -202,9 +185,7 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public int iamax(INDArray arr) {
|
public int iamax(INDArray arr) {
|
||||||
if (arr.isSparse()) {
|
|
||||||
return Nd4j.getSparseBlasWrapper().level1().iamax(arr);
|
|
||||||
}
|
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, arr);
|
OpProfiler.getInstance().processBlasCall(false, arr);
|
||||||
|
|
||||||
|
@ -225,12 +206,8 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public int iamin(INDArray arr) {
|
public int iamin(INDArray arr) {
|
||||||
if (arr.isSparse()) {
|
|
||||||
return Nd4j.getSparseBlasWrapper().level1().iamin(arr);
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* swaps a vector with another vector.
|
* swaps a vector with another vector.
|
||||||
|
@ -243,11 +220,6 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, x, y);
|
OpProfiler.getInstance().processBlasCall(false, x, y);
|
||||||
|
|
||||||
if (x.isSparse() || y.isSparse()) {
|
|
||||||
Nd4j.getSparseBlasWrapper().level1().swap(x, y);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (x.data().dataType() == DataType.DOUBLE) {
|
if (x.data().dataType() == DataType.DOUBLE) {
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
|
||||||
dswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
|
dswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
|
||||||
|
@ -269,10 +241,6 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, x, y);
|
OpProfiler.getInstance().processBlasCall(false, x, y);
|
||||||
|
|
||||||
if (x.isSparse() || y.isSparse()) {
|
|
||||||
Nd4j.getSparseBlasWrapper().level1().copy(x, y);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (x.data().dataType() == DataType.DOUBLE) {
|
if (x.data().dataType() == DataType.DOUBLE) {
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
|
||||||
dcopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
|
dcopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
|
||||||
|
@ -321,9 +289,7 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, x, y);
|
OpProfiler.getInstance().processBlasCall(false, x, y);
|
||||||
|
|
||||||
if (x.isSparse() && !y.isSparse()) {
|
if (x.data().dataType() == DataType.DOUBLE) {
|
||||||
Nd4j.getSparseBlasWrapper().level1().axpy(n, alpha, x, y);
|
|
||||||
} else if (x.data().dataType() == DataType.DOUBLE) {
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
|
||||||
daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
|
daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
|
||||||
} else if (x.data().dataType() == DataType.FLOAT) {
|
} else if (x.data().dataType() == DataType.FLOAT) {
|
||||||
|
@ -384,9 +350,7 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, X, Y);
|
OpProfiler.getInstance().processBlasCall(false, X, Y);
|
||||||
|
|
||||||
if (X.isSparse() && !Y.isSparse()) {
|
if (X.data().dataType() == DataType.DOUBLE) {
|
||||||
Nd4j.getSparseBlasWrapper().level1().rot(N, X, Y, c, s);
|
|
||||||
} else if (X.data().dataType() == DataType.DOUBLE) {
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y);
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y);
|
||||||
drot(N, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(X), c, s);
|
drot(N, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(X), c, s);
|
||||||
} else {
|
} else {
|
||||||
|
@ -421,9 +385,7 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, X);
|
OpProfiler.getInstance().processBlasCall(false, X);
|
||||||
|
|
||||||
if (X.isSparse()) {
|
if (X.data().dataType() == DataType.DOUBLE)
|
||||||
Nd4j.getSparseBlasWrapper().level1().scal(N, alpha, X);
|
|
||||||
} else if (X.data().dataType() == DataType.DOUBLE)
|
|
||||||
dscal(N, alpha, X, BlasBufferUtil.getBlasStride(X));
|
dscal(N, alpha, X, BlasBufferUtil.getBlasStride(X));
|
||||||
else if (X.data().dataType() == DataType.FLOAT)
|
else if (X.data().dataType() == DataType.FLOAT)
|
||||||
sscal(N, (float) alpha, X, BlasBufferUtil.getBlasStride(X));
|
sscal(N, (float) alpha, X, BlasBufferUtil.getBlasStride(X));
|
||||||
|
|
|
@ -56,11 +56,6 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
|
||||||
OpProfiler.getInstance().processBlasCall(false, A, X, Y);
|
OpProfiler.getInstance().processBlasCall(false, A, X, Y);
|
||||||
|
|
||||||
if (A.isSparse() && !X.isSparse()) {
|
|
||||||
Nd4j.getSparseBlasWrapper().level2().gemv(order, transA, alpha, A, X, beta, Y);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
GemvParameters parameters = new GemvParameters(A, X, Y);
|
GemvParameters parameters = new GemvParameters(A, X, Y);
|
||||||
if (A.data().dataType() == DataType.DOUBLE) {
|
if (A.data().dataType() == DataType.DOUBLE) {
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),
|
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),
|
||||||
|
|
|
@ -1,70 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas.impl;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.Lapack;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseBaseLapack implements Lapack {
|
|
||||||
@Override
|
|
||||||
public INDArray getrf(INDArray A) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getPFactor(int M, INDArray ipiv) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getLFactor(INDArray A) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getUFactor(INDArray A) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void getri(int N, INDArray A, int lda, int[] IPIV, INDArray WORK, int lwork, int INFO) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void geqrf(INDArray A, INDArray R) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void potrf(INDArray A, boolean lower) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int syev(char jobz, char uplo, INDArray A, INDArray V) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,24 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas.impl;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public abstract class SparseBaseLevel {
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,364 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas.impl;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.Level1;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArray;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public abstract class SparseBaseLevel1 extends SparseBaseLevel implements Level1 {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* computes a vector-vector dot product.
|
|
||||||
*
|
|
||||||
* @param n number of accessed element
|
|
||||||
* @param alpha
|
|
||||||
* @param X an INDArray
|
|
||||||
* @param Y an INDArray
|
|
||||||
* @return the vector-vector dot product of X and Y
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public double dot(long n, double alpha, INDArray X, INDArray Y) {
|
|
||||||
|
|
||||||
if (X instanceof BaseSparseNDArray) {
|
|
||||||
BaseSparseNDArray sparseX = (BaseSparseNDArray) X;
|
|
||||||
DataBuffer pointers = sparseX.getVectorCoordinates();
|
|
||||||
|
|
||||||
switch (X.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y);
|
|
||||||
return ddoti(n, X, pointers, Y);
|
|
||||||
case FLOAT:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.FLOAT, X, Y);
|
|
||||||
return sdoti(n, X, pointers, Y);
|
|
||||||
case HALF:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.HALF, X, Y);
|
|
||||||
return hdoti(n, X, pointers, Y);
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double dot(long n, DataBuffer dx, int offsetX, int incrX, DataBuffer y, int offsetY, int incrY) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the Euclidean norm of a vector.
|
|
||||||
*
|
|
||||||
* @param arr a vector
|
|
||||||
* @return the Euclidean norm of the vector
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public double nrm2(INDArray arr) {
|
|
||||||
|
|
||||||
switch (arr.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
|
|
||||||
return dnrm2(arr.length(), arr, 1);
|
|
||||||
case FLOAT:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.FLOAT, arr);
|
|
||||||
return snrm2(arr.length(), arr, 1);
|
|
||||||
case HALF:
|
|
||||||
return hnrm2(arr.length(), arr, 1);
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Compute the sum of magnitude of the vector elements
|
|
||||||
*
|
|
||||||
* @param arr a vector
|
|
||||||
* @return the sum of magnitude of the vector elements
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
public double asum(INDArray arr) {
|
|
||||||
|
|
||||||
switch (arr.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
|
|
||||||
return dasum(arr.length(), arr, 1);
|
|
||||||
case FLOAT:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.FLOAT, arr);
|
|
||||||
return sasum(arr.length(), arr, 1);
|
|
||||||
case HALF:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.HALF, arr);
|
|
||||||
return hasum(arr.length(), arr, 1);
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double asum(long n, DataBuffer x, int offsetX, int incrX) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Find the index of the element with maximum absolute value
|
|
||||||
*
|
|
||||||
* @param arr a vector
|
|
||||||
* @return the index of the element with maximum absolute value
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
public int iamax(INDArray arr) {
|
|
||||||
switch (arr.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
|
|
||||||
return idamax(arr.length(), arr, 1);
|
|
||||||
case FLOAT:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.FLOAT, arr);
|
|
||||||
return isamax(arr.length(), arr, 1);
|
|
||||||
case HALF:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.HALF, arr);
|
|
||||||
return ihamax(arr.length(), arr, 1);
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int iamax(long n, INDArray arr, int stride) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int iamax(long n, DataBuffer x, int offsetX, int incrX) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Find the index of the element with maximum absolute value
|
|
||||||
*
|
|
||||||
* @param arr a vector
|
|
||||||
* @return the index of the element with minimum absolute value
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
public int iamin(INDArray arr) {
|
|
||||||
switch (arr.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
|
|
||||||
return idamin(arr.length(), arr, 1);
|
|
||||||
case FLOAT:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.FLOAT, arr);
|
|
||||||
return isamin(arr.length(), arr, 1);
|
|
||||||
case HALF:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.HALF, arr);
|
|
||||||
return ihamin(arr.length(), arr, 1);
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void swap(INDArray x, INDArray y) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(INDArray x, INDArray y) {
|
|
||||||
// FIXME - for Raver119 :)
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(long n, DataBuffer x, int offsetX, int incrX, DataBuffer y, int offsetY, int incrY) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Adds a scalar multiple of compressed sparse vector to a full-storage vector.
|
|
||||||
*
|
|
||||||
* @param n The number of element
|
|
||||||
* @param alpha
|
|
||||||
* @param x a sparse vector
|
|
||||||
* @param y a dense vector
|
|
||||||
*
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
public void axpy(long n, double alpha, INDArray x, INDArray y) {
|
|
||||||
BaseSparseNDArray sparseX = (BaseSparseNDArray) x;
|
|
||||||
DataBuffer pointers = sparseX.getVectorCoordinates();
|
|
||||||
switch (x.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x);
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, y);
|
|
||||||
daxpyi(n, alpha, x, pointers, y);
|
|
||||||
break;
|
|
||||||
case FLOAT:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.FLOAT, x);
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.FLOAT, y);
|
|
||||||
saxpyi(n, alpha, x, pointers, y);
|
|
||||||
break;
|
|
||||||
case HALF:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.HALF, x);
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.HALF, y);
|
|
||||||
haxpyi(n, alpha, x, pointers, y);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void axpy(long n, double alpha, DataBuffer x, int offsetX, int incrX, DataBuffer y, int offsetY, int incrY) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void rotg(INDArray a, INDArray b, INDArray c, INDArray s) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Applies Givens rotation to sparse vectors one of which is in compressed form.
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vectors X and Y
|
|
||||||
* @param X a sparse vector
|
|
||||||
* @param Y a full-storage vector
|
|
||||||
* @param c a scalar
|
|
||||||
* @param s a scalar
|
|
||||||
*
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
public void rot(long N, INDArray X, INDArray Y, double c, double s) {
|
|
||||||
|
|
||||||
|
|
||||||
if (X instanceof BaseSparseNDArray) {
|
|
||||||
BaseSparseNDArray sparseX = (BaseSparseNDArray) X;
|
|
||||||
|
|
||||||
switch (X.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
droti(N, X, sparseX.getVectorCoordinates(), Y, c, s);
|
|
||||||
break;
|
|
||||||
case FLOAT:
|
|
||||||
sroti(N, X, sparseX.getVectorCoordinates(), Y, c, s);
|
|
||||||
break;
|
|
||||||
case HALF:
|
|
||||||
hroti(N, X, sparseX.getVectorCoordinates(), Y, c, s);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void rotmg(INDArray d1, INDArray d2, INDArray b1, double b2, INDArray P) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the product of a vector by a scalar.
|
|
||||||
*
|
|
||||||
* @param N The number of elements of the vector X
|
|
||||||
* @param alpha a scalar
|
|
||||||
* @param X a vector
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
public void scal(long N, double alpha, INDArray X) {
|
|
||||||
switch (X.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
dscal(N, alpha, X, 1);
|
|
||||||
break;
|
|
||||||
case FLOAT:
|
|
||||||
sscal(N, alpha, X, 1);
|
|
||||||
break;
|
|
||||||
case HALF:
|
|
||||||
hscal(N, alpha, X, 1);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean supportsDataBufferL1Ops() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
* ===========================================================================
|
|
||||||
* Prototypes for level 1 BLAS functions (complex are recast as routines)
|
|
||||||
* ===========================================================================
|
|
||||||
*/
|
|
||||||
|
|
||||||
protected abstract double ddoti(long N, INDArray X, DataBuffer indx, INDArray Y);
|
|
||||||
|
|
||||||
protected abstract double sdoti(long N, INDArray X, DataBuffer indx, INDArray Y);
|
|
||||||
|
|
||||||
protected abstract double hdoti(long N, INDArray X, DataBuffer indx, INDArray Y);
|
|
||||||
|
|
||||||
protected abstract double snrm2(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract double dnrm2(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract double hnrm2(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract double dasum(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract double sasum(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract double hasum(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract int isamax(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract int idamax(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract int ihamax(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract int isamin(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract int idamin(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract int ihamin(long N, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract void daxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y);
|
|
||||||
|
|
||||||
protected abstract void saxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y);
|
|
||||||
|
|
||||||
protected abstract void haxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y);
|
|
||||||
|
|
||||||
protected abstract void droti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s);
|
|
||||||
|
|
||||||
protected abstract void sroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s);
|
|
||||||
|
|
||||||
protected abstract void hroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s);
|
|
||||||
|
|
||||||
protected abstract void dscal(long N, double a, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract void sscal(long N, double a, INDArray X, int incx);
|
|
||||||
|
|
||||||
protected abstract void hscal(long N, double a, INDArray X, int incx);
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,146 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas.impl;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.Level2;
|
|
||||||
import org.nd4j.linalg.api.blas.params.SparseCOOGemvParameters;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
|
||||||
|
|
||||||
import static org.nd4j.base.Preconditions.checkArgument;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public abstract class SparseBaseLevel2 extends SparseBaseLevel implements Level2 {
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void gemv(char order, char transA, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
|
|
||||||
checkArgument(A.isMatrix(), "A must be a matrix");
|
|
||||||
checkArgument(X.isVector(), "X must be a vector");
|
|
||||||
checkArgument(Y.isVector(), "Y must be a vector");
|
|
||||||
|
|
||||||
SparseCOOGemvParameters parameters = new SparseCOOGemvParameters(A, X, Y);
|
|
||||||
|
|
||||||
|
|
||||||
switch (A.data().dataType()) {
|
|
||||||
case DOUBLE:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),
|
|
||||||
parameters.getY());
|
|
||||||
dcoomv(parameters.getAOrdering(), parameters.getM(), parameters.getVal(), parameters.getRowInd(),
|
|
||||||
parameters.getColInd(), parameters.getNnz(), parameters.getX(), parameters.getY());
|
|
||||||
|
|
||||||
break;
|
|
||||||
case FLOAT:
|
|
||||||
DefaultOpExecutioner.validateDataType(DataType.FLOAT, parameters.getA(), parameters.getX(),
|
|
||||||
parameters.getY());
|
|
||||||
scoomv(parameters.getAOrdering(), parameters.getM(), parameters.getVal(), parameters.getRowInd(),
|
|
||||||
parameters.getColInd(), parameters.getNnz(), parameters.getX(), parameters.getY());
|
|
||||||
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void gbmv(char order, char TransA, int KL, int KU, double alpha, INDArray A, INDArray X, double beta,
|
|
||||||
INDArray Y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void ger(char order, double alpha, INDArray X, INDArray Y, INDArray A) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sbmv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void spmv(char order, char Uplo, double alpha, INDArray Ap, INDArray X, double beta, INDArray Y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void spr(char order, char Uplo, double alpha, INDArray X, INDArray Ap) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void spr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void symv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void syr(char order, char Uplo, int N, double alpha, INDArray X, INDArray A) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void syr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void tbmv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void tbsv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void tpmv(char order, char Uplo, char TransA, char Diag, INDArray Ap, INDArray X) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void tpsv(char order, char Uplo, char TransA, char Diag, INDArray Ap, INDArray X) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void trmv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void trsv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// ----
|
|
||||||
protected abstract void scoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz,
|
|
||||||
INDArray x, INDArray y);
|
|
||||||
|
|
||||||
protected abstract void dcoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz,
|
|
||||||
INDArray x, INDArray y);
|
|
||||||
}
|
|
|
@ -1,64 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas.impl;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.Level3;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseBaseLevel3 extends SparseBaseLevel implements Level3 {
|
|
||||||
@Override
|
|
||||||
public void gemm(char Order, char TransA, char TransB, double alpha, INDArray A, INDArray B, double beta,
|
|
||||||
INDArray C) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void gemm(INDArray A, INDArray B, INDArray C, boolean transposeA, boolean transposeB, double alpha,
|
|
||||||
double beta) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void symm(char Order, char Side, char Uplo, double alpha, INDArray A, INDArray B, double beta, INDArray C) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void syrk(char Order, char Uplo, char Trans, double alpha, INDArray A, double beta, INDArray C) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void syr2k(char Order, char Uplo, char Trans, double alpha, INDArray A, INDArray B, double beta,
|
|
||||||
INDArray C) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void trmm(char Order, char Side, char Uplo, char TransA, char Diag, double alpha, INDArray A, INDArray B,
|
|
||||||
INDArray C) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void trsm(char Order, char Side, char Uplo, char TransA, char Diag, double alpha, INDArray A, INDArray B) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,66 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas.params;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ndarray.SparseFormat;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class SparseCOOGemvParameters {
|
|
||||||
|
|
||||||
private int m, nnz;
|
|
||||||
DataBuffer val, rowInd, colInd;
|
|
||||||
private INDArray a, x, y;
|
|
||||||
private char aOrdering = 'N';
|
|
||||||
|
|
||||||
public SparseCOOGemvParameters(INDArray a, INDArray x, INDArray y) {
|
|
||||||
this.a = a;
|
|
||||||
this.x = x;
|
|
||||||
this.y = y;
|
|
||||||
|
|
||||||
if (a.isMatrix() && a.getFormat() == SparseFormat.COO) {
|
|
||||||
BaseSparseNDArrayCOO coo = (BaseSparseNDArrayCOO) a;
|
|
||||||
val = coo.getIncludedValues();
|
|
||||||
nnz = coo.nnz();
|
|
||||||
|
|
||||||
// FIXME: int cast
|
|
||||||
m = (int) coo.rows();
|
|
||||||
setIndexes(coo, false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void setIndexes(BaseSparseNDArrayCOO coo, boolean oneBased) {
|
|
||||||
int incr = oneBased ? 1 : 0;
|
|
||||||
int[] idx = coo.getIncludedIndices().asInt();
|
|
||||||
int[] rows = new int[nnz];
|
|
||||||
int[] cols = new int[nnz];
|
|
||||||
for (int i = 0; i < nnz; i++) {
|
|
||||||
rows[i] = idx[i * 2] + incr;
|
|
||||||
cols[i] = idx[(i * 2) + 1] + incr;
|
|
||||||
}
|
|
||||||
rowInd = Nd4j.createBuffer(rows);
|
|
||||||
colInd = Nd4j.createBuffer(cols);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -4309,10 +4309,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
if (n == this)
|
if (n == this)
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
if (n.isSparse()) {
|
|
||||||
return n.equals(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.rank() != n.rank())
|
if (this.rank() != n.rank())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
@ -5442,56 +5438,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return array;
|
return array;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
* ------- Sparse methods -------
|
|
||||||
*/
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataBuffer getVectorCoordinates() {
|
|
||||||
throw new UnsupportedOperationException("Not a sparse ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray toDense() {
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int nnz() {
|
|
||||||
throw new UnsupportedOperationException("Not a sparse ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public SparseFormat getFormat() {
|
|
||||||
return SparseFormat.NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataBuffer sparseInfoDataBuffer() {
|
|
||||||
throw new UnsupportedOperationException("Not a sparse ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int[] flags() {
|
|
||||||
throw new UnsupportedOperationException("Not a sparse ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int[] hiddenDimensions() {
|
|
||||||
throw new UnsupportedOperationException("Not a sparse ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int[] sparseOffsets() {
|
|
||||||
throw new UnsupportedOperationException("Not a sparse ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int underlyingRank() {
|
|
||||||
throw new UnsupportedOperationException("Not a sparse ndarray");
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
protected static DataTypeEx convertType(DataType type) {
|
protected static DataTypeEx convertType(DataType type) {
|
||||||
if (type == DataType.HALF) {
|
if (type == DataType.HALF) {
|
||||||
return DataTypeEx.FLOAT16;
|
return DataTypeEx.FLOAT16;
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ndarray;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
abstract public class BaseSparseInfoProvider implements SparseInfoProvider {
|
|
||||||
@Override
|
|
||||||
public DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions,
|
|
||||||
int underlyingRank) {
|
|
||||||
return createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
|
|
||||||
}
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,273 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ndarray;
|
|
||||||
|
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.indexing.*;
|
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
|
||||||
import org.nd4j.linalg.util.LongUtils;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.nd4j.base.Preconditions.checkArgument;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
|
|
||||||
protected static final SparseFormat format = SparseFormat.CSR;
|
|
||||||
protected transient volatile DataBuffer values;
|
|
||||||
protected transient volatile DataBuffer columnsPointers;
|
|
||||||
protected transient volatile DataBuffer pointerB;
|
|
||||||
protected transient volatile DataBuffer pointerE;
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* The length of the values and columns arrays is equal to the number of non-zero elements in A.
|
|
||||||
* The length of the pointerB and pointerE arrays is equal to the number of rows in A.
|
|
||||||
* @param data a double array that contains the non-zero element of the sparse matrix A
|
|
||||||
* @param columnsPointers Element i of the integer array columns is the number of the column in A that contains the i-th value
|
|
||||||
* in the values array.
|
|
||||||
* @param pointerB Element j of this integer array gives the index of the element in the values array that is first
|
|
||||||
* non-zero element in a row j of A. Note that this index is equal to pointerB(j) - pointerB(1)+1 .
|
|
||||||
* @param pointerE An integer array that contains row indices, such that pointerE(j)-pointerB(1) is the index of the
|
|
||||||
* element in the values array that is last non-zero element in a row j of A.
|
|
||||||
* @param shape Shape of the matrix A
|
|
||||||
*/
|
|
||||||
public BaseSparseNDArrayCSR(double[] data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
checkArgument(data.length == columnsPointers.length);
|
|
||||||
checkArgument(pointerB.length == pointerE.length);
|
|
||||||
// TODO
|
|
||||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
|
|
||||||
init(shape);
|
|
||||||
int valuesSpace = (int) (data.length * THRESHOLD_MEMORY_ALLOCATION);
|
|
||||||
this.values = Nd4j.getDataBufferFactory().createDouble(valuesSpace);
|
|
||||||
this.values.setData(data);
|
|
||||||
this.columnsPointers = Nd4j.getDataBufferFactory().createInt(valuesSpace);
|
|
||||||
this.columnsPointers.setData(columnsPointers);
|
|
||||||
this.length = columnsPointers.length;
|
|
||||||
long pointersSpace = rows;
|
|
||||||
this.pointerB = Nd4j.getDataBufferFactory().createInt(pointersSpace);
|
|
||||||
this.pointerB.setData(pointerB);
|
|
||||||
this.pointerE = Nd4j.getDataBufferFactory().createInt(pointersSpace);
|
|
||||||
this.pointerE.setData(pointerE);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public BaseSparseNDArrayCSR(float[] data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
this(Nd4j.createBuffer(data), columnsPointers, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public BaseSparseNDArrayCSR(DataBuffer data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
checkArgument(pointerB.length == pointerE.length);
|
|
||||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
|
|
||||||
init(shape);
|
|
||||||
this.values = data;
|
|
||||||
this.columnsPointers = Nd4j.getDataBufferFactory().createInt(data.length());
|
|
||||||
this.columnsPointers.setData(columnsPointers);
|
|
||||||
this.length = columnsPointers.length;
|
|
||||||
// The size of these pointers are constant
|
|
||||||
long pointersSpace = rows;
|
|
||||||
this.pointerB = Nd4j.getDataBufferFactory().createInt(pointersSpace);
|
|
||||||
this.pointerB.setData(pointerB);
|
|
||||||
this.pointerE = Nd4j.getDataBufferFactory().createInt(pointersSpace);
|
|
||||||
this.pointerE.setData(pointerE);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public INDArray putScalar(int row, int col, double value) {
|
|
||||||
|
|
||||||
checkArgument(row < rows && 0 <= rows);
|
|
||||||
checkArgument(col < columns && 0 <= columns);
|
|
||||||
|
|
||||||
int idx = pointerB.getInt(row);
|
|
||||||
int idxNextRow = pointerE.getInt(row);
|
|
||||||
|
|
||||||
while (columnsPointers.getInt(idx) < col && columnsPointers.getInt(idx) < idxNextRow) {
|
|
||||||
idx++;
|
|
||||||
}
|
|
||||||
if (columnsPointers.getInt(idx) == col) {
|
|
||||||
values.put(idx, value);
|
|
||||||
} else {
|
|
||||||
//Add a new entry in both buffers at a given position
|
|
||||||
values = addAtPosition(values, length, idx, value);
|
|
||||||
columnsPointers = addAtPosition(columnsPointers, length, idx, col);
|
|
||||||
length++;
|
|
||||||
|
|
||||||
// shift the indices of the next rows
|
|
||||||
pointerE.put(row, pointerE.getInt(row) + 1);
|
|
||||||
for (int i = row + 1; i < rows; i++) {
|
|
||||||
pointerB.put(i, pointerB.getInt(i) + 1);
|
|
||||||
pointerE.put(i, pointerE.getInt(i) + 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray get(INDArrayIndex... indexes) {
|
|
||||||
//check for row/column vector and point index being 0
|
|
||||||
if (indexes.length == 1 && indexes[0] instanceof NDArrayIndexAll || (indexes.length == 2 && (isRowVector()
|
|
||||||
&& indexes[0] instanceof PointIndex && indexes[0].offset() == 0
|
|
||||||
&& indexes[1] instanceof NDArrayIndexAll
|
|
||||||
|| isColumnVector() && indexes[1] instanceof PointIndex && indexes[0].offset() == 0
|
|
||||||
&& indexes[0] instanceof NDArrayIndexAll)))
|
|
||||||
return this;
|
|
||||||
|
|
||||||
indexes = NDArrayIndex.resolve(javaShapeInformation, indexes);
|
|
||||||
|
|
||||||
throw new UnsupportedOperationException("Not implemeted");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return the minor pointers. (columns for CSR, rows for CSC,...)
|
|
||||||
* */
|
|
||||||
public DataBuffer getVectorCoordinates() {
|
|
||||||
return Nd4j.getDataBufferFactory().create(columnsPointers, 0, length());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[] getDoubleValues() {
|
|
||||||
return values.getDoublesAt(0, (int) length);
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[] getColumns() {
|
|
||||||
return columnsPointers.getDoublesAt(0, (int) length);
|
|
||||||
}
|
|
||||||
|
|
||||||
public int[] getPointerBArray() {
|
|
||||||
return pointerB.asInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
public int[] getPointerEArray() {
|
|
||||||
return pointerE.asInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparseFormat getFormat() {
|
|
||||||
return format;
|
|
||||||
}
|
|
||||||
|
|
||||||
public DataBuffer getPointerB() {
|
|
||||||
return Nd4j.getDataBufferFactory().create(pointerB, 0, rows());
|
|
||||||
}
|
|
||||||
|
|
||||||
public DataBuffer getPointerE() {
|
|
||||||
return Nd4j.getDataBufferFactory().create(pointerE, 0, rows());
|
|
||||||
}
|
|
||||||
|
|
||||||
private DataBuffer addAtPosition(DataBuffer buf, long dataSize, int pos, double value) {
|
|
||||||
|
|
||||||
DataBuffer buffer = (buf.length() == dataSize) ? reallocate(buf) : buf;
|
|
||||||
double[] tail = buffer.getDoublesAt(pos, (int) dataSize - pos);
|
|
||||||
|
|
||||||
buffer.put(pos, value);
|
|
||||||
for (int i = 0; i < tail.length; i++) {
|
|
||||||
buffer.put(i + pos + 1, tail[i]);
|
|
||||||
}
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataBuffer data() {
|
|
||||||
return Nd4j.getDataBufferFactory().create(values, 0, length());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray toDense() {
|
|
||||||
// Dummy way - going to use the conversion routines in level2 (?)
|
|
||||||
INDArray result = Nd4j.zeros(shape());
|
|
||||||
|
|
||||||
int[] pointersB = pointerB.asInt();
|
|
||||||
int[] pointersE = pointerE.asInt();
|
|
||||||
|
|
||||||
for (int row = 0; row < rows(); row++) {
|
|
||||||
for (int idx = pointersB[row]; idx < pointersE[row]; idx++) {
|
|
||||||
result.put(row, columnsPointers.getInt(idx), values.getNumber(idx));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataBuffer shapeInfoDataBuffer() {
|
|
||||||
return shapeInformation;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
//TODO use op
|
|
||||||
// fixme
|
|
||||||
if (o == null || !(o instanceof INDArray)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
INDArray n = (INDArray) o;
|
|
||||||
if (n.isSparse()) {
|
|
||||||
BaseSparseNDArray s = (BaseSparseNDArray) n;
|
|
||||||
switch (s.getFormat()) {
|
|
||||||
case CSR:
|
|
||||||
BaseSparseNDArrayCSR csrArray = (BaseSparseNDArrayCSR) s;
|
|
||||||
if (csrArray.rows() == rows() && csrArray.columns() == columns()
|
|
||||||
&& csrArray.getVectorCoordinates().equals(getVectorCoordinates())
|
|
||||||
&& csrArray.data().equals(data()) && csrArray.getPointerB().equals(getPointerB())
|
|
||||||
&& csrArray.getPointerE().equals(getPointerE())) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
INDArray dense = toDense();
|
|
||||||
INDArray oDense = s.toDense();
|
|
||||||
return dense.equals(oDense);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
INDArray dense = toDense();
|
|
||||||
return dense.equals(o);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isView() {
|
|
||||||
return false; //todo
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean wasClosed() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int underlyingRank() {
|
|
||||||
return rank;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray putiColumnVector(INDArray columnVector) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray putiRowVector(INDArray rowVector) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -49,13 +49,6 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
*/
|
*/
|
||||||
DataBuffer shapeInfoDataBuffer();
|
DataBuffer shapeInfoDataBuffer();
|
||||||
|
|
||||||
// TODO: Unused untested method.
|
|
||||||
/**
|
|
||||||
* Sparse info
|
|
||||||
* @return Sparse info.
|
|
||||||
*/
|
|
||||||
DataBuffer sparseInfoDataBuffer();
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Shape info
|
* Shape info
|
||||||
* @return Shape info
|
* @return Shape info
|
||||||
|
@ -2692,47 +2685,6 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
*/
|
*/
|
||||||
INDArray percentile(Number percentile, int... dimension);
|
INDArray percentile(Number percentile, int... dimension);
|
||||||
|
|
||||||
/*
|
|
||||||
* ------------ Sparse methods ------------
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return a array of non-major pointers
|
|
||||||
* i.e. return the column indexes in case of row-major ndarray
|
|
||||||
* @return a DataBuffer of indexes
|
|
||||||
*/
|
|
||||||
DataBuffer getVectorCoordinates();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return a dense representation of the sparse ndarray
|
|
||||||
* */
|
|
||||||
INDArray toDense();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return the number of non-null element
|
|
||||||
* @return nnz
|
|
||||||
*/
|
|
||||||
int nnz();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return the sparse format (i.e COO, CSR, ...)
|
|
||||||
* @return format
|
|
||||||
* @see SparseFormat
|
|
||||||
* */
|
|
||||||
SparseFormat getFormat();
|
|
||||||
|
|
||||||
//TODO: Undocumented but often used method.
|
|
||||||
int[] flags();
|
|
||||||
|
|
||||||
//TODO: Undocumented but often used method.
|
|
||||||
int[] hiddenDimensions();
|
|
||||||
|
|
||||||
//TODO: Undocumented but often used method.
|
|
||||||
int[] sparseOffsets();
|
|
||||||
|
|
||||||
//TODO: Undocumented but often used method.
|
|
||||||
int underlyingRank();
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add an {@link INDArray}
|
* Add an {@link INDArray}
|
||||||
* to flatbuffers builder
|
* to flatbuffers builder
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ndarray;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public interface ISparseNDArray extends INDArray {
|
|
||||||
/*
|
|
||||||
* TODO
|
|
||||||
* Will contain methods such as toDense, toCSRFormat,...
|
|
||||||
*
|
|
||||||
* */
|
|
||||||
|
|
||||||
DataBuffer getVectorCoordinates();
|
|
||||||
|
|
||||||
INDArray toDense();
|
|
||||||
|
|
||||||
int nnz();
|
|
||||||
|
|
||||||
SparseFormat getFormat();
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ndarray;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
* <ul>
|
|
||||||
* <li>CSR: Compressed Sparse Row</li>
|
|
||||||
* <li>CSC: Commpressed Sparse Column</li>
|
|
||||||
* <li>COO: Coordinate Matrix Storage</li>
|
|
||||||
* <li>None: No sparse format</li>
|
|
||||||
* </ul>
|
|
||||||
* @see @see <a href="https://software.intel.com/en-us/node/471374">Sparse Matrix Storage Formats</a>
|
|
||||||
*/
|
|
||||||
public enum SparseFormat {
|
|
||||||
CSR, COO, NONE
|
|
||||||
}
|
|
|
@ -1,29 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ndarray;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public interface SparseInfoProvider {
|
|
||||||
|
|
||||||
DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank);
|
|
||||||
|
|
||||||
void purgeCache();
|
|
||||||
}
|
|
|
@ -1,79 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.shape;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseDescriptor {
|
|
||||||
|
|
||||||
int[] flags;
|
|
||||||
long[] sparseOffsets;
|
|
||||||
int[] hiddenDimension;
|
|
||||||
int underlyingRank;
|
|
||||||
|
|
||||||
public SparseDescriptor(int[] flags, long[] sparseOffsets, int[] hiddenDimension, int underlyingRank) {
|
|
||||||
this.flags = Arrays.copyOf(flags, flags.length);
|
|
||||||
this.sparseOffsets = Arrays.copyOf(sparseOffsets, sparseOffsets.length);
|
|
||||||
this.hiddenDimension = Arrays.copyOf(hiddenDimension, hiddenDimension.length);
|
|
||||||
this.underlyingRank = underlyingRank;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
int result = underlyingRank;
|
|
||||||
result = 31 * result + Arrays.hashCode(flags);
|
|
||||||
result = 31 * result + Arrays.hashCode(sparseOffsets);
|
|
||||||
result = 31 * result + Arrays.hashCode(hiddenDimension);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (this == o)
|
|
||||||
return true;
|
|
||||||
if (o == null || getClass() != o.getClass())
|
|
||||||
return false;
|
|
||||||
|
|
||||||
SparseDescriptor that = (SparseDescriptor) o;
|
|
||||||
|
|
||||||
if (!Arrays.equals(flags, that.flags))
|
|
||||||
return false;
|
|
||||||
if (!Arrays.equals(sparseOffsets, that.sparseOffsets))
|
|
||||||
return false;
|
|
||||||
if (!Arrays.equals(hiddenDimension, that.hiddenDimension))
|
|
||||||
return false;
|
|
||||||
return underlyingRank == that.underlyingRank;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
|
|
||||||
StringBuilder builder = new StringBuilder();
|
|
||||||
|
|
||||||
builder.append(flags.length).append(",").append(Arrays.toString(flags)).append(",").append(sparseOffsets.length)
|
|
||||||
.append(",").append(Arrays.toString(sparseOffsets)).append(",").append(hiddenDimension.length)
|
|
||||||
.append(",").append(Arrays.toString(hiddenDimension)).append(",").append(underlyingRank);
|
|
||||||
|
|
||||||
String result = builder.toString().replaceAll("\\]", "").replaceAll("\\[", "");
|
|
||||||
result = "[" + result + "]";
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -217,34 +217,6 @@ public abstract class BaseBlasWrapper implements BlasWrapper {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c) {
|
|
||||||
LinAlgExceptions.assertMatrix(a, b, c);
|
|
||||||
|
|
||||||
if (a.data().dataType() == DataType.FLOAT) {
|
|
||||||
return gemm((float) alpha, a, b, (float) beta, c);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
level3().gemm(BlasBufferUtil.getCharForTranspose(a), BlasBufferUtil.getCharForTranspose(b),
|
|
||||||
BlasBufferUtil.getCharForTranspose(c), alpha, a, b, beta, c);
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c) {
|
|
||||||
LinAlgExceptions.assertMatrix(a, b, c);
|
|
||||||
|
|
||||||
|
|
||||||
if (a.data().dataType() == DataType.DOUBLE) {
|
|
||||||
return gemm((double) alpha, a, b, (double) beta, c);
|
|
||||||
}
|
|
||||||
|
|
||||||
level3().gemm(BlasBufferUtil.getCharForTranspose(a), BlasBufferUtil.getCharForTranspose(b),
|
|
||||||
BlasBufferUtil.getCharForTranspose(c), alpha, a, b, beta, c);
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray gesv(INDArray a, int[] ipiv, INDArray b) {
|
public INDArray gesv(INDArray a, int[] ipiv, INDArray b) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
|
|
@ -1,243 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.Lapack;
|
|
||||||
import org.nd4j.linalg.api.blas.Level1;
|
|
||||||
import org.nd4j.linalg.api.blas.Level2;
|
|
||||||
import org.nd4j.linalg.api.blas.Level3;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public abstract class BaseSparseBlasWrapper implements BlasWrapper {
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Lapack lapack() {
|
|
||||||
return Nd4j.sparseFactory().lapack();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Level1 level1() {
|
|
||||||
return Nd4j.sparseFactory().level1();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Level2 level2() {
|
|
||||||
return Nd4j.sparseFactory().level2();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Level3 level3() {
|
|
||||||
return Nd4j.sparseFactory().level3();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// ================== TODO ====================
|
|
||||||
@Override
|
|
||||||
public INDArray swap(INDArray x, INDArray y) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray scal(double alpha, INDArray x) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray scal(float alpha, INDArray x) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray copy(INDArray x, INDArray y) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray axpy(double da, INDArray dx, INDArray dy) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray axpy(float da, INDArray dx, INDArray dy) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray axpy(Number da, INDArray dx, INDArray dy) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double dot(INDArray x, INDArray y) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double nrm2(INDArray x) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double asum(INDArray x) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int iamax(INDArray x) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray gemv(Number alpha, INDArray a, INDArray x, double beta, INDArray y) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray gemv(float alpha, INDArray a, INDArray x, float beta, INDArray y) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray ger(Number alpha, INDArray x, INDArray y, INDArray a) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray ger(double alpha, INDArray x, INDArray y, INDArray a) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray ger(float alpha, INDArray x, INDArray y, INDArray a) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray gesv(INDArray a, int[] ipiv, INDArray b) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void checkInfo(String name, int info) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray sysv(char uplo, INDArray a, int[] ipiv, INDArray b) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int syev(char jobz, char uplo, INDArray a, INDArray w) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int syevx(char jobz, char range, char uplo, INDArray a, double vl, double vu, int il, int iu, double abstol,
|
|
||||||
INDArray w, INDArray z) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int syevx(char jobz, char range, char uplo, INDArray a, float vl, float vu, int il, int iu, float abstol,
|
|
||||||
INDArray w, INDArray z) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int syevd(char jobz, char uplo, INDArray A, INDArray w) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int syevr(char jobz, char range, char uplo, INDArray a, double vl, double vu, int il, int iu, double abstol,
|
|
||||||
INDArray w, INDArray z, int[] isuppz) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int syevr(char jobz, char range, char uplo, INDArray a, float vl, float vu, int il, int iu, float abstol,
|
|
||||||
INDArray w, INDArray z, int[] isuppz) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int syevr(char jobz, char range, char uplo, INDArray a, float vl, float vu, int il, int iu, Number abstol,
|
|
||||||
INDArray w, INDArray z, int[] isuppz) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void posv(char uplo, INDArray A, INDArray B) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int geev(char jobvl, char jobvr, INDArray A, INDArray WR, INDArray WI, INDArray VL, INDArray VR) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int sygvd(int itype, char jobz, char uplo, INDArray A, INDArray B, INDArray W) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void gelsd(INDArray A, INDArray B) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void geqrf(INDArray A, INDArray tau) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void ormqr(char side, char trans, INDArray A, INDArray tau, INDArray C) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void saxpy(double alpha, INDArray x, INDArray y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void saxpy(float alpha, INDArray x, INDArray y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.factory;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public abstract class BaseSparseNDArrayFactory extends BaseNDArrayFactory {
|
|
||||||
// TODO override needed methods
|
|
||||||
}
|
|
|
@ -154,21 +154,6 @@ public interface BlasWrapper {
|
||||||
*/
|
*/
|
||||||
INDArray ger(float alpha, INDArray x, INDArray y, INDArray a);
|
INDArray ger(float alpha, INDArray x, INDArray y, INDArray a);
|
||||||
|
|
||||||
/**
|
|
||||||
* ************************************************************************
|
|
||||||
* BLAS Level 3
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Compute c <- a*b + beta * c (general matrix matrix
|
|
||||||
* multiplication)
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ************************************************************************
|
* ************************************************************************
|
||||||
* LAPACK
|
* LAPACK
|
||||||
|
|
|
@ -1413,28 +1413,4 @@ public interface NDArrayFactory {
|
||||||
// =========== String methods ============
|
// =========== String methods ============
|
||||||
|
|
||||||
INDArray create(Collection<String> strings, long[] shape, char order);
|
INDArray create(Collection<String> strings, long[] shape, char order);
|
||||||
|
|
||||||
// =========== Sparse methods ===========
|
|
||||||
|
|
||||||
INDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCOO(double[] values, long[][] indices, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCOO(float[] values, long[][] indices, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCOO(double[] values, int[][] indices, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCOO(float[] values, int[][] indices, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape);
|
|
||||||
|
|
||||||
INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags,
|
|
||||||
int[] hiddenDimensions, int underlyingRank, long[] shape);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -119,16 +119,13 @@ public class Nd4j {
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public final static String DTYPE = ND4JSystemProperties.DTYPE;
|
public final static String DTYPE = ND4JSystemProperties.DTYPE;
|
||||||
private final static String BLAS_OPS = "blas.ops";
|
private final static String BLAS_OPS = "blas.ops";
|
||||||
private final static String SPARSE_BLAS_OPS = "sparseblas.ops";
|
|
||||||
public final static String NATIVE_OPS = "native.ops";
|
public final static String NATIVE_OPS = "native.ops";
|
||||||
private final static String ORDER_KEY = "ndarray.order";
|
private final static String ORDER_KEY = "ndarray.order";
|
||||||
private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
|
private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
|
||||||
private final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
|
|
||||||
private final static String OP_EXECUTIONER = "opexec";
|
private final static String OP_EXECUTIONER = "opexec";
|
||||||
|
|
||||||
public final static String DISTRIBUTION = "dist";
|
public final static String DISTRIBUTION = "dist";
|
||||||
private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
|
private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
|
||||||
private final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
|
|
||||||
private final static String CONSTANT_PROVIDER = "constantsprovider";
|
private final static String CONSTANT_PROVIDER = "constantsprovider";
|
||||||
private final static String AFFINITY_MANAGER = "affinitymanager";
|
private final static String AFFINITY_MANAGER = "affinitymanager";
|
||||||
//disable toString() on compressed arrays for debugging. Should be off by default.
|
//disable toString() on compressed arrays for debugging. Should be off by default.
|
||||||
|
@ -156,14 +153,11 @@ public class Nd4j {
|
||||||
|
|
||||||
private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
|
private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
|
||||||
private static BlasWrapper BLAS_WRAPPER_INSTANCE;
|
private static BlasWrapper BLAS_WRAPPER_INSTANCE;
|
||||||
private static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
|
|
||||||
protected static NDArrayFactory INSTANCE;
|
protected static NDArrayFactory INSTANCE;
|
||||||
private static NDArrayFactory SPARSE_INSTANCE;
|
|
||||||
private static ConvolutionInstance CONVOLUTION_INSTANCE;
|
private static ConvolutionInstance CONVOLUTION_INSTANCE;
|
||||||
private static OpExecutioner OP_EXECUTIONER_INSTANCE;
|
private static OpExecutioner OP_EXECUTIONER_INSTANCE;
|
||||||
private static DistributionFactory DISTRIBUTION_FACTORY;
|
private static DistributionFactory DISTRIBUTION_FACTORY;
|
||||||
private static ShapeInfoProvider shapeInfoProvider;
|
private static ShapeInfoProvider shapeInfoProvider;
|
||||||
private static SparseInfoProvider sparseInfoProvider;
|
|
||||||
private static ConstantHandler constantHandler;
|
private static ConstantHandler constantHandler;
|
||||||
private static AffinityManager affinityManager;
|
private static AffinityManager affinityManager;
|
||||||
private static MemoryManager memoryManager;
|
private static MemoryManager memoryManager;
|
||||||
|
@ -800,14 +794,6 @@ public class Nd4j {
|
||||||
return INSTANCE;
|
return INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* The factory used for creating sparse arrays.
|
|
||||||
* @return the factory used for creating sparse arrays.
|
|
||||||
*/
|
|
||||||
public static NDArrayFactory sparseFactory() {
|
|
||||||
return SPARSE_INSTANCE;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See {@link org.nd4j.linalg.api.ndarray.INDArray#cumsum(int)} with Integer.MAX_VALUE for full array reduction.
|
* See {@link org.nd4j.linalg.api.ndarray.INDArray#cumsum(int)} with Integer.MAX_VALUE for full array reduction.
|
||||||
*
|
*
|
||||||
|
@ -1700,14 +1686,6 @@ public class Nd4j {
|
||||||
return BLAS_WRAPPER_INSTANCE;
|
return BLAS_WRAPPER_INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Retreive the sparse BLAS wrapper.
|
|
||||||
* @return the sparse BLAS wrapper.
|
|
||||||
*/
|
|
||||||
public static BlasWrapper getSparseBlasWrapper() {
|
|
||||||
return SPARSE_BLAS_WRAPPER_INSTANCE;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sort an ndarray along a particular dimension.<br>
|
* Sort an ndarray along a particular dimension.<br>
|
||||||
* Note that the input array is modified in-place.
|
* Note that the input array is modified in-place.
|
||||||
|
@ -5159,8 +5137,6 @@ public class Nd4j {
|
||||||
affinityManager = affinityManagerClazz.newInstance();
|
affinityManager = affinityManagerClazz.newInstance();
|
||||||
Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
||||||
pp.toString(NDARRAY_FACTORY_CLASS));
|
pp.toString(NDARRAY_FACTORY_CLASS));
|
||||||
Class<? extends NDArrayFactory> sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
|
||||||
pp.toString(SPARSE_NDARRAY_FACTORY_CLASS));
|
|
||||||
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
||||||
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
||||||
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
|
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
|
||||||
|
@ -5168,8 +5144,6 @@ public class Nd4j {
|
||||||
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
||||||
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
||||||
.forName(pp.toString(SHAPEINFO_PROVIDER));
|
.forName(pp.toString(SHAPEINFO_PROVIDER));
|
||||||
Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
|
|
||||||
pp.toString(SPARSEINFO_PROVIDER));
|
|
||||||
|
|
||||||
Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
|
Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
|
||||||
.forName(pp.toString(CONSTANT_PROVIDER));
|
.forName(pp.toString(CONSTANT_PROVIDER));
|
||||||
|
@ -5187,8 +5161,6 @@ public class Nd4j {
|
||||||
|
|
||||||
Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
||||||
.forName(pp.toString(BLAS_OPS));
|
.forName(pp.toString(BLAS_OPS));
|
||||||
Class<? extends BlasWrapper> sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
|
||||||
.forName(pp.toString(SPARSE_BLAS_OPS));
|
|
||||||
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
|
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
|
||||||
Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
|
Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
|
||||||
|
|
||||||
|
@ -5196,7 +5168,6 @@ public class Nd4j {
|
||||||
memoryManager = memoryManagerClazz.newInstance();
|
memoryManager = memoryManagerClazz.newInstance();
|
||||||
constantHandler = constantProviderClazz.newInstance();
|
constantHandler = constantProviderClazz.newInstance();
|
||||||
shapeInfoProvider = shapeInfoProviderClazz.newInstance();
|
shapeInfoProvider = shapeInfoProviderClazz.newInstance();
|
||||||
sparseInfoProvider = sparseInfoProviderClazz.newInstance();
|
|
||||||
workspaceManager = workspaceManagerClazz.newInstance();
|
workspaceManager = workspaceManagerClazz.newInstance();
|
||||||
|
|
||||||
Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class
|
Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class
|
||||||
|
@ -5205,10 +5176,8 @@ public class Nd4j {
|
||||||
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
|
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
|
||||||
Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class);
|
Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class);
|
||||||
INSTANCE = (NDArrayFactory) c2.newInstance(dtype, ORDER);
|
INSTANCE = (NDArrayFactory) c2.newInstance(dtype, ORDER);
|
||||||
SPARSE_INSTANCE = sparseNDArrayClazz.newInstance();
|
|
||||||
CONVOLUTION_INSTANCE = convolutionInstanceClazz.newInstance();
|
CONVOLUTION_INSTANCE = convolutionInstanceClazz.newInstance();
|
||||||
BLAS_WRAPPER_INSTANCE = blasWrapperClazz.newInstance();
|
BLAS_WRAPPER_INSTANCE = blasWrapperClazz.newInstance();
|
||||||
SPARSE_BLAS_WRAPPER_INSTANCE = sparseBlasWrapperClazz.newInstance();
|
|
||||||
DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.newInstance();
|
DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.newInstance();
|
||||||
|
|
||||||
DISTRIBUTION_FACTORY = distributionFactoryClazz.newInstance();
|
DISTRIBUTION_FACTORY = distributionFactoryClazz.newInstance();
|
||||||
|
@ -5306,14 +5275,6 @@ public class Nd4j {
|
||||||
return shapeInfoProvider;
|
return shapeInfoProvider;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @return Sparse shape info provider
|
|
||||||
*/
|
|
||||||
public static SparseInfoProvider getSparseInfoProvider() {
|
|
||||||
return sparseInfoProvider;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @return constant handler
|
* @return constant handler
|
||||||
|
|
|
@ -1,41 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.nativeblas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.Blas;
|
|
||||||
|
|
||||||
public abstract class SparseNd4jBlas implements Blas {
|
|
||||||
|
|
||||||
public SparseNd4jBlas() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the BLAS library vendor
|
|
||||||
*
|
|
||||||
* @return the BLAS library vendor
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public Vendor getBlasVendor() {
|
|
||||||
int vendor = getBlasVendorId();
|
|
||||||
boolean isUnknowVendor = ((vendor > Vendor.values().length - 1) || (vendor <= 0));
|
|
||||||
if (isUnknowVendor) {
|
|
||||||
return Vendor.UNKNOWN;
|
|
||||||
}
|
|
||||||
return Vendor.values()[vendor];
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,62 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseInfoProvider;
|
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
|
||||||
import org.nd4j.linalg.api.shape.SparseDescriptor;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class DirectSparseInfoProvider extends BaseSparseInfoProvider {
|
|
||||||
|
|
||||||
private Map<SparseDescriptor, DataBuffer> sparseCache = new ConcurrentHashMap<>();
|
|
||||||
private AtomicInteger counter = new AtomicInteger(0);
|
|
||||||
private static final int MAX_ENTRIES = 100;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank) {
|
|
||||||
|
|
||||||
SparseDescriptor descriptor = new SparseDescriptor(flags, sparseOffsets, hiddenDimensions, underlyingRank);
|
|
||||||
|
|
||||||
if(!sparseCache.containsKey(descriptor)){
|
|
||||||
if(counter.get() < MAX_ENTRIES){
|
|
||||||
if(!sparseCache.containsKey(descriptor)){
|
|
||||||
counter.incrementAndGet();
|
|
||||||
DataBuffer buffer = Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
|
|
||||||
sparseCache.put(descriptor, buffer);
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sparseCache.get(descriptor);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void purgeCache() {
|
|
||||||
sparseCache = new ConcurrentHashMap<>();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1612,55 +1612,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(double[] values, long[][] indices, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(float[] values, long[][] indices, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(double[] values, int[][] indices, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(float[] values, int[][] indices, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray sortCooIndices(INDArray x) {
|
public INDArray sortCooIndices(INDArray x) {
|
||||||
|
|
|
@ -1,53 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class JCusparseNDArrayCOO extends BaseSparseNDArrayCOO {
|
|
||||||
public JCusparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] shape) {
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JCusparseNDArrayCOO(double[] values, long[][] indices, long[] shape) {
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JCusparseNDArrayCOO(float[] values, long[][] indices, long[] shape) {
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JCusparseNDArrayCOO(double[] values, int[][] indices, long[] shape) {
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JCusparseNDArrayCOO(float[] values, int[][] indices, long[] shape) {
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JCusparseNDArrayCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
|
|
||||||
super(values, indices, sparseInformation, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JCusparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
|
|
||||||
super(values, indices, sparseOffsets, flags, hiddenDimensions, underlyingRank, shape);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,570 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataTypeEx;
|
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ndarray.ISparseNDArray;
|
|
||||||
import org.nd4j.linalg.factory.BaseSparseNDArrayFactory;
|
|
||||||
import org.nd4j.linalg.jcublas.blas.*;
|
|
||||||
import org.nd4j.nativeblas.NativeOps;
|
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class JCusparseNDArrayFactory extends BaseSparseNDArrayFactory{
|
|
||||||
|
|
||||||
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
|
||||||
|
|
||||||
public JCusparseNDArrayFactory(){}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, int[] shape, int[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, int[] shape, int[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(long[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(int[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(short[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(byte[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(boolean[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(List<INDArray> list, int[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(long[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(int[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(short[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(byte[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(boolean[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(List<INDArray> list, long[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(long rows, long columns, long[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray empty(DataType type) {
|
|
||||||
throw new IllegalStateException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createBlas() {
|
|
||||||
blas = new SparseCudaBlas();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createLevel1() {
|
|
||||||
level1 = new JcusparseLevel1();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createLevel2() {
|
|
||||||
level2 = new JcusparseLevel2();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createLevel3() {
|
|
||||||
level3 = new JcusparseLevel3();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createLapack() {
|
|
||||||
lapack = new JcusparseLapack();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(int[] shape, DataBuffer buffer) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray toFlattened(char order, Collection<INDArray> matrices) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[][] data) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[][] data, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray specialConcat(int dimension, INDArray... toConcat) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray pullRows(INDArray source, int sourceDimension, long[] indexes) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, int[] indexes) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void shuffle(INDArray array, Random rnd, int... dimension) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void shuffle(Collection<INDArray> array, Random rnd, int... dimension) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void shuffle(List<INDArray> array, Random rnd, List<int[]> dimensions) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray average(INDArray target, INDArray[] arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray average(INDArray[] arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray average(Collection<INDArray> arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray accumulate(INDArray target, INDArray... arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray average(INDArray target, Collection<INDArray> arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long rows, long columns, int[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, int[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, int[] shape, int[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] shape, long[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[][] floats) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[][] data, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, int[] shape, int[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer buffer, int[] shape, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(int[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataType dataType, long[] shape, long[] strides, char ordering, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createUninitialized(int[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createUninitialized(long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createUninitialized(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, long ews, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, char ordering, DataType dataType) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, int[] shape, long offset, Character order) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long rows, long columns, int[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, int[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(List<INDArray> list, int[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(List<INDArray> list, long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, int[] shape, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, int[] shape, int[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, char order, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray convertDataEx(DataTypeEx typeSrc, INDArray source, DataTypeEx typeDst) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataBuffer convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst, DataBuffer target) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target, long length) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, DataBuffer buffer) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createFromNpyPointer(Pointer pointer) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createFromNpyHeaderPointer(Pointer pointer) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createFromNpyFile(File file) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, INDArray> createFromNpzFile(File file) throws Exception {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Pointer convertToNumpy(INDArray array) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray[] tear(INDArray tensor, int... dimensions) {
|
|
||||||
return new INDArray[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray sort(INDArray x, boolean descending) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray sort(INDArray x, boolean descending, int... dimensions) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray sortCooIndices(INDArray x) {
|
|
||||||
//TODO
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long offset, Character order) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long offset, Character order) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(Collection<String> strings, long[] shape, char order) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ISparseNDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
return new JcusparseNDArrayCSR(data, columns, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(double[] values, long[][] indices, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(float[] values, long[][] indices, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(double[] values, int[][] indices, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(float[] values, int[][] indices, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, sparseInformation, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
|
|
||||||
return new JCusparseNDArrayCOO(values, indices, sparseOffsets, flags, hiddenDimensions, underlyingRank, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,176 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas;
|
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
|
||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCSR;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class JcusparseNDArrayCSR extends BaseSparseNDArrayCSR {
|
|
||||||
/**
|
|
||||||
* The length of the values and columns arrays is equal to the number of non-zero elements in A.
|
|
||||||
* The length of the pointerB and pointerE arrays is equal to the number of rows in A.
|
|
||||||
*
|
|
||||||
* @param data a double array that contains the non-zero element of the sparse matrix A
|
|
||||||
* @param columnsPointers Element i of the integer array columns is the number of the column in A that contains the i-th value
|
|
||||||
* in the values array.
|
|
||||||
* @param pointerB Element j of this integer array gives the index of the element in the values array that is first
|
|
||||||
* non-zero element in a row j of A. Note that this index is equal to pointerB(j) - pointerB(1)+1 .
|
|
||||||
* @param pointerE An integer array that contains row indices, such that pointerE(j)-pointerB(1) is the index of the
|
|
||||||
* element in the values array that is last non-zero element in a row j of A.
|
|
||||||
* @param shape Shape of the matrix A
|
|
||||||
*/
|
|
||||||
public JcusparseNDArrayCSR(double[] data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
super(data, columnsPointers, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JcusparseNDArrayCSR(float[] data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
super(data, columnsPointers, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JcusparseNDArrayCSR(DataBuffer data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
super(data, columnsPointers, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getString(long index) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray repeat(int dimension, long... repeats) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray mmuli(INDArray other, MMulTranspose transpose) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getLong(long index) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray reshape(char order, int... newShape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray reshape(char order, boolean enforceView, long... newShape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray reshape(int[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public LongShapeDescriptor shapeDescriptor() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int toFlatArray(FlatBufferBuilder builder) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isEmpty() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isR() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isZ() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isB() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isS() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray castTo(DataType dataType) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean all() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean any() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean none() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean closeable() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray assign(boolean value) {
|
|
||||||
return assign(value ? 1 : 0);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.factory.BaseSparseBlasWrapper;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseBlasWrapper extends BaseSparseBlasWrapper {
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.impl.SparseBaseLapack;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class JcusparseLapack extends SparseBaseLapack {
|
|
||||||
}
|
|
|
@ -1,147 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel1;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class JcusparseLevel1 extends SparseBaseLevel1 {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double ddoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double sdoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double hdoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double snrm2(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double dnrm2(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double hnrm2(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double dasum(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double sasum(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double hasum(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int isamax(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int idamax(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int ihamax(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int isamin(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int idamin(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int ihamin(long N, INDArray X, int incx) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void daxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void saxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void haxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void droti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void sroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void hroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void dscal(long N, double a, INDArray X, int incx) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void sscal(long N, double a, INDArray X, int incx) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void hscal(long N, double a, INDArray X, int incx) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,38 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel2;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class JcusparseLevel2 extends SparseBaseLevel2 {
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void scoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz, INDArray x, INDArray y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void dcoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz, INDArray x, INDArray y) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel3;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class JcusparseLevel3 extends SparseBaseLevel3 {
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
|
||||||
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.Blas;
|
|
||||||
import org.nd4j.nativeblas.SparseNd4jBlas;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseCudaBlas extends SparseNd4jBlas {
|
|
||||||
@Override
|
|
||||||
public void setMaxThreads(int num){
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getMaxThreads() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getBlasVendorId() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -37,6 +37,3 @@ fft = org.nd4j.linalg.jcublas.fft.JcudaFft
|
||||||
opexec.mode= native
|
opexec.mode= native
|
||||||
random=org.nd4j.linalg.jcublas.rng.CudaNativeRandom
|
random=org.nd4j.linalg.jcublas.rng.CudaNativeRandom
|
||||||
|
|
||||||
sparseinfoprovider = org.nd4j.linalg.jcublas.DirectSparseInfoProvider
|
|
||||||
sparseblas.ops = org.nd4j.linalg.jcublas.SparseBlasWrapper
|
|
||||||
sparsendarrayfactory.class = org.nd4j.linalg.jcublas.JCusparseNDArrayFactory
|
|
||||||
|
|
|
@ -1054,59 +1054,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
convertDataEx(typeSrc, source.addressPointer(), typeDst, target.addressPointer(), target.length());
|
convertDataEx(typeSrc, source.addressPointer(), typeDst, target.addressPointer(), target.length());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(double[] values, int[][] indices, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(float[] values, int[][] indices, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(double[] values, long[][] indices, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(float[] values, long[][] indices, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray sort(INDArray x, boolean descending) {
|
public INDArray sort(INDArray x, boolean descending) {
|
||||||
if (x.isScalar())
|
if (x.isScalar())
|
||||||
|
|
|
@ -1,606 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.bytedeco.javacpp.*;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataTypeEx;
|
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ndarray.SparseFormat;
|
|
||||||
import org.nd4j.linalg.cpu.nativecpu.blas.*;
|
|
||||||
import org.nd4j.linalg.factory.BaseSparseNDArrayFactory;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.nativeblas.LongPointerWrapper;
|
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
|
|
||||||
// TODO : Implement the methods
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
|
|
||||||
|
|
||||||
public CpuSparseNDArrayFactory(){}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape){
|
|
||||||
return new SparseNDArrayCSR(data, columns, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape){
|
|
||||||
return new SparseNDArrayCSR(data, columns, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape){
|
|
||||||
return new SparseNDArrayCSR(data, columns, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(double[] values, int[][] indices, long[] shape){
|
|
||||||
return new SparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(float[] values, int[][] indices, long[] shape){
|
|
||||||
return new SparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataType dataType, long[] shape, long[] strides, char ordering, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createUninitialized(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(double[] values, long[][] indices, long[] shape) {
|
|
||||||
return new SparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(float[] values, long[][] indices, long[] shape) {
|
|
||||||
return new SparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape){
|
|
||||||
return new SparseNDArrayCOO(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
|
|
||||||
return new SparseNDArrayCOO(values, indices, sparseOffsets, flags, hiddenDimensions, underlyingRank, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
|
|
||||||
return new SparseNDArrayCOO(values, indices, sparseInformation, shape);
|
|
||||||
}
|
|
||||||
// TODO ->
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray pullRows(INDArray source, int sourceDimension, long[] indexes) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(long[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(int[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(short[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(byte[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(boolean[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(long[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(int[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(short[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(byte[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(boolean[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] shape, long[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(List<INDArray> list, long[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(long rows, long columns, long[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createUninitialized(long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, long ews, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, char ordering, DataType dataType) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(List<INDArray> list, long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, char order, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray specialConcat(int dimension, INDArray... toConcat) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, int[] indexes) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, int[] shape, int[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, int[] shape, int[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(List<INDArray> list, int[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
static{
|
|
||||||
Nd4j.getBlasWrapper();
|
|
||||||
}
|
|
||||||
// contructors ?
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createBlas(){ blas = new SparseCpuBlas();}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createLevel1() {
|
|
||||||
level1 = new SparseCpuLevel1();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createLevel2() {
|
|
||||||
level2 = new SparseCpuLevel2();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createLevel3() { level3 = new SparseCpuLevel3(); }
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void createLapack() {
|
|
||||||
lapack = new SparseCpuLapack();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(int[] shape, DataBuffer buffer) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray toFlattened(char order, Collection<INDArray> matrices) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[][] data) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[][] data, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void shuffle(INDArray array, Random rnd, int... dimension) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void shuffle(Collection<INDArray> array, Random rnd, int... dimension) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void shuffle(List<INDArray> array, Random rnd, List<int[]> dimensions) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray average(INDArray target, INDArray[] arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray average(INDArray[] arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray average(Collection<INDArray> arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray accumulate(INDArray target, INDArray... arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray average(INDArray target, Collection<INDArray> arrays) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, long rows, long columns, int[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, int[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, int[] shape, int[] stride, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[][] floats) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[][] data, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, int[] shape, int[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer buffer, int[] shape, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(int[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createUninitialized(int[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, int[] shape, long offset, Character order) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long rows, long columns, int[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, int[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(List<INDArray> list, int[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, int[] shape, long offset) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, int[] shape, int[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray convertDataEx(DataTypeEx typeSrc, INDArray source, DataTypeEx typeDst) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataBuffer convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst, DataBuffer target) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target, long length) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, DataBuffer buffer) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createFromNpyPointer(Pointer pointer) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createFromNpyHeaderPointer(Pointer pointer) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray createFromNpyFile(File file) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, INDArray> createFromNpzFile(File file){return null; }
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Pointer convertToNumpy(INDArray array) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long[] stride, long offset, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray[] tear(INDArray tensor, int... dimensions) {
|
|
||||||
return new INDArray[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray sort(INDArray x, boolean descending) {
|
|
||||||
if (x.isScalar())
|
|
||||||
return x;
|
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sort(null,
|
|
||||||
x.data().addressPointer(), (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
|
|
||||||
null, null,
|
|
||||||
descending);
|
|
||||||
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray sort(INDArray x, boolean descending, int... dimension) {
|
|
||||||
if (x.isScalar())
|
|
||||||
return x;
|
|
||||||
|
|
||||||
Arrays.sort(dimension);
|
|
||||||
Pair<DataBuffer, DataBuffer> tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, dimension);
|
|
||||||
|
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortTad(null,
|
|
||||||
x.data().addressPointer(), (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
|
|
||||||
null, null,
|
|
||||||
new IntPointer(dimension),
|
|
||||||
dimension.length,
|
|
||||||
(LongPointer) tadBuffers.getFirst().addressPointer(),
|
|
||||||
new LongPointerWrapper(tadBuffers.getSecond().addressPointer()),
|
|
||||||
descending);
|
|
||||||
|
|
||||||
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray sortCooIndices(INDArray x) {
|
|
||||||
|
|
||||||
if(x.getFormat() != SparseFormat.COO){
|
|
||||||
throw new UnsupportedOperationException("Not a COO ndarray");
|
|
||||||
}
|
|
||||||
BaseSparseNDArrayCOO array = (BaseSparseNDArrayCOO) x;
|
|
||||||
DataBuffer val = array.getValues();
|
|
||||||
DataBuffer idx = array.getIndices();
|
|
||||||
long length = val.length();
|
|
||||||
int rank = array.underlyingRank();
|
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), val.addressPointer(), length, rank);
|
|
||||||
|
|
||||||
|
|
||||||
return array;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, long offset, Character order) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, long offset, Character order) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(float[] data, long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(double[] data, long[] shape, char ordering) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray empty(DataType type) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray create(Collection<String> strings, long[] shape, char order) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,62 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseInfoProvider;
|
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
|
||||||
import org.nd4j.linalg.api.shape.SparseDescriptor;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class DirectSparseInfoProvider extends BaseSparseInfoProvider {
|
|
||||||
|
|
||||||
private Map<SparseDescriptor, DataBuffer> sparseCache = new ConcurrentHashMap<>();
|
|
||||||
private AtomicInteger counter = new AtomicInteger(0);
|
|
||||||
private static final int MAX_ENTRIES = 100;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank) {
|
|
||||||
|
|
||||||
SparseDescriptor descriptor = new SparseDescriptor(flags, sparseOffsets, hiddenDimensions, underlyingRank);
|
|
||||||
|
|
||||||
if(!sparseCache.containsKey(descriptor)){
|
|
||||||
if(counter.get() < MAX_ENTRIES){
|
|
||||||
if(!sparseCache.containsKey(descriptor)){
|
|
||||||
counter.incrementAndGet();
|
|
||||||
DataBuffer buffer = Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
|
|
||||||
sparseCache.put(descriptor, buffer);
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sparseCache.get(descriptor);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void purgeCache() {
|
|
||||||
sparseCache = new ConcurrentHashMap<>();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.factory.BaseSparseBlasWrapper;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseBlasWrapper extends BaseSparseBlasWrapper {
|
|
||||||
}
|
|
|
@ -1,54 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseNDArrayCOO extends BaseSparseNDArrayCOO {
|
|
||||||
public SparseNDArrayCOO(double[] values, int[][] indices, long[] shape){
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparseNDArrayCOO(float[] values, int[][] indices, long[] shape) {
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparseNDArrayCOO(double[] values, long[][] indices, long[] shape){
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparseNDArrayCOO(float[] values, long[][] indices, long[] shape) {
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] shape){
|
|
||||||
super(values, indices, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public SparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape){
|
|
||||||
super(values, indices, sparseOffsets, flags, hiddenDimensions, underlyingRank, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparseNDArrayCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
|
|
||||||
super(values, indices, sparseInformation, shape);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,181 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu;
|
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.*;
|
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class SparseNDArrayCSR extends BaseSparseNDArrayCSR {
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* The length of the values and columns arrays is equal to the number of non-zero elements in A.
|
|
||||||
* The length of the pointerB and pointerE arrays is equal to the number of rows in A.
|
|
||||||
* @param data a double array that contains the non-zero element of the sparse matrix A
|
|
||||||
* @param columns Element i of the integer array columns is the number of the column in A that contains the i-th value
|
|
||||||
* in the values array.
|
|
||||||
* @param pointerB Element j of this integer array gives the index of the element in the values array that is first
|
|
||||||
* non-zero element in a row j of A. Note that this index is equal to pointerB(j) - pointerB(1)+1 .
|
|
||||||
* @param pointerE An integer array that contains row indices, such that pointerE(j)-pointerB(1) is the index of the
|
|
||||||
* element in the values array that is last non-zero element in a row j of A.
|
|
||||||
* @param shape Shape of the matrix A
|
|
||||||
*/
|
|
||||||
public SparseNDArrayCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
|
|
||||||
super(data, columns, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
public SparseNDArrayCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
|
|
||||||
super(data, columns, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparseNDArrayCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
|
|
||||||
|
|
||||||
super(data, columns, pointerB, pointerE, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getString(long index) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray assign(boolean value) {
|
|
||||||
return assign(value ? 1 : 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray repeat(int dimension, long... repeats) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray mmuli(INDArray other, MMulTranspose transpose) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getLong(long index) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray reshape(char order, int... newShape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray reshape(char order, boolean enforceView, long... newShape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray reshape(int[] shape) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public LongShapeDescriptor shapeDescriptor() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int toFlatArray(FlatBufferBuilder builder) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isEmpty() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isR() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isZ() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isB() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray castTo(DataType dataType) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean all() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean any() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean none() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean closeable() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isS() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,141 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu.blas;
|
|
||||||
|
|
||||||
import org.nd4j.nativeblas.SparseNd4jBlas;
|
|
||||||
|
|
||||||
import static org.bytedeco.mkl.global.mkl_rt.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseCpuBlas extends SparseNd4jBlas {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts a character
|
|
||||||
* to its proper enum
|
|
||||||
* for row (c) or column (f) ordering
|
|
||||||
* default is row major
|
|
||||||
*/
|
|
||||||
static int convertOrder(int from) {
|
|
||||||
switch (from) {
|
|
||||||
case 'c':
|
|
||||||
case 'C':
|
|
||||||
return CblasRowMajor;
|
|
||||||
case 'f':
|
|
||||||
case 'F':
|
|
||||||
return CblasColMajor;
|
|
||||||
default:
|
|
||||||
return CblasColMajor;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
* Converts a character to its proper enum
|
|
||||||
* t -> transpose
|
|
||||||
* n -> no transpose
|
|
||||||
* c -> conj
|
|
||||||
*/
|
|
||||||
static int convertTranspose(int from) {
|
|
||||||
switch (from) {
|
|
||||||
case 't':
|
|
||||||
case 'T':
|
|
||||||
return CblasTrans;
|
|
||||||
case 'n':
|
|
||||||
case 'N':
|
|
||||||
return CblasNoTrans;
|
|
||||||
case 'c':
|
|
||||||
case 'C':
|
|
||||||
return CblasConjTrans;
|
|
||||||
default:
|
|
||||||
return CblasNoTrans;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Upper or lower
|
|
||||||
* U/u -> upper
|
|
||||||
* L/l -> lower
|
|
||||||
*
|
|
||||||
* Default is upper
|
|
||||||
*/
|
|
||||||
static int convertUplo(int from) {
|
|
||||||
switch (from) {
|
|
||||||
case 'u':
|
|
||||||
case 'U':
|
|
||||||
return CblasUpper;
|
|
||||||
case 'l':
|
|
||||||
case 'L':
|
|
||||||
return CblasLower;
|
|
||||||
default:
|
|
||||||
return CblasUpper;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* For diagonals:
|
|
||||||
* u/U -> unit
|
|
||||||
* n/N -> non unit
|
|
||||||
*
|
|
||||||
* Default: unit
|
|
||||||
*/
|
|
||||||
static int convertDiag(int from) {
|
|
||||||
switch (from) {
|
|
||||||
case 'u':
|
|
||||||
case 'U':
|
|
||||||
return CblasUnit;
|
|
||||||
case 'n':
|
|
||||||
case 'N':
|
|
||||||
return CblasNonUnit;
|
|
||||||
default:
|
|
||||||
return CblasUnit;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Side of a matrix, left or right
|
|
||||||
* l /L -> left
|
|
||||||
* r/R -> right
|
|
||||||
* default: left
|
|
||||||
*/
|
|
||||||
static int convertSide(int from) {
|
|
||||||
switch (from) {
|
|
||||||
case 'l':
|
|
||||||
case 'L':
|
|
||||||
return CblasLeft;
|
|
||||||
case 'r':
|
|
||||||
case 'R':
|
|
||||||
return CblasRight;
|
|
||||||
default:
|
|
||||||
return CblasLeft;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setMaxThreads(int num){
|
|
||||||
MKL_Set_Num_Threads(num);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getMaxThreads() {
|
|
||||||
return MKL_Get_Max_Threads();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getBlasVendorId() {
|
|
||||||
return Vendor.MKL.ordinal();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu.blas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.impl.SparseBaseLapack;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseCpuLapack extends SparseBaseLapack {
|
|
||||||
}
|
|
|
@ -1,289 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu.blas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel1;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.nativeblas.SparseNd4jBlas;
|
|
||||||
import org.bytedeco.javacpp.*;
|
|
||||||
|
|
||||||
import static org.bytedeco.mkl.global.mkl_rt.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseCpuLevel1 extends SparseBaseLevel1 {
|
|
||||||
|
|
||||||
// FIXME: int cast !!!
|
|
||||||
|
|
||||||
private SparseNd4jBlas sparseNd4jBlas = (SparseNd4jBlas) Nd4j.sparseFactory().blas();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the dot product of a compressed sparse double vector by a full-storage real vector.
|
|
||||||
* @param N The number of elements in x and indx
|
|
||||||
* @param X an sparse INDArray. Size at least N
|
|
||||||
* @param indx an Databuffer that Specifies the indices for the elements of x. Size at least N
|
|
||||||
* @param Y a dense INDArray. Size at least max(indx[i])
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected double ddoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
|
|
||||||
return cblas_ddoti((int) N, (DoublePointer) X.data().addressPointer(),(IntPointer) indx.addressPointer(),
|
|
||||||
(DoublePointer) Y.data().addressPointer());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the dot product of a compressed sparse float vector by a full-storage real vector.
|
|
||||||
* @param N The number of elements in x and indx
|
|
||||||
* @param X an sparse INDArray. Size at least N
|
|
||||||
* @param indx an Databuffer that specifies the indices for the elements of x. Size at least N
|
|
||||||
* @param Y a dense INDArray. Size at least max(indx[i])
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected double sdoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
|
|
||||||
return cblas_sdoti((int) N, (FloatPointer) X.data().addressPointer(),(IntPointer) indx.addressPointer(),
|
|
||||||
(FloatPointer) Y.data().addressPointer());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double hdoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the Euclidean norm of a float vector
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param X an INDArray
|
|
||||||
* @param incx the increment of X
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected double snrm2(long N, INDArray X, int incx){
|
|
||||||
return cblas_snrm2((int) N, (FloatPointer) X.data().addressPointer(), incx);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the Euclidean norm of a double vector
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param X an INDArray
|
|
||||||
* @param incx the increment of X
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected double dnrm2(long N, INDArray X, int incx){
|
|
||||||
return cblas_dnrm2((int) N, (DoublePointer) X.data().addressPointer(), incx);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double hnrm2(long N, INDArray X, int incx){
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Compute the sum of magnitude of the double vector elements
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param X a double vector
|
|
||||||
* @param incrx The increment of X
|
|
||||||
* @return the sum of magnitude of the vector elements
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected double dasum(long N, INDArray X, int incrx){
|
|
||||||
return cblas_dasum((int) N, (DoublePointer) X.data().addressPointer(), incrx);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Compute the sum of magnitude of the float vector elements
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param X a float vector
|
|
||||||
* @param incrx The increment of X
|
|
||||||
* @return the sum of magnitude of the vector elements
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected double sasum(long N, INDArray X, int incrx){
|
|
||||||
return cblas_sasum((int) N, (FloatPointer) X.data().addressPointer(), incrx);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected double hasum(long N, INDArray X, int incrx){
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Find the index of the element with maximum absolute value
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param X a vector
|
|
||||||
* @param incX The increment of X
|
|
||||||
* @return the index of the element with maximum absolute value
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected int isamax(long N, INDArray X, int incX) {
|
|
||||||
return (int) cblas_isamax((int) N, (FloatPointer) X.data().addressPointer(), incX);
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
* Find the index of the element with maximum absolute value
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param X a vector
|
|
||||||
* @param incX The increment of X
|
|
||||||
* @return the index of the element with maximum absolute value
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected int idamax(long N, INDArray X, int incX) {
|
|
||||||
return (int) cblas_idamax((int) N, (DoublePointer) X.data().addressPointer(), incX);
|
|
||||||
}
|
|
||||||
@Override
|
|
||||||
protected int ihamax(long N, INDArray X, int incX) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Find the index of the element with minimum absolute value
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param X a vector
|
|
||||||
* @param incX The increment of X
|
|
||||||
* @return the index of the element with minimum absolute value
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected int isamin(long N, INDArray X, int incX) {
|
|
||||||
return (int) cblas_isamin((int) N, (FloatPointer) X.data().addressPointer(), incX);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Find the index of the element with minimum absolute value
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param X a vector
|
|
||||||
* @param incX The increment of X
|
|
||||||
* @return the index of the element with minimum absolute value
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected int idamin(long N, INDArray X, int incX) {
|
|
||||||
return (int) cblas_idamin((int) N, (DoublePointer) X.data().addressPointer(), incX);
|
|
||||||
}
|
|
||||||
@Override
|
|
||||||
protected int ihamin(long N, INDArray X, int incX) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Adds a scalar multiple of double compressed sparse vector to a full-storage vector.
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param alpha
|
|
||||||
* @param X a sparse vector
|
|
||||||
* @param pointers A DataBuffer that specifies the indices for the elements of x.
|
|
||||||
* @param Y a dense vector
|
|
||||||
*
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected void daxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y){
|
|
||||||
cblas_daxpyi((int) N, alpha, (DoublePointer) X.data().addressPointer(), (IntPointer) pointers.addressPointer(),
|
|
||||||
(DoublePointer) Y.data().addressPointer());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Adds a scalar multiple of float compressed sparse vector to a full-storage vector.
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vector X
|
|
||||||
* @param alpha
|
|
||||||
* @param X a sparse vector
|
|
||||||
* @param pointers A DataBuffer that specifies the indices for the elements of x.
|
|
||||||
* @param Y a dense vector
|
|
||||||
*
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected void saxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y) {
|
|
||||||
cblas_saxpyi((int) N, (float) alpha, (FloatPointer) X.data().addressPointer(), (IntPointer) pointers.addressPointer(),
|
|
||||||
(FloatPointer) Y.data().addressPointer());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void haxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y){
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Applies Givens rotation to sparse vectors one of which is in compressed form.
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vectors X and Y
|
|
||||||
* @param X a double sparse vector
|
|
||||||
* @param indexes The indexes of the sparse vector
|
|
||||||
* @param Y a double full-storage vector
|
|
||||||
* @param c a scalar
|
|
||||||
* @param s a scalar
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected void droti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
|
|
||||||
cblas_droti((int) N, (DoublePointer) X.data().addressPointer(), (IntPointer) indexes.addressPointer(),
|
|
||||||
(DoublePointer) Y.data().addressPointer(), c, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Applies Givens rotation to sparse vectors one of which is in compressed form.
|
|
||||||
*
|
|
||||||
* @param N The number of elements in vectors X and Y
|
|
||||||
* @param X a float sparse vector
|
|
||||||
* @param indexes The indexes of the sparse vector
|
|
||||||
* @param Y a float full-storage vector
|
|
||||||
* @param c a scalar
|
|
||||||
* @param s a scalar
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected void sroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
|
|
||||||
cblas_sroti((int) N, (FloatPointer) X.data().addressPointer(), (IntPointer) indexes.addressPointer().capacity(X.columns()),
|
|
||||||
(FloatPointer) Y.data().addressPointer(), (float) c, (float) s);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void hroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the product of a double vector by a scalar.
|
|
||||||
*
|
|
||||||
* @param N The number of elements of the vector X
|
|
||||||
* @param a a scalar
|
|
||||||
* @param X a vector
|
|
||||||
* @param incx the increment of the vector X
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected void dscal(long N, double a, INDArray X, int incx) {
|
|
||||||
cblas_dscal((int) N, a, (DoublePointer) X.data().addressPointer(), incx);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the product of a float vector by a scalar.
|
|
||||||
*
|
|
||||||
* @param N The number of elements of the vector X
|
|
||||||
* @param a a scalar
|
|
||||||
* @param X a vector
|
|
||||||
* @param incx the increment of the vector X
|
|
||||||
* */
|
|
||||||
@Override
|
|
||||||
protected void sscal(long N, double a, INDArray X, int incx) {
|
|
||||||
cblas_sscal((int) N, (float) a, (FloatPointer) X.data().addressPointer(), incx);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void hscal(long N, double a, INDArray X, int incx) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,59 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu.blas;
|
|
||||||
|
|
||||||
import org.bytedeco.javacpp.DoublePointer;
|
|
||||||
import org.bytedeco.javacpp.FloatPointer;
|
|
||||||
import org.bytedeco.javacpp.IntPointer;
|
|
||||||
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel2;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.nativeblas.SparseNd4jBlas;
|
|
||||||
|
|
||||||
import static org.bytedeco.mkl.global.mkl_rt.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseCpuLevel2 extends SparseBaseLevel2 {
|
|
||||||
private SparseNd4jBlas sparseNd4jBlas = (SparseNd4jBlas) Nd4j.sparseFactory().blas();
|
|
||||||
// Mapping with Sparse Blas calls
|
|
||||||
|
|
||||||
public void scoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz, INDArray x, INDArray y){
|
|
||||||
mkl_cspblas_scoogemv(
|
|
||||||
Character.toString(transA),
|
|
||||||
(IntPointer) Nd4j.createBuffer(new int[]{M}).addressPointer(),
|
|
||||||
(FloatPointer) values.addressPointer(),
|
|
||||||
(IntPointer) rowInd.addressPointer(),
|
|
||||||
(IntPointer) colInd.addressPointer(),
|
|
||||||
(IntPointer) Nd4j.createBuffer(new int[]{nnz}).addressPointer(),
|
|
||||||
(FloatPointer) x.data().addressPointer(),
|
|
||||||
(FloatPointer)y.data().addressPointer());
|
|
||||||
}
|
|
||||||
public void dcoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz, INDArray x, INDArray y){
|
|
||||||
mkl_cspblas_dcoogemv(
|
|
||||||
Character.toString(transA),
|
|
||||||
(IntPointer) Nd4j.createBuffer(new int[]{M}).addressPointer(),
|
|
||||||
(DoublePointer) values.addressPointer(),
|
|
||||||
(IntPointer) rowInd.addressPointer(),
|
|
||||||
(IntPointer) colInd.addressPointer(),
|
|
||||||
(IntPointer) Nd4j.createBuffer(nnz).addressPointer(),
|
|
||||||
(DoublePointer) x.data().addressPointer(),
|
|
||||||
(DoublePointer)y.data().addressPointer());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.cpu.nativecpu.blas;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel3;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.nativeblas.SparseNd4jBlas;
|
|
||||||
|
|
||||||
import static org.bytedeco.mkl.global.mkl_rt.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
public class SparseCpuLevel3 extends SparseBaseLevel3 {
|
|
||||||
private SparseNd4jBlas sparseNd4jBlas = (SparseNd4jBlas) Nd4j.sparseFactory().blas();
|
|
||||||
// TODO Mappings with Sparse Blas methods
|
|
||||||
}
|
|
|
@ -17,17 +17,15 @@
|
||||||
real.class.double = org.nd4j.linalg.cpu.NDArray
|
real.class.double = org.nd4j.linalg.cpu.NDArray
|
||||||
complex.class.double = org.nd4j.linalg.cpu.nativecpu.complex.ComplexNDArray
|
complex.class.double = org.nd4j.linalg.cpu.nativecpu.complex.ComplexNDArray
|
||||||
shapeinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectShapeInfoProvider
|
shapeinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectShapeInfoProvider
|
||||||
sparseinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectSparseInfoProvider
|
|
||||||
constantsprovider = org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache
|
constantsprovider = org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache
|
||||||
affinitymanager = org.nd4j.linalg.cpu.nativecpu.CpuAffinityManager
|
affinitymanager = org.nd4j.linalg.cpu.nativecpu.CpuAffinityManager
|
||||||
memorymanager = org.nd4j.linalg.cpu.nativecpu.CpuMemoryManager
|
memorymanager = org.nd4j.linalg.cpu.nativecpu.CpuMemoryManager
|
||||||
dtype = float
|
dtype = float
|
||||||
complex.double.class = org.nd4j.linalg.cpu.nativecpu.complex.ComplexDouble
|
complex.double.class = org.nd4j.linalg.cpu.nativecpu.complex.ComplexDouble
|
||||||
blas.ops = org.nd4j.linalg.cpu.nativecpu.BlasWrapper
|
blas.ops = org.nd4j.linalg.cpu.nativecpu.BlasWrapper
|
||||||
sparseblas.ops = org.nd4j.linalg.cpu.nativecpu.SparseBlasWrapper
|
|
||||||
native.ops= org.nd4j.nativeblas.Nd4jCpu
|
native.ops= org.nd4j.nativeblas.Nd4jCpu
|
||||||
ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory
|
ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory
|
||||||
sparsendarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuSparseNDArrayFactory
|
|
||||||
ndarray.order = c
|
ndarray.order = c
|
||||||
resourcemanager_state = false
|
resourcemanager_state = false
|
||||||
databufferfactory = org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory
|
databufferfactory = org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory
|
||||||
|
|
|
@ -1,659 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
@Ignore // temporary ignored
|
|
||||||
public class SparseNDArrayCOOTest extends BaseNd4jTest {
|
|
||||||
|
|
||||||
public SparseNDArrayCOOTest(Nd4jBackend b){
|
|
||||||
super(b);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public char ordering(){
|
|
||||||
return 'c';
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double[] data = {10, 1, 2, 3, 4, 5};
|
|
||||||
long[] shape = {2, 2, 2};
|
|
||||||
int[][] indices = new int[][] {new int[] {0, 0, 0, 1, 2, 2}, new int[] {0, 0, 1, 1, 1, 2},
|
|
||||||
new int[] {1, 2, 2, 1, 0, 1}};
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldCreateSparseMatrix() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*INDArray sparse = Nd4j.createSparseCOO(data, indices, shape);
|
|
||||||
assertArrayEquals(shape, sparse.shape());
|
|
||||||
assertEquals(data.length, sparse.nnz());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldPutScalar() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*
|
|
||||||
INDArray sparse = Nd4j.createSparseCOO(new double[] {1, 2}, new int[][] {{0, 0}, {0, 2}}, new long[] {1, 3});
|
|
||||||
sparse.putScalar(1, 3);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldntPutZero() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*
|
|
||||||
INDArray sparse = Nd4j.createSparseCOO(new double[] {1, 2}, new int[][] {{0, 0}, {0, 2}}, new long[] {1, 3});
|
|
||||||
int oldNNZ = sparse.nnz();
|
|
||||||
sparse.putScalar(1, 0);
|
|
||||||
assertArrayEquals(new int[] {0, 2}, sparse.getVectorCoordinates().asInt());
|
|
||||||
assertTrue(sparse.isRowVector());
|
|
||||||
assertEquals(oldNNZ, sparse.nnz());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldRemoveZero() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*
|
|
||||||
INDArray sparse = Nd4j.createSparseCOO(new double[] {1, 2}, new int[][] {{0, 0}, {0, 2}}, new long[] {1, 3});
|
|
||||||
sparse.putScalar(0, 0);
|
|
||||||
assertArrayEquals(new int[] {2}, sparse.getVectorCoordinates().asInt());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTakeViewInLeftTopCorner() {
|
|
||||||
// Test with dense ndarray
|
|
||||||
double[] data = {0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0};
|
|
||||||
INDArray array = Nd4j.create(data, new int[] {5, 5}, 0, 'c');
|
|
||||||
INDArray denseView = array.get(NDArrayIndex.interval(0, 2), NDArrayIndex.interval(0, 2));
|
|
||||||
|
|
||||||
// test with sparse :
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*
|
|
||||||
INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
|
|
||||||
|
|
||||||
// subarray in the top right corner
|
|
||||||
BaseSparseNDArrayCOO sparseView = (BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.interval(0, 2),
|
|
||||||
NDArrayIndex.interval(0, 2));
|
|
||||||
assertArrayEquals(denseView.shape(), sparseView.shape());
|
|
||||||
double[] currentValues = sparseView.data().asDouble();
|
|
||||||
assertArrayEquals(values, currentValues, 1e-5);
|
|
||||||
assertArrayEquals(ArrayUtil.flatten(indices), sparseView.getUnderlyingIndices().asInt());
|
|
||||||
assertEquals(0, sparseView.nnz());
|
|
||||||
System.out.println(sparseView.sparseInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTakeViewInLeftBottomCorner() {
|
|
||||||
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*
|
|
||||||
INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
|
|
||||||
|
|
||||||
BaseSparseNDArrayCOO sparseView = (BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.interval(2, 5),
|
|
||||||
NDArrayIndex.interval(0, 2));
|
|
||||||
assertEquals(1, sparseView.nnz());
|
|
||||||
assertArrayEquals(new double[] {3}, sparseView.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
assertArrayEquals(new int[] {0, 1}, sparseView.getIncludedIndices().asInt());
|
|
||||||
|
|
||||||
System.out.println(sparseView.sparseInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTakeViewInRightTopCorner() {
|
|
||||||
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
|
|
||||||
BaseSparseNDArrayCOO sparseView = (BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.interval(0, 2),
|
|
||||||
NDArrayIndex.interval(2, 5));
|
|
||||||
assertEquals(2, sparseView.nnz());
|
|
||||||
assertArrayEquals(new double[] {1, 2}, sparseView.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
assertArrayEquals(new int[] {0, 1, 1, 0}, sparseView.getIncludedIndices().asInt());
|
|
||||||
|
|
||||||
System.out.println(sparseView.sparseInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTakeViewInTheMiddle() {
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
|
|
||||||
BaseSparseNDArrayCOO sparseView = (BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.interval(1, 3),
|
|
||||||
NDArrayIndex.interval(1, 3));
|
|
||||||
assertEquals(2, sparseView.nnz());
|
|
||||||
assertArrayEquals(new double[] {2, 3}, sparseView.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
assertArrayEquals(new int[] {0, 1, 1, 0}, sparseView.getIncludedIndices().asInt());
|
|
||||||
|
|
||||||
System.out.println(sparseView.sparseInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldGetFirstColumn() {
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
|
|
||||||
BaseSparseNDArrayCOO sparseView =
|
|
||||||
(BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
|
||||||
assertEquals(0, sparseView.nnz());
|
|
||||||
|
|
||||||
System.out.println(sparseView.sparseInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldGetRowInTheMiddle() {
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
|
|
||||||
BaseSparseNDArrayCOO sparseView =
|
|
||||||
(BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.point(2), NDArrayIndex.all());
|
|
||||||
assertEquals(1, sparseView.nnz());
|
|
||||||
assertArrayEquals(new int[] {0, 1}, sparseView.getIncludedIndices().asInt());
|
|
||||||
assertArrayEquals(new double[] {3}, sparseView.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
|
|
||||||
System.out.println(sparseView.sparseInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldGetScalar() {
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
|
|
||||||
BaseSparseNDArrayCOO sparseView =
|
|
||||||
(BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.point(2), NDArrayIndex.point(1));
|
|
||||||
assertEquals(1, sparseView.nnz());
|
|
||||||
assertArrayEquals(new int[] {0, 0}, sparseView.getIncludedIndices().asInt());
|
|
||||||
assertArrayEquals(new double[] {3}, sparseView.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
assertTrue(sparseView.isScalar());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTakeView3dimensionArray() {
|
|
||||||
long[] shape = new long[] {2, 2, 2};
|
|
||||||
double[] values = new double[] {2, 1, 4, 3};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 0}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}};
|
|
||||||
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
BaseSparseNDArrayCOO view =
|
|
||||||
(BaseSparseNDArrayCOO) array.get(NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all());
|
|
||||||
assertEquals(2, view.nnz());
|
|
||||||
assertArrayEquals(new long[] {2, 2}, view.shape());
|
|
||||||
assertArrayEquals(new int[] {0, 0, 1, 1}, view.getIncludedIndices().asInt());
|
|
||||||
assertArrayEquals(new double[] {2, 1}, view.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
|
|
||||||
System.out.println(view.sparseInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTakeViewOfView() {
|
|
||||||
long[] shape = new long[] {2, 2, 2};
|
|
||||||
double[] values = new double[] {2, 1, 4, 3};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 0}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}};
|
|
||||||
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
BaseSparseNDArrayCOO baseView =
|
|
||||||
(BaseSparseNDArrayCOO) array.get(NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all());
|
|
||||||
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) baseView.get(NDArrayIndex.point(1), NDArrayIndex.all());
|
|
||||||
assertEquals(1, view.nnz());
|
|
||||||
assertArrayEquals(new long[] {1, 2}, view.shape());
|
|
||||||
assertArrayEquals(new int[] {0, 1}, view.getIncludedIndices().asInt());
|
|
||||||
assertArrayEquals(new double[] {1}, view.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTakeViewOfView2() {
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 1}, {3, 1, 0}};
|
|
||||||
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
BaseSparseNDArrayCOO baseView = (BaseSparseNDArrayCOO) array.get(NDArrayIndex.interval(1, 4),
|
|
||||||
NDArrayIndex.point(1), NDArrayIndex.all());
|
|
||||||
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) baseView.get(NDArrayIndex.all(), NDArrayIndex.point(2));
|
|
||||||
assertEquals(2, view.nnz());
|
|
||||||
assertArrayEquals(new long[] {3, 1}, view.shape());
|
|
||||||
assertArrayEquals(new int[] {0, 0, 1, 0}, view.getIncludedIndices().asInt());
|
|
||||||
assertArrayEquals(new double[] {5, 7}, view.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
assertTrue(view.isColumnVector());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldGetWithSpecifiedIndexes() {
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 1}, {3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*INDArray array = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
BaseSparseNDArrayCOO newArray = (BaseSparseNDArrayCOO) array.get(new SpecifiedIndex(0, 3),
|
|
||||||
NDArrayIndex.all(), NDArrayIndex.all());
|
|
||||||
assertEquals(4, newArray.nnz());
|
|
||||||
assertArrayEquals(new double[] {1, 2, 8, 9}, newArray.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
assertArrayEquals(new int[] {0, 0, 2, 0, 1, 1, 1, 0, 1, 1, 1, 0}, newArray.getIncludedIndices().asInt());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldGetWithSpecifiedIndexes2() {
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 2}, {3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
|
|
||||||
BaseSparseNDArrayCOO newArray = (BaseSparseNDArrayCOO) array.get(NDArrayIndex.interval(1, 4),
|
|
||||||
new SpecifiedIndex(0), new SpecifiedIndex(0, 2));
|
|
||||||
assertEquals(2, newArray.nnz());
|
|
||||||
assertArrayEquals(new double[] {3, 8}, newArray.getIncludedValues().asDouble(), 1e-1);
|
|
||||||
assertArrayEquals(new int[] {0, 0, 2, 1}, newArray.getIncludedIndices().asInt());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void specifiedIndexWithDenseArray() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
System.out.println(arr.toString());
|
|
||||||
INDArray v = arr.get(NDArrayIndex.interval(1, 3), new SpecifiedIndex(0),
|
|
||||||
new SpecifiedIndex(0, 2));
|
|
||||||
|
|
||||||
System.out.println("v ");
|
|
||||||
System.out.println(v.toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void newAxisWithSparseArray() {
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 2}, {3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
INDArray v = array.get(NDArrayIndex.point(0), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void nestedSparseViewWithNewAxis() {
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 2}, {3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
|
|
||||||
System.out.println("\nTaking view (all, point(1), all");
|
|
||||||
INDArray v = array.get(NDArrayIndex.all(), NDArrayIndex.point(1));
|
|
||||||
System.out.println(v.toString());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
System.out.println("Fixed dimension " + v.flags());
|
|
||||||
System.out.println("sparse offsets " + v.sparseOffsets());
|
|
||||||
System.out.println("hidden dimensions " + v.hiddenDimensions());
|
|
||||||
System.out.println("number of hidden dimensions " + ((BaseSparseNDArrayCOO) v).getNumHiddenDimension());
|
|
||||||
// shape 4 x 3
|
|
||||||
|
|
||||||
System.out.println("\nTaking view (all new axis");
|
|
||||||
INDArray v1 = v.get(NDArrayIndex.all(), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v1.toString());
|
|
||||||
System.out.println(v1.shapeInfoDataBuffer());
|
|
||||||
System.out.println("Fixed dimension " + v1.flags());
|
|
||||||
System.out.println("sparse offsets " + v1.sparseOffsets());
|
|
||||||
System.out.println("hidden dimensions " + v1.hiddenDimensions());
|
|
||||||
System.out.println("number of hidden dimensions " + ((BaseSparseNDArrayCOO) v1).getNumHiddenDimension());
|
|
||||||
// shape 4 x 1 x 3
|
|
||||||
|
|
||||||
System.out.println("\nTaking view (all new axis");
|
|
||||||
v1 = v.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v1.toString());
|
|
||||||
System.out.println(v1.shapeInfoDataBuffer());
|
|
||||||
System.out.println("Fixed dimension " + v1.flags());
|
|
||||||
System.out.println("sparse offsets " + v1.sparseOffsets());
|
|
||||||
System.out.println("hidden dimensions " + v1.hiddenDimensions());
|
|
||||||
System.out.println("number of hidden dimensions " + ((BaseSparseNDArrayCOO) v1).getNumHiddenDimension());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void nestedViewWithNewAxis() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
System.out.println(arr.toString());
|
|
||||||
System.out.println(arr.shapeInfoDataBuffer());
|
|
||||||
|
|
||||||
System.out.println("\nTaking view (all, point(1), all");
|
|
||||||
INDArray v = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1));
|
|
||||||
System.out.println(v.toString());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
// shape 4 x 3
|
|
||||||
|
|
||||||
System.out.println("\nTaking view (all new axis");
|
|
||||||
INDArray v1 = v.get(NDArrayIndex.all(), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v1.toString());
|
|
||||||
System.out.println(v1.shapeInfoDataBuffer());
|
|
||||||
// shape 4 x 1 x 3
|
|
||||||
|
|
||||||
System.out.println("\nTaking view (all new axis");
|
|
||||||
v1 = v1.get(NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v1.toString());
|
|
||||||
System.out.println(v1.shapeInfoDataBuffer());
|
|
||||||
// shape 4 x 3
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTranslateViewIndexesToOriginal() {
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 2}, {3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*INDArray original = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.all(), NDArrayIndex.point(1));
|
|
||||||
int[] originalIdx = view.translateToPhysical(new int[] {0, 0});
|
|
||||||
int[] exceptedIdx = new int[] {0, 1, 0};
|
|
||||||
assertArrayEquals(exceptedIdx, originalIdx);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTranslateViewIndexesToOriginal2() {
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 2}, {3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray original = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.all(), NDArrayIndex.newAxis(),
|
|
||||||
NDArrayIndex.point(1));
|
|
||||||
assertArrayEquals(new int[] {0, 1, 0}, view.translateToPhysical(new int[] {0, 0, 0}));
|
|
||||||
assertArrayEquals(new int[] {1, 1, 1}, view.translateToPhysical(new int[] {1, 0, 1}));
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTranslateViewIndexesToOriginal3() {
|
|
||||||
long[] shape = new long[] {4, 2, 3, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2, 0}, {0, 1, 1, 1}, {1, 0, 0, 0}, {1, 0, 1, 0}, {1, 1, 2, 1},
|
|
||||||
{2, 0, 1, 0}, {2, 1, 2, 0}, {3, 0, 2, 1}, {3, 1, 0, 1}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray original = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.all(), NDArrayIndex.newAxis(),
|
|
||||||
NDArrayIndex.point(1), NDArrayIndex.point(2));
|
|
||||||
assertArrayEquals(new int[] {0, 1, 2, 0}, view.translateToPhysical(new int[] {0, 0, 0}));
|
|
||||||
assertArrayEquals(new int[] {1, 1, 2, 1}, view.translateToPhysical(new int[] {1, 0, 1}));
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldTranslateViewWithPrependNewAxis() {
|
|
||||||
// TODO FIX get view with a new prepend axis
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 2}, {3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray original = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
|
|
||||||
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.newAxis(), NDArrayIndex.all(),
|
|
||||||
NDArrayIndex.point(1));
|
|
||||||
System.out.println(view.getIncludedIndices());
|
|
||||||
System.out.println(view.getIncludedValues());
|
|
||||||
assertArrayEquals(new int[] {0, 1, 0}, view.translateToPhysical(new int[] {0, 0, 0}));
|
|
||||||
assertArrayEquals(new int[] {1, 1, 1}, view.translateToPhysical(new int[] {0, 1, 1}));
|
|
||||||
|
|
||||||
int[] originalIdx = view.translateToPhysical(new int[] {0, 1, 2});
|
|
||||||
int[] exceptedIdx = new int[] {1, 0, 2};
|
|
||||||
assertArrayEquals(exceptedIdx, originalIdx);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldSortCOOIndices() {
|
|
||||||
long[] shape = new long[] {4, 3, 3};
|
|
||||||
double[] values = new double[] {1};
|
|
||||||
long[][] indices = new long[][] {{0, 0, 0}};
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/* INDArray original = Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
original.putScalar(2, 2, 2, 3);
|
|
||||||
original.putScalar(1, 1, 1, 2);
|
|
||||||
|
|
||||||
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.all());
|
|
||||||
int[] expectedIdx = new int[] {0, 0, 0, 1, 1, 1, 2, 2, 2};
|
|
||||||
double[] expectedValues = new double[] {1, 2, 3};
|
|
||||||
assertArrayEquals(expectedIdx, view.getIncludedIndices().asInt());
|
|
||||||
assertArrayEquals(expectedValues, view.getIncludedValues().asDouble(), 1e-5);
|
|
||||||
assertTrue(view == original);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testWithDense() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
System.out.println(arr);
|
|
||||||
INDArray view = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1));
|
|
||||||
// System.out.println(view.shapeInfoDataBuffer());
|
|
||||||
view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all());
|
|
||||||
System.out.println("view");
|
|
||||||
System.out.println(view);
|
|
||||||
System.out.println(view.shapeInfoDataBuffer());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void newAxisWithDenseArray() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
System.out.println(arr.toString());
|
|
||||||
System.out.println(arr.shapeInfoDataBuffer());
|
|
||||||
|
|
||||||
System.out.println("\npoint 0");
|
|
||||||
INDArray v = arr.get(NDArrayIndex.point(0));
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
// => shape 2 x 3
|
|
||||||
|
|
||||||
System.out.println("new axis, all, point 1");
|
|
||||||
v = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
|
|
||||||
//System.out.println(v.toString());
|
|
||||||
|
|
||||||
v = arr.get(NDArrayIndex.interval(1, 4), NDArrayIndex.point(0), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
System.out.println(v.isView());
|
|
||||||
// => shape 1 x 2 x 3
|
|
||||||
|
|
||||||
System.out.println("\npoint 0, newaxis");
|
|
||||||
v = arr.get(NDArrayIndex.point(0), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
System.out.println(v.isView());
|
|
||||||
// => shape 1 x 2 x 3
|
|
||||||
|
|
||||||
System.out.println("\n point 0, newaxis, newaxis");
|
|
||||||
v = arr.get(NDArrayIndex.point(0), NDArrayIndex.newAxis(), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
// => shape 1 x 1 x 2 x 3
|
|
||||||
|
|
||||||
System.out.println("\n new axis, point 0, newaxis");
|
|
||||||
v = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.point(0), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
// => shape 1 x 1 x 2 x 3
|
|
||||||
|
|
||||||
System.out.println("\nget( new axis, point(0), point(0), new axis)");
|
|
||||||
v = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
System.out.println(v.toString());
|
|
||||||
// => shape 1 x 1 x 3 x 1
|
|
||||||
|
|
||||||
System.out.println("\nget( specified(1), specified(0), new axis)");
|
|
||||||
v = arr.get(new SpecifiedIndex(0), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
System.out.println(v.toString());
|
|
||||||
// => crash
|
|
||||||
|
|
||||||
// System.out.println("\nget( new axis, point(0), new axis, point(0))");
|
|
||||||
// v = arr.get( NDArrayIndex.newAxis(), NDArrayIndex.point(0), NDArrayIndex.newAxis(), NDArrayIndex.point(0));
|
|
||||||
// System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
// System.out.println(v.toString());
|
|
||||||
// => crash
|
|
||||||
|
|
||||||
System.out.println("\n interval(0, 2), newaxis");
|
|
||||||
v = arr.get(NDArrayIndex.interval(0, 2), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
// => shape 1 x 2 x 2 x 3 - new axis is added at the first position
|
|
||||||
|
|
||||||
/* System.out.println("\n point 0 , all, new axis");
|
|
||||||
v = arr.get(
|
|
||||||
NDArrayIndex.point(0),
|
|
||||||
NDArrayIndex.all(),
|
|
||||||
NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
// => crash
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testDenseNewAxisWithSpecifiedIdx() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
INDArray v = arr.get(new SpecifiedIndex(0), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
System.out.println(v.toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testDenseNewAxisWithSpecifiedIdx2() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
INDArray v = arr.get(NDArrayIndex.newAxis(), new SpecifiedIndex(0, 1), NDArrayIndex.all());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
System.out.println(v.toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testDenseNewAxisWithSpecifiedIdx3() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
INDArray v = arr.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.newAxis());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
System.out.println(v.toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testDenseWithNewAxis() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
|
|
||||||
System.out.println(view);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testWithPrependNewAxis() {
|
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
|
||||||
System.out.println(arr.toString());
|
|
||||||
System.out.println(arr.shapeInfoDataBuffer());
|
|
||||||
|
|
||||||
System.out.println("new axis, all, point 1");
|
|
||||||
INDArray v = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
|
|
||||||
System.out.println(v.toString());
|
|
||||||
System.out.println(v.shapeInfoDataBuffer());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void binarySearch() {
|
|
||||||
long[] shape = new long[] {4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
|
|
||||||
{3, 0, 2}, {3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* BaseSparseNDArrayCOO array = (BaseSparseNDArrayCOO) Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
|
|
||||||
assertEquals(0, array.reverseIndexes(0, 0, 2));
|
|
||||||
assertEquals(7, array.reverseIndexes(3, 0, 2));
|
|
||||||
assertEquals(8, array.reverseIndexes(3, 1, 0));
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void rdmTest(){
|
|
||||||
INDArray i = Nd4j.rand(new int[]{3, 3, 3});
|
|
||||||
INDArray ii = i.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
|
|
||||||
System.out.println(ii);
|
|
||||||
System.out.println(ii.shapeInfoDataBuffer());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void tryToFindABugWithHiddenDim(){
|
|
||||||
|
|
||||||
long[] shape = new long[] {1, 4, 2, 3};
|
|
||||||
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
int[][] indices = new int[][] {{0, 0, 0, 2}, {0, 0, 1, 1}, {0, 1, 0, 0}, {0, 1, 0, 1}, {0, 1, 1, 2}, {0, 2, 0, 1}, {0, 2, 1, 2},
|
|
||||||
{0, 3, 0, 2}, {0, 3, 1, 0}};
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* BaseSparseNDArrayCOO array = (BaseSparseNDArrayCOO) Nd4j.createSparseCOO(values, indices, shape);
|
|
||||||
|
|
||||||
BaseSparseNDArrayCOO view1 = (BaseSparseNDArrayCOO) array.get( NDArrayIndex.point(0), NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.point(0));
|
|
||||||
System.out.println(view1.shapeInfoDataBuffer());
|
|
||||||
System.out.println(view1.sparseInfoDataBuffer());
|
|
||||||
|
|
||||||
BaseSparseNDArrayCOO view2 = (BaseSparseNDArrayCOO) view1.get( NDArrayIndex.point(0), NDArrayIndex.newAxis(),NDArrayIndex.newAxis(), NDArrayIndex.point(0));
|
|
||||||
System.out.println(view2.shapeInfoDataBuffer());
|
|
||||||
System.out.println(view2.sparseInfoDataBuffer());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,270 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg;
|
|
||||||
|
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCSR;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
@Ignore // temporary ignored
|
|
||||||
public class SparseNDArrayCSRTest extends BaseNd4jTest {
|
|
||||||
|
|
||||||
public SparseNDArrayCSRTest(Nd4jBackend b){
|
|
||||||
super(b);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public char ordering(){
|
|
||||||
return 'c';
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* [[1 -1 0 -3 0]
|
|
||||||
* [-2 4 0 0 0 ]
|
|
||||||
* [ 0 0 4 6 4 ] = A
|
|
||||||
* [-4 0 2 7 0 ]
|
|
||||||
* [ 0 8 0 0 -5]]
|
|
||||||
* */
|
|
||||||
|
|
||||||
// CSR representation of the matrix A according to https://software.intel.com/en-us/node/599835
|
|
||||||
private double[] values = {1, -2, -3, -2, 5, 4, 6, 4, -4, 2, 7, 8, -5};
|
|
||||||
private int[] columns = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
|
|
||||||
private int[] pointerB = {0, 3, 5, 8, 11};
|
|
||||||
private int[] pointerE = {3, 5, 8, 11, 13};
|
|
||||||
private long[] shape = {5, 5};
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldAddValueAtAGivenPosition() {
|
|
||||||
/*
|
|
||||||
* [[1 -1 0 -3 0]
|
|
||||||
* [-2 4 0 0 0 ]
|
|
||||||
* [ 0 3 4 6 4 ] = A'
|
|
||||||
* [-4 0 2 7 0 ]
|
|
||||||
* [ 0 8 0 0 -5]]
|
|
||||||
* */
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/*
|
|
||||||
INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
|
|
||||||
if (sparseNDArray instanceof BaseSparseNDArrayCSR) {
|
|
||||||
BaseSparseNDArrayCSR sparseCSRArray = (BaseSparseNDArrayCSR) sparseNDArray;
|
|
||||||
sparseCSRArray.putScalar(2, 1, 3);
|
|
||||||
|
|
||||||
double[] expectedValues = {1, -2, -3, -2, 5, 3, 4, 6, 4, -4, 2, 7, 8, -5};
|
|
||||||
double[] expectedColumns = {0, 1, 3, 0, 1, 1, 2, 3, 4, 0, 2, 3, 1, 4};
|
|
||||||
int[] expectedPointerB = {0, 3, 5, 9, 12};
|
|
||||||
int[] expectedPointerE = {3, 5, 9, 12, 14};
|
|
||||||
long[] expectedShape = {5, 5};
|
|
||||||
|
|
||||||
|
|
||||||
assertArrayEquals(expectedValues, sparseCSRArray.getDoubleValues(), 0);
|
|
||||||
assertArrayEquals(expectedColumns, sparseCSRArray.getColumns(), 0);
|
|
||||||
assertArrayEquals(expectedPointerB, sparseCSRArray.getPointerBArray());
|
|
||||||
assertArrayEquals(expectedPointerE, sparseCSRArray.getPointerEArray());
|
|
||||||
assertArrayEquals(expectedShape, shape);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldReallocate() {
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/*
|
|
||||||
INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
|
|
||||||
if (sparseNDArray instanceof BaseSparseNDArrayCSR) {
|
|
||||||
BaseSparseNDArrayCSR sparseCSRArray = (BaseSparseNDArrayCSR) sparseNDArray;
|
|
||||||
int initialSize = sparseCSRArray.getDoubleValues().length;
|
|
||||||
|
|
||||||
for (int i = 0; i < shape[0]; i++) {
|
|
||||||
for (int j = 0; j < shape[1]; j++) {
|
|
||||||
sparseCSRArray.putScalar(i, j, i + j);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
int finalSize = sparseCSRArray.getDoubleValues().length;
|
|
||||||
assertTrue(finalSize > initialSize);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldReplaceValueAtAGivenPosition() {
|
|
||||||
/*
|
|
||||||
* [[1 -1 0 -3 0]
|
|
||||||
* [-2 4 0 0 0 ]
|
|
||||||
* [ 0 0 10 6 4] = A'
|
|
||||||
* [-4 0 2 7 0 ]
|
|
||||||
* [ 0 8 0 0 -5]]
|
|
||||||
* */
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/*
|
|
||||||
INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
|
|
||||||
if (sparseNDArray instanceof BaseSparseNDArrayCSR) {
|
|
||||||
BaseSparseNDArrayCSR sparseCSRArray = (BaseSparseNDArrayCSR) sparseNDArray;
|
|
||||||
sparseCSRArray.putScalar(2, 2, 10);
|
|
||||||
|
|
||||||
double[] expectedValues = {1, -2, -3, -2, 5, 10, 6, 4, -4, 2, 7, 8, -5};
|
|
||||||
double[] expectedColumns = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
|
|
||||||
int[] expectedPointerB = {0, 3, 5, 8, 11};
|
|
||||||
int[] expectedPointerE = {3, 5, 8, 11, 13};
|
|
||||||
long[] expectedShape = {5, 5};
|
|
||||||
|
|
||||||
assertArrayEquals(expectedValues, sparseCSRArray.getDoubleValues(), 0);
|
|
||||||
assertArrayEquals(expectedColumns, sparseCSRArray.getColumns(), 0);
|
|
||||||
assertArrayEquals(expectedPointerB, sparseCSRArray.getPointerBArray());
|
|
||||||
assertArrayEquals(expectedPointerE, sparseCSRArray.getPointerEArray());
|
|
||||||
assertArrayEquals(expectedShape, shape);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldGetValueAtAGivenPosition() {
|
|
||||||
// Not yet implemented
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldBeEqualToDense() {
|
|
||||||
// Not yet implemented
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldGetAView() {
|
|
||||||
|
|
||||||
double[] values = {1, -1, -3, -2, 5, 4, 6, 4, -4, 2, 7, 8, 5};
|
|
||||||
int[] columns = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
|
|
||||||
int[] pointerB = {0, 3, 5, 8, 11};
|
|
||||||
int[] pointerE = {3, 5, 8, 11, 13};
|
|
||||||
|
|
||||||
// Test with dense ndarray
|
|
||||||
double[] data = {1, -1, 0, -3, 0, -2, 5, 0, 0, 0, 0, 0, 4, 6, 4, -4, 0, 2, 7, 0, 0, 8, 0, 0, 5};
|
|
||||||
INDArray array = Nd4j.create(data, new int[] {5, 5}, 0, 'c');
|
|
||||||
INDArray denseView = array.get(NDArrayIndex.interval(1, 3), NDArrayIndex.interval(1, 3));
|
|
||||||
|
|
||||||
// test with sparse :
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/*
|
|
||||||
INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
|
|
||||||
|
|
||||||
// subarray in the top right corner
|
|
||||||
BaseSparseNDArrayCSR sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(0, 3),
|
|
||||||
NDArrayIndex.interval(3, 5));
|
|
||||||
assertArrayEquals(new int[] {0, 0, 1}, sparseView.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {2, 3, 6}, sparseView.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {3, 3, 8}, sparseView.getPointerEArray());
|
|
||||||
|
|
||||||
// subarray in the middle
|
|
||||||
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(1, 3), NDArrayIndex.interval(1, 3));
|
|
||||||
assertArrayEquals(new int[] {0, 1}, sparseView.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {4, 5}, sparseView.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {5, 6}, sparseView.getPointerEArray());
|
|
||||||
|
|
||||||
// get the first row
|
|
||||||
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
|
||||||
assertArrayEquals(new int[] {0, 0, 0}, sparseView.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {0, 3, 4, 8, 9}, sparseView.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {1, 4, 4, 9, 9}, sparseView.getPointerEArray());
|
|
||||||
|
|
||||||
// get several rows
|
|
||||||
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all());
|
|
||||||
assertArrayEquals(new int[] {0, 1, 3, 0, 1}, sparseView.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {0, 3}, sparseView.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {3, 5}, sparseView.getPointerEArray());
|
|
||||||
|
|
||||||
// get a row in the middle
|
|
||||||
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.point(2), NDArrayIndex.all());
|
|
||||||
assertArrayEquals(new int[] {2, 3, 4}, sparseView.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {5}, sparseView.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {8}, sparseView.getPointerEArray());
|
|
||||||
|
|
||||||
// get the first column
|
|
||||||
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
|
||||||
assertArrayEquals(new int[] {0, 0, 0}, sparseView.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {0, 3, 4, 8, 9}, sparseView.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {1, 4, 4, 9, 9}, sparseView.getPointerEArray());
|
|
||||||
|
|
||||||
// get a column in the middle
|
|
||||||
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.all(), NDArrayIndex.point(2));
|
|
||||||
assertArrayEquals(new int[] {0, 0}, sparseView.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {0, 0, 5, 9, 10}, sparseView.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {0, 0, 6, 10, 10}, sparseView.getPointerEArray());
|
|
||||||
|
|
||||||
// get a part of the column in the middle
|
|
||||||
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(1, 4), NDArrayIndex.point(2));
|
|
||||||
assertArrayEquals(new int[] {0, 0}, sparseView.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {0, 5, 9}, sparseView.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {0, 6, 10}, sparseView.getPointerEArray());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldGetAViewFromView() {
|
|
||||||
double[] values = {1, -1, -3, -2, 5, 4, 6, 4, -4, 2, 7, 8, 5};
|
|
||||||
int[] columns = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
|
|
||||||
int[] pointerB = {0, 3, 5, 8, 11};
|
|
||||||
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
//INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
|
|
||||||
/* [0, -3, 0]
|
|
||||||
* sparseView = [0, 0, 0] subview = [[0,0], [4,6]]
|
|
||||||
* [4, 6, 4]
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*BaseSparseNDArrayCSR sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(0, 3),
|
|
||||||
NDArrayIndex.interval(2, 5));
|
|
||||||
BaseSparseNDArrayCSR subview =
|
|
||||||
(BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.interval(1, 3), NDArrayIndex.interval(0, 2));
|
|
||||||
assertArrayEquals(new int[] {0, 1}, subview.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {0, 5}, subview.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {0, 7}, subview.getPointerEArray());
|
|
||||||
|
|
||||||
// get the first column
|
|
||||||
subview = (BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
|
||||||
assertArrayEquals(new int[] {0}, subview.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {0, 0, 5}, subview.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {0, 0, 6}, subview.getPointerEArray());
|
|
||||||
|
|
||||||
// get a column in the middle
|
|
||||||
subview = (BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.all(), NDArrayIndex.point(1));
|
|
||||||
assertArrayEquals(new int[] {0, 0}, subview.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {2, 3, 6}, subview.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {3, 3, 7}, subview.getPointerEArray());
|
|
||||||
|
|
||||||
// get the first row
|
|
||||||
subview = (BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.point(0), NDArrayIndex.all());
|
|
||||||
assertArrayEquals(new int[] {1}, subview.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {2}, subview.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {3}, subview.getPointerEArray());
|
|
||||||
|
|
||||||
// get a row in the middle
|
|
||||||
subview = (BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.point(1), NDArrayIndex.all());
|
|
||||||
assertArrayEquals(new int[] {}, subview.getVectorCoordinates().asInt());
|
|
||||||
assertArrayEquals(new int[] {0}, subview.getPointerBArray());
|
|
||||||
assertArrayEquals(new int[] {0}, subview.getPointerEArray());
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,167 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas;
|
|
||||||
|
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.runner.RunWith;
|
|
||||||
import org.junit.runners.Parameterized;
|
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArray;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
@Ignore // temporary ignored
|
|
||||||
@RunWith(Parameterized.class)
|
|
||||||
public class SparseCOOLevel1Test extends BaseNd4jTest {
|
|
||||||
|
|
||||||
// vector = [1, 2, 0, 4]
|
|
||||||
private double[] data = {1, 2, 4};
|
|
||||||
private int[][] indexes = new int[][] {{0, 0}, {0, 1}, {0, 3}};
|
|
||||||
private long[] shape = {1, 4};
|
|
||||||
|
|
||||||
public SparseCOOLevel1Test(Nd4jBackend backend) {
|
|
||||||
super(backend);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeDot() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
//INDArray vec = Nd4j.create( new double[] {1 ,2, 3, 4});
|
|
||||||
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.FLOAT).reshape(1, 4);
|
|
||||||
INDArray vec = matrix.getRow(0);
|
|
||||||
assertEquals(21, Nd4j.getBlasWrapper().dot(sparseVec, vec), 1e-1);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeNrm2() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
//INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
//assertEquals(Math.sqrt(21), Nd4j.getBlasWrapper().nrm2(sparseVec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeAsum() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
//INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
//assertEquals(7, Nd4j.getBlasWrapper().asum(sparseVec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeIamax() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
// INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
// assertEquals(2, Nd4j.getBlasWrapper().iamax(sparseVec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeIamin() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
// INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
// assertEquals(0, Nd4j.getBlasWrapper().level1().iamin(sparseVec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeAxpy() {
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/* INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
|
|
||||||
INDArray expected = Nd4j.create(new double[] {2, 4, 3, 8});
|
|
||||||
Nd4j.getBlasWrapper().level1().axpy(vec.length(), 1, sparseVec, vec);
|
|
||||||
assertEquals(getFailureMessage(), expected, vec);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeRot() {
|
|
||||||
|
|
||||||
// try with dense vectors to get the expected result
|
|
||||||
|
|
||||||
INDArray temp1 = Nd4j.create(new double[] {1, 2, 0, 4});
|
|
||||||
INDArray temp2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
|
||||||
System.out.println("before: " + temp1.data() + " " + temp2.data());
|
|
||||||
Nd4j.getBlasWrapper().level1().rot(temp1.length(), temp1, temp2, 1, 2);
|
|
||||||
System.out.println("after: " + temp1.data() + " " + temp2.data());
|
|
||||||
|
|
||||||
//before: [1.0,2.0,0.0,4.0] [1.0,2.0,3.0,4.0]
|
|
||||||
// after: [3.0,6.0,6.0,12.0] [-1.0,-2.0,3.0,-4.0]
|
|
||||||
|
|
||||||
// Commented out on removal of Nd4j createSparse methods
|
|
||||||
/*INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
|
|
||||||
Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);
|
|
||||||
System.out.println(sparseVec.data() + " " + vec.data());
|
|
||||||
//System.out.println("indexes: " + ((BaseSparseNDArray) sparseVec).getVectorCoordinates().toString());
|
|
||||||
|
|
||||||
INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 6, 12}, new int[] {0, 1, 2, 3},
|
|
||||||
new int[] {0}, new int[] {4}, new long[] {1, 4});
|
|
||||||
|
|
||||||
INDArray expectedVec = Nd4j.create(new double[] {-1, -2, 3, -4});
|
|
||||||
assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
|
|
||||||
|
|
||||||
assertEquals(getFailureMessage(), expectedVec, vec);
|
|
||||||
*/
|
|
||||||
// TODO FIXME
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeRotWithFullVector() {
|
|
||||||
|
|
||||||
// try with dense vectors to get the expected result
|
|
||||||
/*
|
|
||||||
INDArray temp1 = Nd4j.create( new double[] {1 ,2, 3, 4});
|
|
||||||
INDArray temp2 = Nd4j.create( new double[] {1 ,2, 3, 4});
|
|
||||||
System.out.println("before: " + temp1.data() + " " + temp2.data());
|
|
||||||
Nd4j.getBlasWrapper().level1().rot(temp1.length(), temp1, temp2, 1, 2);
|
|
||||||
System.out.println("after: " + temp1.data() + " " + temp2.data());
|
|
||||||
*/
|
|
||||||
//before: [1.0,2.0,3.0,4.0] [1.0,2.0,3.0,4.0]
|
|
||||||
// after: [3.0,6.0,0.0,12.0] [-1.0,-2.0,-3.0,-4.0]
|
|
||||||
|
|
||||||
int[] cols = {0, 1, 2, 3};
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/* INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
|
|
||||||
Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);
|
|
||||||
INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 9, 12}, new int[] {0, 1, 2, 3},
|
|
||||||
new int[] {0}, new int[] {4}, new long[] {1, 4});
|
|
||||||
INDArray expectedVec = Nd4j.create(new double[] {-1, -2, -3, -4});
|
|
||||||
assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
|
|
||||||
assertEquals(getFailureMessage(), expectedVec, vec);
|
|
||||||
if (expectedSparseVec.isSparse() && sparseVec.isSparse()) {
|
|
||||||
BaseSparseNDArray vec2 = ((BaseSparseNDArray) expectedSparseVec);
|
|
||||||
BaseSparseNDArray vecSparse2 = ((BaseSparseNDArray) sparseVec);
|
|
||||||
assertEquals(getFailureMessage(), vec2.getVectorCoordinates(), vecSparse2);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public char ordering() {
|
|
||||||
return 'c';
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,65 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas;
|
|
||||||
|
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.runner.RunWith;
|
|
||||||
import org.junit.runners.Parameterized;
|
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
@Ignore // temporary ignored
|
|
||||||
@RunWith(Parameterized.class)
|
|
||||||
public class SparseCOOLevel2Test extends BaseNd4jTest {
|
|
||||||
|
|
||||||
// matrix = [[1, 2], [0, 0]]
|
|
||||||
private double[] data = {1, 2};
|
|
||||||
private int[][] indexes = new int[][] {{0, 0}, {0, 1}};
|
|
||||||
private long[] shape = {2, 2};
|
|
||||||
|
|
||||||
public SparseCOOLevel2Test(Nd4jBackend backend) {
|
|
||||||
super(backend);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGemv() {
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/* INDArray array1 = Nd4j.createSparseCOO(data, indexes, shape);
|
|
||||||
INDArray array2 = Nd4j.linspace(1, 2, 2).reshape(2, 1);
|
|
||||||
|
|
||||||
INDArray array3 = array1.mmul(array2); // should be [5, 0]
|
|
||||||
assertEquals(2, array3.length());
|
|
||||||
assertEquals(5, array3.getFloat(0), 1e-5);
|
|
||||||
assertEquals(0, array3.getFloat(1), 1e-5);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public char ordering() {
|
|
||||||
return 'c';
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,168 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.blas;
|
|
||||||
|
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.runner.RunWith;
|
|
||||||
import org.junit.runners.Parameterized;
|
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.BaseSparseNDArray;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Audrey Loeffel
|
|
||||||
*/
|
|
||||||
@Ignore // temporary ignored
|
|
||||||
@RunWith(Parameterized.class)
|
|
||||||
public class SparseCSRLevel1Test extends BaseNd4jTest {
|
|
||||||
|
|
||||||
private double[] data = {1, 2, 4};
|
|
||||||
private int[] col = {0, 1, 3};
|
|
||||||
private int[] pointerB = {0};
|
|
||||||
private int[] pointerE = {4};
|
|
||||||
private long[] shape = {1, 4};
|
|
||||||
|
|
||||||
public SparseCSRLevel1Test(Nd4jBackend backend) {
|
|
||||||
super(backend);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeDot() {
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
|
|
||||||
//INDArray vec = Nd4j.create( new double[] {1 ,2, 3, 4});
|
|
||||||
INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(1, 4);
|
|
||||||
INDArray vec = matrix.getRow(0);
|
|
||||||
//assertEquals(21, Nd4j.getBlasWrapper().dot(sparseVec, vec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeNrm2() {
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
|
|
||||||
//assertEquals(Math.sqrt(21), Nd4j.getBlasWrapper().nrm2(sparseVec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeAsum() {
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
|
|
||||||
//assertEquals(7, Nd4j.getBlasWrapper().asum(sparseVec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeIamax() {
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
|
|
||||||
//assertEquals(2, Nd4j.getBlasWrapper().iamax(sparseVec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeIamin() {
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
|
|
||||||
//assertEquals(0, Nd4j.getBlasWrapper().level1().iamin(sparseVec), 1e-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeAxpy() {
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/* INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
|
|
||||||
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
|
|
||||||
INDArray expected = Nd4j.create(new double[] {2, 4, 3, 8});
|
|
||||||
Nd4j.getBlasWrapper().level1().axpy(vec.length(), 1, sparseVec, vec);
|
|
||||||
assertEquals(getFailureMessage(), expected, vec);
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeRot() {
|
|
||||||
|
|
||||||
// try with dense vectors to get the expected result
|
|
||||||
|
|
||||||
INDArray temp1 = Nd4j.create(new double[] {1, 2, 0, 4});
|
|
||||||
INDArray temp2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
|
||||||
System.out.println("before: " + temp1.data() + " " + temp2.data());
|
|
||||||
Nd4j.getBlasWrapper().level1().rot(temp1.length(), temp1, temp2, 1, 2);
|
|
||||||
System.out.println("after: " + temp1.data() + " " + temp2.data());
|
|
||||||
|
|
||||||
//before: [1.0,2.0,0.0,4.0] [1.0,2.0,3.0,4.0]
|
|
||||||
// after: [3.0,6.0,6.0,12.0] [-1.0,-2.0,3.0,-4.0]
|
|
||||||
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/* INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
|
|
||||||
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
|
|
||||||
Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);
|
|
||||||
System.out.println(sparseVec.data() + " " + vec.data());
|
|
||||||
|
|
||||||
//System.out.println("indexes: " + ((BaseSparseNDArray) sparseVec).getVectorCoordinates().toString());
|
|
||||||
INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 6, 12}, new int[] {0, 1, 2, 3},
|
|
||||||
new int[] {0}, new int[] {4}, new long[] {1, 4});
|
|
||||||
INDArray expectedVec = Nd4j.create(new double[] {-1, -2, 3, -4});
|
|
||||||
assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
|
|
||||||
|
|
||||||
assertEquals(getFailureMessage(), expectedVec, vec);
|
|
||||||
*/
|
|
||||||
// TODO fix it
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void shouldComputeRotWithFullVector() {
|
|
||||||
|
|
||||||
// try with dense vectors to get the expected result
|
|
||||||
/*
|
|
||||||
INDArray temp1 = Nd4j.create( new double[] {1 ,2, 3, 4});
|
|
||||||
INDArray temp2 = Nd4j.create( new double[] {1 ,2, 3, 4});
|
|
||||||
System.out.println("before: " + temp1.data() + " " + temp2.data());
|
|
||||||
Nd4j.getBlasWrapper().level1().rot(temp1.length(), temp1, temp2, 1, 2);
|
|
||||||
System.out.println("after: " + temp1.data() + " " + temp2.data());
|
|
||||||
*/
|
|
||||||
//before: [1.0,2.0,3.0,4.0] [1.0,2.0,3.0,4.0]
|
|
||||||
// after: [3.0,6.0,0.0,12.0] [-1.0,-2.0,-3.0,-4.0]
|
|
||||||
|
|
||||||
int[] cols = {0, 1, 2, 3};
|
|
||||||
double[] values = {1, 2, 3, 4};
|
|
||||||
// commented out on removal of createSparse methods from Nd4j
|
|
||||||
/* INDArray sparseVec = Nd4j.createSparseCSR(values, cols, pointerB, pointerE, shape);
|
|
||||||
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
|
|
||||||
Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);
|
|
||||||
|
|
||||||
INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 9, 12}, new int[] {0, 1, 2, 3},
|
|
||||||
new int[] {0}, new int[] {4}, new long[] {1, 4});
|
|
||||||
INDArray expectedVec = Nd4j.create(new double[] {-1, -2, -3, -4});
|
|
||||||
assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
|
|
||||||
assertEquals(getFailureMessage(), expectedVec, vec);
|
|
||||||
if (expectedSparseVec.isSparse() && sparseVec.isSparse()) {
|
|
||||||
BaseSparseNDArray vec2 = ((BaseSparseNDArray) expectedSparseVec);
|
|
||||||
BaseSparseNDArray vecSparse2 = ((BaseSparseNDArray) sparseVec);
|
|
||||||
assertEquals(getFailureMessage(), vec2.getVectorCoordinates(), vecSparse2);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public char ordering() {
|
|
||||||
return 'c';
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue