Sparse matrix refactoring. (#8238)

* remove sparse method from INDArray.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove gemm

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove useage of n4j.sparseFactory

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Nd4j.sparseFactory removed.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* sparseNDArray deleted.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* iremove more sparse calls and constants.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove SparseBlasWrapper.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* delete BaseSparseBlaswrapper.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove 3 sparse factory classes.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* delete SparseCPULevel.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* deletes JcusparseLevel, CUDASparselevel.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* delete nativeCPU sparse classes.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* removes sparse methods from NDArrayFactory.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* more deletes.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* delete (ignored) tests. BaseSparseNDArray.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* deletes ISparseNDArray.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove sparse methods from indArray.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* deletes sparse classes.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-09-18 04:56:29 +09:00 committed by raver119
parent 979ef13c0b
commit 83d958d536
54 changed files with 8 additions and 9004 deletions

View File

@ -48,15 +48,6 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, X, Y); OpProfiler.getInstance().processBlasCall(false, X, Y);
if (X.isSparse() && !Y.isSparse()) {
return Nd4j.getSparseBlasWrapper().level1().dot(n, alpha, X, Y);
} else if (!X.isSparse() && Y.isSparse()) {
return Nd4j.getSparseBlasWrapper().level1().dot(n, alpha, Y, X);
} else if (X.isSparse() && Y.isSparse()) {
// TODO - MKL doesn't contain such routines
return 0;
}
if (X.data().dataType() == DataType.DOUBLE) { if (X.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y); DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y);
return ddot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y)); return ddot(n, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(Y));
@ -100,14 +91,9 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
@Override @Override
public double nrm2(INDArray arr) { public double nrm2(INDArray arr) {
if (arr.isSparse()) {
return Nd4j.getSparseBlasWrapper().level1().nrm2(arr);
}
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, arr); OpProfiler.getInstance().processBlasCall(false, arr);
if (arr.isSparse()) {
return Nd4j.getSparseBlasWrapper().level1().nrm2(arr);
}
if (arr.data().dataType() == DataType.DOUBLE) { if (arr.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr); DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
return dnrm2(arr.length(), arr, BlasBufferUtil.getBlasStride(arr)); return dnrm2(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
@ -127,9 +113,6 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
@Override @Override
public double asum(INDArray arr) { public double asum(INDArray arr) {
if (arr.isSparse()) {
return Nd4j.getSparseBlasWrapper().level1().asum(arr);
}
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, arr); OpProfiler.getInstance().processBlasCall(false, arr);
@ -202,9 +185,7 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
*/ */
@Override @Override
public int iamax(INDArray arr) { public int iamax(INDArray arr) {
if (arr.isSparse()) {
return Nd4j.getSparseBlasWrapper().level1().iamax(arr);
}
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, arr); OpProfiler.getInstance().processBlasCall(false, arr);
@ -225,12 +206,8 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
*/ */
@Override @Override
public int iamin(INDArray arr) { public int iamin(INDArray arr) {
if (arr.isSparse()) {
return Nd4j.getSparseBlasWrapper().level1().iamin(arr);
} else {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
}
/** /**
* swaps a vector with another vector. * swaps a vector with another vector.
@ -243,11 +220,6 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, x, y); OpProfiler.getInstance().processBlasCall(false, x, y);
if (x.isSparse() || y.isSparse()) {
Nd4j.getSparseBlasWrapper().level1().swap(x, y);
return;
}
if (x.data().dataType() == DataType.DOUBLE) { if (x.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y); DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
dswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y)); dswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
@ -269,10 +241,6 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, x, y); OpProfiler.getInstance().processBlasCall(false, x, y);
if (x.isSparse() || y.isSparse()) {
Nd4j.getSparseBlasWrapper().level1().copy(x, y);
return;
}
if (x.data().dataType() == DataType.DOUBLE) { if (x.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y); DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
dcopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y)); dcopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
@ -321,9 +289,7 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, x, y); OpProfiler.getInstance().processBlasCall(false, x, y);
if (x.isSparse() && !y.isSparse()) { if (x.data().dataType() == DataType.DOUBLE) {
Nd4j.getSparseBlasWrapper().level1().axpy(n, alpha, x, y);
} else if (x.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y); DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y)); daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
} else if (x.data().dataType() == DataType.FLOAT) { } else if (x.data().dataType() == DataType.FLOAT) {
@ -384,9 +350,7 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, X, Y); OpProfiler.getInstance().processBlasCall(false, X, Y);
if (X.isSparse() && !Y.isSparse()) { if (X.data().dataType() == DataType.DOUBLE) {
Nd4j.getSparseBlasWrapper().level1().rot(N, X, Y, c, s);
} else if (X.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y); DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y);
drot(N, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(X), c, s); drot(N, X, BlasBufferUtil.getBlasStride(X), Y, BlasBufferUtil.getBlasStride(X), c, s);
} else { } else {
@ -421,9 +385,7 @@ public abstract class BaseLevel1 extends BaseLevel implements Level1 {
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, X); OpProfiler.getInstance().processBlasCall(false, X);
if (X.isSparse()) { if (X.data().dataType() == DataType.DOUBLE)
Nd4j.getSparseBlasWrapper().level1().scal(N, alpha, X);
} else if (X.data().dataType() == DataType.DOUBLE)
dscal(N, alpha, X, BlasBufferUtil.getBlasStride(X)); dscal(N, alpha, X, BlasBufferUtil.getBlasStride(X));
else if (X.data().dataType() == DataType.FLOAT) else if (X.data().dataType() == DataType.FLOAT)
sscal(N, (float) alpha, X, BlasBufferUtil.getBlasStride(X)); sscal(N, (float) alpha, X, BlasBufferUtil.getBlasStride(X));

View File

@ -56,11 +56,6 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 {
if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
OpProfiler.getInstance().processBlasCall(false, A, X, Y); OpProfiler.getInstance().processBlasCall(false, A, X, Y);
if (A.isSparse() && !X.isSparse()) {
Nd4j.getSparseBlasWrapper().level2().gemv(order, transA, alpha, A, X, beta, Y);
return;
}
GemvParameters parameters = new GemvParameters(A, X, Y); GemvParameters parameters = new GemvParameters(A, X, Y);
if (A.data().dataType() == DataType.DOUBLE) { if (A.data().dataType() == DataType.DOUBLE) {
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(), DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),

View File

@ -1,70 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas.impl;
import org.nd4j.linalg.api.blas.Lapack;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* @author Audrey Loeffel
*/
public class SparseBaseLapack implements Lapack {
@Override
public INDArray getrf(INDArray A) {
return null;
}
@Override
public INDArray getPFactor(int M, INDArray ipiv) {
return null;
}
@Override
public INDArray getLFactor(INDArray A) {
return null;
}
@Override
public INDArray getUFactor(INDArray A) {
return null;
}
@Override
public void getri(int N, INDArray A, int lda, int[] IPIV, INDArray WORK, int lwork, int INFO) {
}
@Override
public void geqrf(INDArray A, INDArray R) {
}
@Override
public void potrf(INDArray A, boolean lower) {
}
@Override
public int syev(char jobz, char uplo, INDArray A, INDArray V) {
return 0;
}
@Override
public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) {
}
}

View File

@ -1,24 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas.impl;
/**
* @author Audrey Loeffel
*/
public abstract class SparseBaseLevel {
}

View File

@ -1,364 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas.impl;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
/**
* @author Audrey Loeffel
*/
public abstract class SparseBaseLevel1 extends SparseBaseLevel implements Level1 {
/**
* computes a vector-vector dot product.
*
* @param n number of accessed element
* @param alpha
* @param X an INDArray
* @param Y an INDArray
* @return the vector-vector dot product of X and Y
*/
@Override
public double dot(long n, double alpha, INDArray X, INDArray Y) {
if (X instanceof BaseSparseNDArray) {
BaseSparseNDArray sparseX = (BaseSparseNDArray) X;
DataBuffer pointers = sparseX.getVectorCoordinates();
switch (X.data().dataType()) {
case DOUBLE:
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y);
return ddoti(n, X, pointers, Y);
case FLOAT:
DefaultOpExecutioner.validateDataType(DataType.FLOAT, X, Y);
return sdoti(n, X, pointers, Y);
case HALF:
DefaultOpExecutioner.validateDataType(DataType.HALF, X, Y);
return hdoti(n, X, pointers, Y);
default:
}
}
throw new UnsupportedOperationException();
}
@Override
public double dot(long n, DataBuffer dx, int offsetX, int incrX, DataBuffer y, int offsetY, int incrY) {
throw new UnsupportedOperationException();
}
/**
* Computes the Euclidean norm of a vector.
*
* @param arr a vector
* @return the Euclidean norm of the vector
*/
@Override
public double nrm2(INDArray arr) {
switch (arr.data().dataType()) {
case DOUBLE:
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
return dnrm2(arr.length(), arr, 1);
case FLOAT:
DefaultOpExecutioner.validateDataType(DataType.FLOAT, arr);
return snrm2(arr.length(), arr, 1);
case HALF:
return hnrm2(arr.length(), arr, 1);
default:
}
throw new UnsupportedOperationException();
}
/**
* Compute the sum of magnitude of the vector elements
*
* @param arr a vector
* @return the sum of magnitude of the vector elements
* */
@Override
public double asum(INDArray arr) {
switch (arr.data().dataType()) {
case DOUBLE:
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
return dasum(arr.length(), arr, 1);
case FLOAT:
DefaultOpExecutioner.validateDataType(DataType.FLOAT, arr);
return sasum(arr.length(), arr, 1);
case HALF:
DefaultOpExecutioner.validateDataType(DataType.HALF, arr);
return hasum(arr.length(), arr, 1);
default:
}
throw new UnsupportedOperationException();
}
@Override
public double asum(long n, DataBuffer x, int offsetX, int incrX) {
throw new UnsupportedOperationException();
}
/**
* Find the index of the element with maximum absolute value
*
* @param arr a vector
* @return the index of the element with maximum absolute value
* */
@Override
public int iamax(INDArray arr) {
switch (arr.data().dataType()) {
case DOUBLE:
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
return idamax(arr.length(), arr, 1);
case FLOAT:
DefaultOpExecutioner.validateDataType(DataType.FLOAT, arr);
return isamax(arr.length(), arr, 1);
case HALF:
DefaultOpExecutioner.validateDataType(DataType.HALF, arr);
return ihamax(arr.length(), arr, 1);
default:
}
throw new UnsupportedOperationException();
}
@Override
public int iamax(long n, INDArray arr, int stride) {
throw new UnsupportedOperationException();
}
@Override
public int iamax(long n, DataBuffer x, int offsetX, int incrX) {
throw new UnsupportedOperationException();
}
/**
* Find the index of the element with maximum absolute value
*
* @param arr a vector
* @return the index of the element with minimum absolute value
* */
@Override
public int iamin(INDArray arr) {
switch (arr.data().dataType()) {
case DOUBLE:
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, arr);
return idamin(arr.length(), arr, 1);
case FLOAT:
DefaultOpExecutioner.validateDataType(DataType.FLOAT, arr);
return isamin(arr.length(), arr, 1);
case HALF:
DefaultOpExecutioner.validateDataType(DataType.HALF, arr);
return ihamin(arr.length(), arr, 1);
default:
}
throw new UnsupportedOperationException();
}
@Override
public void swap(INDArray x, INDArray y) {
throw new UnsupportedOperationException();
}
@Override
public void copy(INDArray x, INDArray y) {
// FIXME - for Raver119 :)
throw new UnsupportedOperationException();
}
@Override
public void copy(long n, DataBuffer x, int offsetX, int incrX, DataBuffer y, int offsetY, int incrY) {
throw new UnsupportedOperationException();
}
/**
* Adds a scalar multiple of compressed sparse vector to a full-storage vector.
*
* @param n The number of element
* @param alpha
* @param x a sparse vector
* @param y a dense vector
*
* */
@Override
public void axpy(long n, double alpha, INDArray x, INDArray y) {
BaseSparseNDArray sparseX = (BaseSparseNDArray) x;
DataBuffer pointers = sparseX.getVectorCoordinates();
switch (x.data().dataType()) {
case DOUBLE:
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x);
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, y);
daxpyi(n, alpha, x, pointers, y);
break;
case FLOAT:
DefaultOpExecutioner.validateDataType(DataType.FLOAT, x);
DefaultOpExecutioner.validateDataType(DataType.FLOAT, y);
saxpyi(n, alpha, x, pointers, y);
break;
case HALF:
DefaultOpExecutioner.validateDataType(DataType.HALF, x);
DefaultOpExecutioner.validateDataType(DataType.HALF, y);
haxpyi(n, alpha, x, pointers, y);
break;
default:
throw new UnsupportedOperationException();
}
}
@Override
public void axpy(long n, double alpha, DataBuffer x, int offsetX, int incrX, DataBuffer y, int offsetY, int incrY) {
throw new UnsupportedOperationException();
}
@Override
public void rotg(INDArray a, INDArray b, INDArray c, INDArray s) {
throw new UnsupportedOperationException();
}
/**
* Applies Givens rotation to sparse vectors one of which is in compressed form.
*
* @param N The number of elements in vectors X and Y
* @param X a sparse vector
* @param Y a full-storage vector
* @param c a scalar
* @param s a scalar
*
* */
@Override
public void rot(long N, INDArray X, INDArray Y, double c, double s) {
if (X instanceof BaseSparseNDArray) {
BaseSparseNDArray sparseX = (BaseSparseNDArray) X;
switch (X.data().dataType()) {
case DOUBLE:
droti(N, X, sparseX.getVectorCoordinates(), Y, c, s);
break;
case FLOAT:
sroti(N, X, sparseX.getVectorCoordinates(), Y, c, s);
break;
case HALF:
hroti(N, X, sparseX.getVectorCoordinates(), Y, c, s);
break;
default:
throw new UnsupportedOperationException();
}
} else {
throw new UnsupportedOperationException();
}
}
@Override
public void rotmg(INDArray d1, INDArray d2, INDArray b1, double b2, INDArray P) {
throw new UnsupportedOperationException();
}
/**
* Computes the product of a vector by a scalar.
*
* @param N The number of elements of the vector X
* @param alpha a scalar
* @param X a vector
* */
@Override
public void scal(long N, double alpha, INDArray X) {
switch (X.data().dataType()) {
case DOUBLE:
dscal(N, alpha, X, 1);
break;
case FLOAT:
sscal(N, alpha, X, 1);
break;
case HALF:
hscal(N, alpha, X, 1);
break;
default:
throw new UnsupportedOperationException();
}
}
@Override
public boolean supportsDataBufferL1Ops() {
return false;
}
/*
* ===========================================================================
* Prototypes for level 1 BLAS functions (complex are recast as routines)
* ===========================================================================
*/
protected abstract double ddoti(long N, INDArray X, DataBuffer indx, INDArray Y);
protected abstract double sdoti(long N, INDArray X, DataBuffer indx, INDArray Y);
protected abstract double hdoti(long N, INDArray X, DataBuffer indx, INDArray Y);
protected abstract double snrm2(long N, INDArray X, int incx);
protected abstract double dnrm2(long N, INDArray X, int incx);
protected abstract double hnrm2(long N, INDArray X, int incx);
protected abstract double dasum(long N, INDArray X, int incx);
protected abstract double sasum(long N, INDArray X, int incx);
protected abstract double hasum(long N, INDArray X, int incx);
protected abstract int isamax(long N, INDArray X, int incx);
protected abstract int idamax(long N, INDArray X, int incx);
protected abstract int ihamax(long N, INDArray X, int incx);
protected abstract int isamin(long N, INDArray X, int incx);
protected abstract int idamin(long N, INDArray X, int incx);
protected abstract int ihamin(long N, INDArray X, int incx);
protected abstract void daxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y);
protected abstract void saxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y);
protected abstract void haxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y);
protected abstract void droti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s);
protected abstract void sroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s);
protected abstract void hroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s);
protected abstract void dscal(long N, double a, INDArray X, int incx);
protected abstract void sscal(long N, double a, INDArray X, int incx);
protected abstract void hscal(long N, double a, INDArray X, int incx);
}

View File

@ -1,146 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas.impl;
import org.nd4j.linalg.api.blas.Level2;
import org.nd4j.linalg.api.blas.params.SparseCOOGemvParameters;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import static org.nd4j.base.Preconditions.checkArgument;
/**
* @author Audrey Loeffel
*/
public abstract class SparseBaseLevel2 extends SparseBaseLevel implements Level2 {
@Override
public void gemv(char order, char transA, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
checkArgument(A.isMatrix(), "A must be a matrix");
checkArgument(X.isVector(), "X must be a vector");
checkArgument(Y.isVector(), "Y must be a vector");
SparseCOOGemvParameters parameters = new SparseCOOGemvParameters(A, X, Y);
switch (A.data().dataType()) {
case DOUBLE:
DefaultOpExecutioner.validateDataType(DataType.DOUBLE, parameters.getA(), parameters.getX(),
parameters.getY());
dcoomv(parameters.getAOrdering(), parameters.getM(), parameters.getVal(), parameters.getRowInd(),
parameters.getColInd(), parameters.getNnz(), parameters.getX(), parameters.getY());
break;
case FLOAT:
DefaultOpExecutioner.validateDataType(DataType.FLOAT, parameters.getA(), parameters.getX(),
parameters.getY());
scoomv(parameters.getAOrdering(), parameters.getM(), parameters.getVal(), parameters.getRowInd(),
parameters.getColInd(), parameters.getNnz(), parameters.getX(), parameters.getY());
break;
default:
throw new UnsupportedOperationException();
}
}
@Override
public void gbmv(char order, char TransA, int KL, int KU, double alpha, INDArray A, INDArray X, double beta,
INDArray Y) {
}
@Override
public void ger(char order, double alpha, INDArray X, INDArray Y, INDArray A) {
}
@Override
public void sbmv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
}
@Override
public void spmv(char order, char Uplo, double alpha, INDArray Ap, INDArray X, double beta, INDArray Y) {
}
@Override
public void spr(char order, char Uplo, double alpha, INDArray X, INDArray Ap) {
}
@Override
public void spr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
}
@Override
public void symv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
}
@Override
public void syr(char order, char Uplo, int N, double alpha, INDArray X, INDArray A) {
}
@Override
public void syr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
}
@Override
public void tbmv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
}
@Override
public void tbsv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
}
@Override
public void tpmv(char order, char Uplo, char TransA, char Diag, INDArray Ap, INDArray X) {
}
@Override
public void tpsv(char order, char Uplo, char TransA, char Diag, INDArray Ap, INDArray X) {
}
@Override
public void trmv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
}
@Override
public void trsv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
}
// ----
protected abstract void scoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz,
INDArray x, INDArray y);
protected abstract void dcoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz,
INDArray x, INDArray y);
}

View File

@ -1,64 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas.impl;
import org.nd4j.linalg.api.blas.Level3;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* @author Audrey Loeffel
*/
public class SparseBaseLevel3 extends SparseBaseLevel implements Level3 {
@Override
public void gemm(char Order, char TransA, char TransB, double alpha, INDArray A, INDArray B, double beta,
INDArray C) {
}
@Override
public void gemm(INDArray A, INDArray B, INDArray C, boolean transposeA, boolean transposeB, double alpha,
double beta) {
}
@Override
public void symm(char Order, char Side, char Uplo, double alpha, INDArray A, INDArray B, double beta, INDArray C) {
}
@Override
public void syrk(char Order, char Uplo, char Trans, double alpha, INDArray A, double beta, INDArray C) {
}
@Override
public void syr2k(char Order, char Uplo, char Trans, double alpha, INDArray A, INDArray B, double beta,
INDArray C) {
}
@Override
public void trmm(char Order, char Side, char Uplo, char TransA, char Diag, double alpha, INDArray A, INDArray B,
INDArray C) {
}
@Override
public void trsm(char Order, char Side, char Uplo, char TransA, char Diag, double alpha, INDArray A, INDArray B) {
}
}

View File

@ -1,66 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas.params;
import lombok.Data;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.SparseFormat;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author Audrey Loeffel
*/
@Data
public class SparseCOOGemvParameters {
private int m, nnz;
DataBuffer val, rowInd, colInd;
private INDArray a, x, y;
private char aOrdering = 'N';
public SparseCOOGemvParameters(INDArray a, INDArray x, INDArray y) {
this.a = a;
this.x = x;
this.y = y;
if (a.isMatrix() && a.getFormat() == SparseFormat.COO) {
BaseSparseNDArrayCOO coo = (BaseSparseNDArrayCOO) a;
val = coo.getIncludedValues();
nnz = coo.nnz();
// FIXME: int cast
m = (int) coo.rows();
setIndexes(coo, false);
}
}
private void setIndexes(BaseSparseNDArrayCOO coo, boolean oneBased) {
int incr = oneBased ? 1 : 0;
int[] idx = coo.getIncludedIndices().asInt();
int[] rows = new int[nnz];
int[] cols = new int[nnz];
for (int i = 0; i < nnz; i++) {
rows[i] = idx[i * 2] + incr;
cols[i] = idx[(i * 2) + 1] + incr;
}
rowInd = Nd4j.createBuffer(rows);
colInd = Nd4j.createBuffer(cols);
}
}

View File

@ -4309,10 +4309,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
if (n == this) if (n == this)
return true; return true;
if (n.isSparse()) {
return n.equals(this);
}
if (this.rank() != n.rank()) if (this.rank() != n.rank())
return false; return false;
@ -5442,56 +5438,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return array; return array;
} }
/*
* ------- Sparse methods -------
*/
@Override
public DataBuffer getVectorCoordinates() {
throw new UnsupportedOperationException("Not a sparse ndarray");
}
@Override
public INDArray toDense() {
return this;
}
@Override
public int nnz() {
throw new UnsupportedOperationException("Not a sparse ndarray");
}
@Override
public SparseFormat getFormat() {
return SparseFormat.NONE;
}
@Override
public DataBuffer sparseInfoDataBuffer() {
throw new UnsupportedOperationException("Not a sparse ndarray");
}
@Override
public int[] flags() {
throw new UnsupportedOperationException("Not a sparse ndarray");
}
@Override
public int[] hiddenDimensions() {
throw new UnsupportedOperationException("Not a sparse ndarray");
}
@Override
public int[] sparseOffsets() {
throw new UnsupportedOperationException("Not a sparse ndarray");
}
@Override
public int underlyingRank() {
throw new UnsupportedOperationException("Not a sparse ndarray");
}
protected static DataTypeEx convertType(DataType type) { protected static DataTypeEx convertType(DataType type) {
if (type == DataType.HALF) { if (type == DataType.HALF) {
return DataTypeEx.FLOAT16; return DataTypeEx.FLOAT16;

View File

@ -1,30 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
import org.nd4j.linalg.api.buffer.DataBuffer;
/**
* @author Audrey Loeffel
*/
abstract public class BaseSparseInfoProvider implements SparseInfoProvider {
@Override
public DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions,
int underlyingRank) {
return createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
}
}

View File

@ -1,273 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.*;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.LongUtils;
import java.util.ArrayList;
import java.util.List;
import static org.nd4j.base.Preconditions.checkArgument;
/**
* @author Audrey Loeffel
*/
public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
protected static final SparseFormat format = SparseFormat.CSR;
protected transient volatile DataBuffer values;
protected transient volatile DataBuffer columnsPointers;
protected transient volatile DataBuffer pointerB;
protected transient volatile DataBuffer pointerE;
/**
*
*
* The length of the values and columns arrays is equal to the number of non-zero elements in A.
* The length of the pointerB and pointerE arrays is equal to the number of rows in A.
* @param data a double array that contains the non-zero element of the sparse matrix A
* @param columnsPointers Element i of the integer array columns is the number of the column in A that contains the i-th value
* in the values array.
* @param pointerB Element j of this integer array gives the index of the element in the values array that is first
* non-zero element in a row j of A. Note that this index is equal to pointerB(j) - pointerB(1)+1 .
* @param pointerE An integer array that contains row indices, such that pointerE(j)-pointerB(1) is the index of the
* element in the values array that is last non-zero element in a row j of A.
* @param shape Shape of the matrix A
*/
public BaseSparseNDArrayCSR(double[] data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
checkArgument(data.length == columnsPointers.length);
checkArgument(pointerB.length == pointerE.length);
// TODO
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
init(shape);
int valuesSpace = (int) (data.length * THRESHOLD_MEMORY_ALLOCATION);
this.values = Nd4j.getDataBufferFactory().createDouble(valuesSpace);
this.values.setData(data);
this.columnsPointers = Nd4j.getDataBufferFactory().createInt(valuesSpace);
this.columnsPointers.setData(columnsPointers);
this.length = columnsPointers.length;
long pointersSpace = rows;
this.pointerB = Nd4j.getDataBufferFactory().createInt(pointersSpace);
this.pointerB.setData(pointerB);
this.pointerE = Nd4j.getDataBufferFactory().createInt(pointersSpace);
this.pointerE.setData(pointerE);
}
public BaseSparseNDArrayCSR(float[] data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
this(Nd4j.createBuffer(data), columnsPointers, pointerB, pointerE, shape);
}
public BaseSparseNDArrayCSR(DataBuffer data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
checkArgument(pointerB.length == pointerE.length);
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
init(shape);
this.values = data;
this.columnsPointers = Nd4j.getDataBufferFactory().createInt(data.length());
this.columnsPointers.setData(columnsPointers);
this.length = columnsPointers.length;
// The size of these pointers are constant
long pointersSpace = rows;
this.pointerB = Nd4j.getDataBufferFactory().createInt(pointersSpace);
this.pointerB.setData(pointerB);
this.pointerE = Nd4j.getDataBufferFactory().createInt(pointersSpace);
this.pointerE.setData(pointerE);
}
public INDArray putScalar(int row, int col, double value) {
checkArgument(row < rows && 0 <= rows);
checkArgument(col < columns && 0 <= columns);
int idx = pointerB.getInt(row);
int idxNextRow = pointerE.getInt(row);
while (columnsPointers.getInt(idx) < col && columnsPointers.getInt(idx) < idxNextRow) {
idx++;
}
if (columnsPointers.getInt(idx) == col) {
values.put(idx, value);
} else {
//Add a new entry in both buffers at a given position
values = addAtPosition(values, length, idx, value);
columnsPointers = addAtPosition(columnsPointers, length, idx, col);
length++;
// shift the indices of the next rows
pointerE.put(row, pointerE.getInt(row) + 1);
for (int i = row + 1; i < rows; i++) {
pointerB.put(i, pointerB.getInt(i) + 1);
pointerE.put(i, pointerE.getInt(i) + 1);
}
}
return this;
}
@Override
public INDArray get(INDArrayIndex... indexes) {
//check for row/column vector and point index being 0
if (indexes.length == 1 && indexes[0] instanceof NDArrayIndexAll || (indexes.length == 2 && (isRowVector()
&& indexes[0] instanceof PointIndex && indexes[0].offset() == 0
&& indexes[1] instanceof NDArrayIndexAll
|| isColumnVector() && indexes[1] instanceof PointIndex && indexes[0].offset() == 0
&& indexes[0] instanceof NDArrayIndexAll)))
return this;
indexes = NDArrayIndex.resolve(javaShapeInformation, indexes);
throw new UnsupportedOperationException("Not implemeted");
}
/**
* Return the minor pointers. (columns for CSR, rows for CSC,...)
* */
public DataBuffer getVectorCoordinates() {
return Nd4j.getDataBufferFactory().create(columnsPointers, 0, length());
}
public double[] getDoubleValues() {
return values.getDoublesAt(0, (int) length);
}
public double[] getColumns() {
return columnsPointers.getDoublesAt(0, (int) length);
}
public int[] getPointerBArray() {
return pointerB.asInt();
}
public int[] getPointerEArray() {
return pointerE.asInt();
}
public SparseFormat getFormat() {
return format;
}
public DataBuffer getPointerB() {
return Nd4j.getDataBufferFactory().create(pointerB, 0, rows());
}
public DataBuffer getPointerE() {
return Nd4j.getDataBufferFactory().create(pointerE, 0, rows());
}
private DataBuffer addAtPosition(DataBuffer buf, long dataSize, int pos, double value) {
DataBuffer buffer = (buf.length() == dataSize) ? reallocate(buf) : buf;
double[] tail = buffer.getDoublesAt(pos, (int) dataSize - pos);
buffer.put(pos, value);
for (int i = 0; i < tail.length; i++) {
buffer.put(i + pos + 1, tail[i]);
}
return buffer;
}
@Override
public DataBuffer data() {
return Nd4j.getDataBufferFactory().create(values, 0, length());
}
@Override
public INDArray toDense() {
// Dummy way - going to use the conversion routines in level2 (?)
INDArray result = Nd4j.zeros(shape());
int[] pointersB = pointerB.asInt();
int[] pointersE = pointerE.asInt();
for (int row = 0; row < rows(); row++) {
for (int idx = pointersB[row]; idx < pointersE[row]; idx++) {
result.put(row, columnsPointers.getInt(idx), values.getNumber(idx));
}
}
return result;
}
@Override
public DataBuffer shapeInfoDataBuffer() {
return shapeInformation;
}
@Override
public boolean equals(Object o) {
//TODO use op
// fixme
if (o == null || !(o instanceof INDArray)) {
return false;
}
INDArray n = (INDArray) o;
if (n.isSparse()) {
BaseSparseNDArray s = (BaseSparseNDArray) n;
switch (s.getFormat()) {
case CSR:
BaseSparseNDArrayCSR csrArray = (BaseSparseNDArrayCSR) s;
if (csrArray.rows() == rows() && csrArray.columns() == columns()
&& csrArray.getVectorCoordinates().equals(getVectorCoordinates())
&& csrArray.data().equals(data()) && csrArray.getPointerB().equals(getPointerB())
&& csrArray.getPointerE().equals(getPointerE())) {
return true;
}
break;
default:
INDArray dense = toDense();
INDArray oDense = s.toDense();
return dense.equals(oDense);
}
} else {
INDArray dense = toDense();
return dense.equals(o);
}
return false;
}
@Override
public boolean isView() {
return false; //todo
}
@Override
public boolean wasClosed() {
return false;
}
@Override
public int underlyingRank() {
return rank;
}
@Override
public INDArray putiColumnVector(INDArray columnVector) {
return null;
}
@Override
public INDArray putiRowVector(INDArray rowVector) {
return null;
}
}

View File

@ -49,13 +49,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
DataBuffer shapeInfoDataBuffer(); DataBuffer shapeInfoDataBuffer();
// TODO: Unused untested method.
/**
* Sparse info
* @return Sparse info.
*/
DataBuffer sparseInfoDataBuffer();
/** /**
* Shape info * Shape info
* @return Shape info * @return Shape info
@ -2692,47 +2685,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray percentile(Number percentile, int... dimension); INDArray percentile(Number percentile, int... dimension);
/*
* ------------ Sparse methods ------------
*/
/**
* Return a array of non-major pointers
* i.e. return the column indexes in case of row-major ndarray
* @return a DataBuffer of indexes
*/
DataBuffer getVectorCoordinates();
/**
* Return a dense representation of the sparse ndarray
* */
INDArray toDense();
/**
* Return the number of non-null element
* @return nnz
*/
int nnz();
/**
* Return the sparse format (i.e COO, CSR, ...)
* @return format
* @see SparseFormat
* */
SparseFormat getFormat();
//TODO: Undocumented but often used method.
int[] flags();
//TODO: Undocumented but often used method.
int[] hiddenDimensions();
//TODO: Undocumented but often used method.
int[] sparseOffsets();
//TODO: Undocumented but often used method.
int underlyingRank();
/** /**
* Add an {@link INDArray} * Add an {@link INDArray}
* to flatbuffers builder * to flatbuffers builder

View File

@ -1,38 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
import org.nd4j.linalg.api.buffer.DataBuffer;
/**
* @author Audrey Loeffel
*/
public interface ISparseNDArray extends INDArray {
/*
* TODO
* Will contain methods such as toDense, toCSRFormat,...
*
* */
DataBuffer getVectorCoordinates();
INDArray toDense();
int nnz();
SparseFormat getFormat();
}

View File

@ -1,31 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
/**
* @author Audrey Loeffel
* <ul>
* <li>CSR: Compressed Sparse Row</li>
* <li>CSC: Commpressed Sparse Column</li>
* <li>COO: Coordinate Matrix Storage</li>
* <li>None: No sparse format</li>
* </ul>
* @see @see <a href="https://software.intel.com/en-us/node/471374">Sparse Matrix Storage Formats</a>
*/
public enum SparseFormat {
CSR, COO, NONE
}

View File

@ -1,29 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ndarray;
import org.nd4j.linalg.api.buffer.DataBuffer;
/**
* @author Audrey Loeffel
*/
public interface SparseInfoProvider {
DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank);
void purgeCache();
}

View File

@ -1,79 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.shape;
import java.util.Arrays;
/**
* @author Audrey Loeffel
*/
public class SparseDescriptor {
int[] flags;
long[] sparseOffsets;
int[] hiddenDimension;
int underlyingRank;
public SparseDescriptor(int[] flags, long[] sparseOffsets, int[] hiddenDimension, int underlyingRank) {
this.flags = Arrays.copyOf(flags, flags.length);
this.sparseOffsets = Arrays.copyOf(sparseOffsets, sparseOffsets.length);
this.hiddenDimension = Arrays.copyOf(hiddenDimension, hiddenDimension.length);
this.underlyingRank = underlyingRank;
}
@Override
public int hashCode() {
int result = underlyingRank;
result = 31 * result + Arrays.hashCode(flags);
result = 31 * result + Arrays.hashCode(sparseOffsets);
result = 31 * result + Arrays.hashCode(hiddenDimension);
return result;
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
SparseDescriptor that = (SparseDescriptor) o;
if (!Arrays.equals(flags, that.flags))
return false;
if (!Arrays.equals(sparseOffsets, that.sparseOffsets))
return false;
if (!Arrays.equals(hiddenDimension, that.hiddenDimension))
return false;
return underlyingRank == that.underlyingRank;
}
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append(flags.length).append(",").append(Arrays.toString(flags)).append(",").append(sparseOffsets.length)
.append(",").append(Arrays.toString(sparseOffsets)).append(",").append(hiddenDimension.length)
.append(",").append(Arrays.toString(hiddenDimension)).append(",").append(underlyingRank);
String result = builder.toString().replaceAll("\\]", "").replaceAll("\\[", "");
result = "[" + result + "]";
return result;
}
}

View File

@ -217,34 +217,6 @@ public abstract class BaseBlasWrapper implements BlasWrapper {
return a; return a;
} }
@Override
public INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c) {
LinAlgExceptions.assertMatrix(a, b, c);
if (a.data().dataType() == DataType.FLOAT) {
return gemm((float) alpha, a, b, (float) beta, c);
}
level3().gemm(BlasBufferUtil.getCharForTranspose(a), BlasBufferUtil.getCharForTranspose(b),
BlasBufferUtil.getCharForTranspose(c), alpha, a, b, beta, c);
return c;
}
@Override
public INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c) {
LinAlgExceptions.assertMatrix(a, b, c);
if (a.data().dataType() == DataType.DOUBLE) {
return gemm((double) alpha, a, b, (double) beta, c);
}
level3().gemm(BlasBufferUtil.getCharForTranspose(a), BlasBufferUtil.getCharForTranspose(b),
BlasBufferUtil.getCharForTranspose(c), alpha, a, b, beta, c);
return c;
}
@Override @Override
public INDArray gesv(INDArray a, int[] ipiv, INDArray b) { public INDArray gesv(INDArray a, int[] ipiv, INDArray b) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -1,243 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.factory;
import org.nd4j.linalg.api.blas.Lapack;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.blas.Level2;
import org.nd4j.linalg.api.blas.Level3;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* @author Audrey Loeffel
*/
public abstract class BaseSparseBlasWrapper implements BlasWrapper {
@Override
public Lapack lapack() {
return Nd4j.sparseFactory().lapack();
}
@Override
public Level1 level1() {
return Nd4j.sparseFactory().level1();
}
@Override
public Level2 level2() {
return Nd4j.sparseFactory().level2();
}
@Override
public Level3 level3() {
return Nd4j.sparseFactory().level3();
}
// ================== TODO ====================
@Override
public INDArray swap(INDArray x, INDArray y) {
return null;
}
@Override
public INDArray scal(double alpha, INDArray x) {
return null;
}
@Override
public INDArray scal(float alpha, INDArray x) {
return null;
}
@Override
public INDArray copy(INDArray x, INDArray y) {
return null;
}
@Override
public INDArray axpy(double da, INDArray dx, INDArray dy) {
return null;
}
@Override
public INDArray axpy(float da, INDArray dx, INDArray dy) {
return null;
}
@Override
public INDArray axpy(Number da, INDArray dx, INDArray dy) {
return null;
}
@Override
public double dot(INDArray x, INDArray y) {
return 0;
}
@Override
public double nrm2(INDArray x) {
return 0;
}
@Override
public double asum(INDArray x) {
return 0;
}
@Override
public int iamax(INDArray x) {
return 0;
}
@Override
public INDArray gemv(Number alpha, INDArray a, INDArray x, double beta, INDArray y) {
return null;
}
@Override
public INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y) {
return null;
}
@Override
public INDArray gemv(float alpha, INDArray a, INDArray x, float beta, INDArray y) {
return null;
}
@Override
public INDArray ger(Number alpha, INDArray x, INDArray y, INDArray a) {
return null;
}
@Override
public INDArray ger(double alpha, INDArray x, INDArray y, INDArray a) {
return null;
}
@Override
public INDArray ger(float alpha, INDArray x, INDArray y, INDArray a) {
return null;
}
@Override
public INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c) {
return null;
}
@Override
public INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c) {
return null;
}
@Override
public INDArray gesv(INDArray a, int[] ipiv, INDArray b) {
return null;
}
@Override
public void checkInfo(String name, int info) {
}
@Override
public INDArray sysv(char uplo, INDArray a, int[] ipiv, INDArray b) {
return null;
}
@Override
public int syev(char jobz, char uplo, INDArray a, INDArray w) {
return 0;
}
@Override
public int syevx(char jobz, char range, char uplo, INDArray a, double vl, double vu, int il, int iu, double abstol,
INDArray w, INDArray z) {
return 0;
}
@Override
public int syevx(char jobz, char range, char uplo, INDArray a, float vl, float vu, int il, int iu, float abstol,
INDArray w, INDArray z) {
return 0;
}
@Override
public int syevd(char jobz, char uplo, INDArray A, INDArray w) {
return 0;
}
@Override
public int syevr(char jobz, char range, char uplo, INDArray a, double vl, double vu, int il, int iu, double abstol,
INDArray w, INDArray z, int[] isuppz) {
return 0;
}
@Override
public int syevr(char jobz, char range, char uplo, INDArray a, float vl, float vu, int il, int iu, float abstol,
INDArray w, INDArray z, int[] isuppz) {
return 0;
}
@Override
public int syevr(char jobz, char range, char uplo, INDArray a, float vl, float vu, int il, int iu, Number abstol,
INDArray w, INDArray z, int[] isuppz) {
return 0;
}
@Override
public void posv(char uplo, INDArray A, INDArray B) {
}
@Override
public int geev(char jobvl, char jobvr, INDArray A, INDArray WR, INDArray WI, INDArray VL, INDArray VR) {
return 0;
}
@Override
public int sygvd(int itype, char jobz, char uplo, INDArray A, INDArray B, INDArray W) {
return 0;
}
@Override
public void gelsd(INDArray A, INDArray B) {
}
@Override
public void geqrf(INDArray A, INDArray tau) {
}
@Override
public void ormqr(char side, char trans, INDArray A, INDArray tau, INDArray C) {
}
@Override
public void saxpy(double alpha, INDArray x, INDArray y) {
}
@Override
public void saxpy(float alpha, INDArray x, INDArray y) {
}
}

View File

@ -1,25 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.factory;
/**
* @author Audrey Loeffel
*/
public abstract class BaseSparseNDArrayFactory extends BaseNDArrayFactory {
// TODO override needed methods
}

View File

@ -154,21 +154,6 @@ public interface BlasWrapper {
*/ */
INDArray ger(float alpha, INDArray x, INDArray y, INDArray a); INDArray ger(float alpha, INDArray x, INDArray y, INDArray a);
/**
* ************************************************************************
* BLAS Level 3
*/
@Deprecated
INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c);
/**
* Compute c <- a*b + beta * c (general matrix matrix
* multiplication)
*/
@Deprecated
INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c);
/** /**
* ************************************************************************ * ************************************************************************
* LAPACK * LAPACK

View File

@ -1413,28 +1413,4 @@ public interface NDArrayFactory {
// =========== String methods ============ // =========== String methods ============
INDArray create(Collection<String> strings, long[] shape, char order); INDArray create(Collection<String> strings, long[] shape, char order);
// =========== Sparse methods ===========
INDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape);
INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape);
INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape);
INDArray createSparseCOO(double[] values, long[][] indices, long[] shape);
INDArray createSparseCOO(float[] values, long[][] indices, long[] shape);
INDArray createSparseCOO(double[] values, int[][] indices, long[] shape);
INDArray createSparseCOO(float[] values, int[][] indices, long[] shape);
INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape);
INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape);
INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags,
int[] hiddenDimensions, int underlyingRank, long[] shape);
} }

View File

@ -119,16 +119,13 @@ public class Nd4j {
@Deprecated @Deprecated
public final static String DTYPE = ND4JSystemProperties.DTYPE; public final static String DTYPE = ND4JSystemProperties.DTYPE;
private final static String BLAS_OPS = "blas.ops"; private final static String BLAS_OPS = "blas.ops";
private final static String SPARSE_BLAS_OPS = "sparseblas.ops";
public final static String NATIVE_OPS = "native.ops"; public final static String NATIVE_OPS = "native.ops";
private final static String ORDER_KEY = "ndarray.order"; private final static String ORDER_KEY = "ndarray.order";
private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class"; private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
private final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
private final static String OP_EXECUTIONER = "opexec"; private final static String OP_EXECUTIONER = "opexec";
public final static String DISTRIBUTION = "dist"; public final static String DISTRIBUTION = "dist";
private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider"; private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
private final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
private final static String CONSTANT_PROVIDER = "constantsprovider"; private final static String CONSTANT_PROVIDER = "constantsprovider";
private final static String AFFINITY_MANAGER = "affinitymanager"; private final static String AFFINITY_MANAGER = "affinitymanager";
//disable toString() on compressed arrays for debugging. Should be off by default. //disable toString() on compressed arrays for debugging. Should be off by default.
@ -156,14 +153,11 @@ public class Nd4j {
private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE; private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
private static BlasWrapper BLAS_WRAPPER_INSTANCE; private static BlasWrapper BLAS_WRAPPER_INSTANCE;
private static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
protected static NDArrayFactory INSTANCE; protected static NDArrayFactory INSTANCE;
private static NDArrayFactory SPARSE_INSTANCE;
private static ConvolutionInstance CONVOLUTION_INSTANCE; private static ConvolutionInstance CONVOLUTION_INSTANCE;
private static OpExecutioner OP_EXECUTIONER_INSTANCE; private static OpExecutioner OP_EXECUTIONER_INSTANCE;
private static DistributionFactory DISTRIBUTION_FACTORY; private static DistributionFactory DISTRIBUTION_FACTORY;
private static ShapeInfoProvider shapeInfoProvider; private static ShapeInfoProvider shapeInfoProvider;
private static SparseInfoProvider sparseInfoProvider;
private static ConstantHandler constantHandler; private static ConstantHandler constantHandler;
private static AffinityManager affinityManager; private static AffinityManager affinityManager;
private static MemoryManager memoryManager; private static MemoryManager memoryManager;
@ -800,14 +794,6 @@ public class Nd4j {
return INSTANCE; return INSTANCE;
} }
/**
* The factory used for creating sparse arrays.
* @return the factory used for creating sparse arrays.
*/
public static NDArrayFactory sparseFactory() {
return SPARSE_INSTANCE;
}
/** /**
* See {@link org.nd4j.linalg.api.ndarray.INDArray#cumsum(int)} with Integer.MAX_VALUE for full array reduction. * See {@link org.nd4j.linalg.api.ndarray.INDArray#cumsum(int)} with Integer.MAX_VALUE for full array reduction.
* *
@ -1700,14 +1686,6 @@ public class Nd4j {
return BLAS_WRAPPER_INSTANCE; return BLAS_WRAPPER_INSTANCE;
} }
/**
* Retreive the sparse BLAS wrapper.
* @return the sparse BLAS wrapper.
*/
public static BlasWrapper getSparseBlasWrapper() {
return SPARSE_BLAS_WRAPPER_INSTANCE;
}
/** /**
* Sort an ndarray along a particular dimension.<br> * Sort an ndarray along a particular dimension.<br>
* Note that the input array is modified in-place. * Note that the input array is modified in-place.
@ -5159,8 +5137,6 @@ public class Nd4j {
affinityManager = affinityManagerClazz.newInstance(); affinityManager = affinityManagerClazz.newInstance();
Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName( Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
pp.toString(NDARRAY_FACTORY_CLASS)); pp.toString(NDARRAY_FACTORY_CLASS));
Class<? extends NDArrayFactory> sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
pp.toString(SPARSE_NDARRAY_FACTORY_CLASS));
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName())); .forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName()); String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
@ -5168,8 +5144,6 @@ public class Nd4j {
.forName(pp.toString(DATA_BUFFER_OPS, defaultName)); .forName(pp.toString(DATA_BUFFER_OPS, defaultName));
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
.forName(pp.toString(SHAPEINFO_PROVIDER)); .forName(pp.toString(SHAPEINFO_PROVIDER));
Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
pp.toString(SPARSEINFO_PROVIDER));
Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
.forName(pp.toString(CONSTANT_PROVIDER)); .forName(pp.toString(CONSTANT_PROVIDER));
@ -5187,8 +5161,6 @@ public class Nd4j {
Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class
.forName(pp.toString(BLAS_OPS)); .forName(pp.toString(BLAS_OPS));
Class<? extends BlasWrapper> sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
.forName(pp.toString(SPARSE_BLAS_OPS));
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName()); String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName); Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
@ -5196,7 +5168,6 @@ public class Nd4j {
memoryManager = memoryManagerClazz.newInstance(); memoryManager = memoryManagerClazz.newInstance();
constantHandler = constantProviderClazz.newInstance(); constantHandler = constantProviderClazz.newInstance();
shapeInfoProvider = shapeInfoProviderClazz.newInstance(); shapeInfoProvider = shapeInfoProviderClazz.newInstance();
sparseInfoProvider = sparseInfoProviderClazz.newInstance();
workspaceManager = workspaceManagerClazz.newInstance(); workspaceManager = workspaceManagerClazz.newInstance();
Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class
@ -5205,10 +5176,8 @@ public class Nd4j {
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance(); OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class); Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class);
INSTANCE = (NDArrayFactory) c2.newInstance(dtype, ORDER); INSTANCE = (NDArrayFactory) c2.newInstance(dtype, ORDER);
SPARSE_INSTANCE = sparseNDArrayClazz.newInstance();
CONVOLUTION_INSTANCE = convolutionInstanceClazz.newInstance(); CONVOLUTION_INSTANCE = convolutionInstanceClazz.newInstance();
BLAS_WRAPPER_INSTANCE = blasWrapperClazz.newInstance(); BLAS_WRAPPER_INSTANCE = blasWrapperClazz.newInstance();
SPARSE_BLAS_WRAPPER_INSTANCE = sparseBlasWrapperClazz.newInstance();
DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.newInstance(); DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.newInstance();
DISTRIBUTION_FACTORY = distributionFactoryClazz.newInstance(); DISTRIBUTION_FACTORY = distributionFactoryClazz.newInstance();
@ -5306,14 +5275,6 @@ public class Nd4j {
return shapeInfoProvider; return shapeInfoProvider;
} }
/**
*
* @return Sparse shape info provider
*/
public static SparseInfoProvider getSparseInfoProvider() {
return sparseInfoProvider;
}
/** /**
* *
* @return constant handler * @return constant handler

View File

@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.nd4j.linalg.api.blas.Blas;
public abstract class SparseNd4jBlas implements Blas {
public SparseNd4jBlas() {
}
/**
* Returns the BLAS library vendor
*
* @return the BLAS library vendor
*/
@Override
public Vendor getBlasVendor() {
int vendor = getBlasVendorId();
boolean isUnknowVendor = ((vendor > Vendor.values().length - 1) || (vendor <= 0));
if (isUnknowVendor) {
return Vendor.UNKNOWN;
}
return Vendor.values()[vendor];
}
}

View File

@ -1,62 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.BaseSparseInfoProvider;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.SparseDescriptor;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @author Audrey Loeffel
*/
public class DirectSparseInfoProvider extends BaseSparseInfoProvider {
private Map<SparseDescriptor, DataBuffer> sparseCache = new ConcurrentHashMap<>();
private AtomicInteger counter = new AtomicInteger(0);
private static final int MAX_ENTRIES = 100;
@Override
public DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank) {
SparseDescriptor descriptor = new SparseDescriptor(flags, sparseOffsets, hiddenDimensions, underlyingRank);
if(!sparseCache.containsKey(descriptor)){
if(counter.get() < MAX_ENTRIES){
if(!sparseCache.containsKey(descriptor)){
counter.incrementAndGet();
DataBuffer buffer = Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
sparseCache.put(descriptor, buffer);
return buffer;
}
} else {
return Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
}
}
return sparseCache.get(descriptor);
}
@Override
public void purgeCache() {
sparseCache = new ConcurrentHashMap<>();
}
}

View File

@ -1612,55 +1612,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@Override
public INDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(double[] values, long[][] indices, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(float[] values, long[][] indices, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(double[] values, int[][] indices, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(float[] values, int[][] indices, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
throw new UnsupportedOperationException();
}
@Override @Override
public INDArray sortCooIndices(INDArray x) { public INDArray sortCooIndices(INDArray x) {

View File

@ -1,53 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
/**
* @author Audrey Loeffel
*/
public class JCusparseNDArrayCOO extends BaseSparseNDArrayCOO {
public JCusparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] shape) {
super(values, indices, shape);
}
public JCusparseNDArrayCOO(double[] values, long[][] indices, long[] shape) {
super(values, indices, shape);
}
public JCusparseNDArrayCOO(float[] values, long[][] indices, long[] shape) {
super(values, indices, shape);
}
public JCusparseNDArrayCOO(double[] values, int[][] indices, long[] shape) {
super(values, indices, shape);
}
public JCusparseNDArrayCOO(float[] values, int[][] indices, long[] shape) {
super(values, indices, shape);
}
public JCusparseNDArrayCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
super(values, indices, sparseInformation, shape);
}
public JCusparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
super(values, indices, sparseOffsets, flags, hiddenDimensions, underlyingRank, shape);
}
}

View File

@ -1,570 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.ISparseNDArray;
import org.nd4j.linalg.factory.BaseSparseNDArrayFactory;
import org.nd4j.linalg.jcublas.blas.*;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.io.File;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Random;
/**
* @author Audrey Loeffel
*/
@Slf4j
public class JCusparseNDArrayFactory extends BaseSparseNDArrayFactory{
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
public JCusparseNDArrayFactory(){}
@Override
public INDArray create(float[] data, int[] shape, int[] stride, long offset) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, long offset) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType) {
return null;
}
@Override
public INDArray create(double[] data, int[] shape, int[] stride, long offset) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long[] stride, long offset) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(long[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(int[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(short[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(byte[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(boolean[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(List<INDArray> list, int[] shape) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(long[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(int[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(short[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(byte[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(boolean[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(List<INDArray> list, long[] shape) {
return null;
}
@Override
public INDArray create(long rows, long columns, long[] stride, long offset) {
return null;
}
@Override
public INDArray empty(DataType type) {
throw new IllegalStateException();
}
@Override
public void createBlas() {
blas = new SparseCudaBlas();
}
@Override
public void createLevel1() {
level1 = new JcusparseLevel1();
}
@Override
public void createLevel2() {
level2 = new JcusparseLevel2();
}
@Override
public void createLevel3() {
level3 = new JcusparseLevel3();
}
@Override
public void createLapack() {
lapack = new JcusparseLapack();
}
@Override
public INDArray create(int[] shape, DataBuffer buffer) {
return null;
}
@Override
public INDArray toFlattened(char order, Collection<INDArray> matrices) {
return null;
}
@Override
public INDArray create(double[][] data) {
return null;
}
@Override
public INDArray create(double[][] data, char ordering) {
return null;
}
@Override
public INDArray specialConcat(int dimension, INDArray... toConcat) {
return null;
}
@Override
public INDArray pullRows(INDArray source, int sourceDimension, long[] indexes) {
return null;
}
@Override
public INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, int[] indexes) {
return null;
}
@Override
public void shuffle(INDArray array, Random rnd, int... dimension) {
}
@Override
public void shuffle(Collection<INDArray> array, Random rnd, int... dimension) {
}
@Override
public void shuffle(List<INDArray> array, Random rnd, List<int[]> dimensions) {
}
@Override
public INDArray average(INDArray target, INDArray[] arrays) {
return null;
}
@Override
public INDArray average(INDArray[] arrays) {
return null;
}
@Override
public INDArray average(Collection<INDArray> arrays) {
return null;
}
@Override
public INDArray accumulate(INDArray target, INDArray... arrays) {
return null;
}
@Override
public INDArray average(INDArray target, Collection<INDArray> arrays) {
return null;
}
@Override
public INDArray create(DataBuffer data) {
return null;
}
@Override
public INDArray create(DataBuffer data, long rows, long columns, int[] stride, long offset) {
return null;
}
@Override
public INDArray create(DataBuffer data, int[] shape) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] shape) {
return null;
}
@Override
public INDArray create(DataBuffer data, int[] shape, int[] stride, long offset) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] shape, long[] stride, long offset) {
return null;
}
@Override
public INDArray create(float[][] floats) {
return null;
}
@Override
public INDArray create(float[][] data, char ordering) {
return null;
}
@Override
public INDArray create(float[] data, int[] shape, int[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer buffer, int[] shape, long offset) {
return null;
}
@Override
public INDArray create(int[] shape, char ordering) {
return null;
}
@Override
public INDArray create(long[] shape, char ordering) {
return null;
}
@Override
public INDArray create(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(DataType dataType, long[] shape, long[] strides, char ordering, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray createUninitialized(int[] shape, char ordering) {
return null;
}
@Override
public INDArray createUninitialized(long[] shape, char ordering) {
return null;
}
@Override
public INDArray createUninitialized(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape) {
return null;
}
@Override
public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, long ews, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, char ordering, DataType dataType) {
return null;
}
@Override
public INDArray create(float[] data, int[] shape, long offset, Character order) {
return null;
}
@Override
public INDArray create(float[] data, long rows, long columns, int[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(double[] data, int[] shape, char ordering) {
return null;
}
@Override
public INDArray create(List<INDArray> list, int[] shape, char ordering) {
return null;
}
@Override
public INDArray create(List<INDArray> list, long[] shape, char ordering) {
return null;
}
@Override
public INDArray create(double[] data, int[] shape, long offset) {
return null;
}
@Override
public INDArray create(double[] data, int[] shape, int[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, char order, long offset) {
return null;
}
@Override
public INDArray convertDataEx(DataTypeEx typeSrc, INDArray source, DataTypeEx typeDst) {
return null;
}
@Override
public DataBuffer convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst) {
return null;
}
@Override
public void convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst, DataBuffer target) {
}
@Override
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target, long length) {
}
@Override
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, DataBuffer buffer) {
}
@Override
public INDArray createFromNpyPointer(Pointer pointer) {
return null;
}
@Override
public INDArray createFromNpyHeaderPointer(Pointer pointer) {
return null;
}
@Override
public INDArray createFromNpyFile(File file) {
return null;
}
@Override
public Map<String, INDArray> createFromNpzFile(File file) throws Exception {
return null;
}
@Override
public Pointer convertToNumpy(INDArray array) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray[] tear(INDArray tensor, int... dimensions) {
return new INDArray[0];
}
@Override
public INDArray sort(INDArray x, boolean descending) {
return null;
}
@Override
public INDArray sort(INDArray x, boolean descending, int... dimensions) {
return null;
}
@Override
public INDArray sortCooIndices(INDArray x) {
//TODO
throw new UnsupportedOperationException();
}
@Override
public INDArray create(float[] data, long[] shape, long offset, Character order) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long offset, Character order) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, char ordering) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, char ordering) {
return null;
}
@Override
public INDArray create(Collection<String> strings, long[] shape, char order) {
return null;
}
@Override
public ISparseNDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
return new JcusparseNDArrayCSR(data, columns, pointerB, pointerE, shape);
}
@Override
public INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(double[] values, long[][] indices, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(float[] values, long[][] indices, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(double[] values, int[][] indices, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(float[] values, int[][] indices, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, sparseInformation, shape);
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
return new JCusparseNDArrayCOO(values, indices, sparseOffsets, flags, hiddenDimensions, underlyingRank, shape);
}
}

View File

@ -1,176 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas;
import com.google.flatbuffers.FlatBufferBuilder;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCSR;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
/**
* @author Audrey Loeffel
*/
public class JcusparseNDArrayCSR extends BaseSparseNDArrayCSR {
/**
* The length of the values and columns arrays is equal to the number of non-zero elements in A.
* The length of the pointerB and pointerE arrays is equal to the number of rows in A.
*
* @param data a double array that contains the non-zero element of the sparse matrix A
* @param columnsPointers Element i of the integer array columns is the number of the column in A that contains the i-th value
* in the values array.
* @param pointerB Element j of this integer array gives the index of the element in the values array that is first
* non-zero element in a row j of A. Note that this index is equal to pointerB(j) - pointerB(1)+1 .
* @param pointerE An integer array that contains row indices, such that pointerE(j)-pointerB(1) is the index of the
* element in the values array that is last non-zero element in a row j of A.
* @param shape Shape of the matrix A
*/
public JcusparseNDArrayCSR(double[] data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
super(data, columnsPointers, pointerB, pointerE, shape);
}
public JcusparseNDArrayCSR(float[] data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
super(data, columnsPointers, pointerB, pointerE, shape);
}
public JcusparseNDArrayCSR(DataBuffer data, int[] columnsPointers, int[] pointerB, int[] pointerE, long[] shape) {
super(data, columnsPointers, pointerB, pointerE, shape);
}
@Override
public String getString(long index) {
return null;
}
@Override
public INDArray repeat(int dimension, long... repeats) {
return null;
}
@Override
public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
return null;
}
@Override
public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
return null;
}
@Override
public INDArray mmuli(INDArray other, MMulTranspose transpose) {
return null;
}
@Override
public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
return null;
}
@Override
public long getLong(long index) {
return 0;
}
@Override
public INDArray reshape(char order, int... newShape) {
return null;
}
@Override
public INDArray reshape(char order, boolean enforceView, long... newShape) {
return null;
}
@Override
public INDArray reshape(int[] shape) {
return null;
}
@Override
public LongShapeDescriptor shapeDescriptor() {
return null;
}
@Override
public int toFlatArray(FlatBufferBuilder builder) {
throw new UnsupportedOperationException();
}
@Override
public boolean isEmpty() {
throw new UnsupportedOperationException();
}
@Override
public boolean isR() {
return false;
}
@Override
public boolean isZ() {
return false;
}
@Override
public boolean isB() {
return false;
}
@Override
public boolean isS() {
return false;
}
@Override
public INDArray castTo(DataType dataType) {
return null;
}
@Override
public boolean all() {
return false;
}
@Override
public boolean any() {
return false;
}
@Override
public boolean none() {
return false;
}
@Override
public boolean closeable() {
return false;
}
@Override
public void close() {
}
@Override
public INDArray assign(boolean value) {
return assign(value ? 1 : 0);
}
}

View File

@ -1,25 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas;
import org.nd4j.linalg.factory.BaseSparseBlasWrapper;
/**
* @author Audrey Loeffel
*/
public class SparseBlasWrapper extends BaseSparseBlasWrapper {
}

View File

@ -1,25 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas.blas;
import org.nd4j.linalg.api.blas.impl.SparseBaseLapack;
/**
* @author Audrey Loeffel
*/
public class JcusparseLapack extends SparseBaseLapack {
}

View File

@ -1,147 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas.blas;
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel1;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* @author Audrey Loeffel
*/
public class JcusparseLevel1 extends SparseBaseLevel1 {
@Override
protected double ddoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
return 0;
}
@Override
protected double sdoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
return 0;
}
@Override
protected double hdoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
return 0;
}
@Override
protected double snrm2(long N, INDArray X, int incx) {
return 0;
}
@Override
protected double dnrm2(long N, INDArray X, int incx) {
return 0;
}
@Override
protected double hnrm2(long N, INDArray X, int incx) {
return 0;
}
@Override
protected double dasum(long N, INDArray X, int incx) {
return 0;
}
@Override
protected double sasum(long N, INDArray X, int incx) {
return 0;
}
@Override
protected double hasum(long N, INDArray X, int incx) {
return 0;
}
@Override
protected int isamax(long N, INDArray X, int incx) {
return 0;
}
@Override
protected int idamax(long N, INDArray X, int incx) {
return 0;
}
@Override
protected int ihamax(long N, INDArray X, int incx) {
return 0;
}
@Override
protected int isamin(long N, INDArray X, int incx) {
return 0;
}
@Override
protected int idamin(long N, INDArray X, int incx) {
return 0;
}
@Override
protected int ihamin(long N, INDArray X, int incx) {
return 0;
}
@Override
protected void daxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y) {
}
@Override
protected void saxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y) {
}
@Override
protected void haxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y) {
}
@Override
protected void droti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
}
@Override
protected void sroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
}
@Override
protected void hroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
}
@Override
protected void dscal(long N, double a, INDArray X, int incx) {
}
@Override
protected void sscal(long N, double a, INDArray X, int incx) {
}
@Override
protected void hscal(long N, double a, INDArray X, int incx) {
}
}

View File

@ -1,38 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas.blas;
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel2;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* @author Audrey Loeffel
*/
public class JcusparseLevel2 extends SparseBaseLevel2 {
@Override
protected void scoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz, INDArray x, INDArray y) {
}
@Override
protected void dcoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz, INDArray x, INDArray y) {
}
}

View File

@ -1,25 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas.blas;
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel3;
/**
* @author Audrey Loeffel
*/
public class JcusparseLevel3 extends SparseBaseLevel3 {
}

View File

@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.jcublas.blas;
import org.nd4j.linalg.api.blas.Blas;
import org.nd4j.nativeblas.SparseNd4jBlas;
/**
* @author Audrey Loeffel
*/
public class SparseCudaBlas extends SparseNd4jBlas {
@Override
public void setMaxThreads(int num){
}
@Override
public int getMaxThreads() {
return 0;
}
@Override
public int getBlasVendorId() {
return 0;
}
}

View File

@ -37,6 +37,3 @@ fft = org.nd4j.linalg.jcublas.fft.JcudaFft
opexec.mode= native opexec.mode= native
random=org.nd4j.linalg.jcublas.rng.CudaNativeRandom random=org.nd4j.linalg.jcublas.rng.CudaNativeRandom
sparseinfoprovider = org.nd4j.linalg.jcublas.DirectSparseInfoProvider
sparseblas.ops = org.nd4j.linalg.jcublas.SparseBlasWrapper
sparsendarrayfactory.class = org.nd4j.linalg.jcublas.JCusparseNDArrayFactory

View File

@ -1054,59 +1054,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
convertDataEx(typeSrc, source.addressPointer(), typeDst, target.addressPointer(), target.length()); convertDataEx(typeSrc, source.addressPointer(), typeDst, target.addressPointer(), target.length());
} }
@Override
public INDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(double[] values, int[][] indices, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(float[] values, int[][] indices, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(double[] values, long[][] indices, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(float[] values, long[][] indices, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
throw new UnsupportedOperationException();
}
@Override @Override
public INDArray sort(INDArray x, boolean descending) { public INDArray sort(INDArray x, boolean descending) {
if (x.isScalar()) if (x.isScalar())

View File

@ -1,606 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.SparseFormat;
import org.nd4j.linalg.cpu.nativecpu.blas.*;
import org.nd4j.linalg.factory.BaseSparseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.io.File;
import java.util.*;
/**
* @author Audrey Loeffel
*/
// TODO : Implement the methods
@Slf4j
public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
public CpuSparseNDArrayFactory(){}
@Override
public INDArray createSparseCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape){
return new SparseNDArrayCSR(data, columns, pointerB, pointerE, shape);
}
@Override
public INDArray createSparseCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape){
return new SparseNDArrayCSR(data, columns, pointerB, pointerE, shape);
}
@Override
public INDArray createSparseCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape){
return new SparseNDArrayCSR(data, columns, pointerB, pointerE, shape);
}
@Override
public INDArray createSparseCOO(double[] values, int[][] indices, long[] shape){
return new SparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(float[] values, int[][] indices, long[] shape){
return new SparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray create(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(DataType dataType, long[] shape, long[] strides, char ordering, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray createUninitialized(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray createSparseCOO(double[] values, long[][] indices, long[] shape) {
return new SparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(float[] values, long[][] indices, long[] shape) {
return new SparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] shape){
return new SparseNDArrayCOO(values, indices, shape);
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape) {
return new SparseNDArrayCOO(values, indices, sparseOffsets, flags, hiddenDimensions, underlyingRank, shape);
}
@Override
public INDArray createSparseCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
return new SparseNDArrayCOO(values, indices, sparseInformation, shape);
}
// TODO ->
@Override
public INDArray pullRows(INDArray source, int sourceDimension, long[] indexes) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, long offset) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long[] stride, long offset) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(long[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(int[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(short[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(byte[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(boolean[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(long[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(int[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(short[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(byte[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(boolean[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] shape) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] shape, long[] stride, long offset) {
return null;
}
@Override
public INDArray create(List<INDArray> list, long[] shape) {
return null;
}
@Override
public INDArray create(long rows, long columns, long[] stride, long offset) {
return null;
}
@Override
public INDArray create(long[] shape, char ordering) {
return null;
}
@Override
public INDArray createUninitialized(long[] shape, char ordering) {
return null;
}
@Override
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
return null;
}
@Override
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, long ews, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer data, long[] newShape, long[] newStride, long offset, char ordering, DataType dataType) {
return null;
}
@Override
public INDArray create(List<INDArray> list, long[] shape, char ordering) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, char order, long offset) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, char order, DataType dataType) {
throw new UnsupportedOperationException();
}
@Override
public INDArray specialConcat(int dimension, INDArray... toConcat) {
return null;
}
@Override
public INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, int[] indexes) {
return null;
}
@Override
public INDArray create(float[] data, int[] shape, int[] stride, long offset) {
return null;
}
@Override
public INDArray create(double[] data, int[] shape, int[] stride, long offset) {
return null;
}
@Override
public INDArray create(List<INDArray> list, int[] shape) {
return null;
}
static{
Nd4j.getBlasWrapper();
}
// contructors ?
@Override
public void createBlas(){ blas = new SparseCpuBlas();}
@Override
public void createLevel1() {
level1 = new SparseCpuLevel1();
}
@Override
public void createLevel2() {
level2 = new SparseCpuLevel2();
}
@Override
public void createLevel3() { level3 = new SparseCpuLevel3(); }
@Override
public void createLapack() {
lapack = new SparseCpuLapack();
}
@Override
public INDArray create(int[] shape, DataBuffer buffer) {
return null;
}
@Override
public INDArray toFlattened(char order, Collection<INDArray> matrices) {
return null;
}
@Override
public INDArray create(double[][] data) {
return null;
}
@Override
public INDArray create(double[][] data, char ordering) {
return null;
}
@Override
public void shuffle(INDArray array, Random rnd, int... dimension) {
}
@Override
public void shuffle(Collection<INDArray> array, Random rnd, int... dimension) {
}
@Override
public void shuffle(List<INDArray> array, Random rnd, List<int[]> dimensions) {
}
@Override
public INDArray average(INDArray target, INDArray[] arrays) {
return null;
}
@Override
public INDArray average(INDArray[] arrays) {
return null;
}
@Override
public INDArray average(Collection<INDArray> arrays) {
return null;
}
@Override
public INDArray accumulate(INDArray target, INDArray... arrays) {
return null;
}
@Override
public INDArray average(INDArray target, Collection<INDArray> arrays) {
return null;
}
@Override
public INDArray create(DataBuffer data) {
return null;
}
@Override
public INDArray create(DataBuffer data, long rows, long columns, int[] stride, long offset) {
return null;
}
@Override
public INDArray create(DataBuffer data, int[] shape) {
return null;
}
@Override
public INDArray create(DataBuffer data, int[] shape, int[] stride, long offset) {
return null;
}
@Override
public INDArray create(float[][] floats) {
return null;
}
@Override
public INDArray create(float[][] data, char ordering) {
return null;
}
@Override
public INDArray create(float[] data, int[] shape, int[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer buffer, int[] shape, long offset) {
return null;
}
@Override
public INDArray create(int[] shape, char ordering) {
return null;
}
@Override
public INDArray createUninitialized(int[] shape, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(float[] data, int[] shape, long offset, Character order) {
return null;
}
@Override
public INDArray create(float[] data, long rows, long columns, int[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(double[] data, int[] shape, char ordering) {
return null;
}
@Override
public INDArray create(List<INDArray> list, int[] shape, char ordering) {
return null;
}
@Override
public INDArray create(double[] data, int[] shape, long offset) {
return null;
}
@Override
public INDArray create(double[] data, int[] shape, int[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray convertDataEx(DataTypeEx typeSrc, INDArray source, DataTypeEx typeDst) {
return null;
}
@Override
public DataBuffer convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst) {
return null;
}
@Override
public void convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst, DataBuffer target) {
}
@Override
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target, long length) {
}
@Override
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, DataBuffer buffer) {
}
@Override
public INDArray createFromNpyPointer(Pointer pointer) {
return null;
}
@Override
public INDArray createFromNpyHeaderPointer(Pointer pointer) {
return null;
}
@Override
public INDArray createFromNpyFile(File file) {
return null;
}
@Override
public Map<String, INDArray> createFromNpzFile(File file){return null; }
@Override
public Pointer convertToNumpy(INDArray array) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, long[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long[] stride, long offset, char ordering) {
return null;
}
@Override
public INDArray[] tear(INDArray tensor, int... dimensions) {
return new INDArray[0];
}
@Override
public INDArray sort(INDArray x, boolean descending) {
if (x.isScalar())
return x;
NativeOpsHolder.getInstance().getDeviceNativeOps().sort(null,
x.data().addressPointer(), (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
null, null,
descending);
return x;
}
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
if (x.isScalar())
return x;
Arrays.sort(dimension);
Pair<DataBuffer, DataBuffer> tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, dimension);
NativeOpsHolder.getInstance().getDeviceNativeOps().sortTad(null,
x.data().addressPointer(), (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
null, null,
new IntPointer(dimension),
dimension.length,
(LongPointer) tadBuffers.getFirst().addressPointer(),
new LongPointerWrapper(tadBuffers.getSecond().addressPointer()),
descending);
return x;
}
@Override
public INDArray sortCooIndices(INDArray x) {
if(x.getFormat() != SparseFormat.COO){
throw new UnsupportedOperationException("Not a COO ndarray");
}
BaseSparseNDArrayCOO array = (BaseSparseNDArrayCOO) x;
DataBuffer val = array.getValues();
DataBuffer idx = array.getIndices();
long length = val.length();
int rank = array.underlyingRank();
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), val.addressPointer(), length, rank);
return array;
}
@Override
public INDArray create(float[] data, long[] shape, long offset, Character order) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, long offset, Character order) {
return null;
}
@Override
public INDArray create(float[] data, long[] shape, char ordering) {
return null;
}
@Override
public INDArray create(double[] data, long[] shape, char ordering) {
return null;
}
@Override
public INDArray empty(DataType type) {
throw new UnsupportedOperationException();
}
@Override
public INDArray create(Collection<String> strings, long[] shape, char order) {
throw new UnsupportedOperationException();
}
}

View File

@ -1,62 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.BaseSparseInfoProvider;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.SparseDescriptor;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @author Audrey Loeffel
*/
public class DirectSparseInfoProvider extends BaseSparseInfoProvider {
private Map<SparseDescriptor, DataBuffer> sparseCache = new ConcurrentHashMap<>();
private AtomicInteger counter = new AtomicInteger(0);
private static final int MAX_ENTRIES = 100;
@Override
public DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank) {
SparseDescriptor descriptor = new SparseDescriptor(flags, sparseOffsets, hiddenDimensions, underlyingRank);
if(!sparseCache.containsKey(descriptor)){
if(counter.get() < MAX_ENTRIES){
if(!sparseCache.containsKey(descriptor)){
counter.incrementAndGet();
DataBuffer buffer = Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
sparseCache.put(descriptor, buffer);
return buffer;
}
} else {
return Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
}
}
return sparseCache.get(descriptor);
}
@Override
public void purgeCache() {
sparseCache = new ConcurrentHashMap<>();
}
}

View File

@ -1,25 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu;
import org.nd4j.linalg.factory.BaseSparseBlasWrapper;
/**
* @author Audrey Loeffel
*/
public class SparseBlasWrapper extends BaseSparseBlasWrapper {
}

View File

@ -1,54 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
/**
* @author Audrey Loeffel
*/
public class SparseNDArrayCOO extends BaseSparseNDArrayCOO {
public SparseNDArrayCOO(double[] values, int[][] indices, long[] shape){
super(values, indices, shape);
}
public SparseNDArrayCOO(float[] values, int[][] indices, long[] shape) {
super(values, indices, shape);
}
public SparseNDArrayCOO(double[] values, long[][] indices, long[] shape){
super(values, indices, shape);
}
public SparseNDArrayCOO(float[] values, long[][] indices, long[] shape) {
super(values, indices, shape);
}
public SparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] shape){
super(values, indices, shape);
}
public SparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags, int[] hiddenDimensions, int underlyingRank, long[] shape){
super(values, indices, sparseOffsets, flags, hiddenDimensions, underlyingRank, shape);
}
public SparseNDArrayCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
super(values, indices, sparseInformation, shape);
}
}

View File

@ -1,181 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.*;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
/**
* @author Audrey Loeffel
*/
@Slf4j
public class SparseNDArrayCSR extends BaseSparseNDArrayCSR {
/**
*
*
* The length of the values and columns arrays is equal to the number of non-zero elements in A.
* The length of the pointerB and pointerE arrays is equal to the number of rows in A.
* @param data a double array that contains the non-zero element of the sparse matrix A
* @param columns Element i of the integer array columns is the number of the column in A that contains the i-th value
* in the values array.
* @param pointerB Element j of this integer array gives the index of the element in the values array that is first
* non-zero element in a row j of A. Note that this index is equal to pointerB(j) - pointerB(1)+1 .
* @param pointerE An integer array that contains row indices, such that pointerE(j)-pointerB(1) is the index of the
* element in the values array that is last non-zero element in a row j of A.
* @param shape Shape of the matrix A
*/
public SparseNDArrayCSR(double[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
super(data, columns, pointerB, pointerE, shape);
}
public SparseNDArrayCSR(float[] data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
super(data, columns, pointerB, pointerE, shape);
}
public SparseNDArrayCSR(DataBuffer data, int[] columns, int[] pointerB, int[] pointerE, long[] shape) {
super(data, columns, pointerB, pointerE, shape);
}
@Override
public String getString(long index) {
return null;
}
@Override
public INDArray assign(boolean value) {
return assign(value ? 1 : 0);
}
@Override
public INDArray repeat(int dimension, long... repeats) {
return null;
}
@Override
public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
return null;
}
@Override
public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
return null;
}
@Override
public INDArray mmuli(INDArray other, MMulTranspose transpose) {
return null;
}
@Override
public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
return null;
}
@Override
public long getLong(long index) {
return 0;
}
@Override
public INDArray reshape(char order, int... newShape) {
return null;
}
@Override
public INDArray reshape(char order, boolean enforceView, long... newShape) {
return null;
}
@Override
public INDArray reshape(int[] shape) {
return null;
}
@Override
public LongShapeDescriptor shapeDescriptor() {
return null;
}
@Override
public int toFlatArray(FlatBufferBuilder builder) {
throw new UnsupportedOperationException();
}
@Override
public boolean isEmpty() {
return false;
}
@Override
public boolean isR() {
return false;
}
@Override
public boolean isZ() {
return false;
}
@Override
public boolean isB() {
return false;
}
@Override
public INDArray castTo(DataType dataType) {
return null;
}
@Override
public boolean all() {
return false;
}
@Override
public boolean any() {
throw new UnsupportedOperationException();
}
@Override
public boolean none() {
throw new UnsupportedOperationException();
}
@Override
public boolean closeable() {
return false;
}
@Override
public void close() {
}
@Override
public boolean isS() {
return false;
}
}

View File

@ -1,141 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu.blas;
import org.nd4j.nativeblas.SparseNd4jBlas;
import static org.bytedeco.mkl.global.mkl_rt.*;
/**
* @author Audrey Loeffel
*/
public class SparseCpuBlas extends SparseNd4jBlas {
/**
* Converts a character
* to its proper enum
* for row (c) or column (f) ordering
* default is row major
*/
static int convertOrder(int from) {
switch (from) {
case 'c':
case 'C':
return CblasRowMajor;
case 'f':
case 'F':
return CblasColMajor;
default:
return CblasColMajor;
}
}
/**
* Converts a character to its proper enum
* t -> transpose
* n -> no transpose
* c -> conj
*/
static int convertTranspose(int from) {
switch (from) {
case 't':
case 'T':
return CblasTrans;
case 'n':
case 'N':
return CblasNoTrans;
case 'c':
case 'C':
return CblasConjTrans;
default:
return CblasNoTrans;
}
}
/**
* Upper or lower
* U/u -> upper
* L/l -> lower
*
* Default is upper
*/
static int convertUplo(int from) {
switch (from) {
case 'u':
case 'U':
return CblasUpper;
case 'l':
case 'L':
return CblasLower;
default:
return CblasUpper;
}
}
/**
* For diagonals:
* u/U -> unit
* n/N -> non unit
*
* Default: unit
*/
static int convertDiag(int from) {
switch (from) {
case 'u':
case 'U':
return CblasUnit;
case 'n':
case 'N':
return CblasNonUnit;
default:
return CblasUnit;
}
}
/**
* Side of a matrix, left or right
* l /L -> left
* r/R -> right
* default: left
*/
static int convertSide(int from) {
switch (from) {
case 'l':
case 'L':
return CblasLeft;
case 'r':
case 'R':
return CblasRight;
default:
return CblasLeft;
}
}
@Override
public void setMaxThreads(int num){
MKL_Set_Num_Threads(num);
}
@Override
public int getMaxThreads() {
return MKL_Get_Max_Threads();
}
@Override
public int getBlasVendorId() {
return Vendor.MKL.ordinal();
}
}

View File

@ -1,25 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu.blas;
import org.nd4j.linalg.api.blas.impl.SparseBaseLapack;
/**
* @author Audrey Loeffel
*/
public class SparseCpuLapack extends SparseBaseLapack {
}

View File

@ -1,289 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu.blas;
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel1;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.SparseNd4jBlas;
import org.bytedeco.javacpp.*;
import static org.bytedeco.mkl.global.mkl_rt.*;
/**
* @author Audrey Loeffel
*/
public class SparseCpuLevel1 extends SparseBaseLevel1 {
// FIXME: int cast !!!
private SparseNd4jBlas sparseNd4jBlas = (SparseNd4jBlas) Nd4j.sparseFactory().blas();
/**
* Computes the dot product of a compressed sparse double vector by a full-storage real vector.
* @param N The number of elements in x and indx
* @param X an sparse INDArray. Size at least N
* @param indx an Databuffer that Specifies the indices for the elements of x. Size at least N
* @param Y a dense INDArray. Size at least max(indx[i])
* */
@Override
protected double ddoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
return cblas_ddoti((int) N, (DoublePointer) X.data().addressPointer(),(IntPointer) indx.addressPointer(),
(DoublePointer) Y.data().addressPointer());
}
/**
* Computes the dot product of a compressed sparse float vector by a full-storage real vector.
* @param N The number of elements in x and indx
* @param X an sparse INDArray. Size at least N
* @param indx an Databuffer that specifies the indices for the elements of x. Size at least N
* @param Y a dense INDArray. Size at least max(indx[i])
* */
@Override
protected double sdoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
return cblas_sdoti((int) N, (FloatPointer) X.data().addressPointer(),(IntPointer) indx.addressPointer(),
(FloatPointer) Y.data().addressPointer());
}
@Override
protected double hdoti(long N, INDArray X, DataBuffer indx, INDArray Y) {
throw new UnsupportedOperationException();
}
/**
* Computes the Euclidean norm of a float vector
* @param N The number of elements in vector X
* @param X an INDArray
* @param incx the increment of X
* */
@Override
protected double snrm2(long N, INDArray X, int incx){
return cblas_snrm2((int) N, (FloatPointer) X.data().addressPointer(), incx);
}
/**
* Computes the Euclidean norm of a double vector
* @param N The number of elements in vector X
* @param X an INDArray
* @param incx the increment of X
* */
@Override
protected double dnrm2(long N, INDArray X, int incx){
return cblas_dnrm2((int) N, (DoublePointer) X.data().addressPointer(), incx);
}
@Override
protected double hnrm2(long N, INDArray X, int incx){
throw new UnsupportedOperationException();
}
/**
* Compute the sum of magnitude of the double vector elements
*
* @param N The number of elements in vector X
* @param X a double vector
* @param incrx The increment of X
* @return the sum of magnitude of the vector elements
* */
@Override
protected double dasum(long N, INDArray X, int incrx){
return cblas_dasum((int) N, (DoublePointer) X.data().addressPointer(), incrx);
}
/**
* Compute the sum of magnitude of the float vector elements
*
* @param N The number of elements in vector X
* @param X a float vector
* @param incrx The increment of X
* @return the sum of magnitude of the vector elements
* */
@Override
protected double sasum(long N, INDArray X, int incrx){
return cblas_sasum((int) N, (FloatPointer) X.data().addressPointer(), incrx);
}
@Override
protected double hasum(long N, INDArray X, int incrx){
throw new UnsupportedOperationException();
}
/**
* Find the index of the element with maximum absolute value
*
* @param N The number of elements in vector X
* @param X a vector
* @param incX The increment of X
* @return the index of the element with maximum absolute value
* */
@Override
protected int isamax(long N, INDArray X, int incX) {
return (int) cblas_isamax((int) N, (FloatPointer) X.data().addressPointer(), incX);
}
/**
* Find the index of the element with maximum absolute value
*
* @param N The number of elements in vector X
* @param X a vector
* @param incX The increment of X
* @return the index of the element with maximum absolute value
* */
@Override
protected int idamax(long N, INDArray X, int incX) {
return (int) cblas_idamax((int) N, (DoublePointer) X.data().addressPointer(), incX);
}
@Override
protected int ihamax(long N, INDArray X, int incX) {
throw new UnsupportedOperationException();
}
/**
* Find the index of the element with minimum absolute value
*
* @param N The number of elements in vector X
* @param X a vector
* @param incX The increment of X
* @return the index of the element with minimum absolute value
* */
@Override
protected int isamin(long N, INDArray X, int incX) {
return (int) cblas_isamin((int) N, (FloatPointer) X.data().addressPointer(), incX);
}
/**
* Find the index of the element with minimum absolute value
*
* @param N The number of elements in vector X
* @param X a vector
* @param incX The increment of X
* @return the index of the element with minimum absolute value
* */
@Override
protected int idamin(long N, INDArray X, int incX) {
return (int) cblas_idamin((int) N, (DoublePointer) X.data().addressPointer(), incX);
}
@Override
protected int ihamin(long N, INDArray X, int incX) {
throw new UnsupportedOperationException();
}
/**
* Adds a scalar multiple of double compressed sparse vector to a full-storage vector.
*
* @param N The number of elements in vector X
* @param alpha
* @param X a sparse vector
* @param pointers A DataBuffer that specifies the indices for the elements of x.
* @param Y a dense vector
*
* */
@Override
protected void daxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y){
cblas_daxpyi((int) N, alpha, (DoublePointer) X.data().addressPointer(), (IntPointer) pointers.addressPointer(),
(DoublePointer) Y.data().addressPointer());
}
/**
* Adds a scalar multiple of float compressed sparse vector to a full-storage vector.
*
* @param N The number of elements in vector X
* @param alpha
* @param X a sparse vector
* @param pointers A DataBuffer that specifies the indices for the elements of x.
* @param Y a dense vector
*
* */
@Override
protected void saxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y) {
cblas_saxpyi((int) N, (float) alpha, (FloatPointer) X.data().addressPointer(), (IntPointer) pointers.addressPointer(),
(FloatPointer) Y.data().addressPointer());
}
@Override
protected void haxpyi(long N, double alpha, INDArray X, DataBuffer pointers, INDArray Y){
throw new UnsupportedOperationException();
}
/**
* Applies Givens rotation to sparse vectors one of which is in compressed form.
*
* @param N The number of elements in vectors X and Y
* @param X a double sparse vector
* @param indexes The indexes of the sparse vector
* @param Y a double full-storage vector
* @param c a scalar
* @param s a scalar
* */
@Override
protected void droti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
cblas_droti((int) N, (DoublePointer) X.data().addressPointer(), (IntPointer) indexes.addressPointer(),
(DoublePointer) Y.data().addressPointer(), c, s);
}
/**
* Applies Givens rotation to sparse vectors one of which is in compressed form.
*
* @param N The number of elements in vectors X and Y
* @param X a float sparse vector
* @param indexes The indexes of the sparse vector
* @param Y a float full-storage vector
* @param c a scalar
* @param s a scalar
* */
@Override
protected void sroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
cblas_sroti((int) N, (FloatPointer) X.data().addressPointer(), (IntPointer) indexes.addressPointer().capacity(X.columns()),
(FloatPointer) Y.data().addressPointer(), (float) c, (float) s);
}
@Override
protected void hroti(long N, INDArray X, DataBuffer indexes, INDArray Y, double c, double s) {
throw new UnsupportedOperationException();
}
/**
* Computes the product of a double vector by a scalar.
*
* @param N The number of elements of the vector X
* @param a a scalar
* @param X a vector
* @param incx the increment of the vector X
* */
@Override
protected void dscal(long N, double a, INDArray X, int incx) {
cblas_dscal((int) N, a, (DoublePointer) X.data().addressPointer(), incx);
}
/**
* Computes the product of a float vector by a scalar.
*
* @param N The number of elements of the vector X
* @param a a scalar
* @param X a vector
* @param incx the increment of the vector X
* */
@Override
protected void sscal(long N, double a, INDArray X, int incx) {
cblas_sscal((int) N, (float) a, (FloatPointer) X.data().addressPointer(), incx);
}
@Override
protected void hscal(long N, double a, INDArray X, int incx) {
throw new UnsupportedOperationException();
}
}

View File

@ -1,59 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu.blas;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel2;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.SparseNd4jBlas;
import static org.bytedeco.mkl.global.mkl_rt.*;
/**
* @author Audrey Loeffel
*/
public class SparseCpuLevel2 extends SparseBaseLevel2 {
private SparseNd4jBlas sparseNd4jBlas = (SparseNd4jBlas) Nd4j.sparseFactory().blas();
// Mapping with Sparse Blas calls
public void scoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz, INDArray x, INDArray y){
mkl_cspblas_scoogemv(
Character.toString(transA),
(IntPointer) Nd4j.createBuffer(new int[]{M}).addressPointer(),
(FloatPointer) values.addressPointer(),
(IntPointer) rowInd.addressPointer(),
(IntPointer) colInd.addressPointer(),
(IntPointer) Nd4j.createBuffer(new int[]{nnz}).addressPointer(),
(FloatPointer) x.data().addressPointer(),
(FloatPointer)y.data().addressPointer());
}
public void dcoomv(char transA, int M, DataBuffer values, DataBuffer rowInd, DataBuffer colInd, int nnz, INDArray x, INDArray y){
mkl_cspblas_dcoogemv(
Character.toString(transA),
(IntPointer) Nd4j.createBuffer(new int[]{M}).addressPointer(),
(DoublePointer) values.addressPointer(),
(IntPointer) rowInd.addressPointer(),
(IntPointer) colInd.addressPointer(),
(IntPointer) Nd4j.createBuffer(nnz).addressPointer(),
(DoublePointer) x.data().addressPointer(),
(DoublePointer)y.data().addressPointer());
}
}

View File

@ -1,31 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.cpu.nativecpu.blas;
import org.nd4j.linalg.api.blas.impl.SparseBaseLevel3;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.SparseNd4jBlas;
import static org.bytedeco.mkl.global.mkl_rt.*;
/**
* @author Audrey Loeffel
*/
public class SparseCpuLevel3 extends SparseBaseLevel3 {
private SparseNd4jBlas sparseNd4jBlas = (SparseNd4jBlas) Nd4j.sparseFactory().blas();
// TODO Mappings with Sparse Blas methods
}

View File

@ -17,17 +17,15 @@
real.class.double = org.nd4j.linalg.cpu.NDArray real.class.double = org.nd4j.linalg.cpu.NDArray
complex.class.double = org.nd4j.linalg.cpu.nativecpu.complex.ComplexNDArray complex.class.double = org.nd4j.linalg.cpu.nativecpu.complex.ComplexNDArray
shapeinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectShapeInfoProvider shapeinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectShapeInfoProvider
sparseinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectSparseInfoProvider
constantsprovider = org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache constantsprovider = org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache
affinitymanager = org.nd4j.linalg.cpu.nativecpu.CpuAffinityManager affinitymanager = org.nd4j.linalg.cpu.nativecpu.CpuAffinityManager
memorymanager = org.nd4j.linalg.cpu.nativecpu.CpuMemoryManager memorymanager = org.nd4j.linalg.cpu.nativecpu.CpuMemoryManager
dtype = float dtype = float
complex.double.class = org.nd4j.linalg.cpu.nativecpu.complex.ComplexDouble complex.double.class = org.nd4j.linalg.cpu.nativecpu.complex.ComplexDouble
blas.ops = org.nd4j.linalg.cpu.nativecpu.BlasWrapper blas.ops = org.nd4j.linalg.cpu.nativecpu.BlasWrapper
sparseblas.ops = org.nd4j.linalg.cpu.nativecpu.SparseBlasWrapper
native.ops= org.nd4j.nativeblas.Nd4jCpu native.ops= org.nd4j.nativeblas.Nd4jCpu
ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory
sparsendarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuSparseNDArrayFactory
ndarray.order = c ndarray.order = c
resourcemanager_state = false resourcemanager_state = false
databufferfactory = org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory databufferfactory = org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory

View File

@ -1,659 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.util.ArrayUtil;
import static org.junit.Assert.*;
/**
* @author Audrey Loeffel
*/
@Slf4j
@Ignore // temporary ignored
public class SparseNDArrayCOOTest extends BaseNd4jTest {
public SparseNDArrayCOOTest(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
return 'c';
}
double[] data = {10, 1, 2, 3, 4, 5};
long[] shape = {2, 2, 2};
int[][] indices = new int[][] {new int[] {0, 0, 0, 1, 2, 2}, new int[] {0, 0, 1, 1, 1, 2},
new int[] {1, 2, 2, 1, 0, 1}};
@Test
public void shouldCreateSparseMatrix() {
// Commented out on removal of Nd4j createSparse methods
/*INDArray sparse = Nd4j.createSparseCOO(data, indices, shape);
assertArrayEquals(shape, sparse.shape());
assertEquals(data.length, sparse.nnz());
*/
}
@Test
public void shouldPutScalar() {
// Commented out on removal of Nd4j createSparse methods
/*
INDArray sparse = Nd4j.createSparseCOO(new double[] {1, 2}, new int[][] {{0, 0}, {0, 2}}, new long[] {1, 3});
sparse.putScalar(1, 3);
*/
}
@Test
public void shouldntPutZero() {
// Commented out on removal of Nd4j createSparse methods
/*
INDArray sparse = Nd4j.createSparseCOO(new double[] {1, 2}, new int[][] {{0, 0}, {0, 2}}, new long[] {1, 3});
int oldNNZ = sparse.nnz();
sparse.putScalar(1, 0);
assertArrayEquals(new int[] {0, 2}, sparse.getVectorCoordinates().asInt());
assertTrue(sparse.isRowVector());
assertEquals(oldNNZ, sparse.nnz());
*/
}
@Test
public void shouldRemoveZero() {
// Commented out on removal of Nd4j createSparse methods
/*
INDArray sparse = Nd4j.createSparseCOO(new double[] {1, 2}, new int[][] {{0, 0}, {0, 2}}, new long[] {1, 3});
sparse.putScalar(0, 0);
assertArrayEquals(new int[] {2}, sparse.getVectorCoordinates().asInt());
*/
}
@Test
public void shouldTakeViewInLeftTopCorner() {
// Test with dense ndarray
double[] data = {0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0};
INDArray array = Nd4j.create(data, new int[] {5, 5}, 0, 'c');
INDArray denseView = array.get(NDArrayIndex.interval(0, 2), NDArrayIndex.interval(0, 2));
// test with sparse :
double[] values = {1, 2, 3, 4};
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
// Commented out on removal of Nd4j createSparse methods
/*
INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
// subarray in the top right corner
BaseSparseNDArrayCOO sparseView = (BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.interval(0, 2),
NDArrayIndex.interval(0, 2));
assertArrayEquals(denseView.shape(), sparseView.shape());
double[] currentValues = sparseView.data().asDouble();
assertArrayEquals(values, currentValues, 1e-5);
assertArrayEquals(ArrayUtil.flatten(indices), sparseView.getUnderlyingIndices().asInt());
assertEquals(0, sparseView.nnz());
System.out.println(sparseView.sparseInfoDataBuffer());
*/
}
@Test
public void shouldTakeViewInLeftBottomCorner() {
double[] values = {1, 2, 3, 4};
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
// Commented out on removal of Nd4j createSparse methods
/*
INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
BaseSparseNDArrayCOO sparseView = (BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.interval(2, 5),
NDArrayIndex.interval(0, 2));
assertEquals(1, sparseView.nnz());
assertArrayEquals(new double[] {3}, sparseView.getIncludedValues().asDouble(), 1e-1);
assertArrayEquals(new int[] {0, 1}, sparseView.getIncludedIndices().asInt());
System.out.println(sparseView.sparseInfoDataBuffer());
*/
}
@Test
public void shouldTakeViewInRightTopCorner() {
double[] values = {1, 2, 3, 4};
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
BaseSparseNDArrayCOO sparseView = (BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.interval(0, 2),
NDArrayIndex.interval(2, 5));
assertEquals(2, sparseView.nnz());
assertArrayEquals(new double[] {1, 2}, sparseView.getIncludedValues().asDouble(), 1e-1);
assertArrayEquals(new int[] {0, 1, 1, 0}, sparseView.getIncludedIndices().asInt());
System.out.println(sparseView.sparseInfoDataBuffer());
*/
}
@Test
public void shouldTakeViewInTheMiddle() {
double[] values = {1, 2, 3, 4};
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
BaseSparseNDArrayCOO sparseView = (BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.interval(1, 3),
NDArrayIndex.interval(1, 3));
assertEquals(2, sparseView.nnz());
assertArrayEquals(new double[] {2, 3}, sparseView.getIncludedValues().asDouble(), 1e-1);
assertArrayEquals(new int[] {0, 1, 1, 0}, sparseView.getIncludedIndices().asInt());
System.out.println(sparseView.sparseInfoDataBuffer());
*/
}
@Test
public void shouldGetFirstColumn() {
double[] values = {1, 2, 3, 4};
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
BaseSparseNDArrayCOO sparseView =
(BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.all(), NDArrayIndex.point(0));
assertEquals(0, sparseView.nnz());
System.out.println(sparseView.sparseInfoDataBuffer());
*/
}
@Test
public void shouldGetRowInTheMiddle() {
double[] values = {1, 2, 3, 4};
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
BaseSparseNDArrayCOO sparseView =
(BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.point(2), NDArrayIndex.all());
assertEquals(1, sparseView.nnz());
assertArrayEquals(new int[] {0, 1}, sparseView.getIncludedIndices().asInt());
assertArrayEquals(new double[] {3}, sparseView.getIncludedValues().asDouble(), 1e-1);
System.out.println(sparseView.sparseInfoDataBuffer());
*/
}
@Test
public void shouldGetScalar() {
double[] values = {1, 2, 3, 4};
int[][] indices = {{0, 3}, {1, 2}, {2, 1}, {3, 4}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray sparseNDArray = Nd4j.createSparseCOO(values, indices, new long[] {5, 5});
BaseSparseNDArrayCOO sparseView =
(BaseSparseNDArrayCOO) sparseNDArray.get(NDArrayIndex.point(2), NDArrayIndex.point(1));
assertEquals(1, sparseView.nnz());
assertArrayEquals(new int[] {0, 0}, sparseView.getIncludedIndices().asInt());
assertArrayEquals(new double[] {3}, sparseView.getIncludedValues().asDouble(), 1e-1);
assertTrue(sparseView.isScalar());
*/
}
@Test
public void shouldTakeView3dimensionArray() {
long[] shape = new long[] {2, 2, 2};
double[] values = new double[] {2, 1, 4, 3};
int[][] indices = new int[][] {{0, 0, 0}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO view =
(BaseSparseNDArrayCOO) array.get(NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all());
assertEquals(2, view.nnz());
assertArrayEquals(new long[] {2, 2}, view.shape());
assertArrayEquals(new int[] {0, 0, 1, 1}, view.getIncludedIndices().asInt());
assertArrayEquals(new double[] {2, 1}, view.getIncludedValues().asDouble(), 1e-1);
System.out.println(view.sparseInfoDataBuffer());
*/
}
@Test
public void shouldTakeViewOfView() {
long[] shape = new long[] {2, 2, 2};
double[] values = new double[] {2, 1, 4, 3};
int[][] indices = new int[][] {{0, 0, 0}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO baseView =
(BaseSparseNDArrayCOO) array.get(NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all());
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) baseView.get(NDArrayIndex.point(1), NDArrayIndex.all());
assertEquals(1, view.nnz());
assertArrayEquals(new long[] {1, 2}, view.shape());
assertArrayEquals(new int[] {0, 1}, view.getIncludedIndices().asInt());
assertArrayEquals(new double[] {1}, view.getIncludedValues().asDouble(), 1e-1);
*/
}
@Test
public void shouldTakeViewOfView2() {
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 1}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO baseView = (BaseSparseNDArrayCOO) array.get(NDArrayIndex.interval(1, 4),
NDArrayIndex.point(1), NDArrayIndex.all());
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) baseView.get(NDArrayIndex.all(), NDArrayIndex.point(2));
assertEquals(2, view.nnz());
assertArrayEquals(new long[] {3, 1}, view.shape());
assertArrayEquals(new int[] {0, 0, 1, 0}, view.getIncludedIndices().asInt());
assertArrayEquals(new double[] {5, 7}, view.getIncludedValues().asDouble(), 1e-1);
assertTrue(view.isColumnVector());
*/
}
@Test
public void shouldGetWithSpecifiedIndexes() {
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 1}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/*INDArray array = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO newArray = (BaseSparseNDArrayCOO) array.get(new SpecifiedIndex(0, 3),
NDArrayIndex.all(), NDArrayIndex.all());
assertEquals(4, newArray.nnz());
assertArrayEquals(new double[] {1, 2, 8, 9}, newArray.getIncludedValues().asDouble(), 1e-1);
assertArrayEquals(new int[] {0, 0, 2, 0, 1, 1, 1, 0, 1, 1, 1, 0}, newArray.getIncludedIndices().asInt());
*/
}
@Test
public void shouldGetWithSpecifiedIndexes2() {
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 2}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO newArray = (BaseSparseNDArrayCOO) array.get(NDArrayIndex.interval(1, 4),
new SpecifiedIndex(0), new SpecifiedIndex(0, 2));
assertEquals(2, newArray.nnz());
assertArrayEquals(new double[] {3, 8}, newArray.getIncludedValues().asDouble(), 1e-1);
assertArrayEquals(new int[] {0, 0, 2, 1}, newArray.getIncludedIndices().asInt());
*/
}
@Test
public void specifiedIndexWithDenseArray() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
System.out.println(arr.toString());
INDArray v = arr.get(NDArrayIndex.interval(1, 3), new SpecifiedIndex(0),
new SpecifiedIndex(0, 2));
System.out.println("v ");
System.out.println(v.toString());
}
@Test
public void newAxisWithSparseArray() {
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 2}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
INDArray v = array.get(NDArrayIndex.point(0), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
*/
}
@Test
public void nestedSparseViewWithNewAxis() {
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 2}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray array = Nd4j.createSparseCOO(values, indices, shape);
System.out.println("\nTaking view (all, point(1), all");
INDArray v = array.get(NDArrayIndex.all(), NDArrayIndex.point(1));
System.out.println(v.toString());
System.out.println(v.shapeInfoDataBuffer());
System.out.println("Fixed dimension " + v.flags());
System.out.println("sparse offsets " + v.sparseOffsets());
System.out.println("hidden dimensions " + v.hiddenDimensions());
System.out.println("number of hidden dimensions " + ((BaseSparseNDArrayCOO) v).getNumHiddenDimension());
// shape 4 x 3
System.out.println("\nTaking view (all new axis");
INDArray v1 = v.get(NDArrayIndex.all(), NDArrayIndex.newAxis());
System.out.println(v1.toString());
System.out.println(v1.shapeInfoDataBuffer());
System.out.println("Fixed dimension " + v1.flags());
System.out.println("sparse offsets " + v1.sparseOffsets());
System.out.println("hidden dimensions " + v1.hiddenDimensions());
System.out.println("number of hidden dimensions " + ((BaseSparseNDArrayCOO) v1).getNumHiddenDimension());
// shape 4 x 1 x 3
System.out.println("\nTaking view (all new axis");
v1 = v.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.newAxis());
System.out.println(v1.toString());
System.out.println(v1.shapeInfoDataBuffer());
System.out.println("Fixed dimension " + v1.flags());
System.out.println("sparse offsets " + v1.sparseOffsets());
System.out.println("hidden dimensions " + v1.hiddenDimensions());
System.out.println("number of hidden dimensions " + ((BaseSparseNDArrayCOO) v1).getNumHiddenDimension());
*/
}
@Test
public void nestedViewWithNewAxis() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
System.out.println(arr.toString());
System.out.println(arr.shapeInfoDataBuffer());
System.out.println("\nTaking view (all, point(1), all");
INDArray v = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1));
System.out.println(v.toString());
System.out.println(v.shapeInfoDataBuffer());
// shape 4 x 3
System.out.println("\nTaking view (all new axis");
INDArray v1 = v.get(NDArrayIndex.all(), NDArrayIndex.newAxis());
System.out.println(v1.toString());
System.out.println(v1.shapeInfoDataBuffer());
// shape 4 x 1 x 3
System.out.println("\nTaking view (all new axis");
v1 = v1.get(NDArrayIndex.newAxis());
System.out.println(v1.toString());
System.out.println(v1.shapeInfoDataBuffer());
// shape 4 x 3
}
@Test
public void shouldTranslateViewIndexesToOriginal() {
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 2}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/*INDArray original = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.all(), NDArrayIndex.point(1));
int[] originalIdx = view.translateToPhysical(new int[] {0, 0});
int[] exceptedIdx = new int[] {0, 1, 0};
assertArrayEquals(exceptedIdx, originalIdx);
*/
}
@Test
public void shouldTranslateViewIndexesToOriginal2() {
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 2}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray original = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.all(), NDArrayIndex.newAxis(),
NDArrayIndex.point(1));
assertArrayEquals(new int[] {0, 1, 0}, view.translateToPhysical(new int[] {0, 0, 0}));
assertArrayEquals(new int[] {1, 1, 1}, view.translateToPhysical(new int[] {1, 0, 1}));
*/
}
@Test
public void shouldTranslateViewIndexesToOriginal3() {
long[] shape = new long[] {4, 2, 3, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2, 0}, {0, 1, 1, 1}, {1, 0, 0, 0}, {1, 0, 1, 0}, {1, 1, 2, 1},
{2, 0, 1, 0}, {2, 1, 2, 0}, {3, 0, 2, 1}, {3, 1, 0, 1}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray original = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.all(), NDArrayIndex.newAxis(),
NDArrayIndex.point(1), NDArrayIndex.point(2));
assertArrayEquals(new int[] {0, 1, 2, 0}, view.translateToPhysical(new int[] {0, 0, 0}));
assertArrayEquals(new int[] {1, 1, 2, 1}, view.translateToPhysical(new int[] {1, 0, 1}));
*/
}
@Test
public void shouldTranslateViewWithPrependNewAxis() {
// TODO FIX get view with a new prepend axis
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 2}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/* INDArray original = Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.newAxis(), NDArrayIndex.all(),
NDArrayIndex.point(1));
System.out.println(view.getIncludedIndices());
System.out.println(view.getIncludedValues());
assertArrayEquals(new int[] {0, 1, 0}, view.translateToPhysical(new int[] {0, 0, 0}));
assertArrayEquals(new int[] {1, 1, 1}, view.translateToPhysical(new int[] {0, 1, 1}));
int[] originalIdx = view.translateToPhysical(new int[] {0, 1, 2});
int[] exceptedIdx = new int[] {1, 0, 2};
assertArrayEquals(exceptedIdx, originalIdx);
*/
}
@Test
public void shouldSortCOOIndices() {
long[] shape = new long[] {4, 3, 3};
double[] values = new double[] {1};
long[][] indices = new long[][] {{0, 0, 0}};
// commented out on removal of createSparse methods from Nd4j
/* INDArray original = Nd4j.createSparseCOO(values, indices, shape);
original.putScalar(2, 2, 2, 3);
original.putScalar(1, 1, 1, 2);
BaseSparseNDArrayCOO view = (BaseSparseNDArrayCOO) original.get(NDArrayIndex.all());
int[] expectedIdx = new int[] {0, 0, 0, 1, 1, 1, 2, 2, 2};
double[] expectedValues = new double[] {1, 2, 3};
assertArrayEquals(expectedIdx, view.getIncludedIndices().asInt());
assertArrayEquals(expectedValues, view.getIncludedValues().asDouble(), 1e-5);
assertTrue(view == original);
*/
}
@Test
public void testWithDense() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
System.out.println(arr);
INDArray view = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1));
// System.out.println(view.shapeInfoDataBuffer());
view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all());
System.out.println("view");
System.out.println(view);
System.out.println(view.shapeInfoDataBuffer());
}
@Test
public void newAxisWithDenseArray() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
System.out.println(arr.toString());
System.out.println(arr.shapeInfoDataBuffer());
System.out.println("\npoint 0");
INDArray v = arr.get(NDArrayIndex.point(0));
System.out.println(v.shapeInfoDataBuffer());
// => shape 2 x 3
System.out.println("new axis, all, point 1");
v = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
//System.out.println(v.toString());
v = arr.get(NDArrayIndex.interval(1, 4), NDArrayIndex.point(0), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
System.out.println(v.isView());
// => shape 1 x 2 x 3
System.out.println("\npoint 0, newaxis");
v = arr.get(NDArrayIndex.point(0), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
System.out.println(v.isView());
// => shape 1 x 2 x 3
System.out.println("\n point 0, newaxis, newaxis");
v = arr.get(NDArrayIndex.point(0), NDArrayIndex.newAxis(), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
// => shape 1 x 1 x 2 x 3
System.out.println("\n new axis, point 0, newaxis");
v = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.point(0), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
// => shape 1 x 1 x 2 x 3
System.out.println("\nget( new axis, point(0), point(0), new axis)");
v = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
System.out.println(v.toString());
// => shape 1 x 1 x 3 x 1
System.out.println("\nget( specified(1), specified(0), new axis)");
v = arr.get(new SpecifiedIndex(0), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
System.out.println(v.toString());
// => crash
// System.out.println("\nget( new axis, point(0), new axis, point(0))");
// v = arr.get( NDArrayIndex.newAxis(), NDArrayIndex.point(0), NDArrayIndex.newAxis(), NDArrayIndex.point(0));
// System.out.println(v.shapeInfoDataBuffer());
// System.out.println(v.toString());
// => crash
System.out.println("\n interval(0, 2), newaxis");
v = arr.get(NDArrayIndex.interval(0, 2), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
// => shape 1 x 2 x 2 x 3 - new axis is added at the first position
/* System.out.println("\n point 0 , all, new axis");
v = arr.get(
NDArrayIndex.point(0),
NDArrayIndex.all(),
NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
*/
// => crash
}
@Test
public void testDenseNewAxisWithSpecifiedIdx() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray v = arr.get(new SpecifiedIndex(0), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
System.out.println(v.toString());
}
@Test
public void testDenseNewAxisWithSpecifiedIdx2() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray v = arr.get(NDArrayIndex.newAxis(), new SpecifiedIndex(0, 1), NDArrayIndex.all());
System.out.println(v.shapeInfoDataBuffer());
System.out.println(v.toString());
}
@Test
public void testDenseNewAxisWithSpecifiedIdx3() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray v = arr.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.newAxis());
System.out.println(v.shapeInfoDataBuffer());
System.out.println(v.toString());
}
@Test
public void testDenseWithNewAxis() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
System.out.println(view);
}
@Test
public void testWithPrependNewAxis() {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
System.out.println(arr.toString());
System.out.println(arr.shapeInfoDataBuffer());
System.out.println("new axis, all, point 1");
INDArray v = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
System.out.println(v.toString());
System.out.println(v.shapeInfoDataBuffer());
}
@Test
public void binarySearch() {
long[] shape = new long[] {4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 2}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 2}, {2, 0, 1}, {2, 1, 2},
{3, 0, 2}, {3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/* BaseSparseNDArrayCOO array = (BaseSparseNDArrayCOO) Nd4j.createSparseCOO(values, indices, shape);
assertEquals(0, array.reverseIndexes(0, 0, 2));
assertEquals(7, array.reverseIndexes(3, 0, 2));
assertEquals(8, array.reverseIndexes(3, 1, 0));
*/
}
@Test
public void rdmTest(){
INDArray i = Nd4j.rand(new int[]{3, 3, 3});
INDArray ii = i.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
System.out.println(ii);
System.out.println(ii.shapeInfoDataBuffer());
}
@Test
public void tryToFindABugWithHiddenDim(){
long[] shape = new long[] {1, 4, 2, 3};
double[] values = new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
int[][] indices = new int[][] {{0, 0, 0, 2}, {0, 0, 1, 1}, {0, 1, 0, 0}, {0, 1, 0, 1}, {0, 1, 1, 2}, {0, 2, 0, 1}, {0, 2, 1, 2},
{0, 3, 0, 2}, {0, 3, 1, 0}};
// Commented out on removal of Nd4j createSparse methods
/* BaseSparseNDArrayCOO array = (BaseSparseNDArrayCOO) Nd4j.createSparseCOO(values, indices, shape);
BaseSparseNDArrayCOO view1 = (BaseSparseNDArrayCOO) array.get( NDArrayIndex.point(0), NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.point(0));
System.out.println(view1.shapeInfoDataBuffer());
System.out.println(view1.sparseInfoDataBuffer());
BaseSparseNDArrayCOO view2 = (BaseSparseNDArrayCOO) view1.get( NDArrayIndex.point(0), NDArrayIndex.newAxis(),NDArrayIndex.newAxis(), NDArrayIndex.point(0));
System.out.println(view2.shapeInfoDataBuffer());
System.out.println(view2.sparseInfoDataBuffer());
*/
}
}

View File

@ -1,270 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCSR;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertTrue;
/**
* @author Audrey Loeffel
*/
@Ignore // temporary ignored
public class SparseNDArrayCSRTest extends BaseNd4jTest {
public SparseNDArrayCSRTest(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
return 'c';
}
/*
* [[1 -1 0 -3 0]
* [-2 4 0 0 0 ]
* [ 0 0 4 6 4 ] = A
* [-4 0 2 7 0 ]
* [ 0 8 0 0 -5]]
* */
// CSR representation of the matrix A according to https://software.intel.com/en-us/node/599835
private double[] values = {1, -2, -3, -2, 5, 4, 6, 4, -4, 2, 7, 8, -5};
private int[] columns = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
private int[] pointerB = {0, 3, 5, 8, 11};
private int[] pointerE = {3, 5, 8, 11, 13};
private long[] shape = {5, 5};
@Test
public void shouldAddValueAtAGivenPosition() {
/*
* [[1 -1 0 -3 0]
* [-2 4 0 0 0 ]
* [ 0 3 4 6 4 ] = A'
* [-4 0 2 7 0 ]
* [ 0 8 0 0 -5]]
* */
// commented out on removal of createSparse methods from Nd4j
/*
INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
if (sparseNDArray instanceof BaseSparseNDArrayCSR) {
BaseSparseNDArrayCSR sparseCSRArray = (BaseSparseNDArrayCSR) sparseNDArray;
sparseCSRArray.putScalar(2, 1, 3);
double[] expectedValues = {1, -2, -3, -2, 5, 3, 4, 6, 4, -4, 2, 7, 8, -5};
double[] expectedColumns = {0, 1, 3, 0, 1, 1, 2, 3, 4, 0, 2, 3, 1, 4};
int[] expectedPointerB = {0, 3, 5, 9, 12};
int[] expectedPointerE = {3, 5, 9, 12, 14};
long[] expectedShape = {5, 5};
assertArrayEquals(expectedValues, sparseCSRArray.getDoubleValues(), 0);
assertArrayEquals(expectedColumns, sparseCSRArray.getColumns(), 0);
assertArrayEquals(expectedPointerB, sparseCSRArray.getPointerBArray());
assertArrayEquals(expectedPointerE, sparseCSRArray.getPointerEArray());
assertArrayEquals(expectedShape, shape);
}
*/
}
@Test
public void shouldReallocate() {
// commented out on removal of createSparse methods from Nd4j
/*
INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
if (sparseNDArray instanceof BaseSparseNDArrayCSR) {
BaseSparseNDArrayCSR sparseCSRArray = (BaseSparseNDArrayCSR) sparseNDArray;
int initialSize = sparseCSRArray.getDoubleValues().length;
for (int i = 0; i < shape[0]; i++) {
for (int j = 0; j < shape[1]; j++) {
sparseCSRArray.putScalar(i, j, i + j);
}
}
int finalSize = sparseCSRArray.getDoubleValues().length;
assertTrue(finalSize > initialSize);
}
*/
}
@Test
public void shouldReplaceValueAtAGivenPosition() {
/*
* [[1 -1 0 -3 0]
* [-2 4 0 0 0 ]
* [ 0 0 10 6 4] = A'
* [-4 0 2 7 0 ]
* [ 0 8 0 0 -5]]
* */
// commented out on removal of createSparse methods from Nd4j
/*
INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
if (sparseNDArray instanceof BaseSparseNDArrayCSR) {
BaseSparseNDArrayCSR sparseCSRArray = (BaseSparseNDArrayCSR) sparseNDArray;
sparseCSRArray.putScalar(2, 2, 10);
double[] expectedValues = {1, -2, -3, -2, 5, 10, 6, 4, -4, 2, 7, 8, -5};
double[] expectedColumns = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
int[] expectedPointerB = {0, 3, 5, 8, 11};
int[] expectedPointerE = {3, 5, 8, 11, 13};
long[] expectedShape = {5, 5};
assertArrayEquals(expectedValues, sparseCSRArray.getDoubleValues(), 0);
assertArrayEquals(expectedColumns, sparseCSRArray.getColumns(), 0);
assertArrayEquals(expectedPointerB, sparseCSRArray.getPointerBArray());
assertArrayEquals(expectedPointerE, sparseCSRArray.getPointerEArray());
assertArrayEquals(expectedShape, shape);
}
*/
}
@Test
public void shouldGetValueAtAGivenPosition() {
// Not yet implemented
}
@Test
public void shouldBeEqualToDense() {
// Not yet implemented
}
@Test
public void shouldGetAView() {
double[] values = {1, -1, -3, -2, 5, 4, 6, 4, -4, 2, 7, 8, 5};
int[] columns = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
int[] pointerB = {0, 3, 5, 8, 11};
int[] pointerE = {3, 5, 8, 11, 13};
// Test with dense ndarray
double[] data = {1, -1, 0, -3, 0, -2, 5, 0, 0, 0, 0, 0, 4, 6, 4, -4, 0, 2, 7, 0, 0, 8, 0, 0, 5};
INDArray array = Nd4j.create(data, new int[] {5, 5}, 0, 'c');
INDArray denseView = array.get(NDArrayIndex.interval(1, 3), NDArrayIndex.interval(1, 3));
// test with sparse :
// commented out on removal of createSparse methods from Nd4j
/*
INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
// subarray in the top right corner
BaseSparseNDArrayCSR sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(0, 3),
NDArrayIndex.interval(3, 5));
assertArrayEquals(new int[] {0, 0, 1}, sparseView.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {2, 3, 6}, sparseView.getPointerBArray());
assertArrayEquals(new int[] {3, 3, 8}, sparseView.getPointerEArray());
// subarray in the middle
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(1, 3), NDArrayIndex.interval(1, 3));
assertArrayEquals(new int[] {0, 1}, sparseView.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {4, 5}, sparseView.getPointerBArray());
assertArrayEquals(new int[] {5, 6}, sparseView.getPointerEArray());
// get the first row
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.all(), NDArrayIndex.point(0));
assertArrayEquals(new int[] {0, 0, 0}, sparseView.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {0, 3, 4, 8, 9}, sparseView.getPointerBArray());
assertArrayEquals(new int[] {1, 4, 4, 9, 9}, sparseView.getPointerEArray());
// get several rows
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all());
assertArrayEquals(new int[] {0, 1, 3, 0, 1}, sparseView.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {0, 3}, sparseView.getPointerBArray());
assertArrayEquals(new int[] {3, 5}, sparseView.getPointerEArray());
// get a row in the middle
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.point(2), NDArrayIndex.all());
assertArrayEquals(new int[] {2, 3, 4}, sparseView.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {5}, sparseView.getPointerBArray());
assertArrayEquals(new int[] {8}, sparseView.getPointerEArray());
// get the first column
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.all(), NDArrayIndex.point(0));
assertArrayEquals(new int[] {0, 0, 0}, sparseView.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {0, 3, 4, 8, 9}, sparseView.getPointerBArray());
assertArrayEquals(new int[] {1, 4, 4, 9, 9}, sparseView.getPointerEArray());
// get a column in the middle
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.all(), NDArrayIndex.point(2));
assertArrayEquals(new int[] {0, 0}, sparseView.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {0, 0, 5, 9, 10}, sparseView.getPointerBArray());
assertArrayEquals(new int[] {0, 0, 6, 10, 10}, sparseView.getPointerEArray());
// get a part of the column in the middle
sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(1, 4), NDArrayIndex.point(2));
assertArrayEquals(new int[] {0, 0}, sparseView.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {0, 5, 9}, sparseView.getPointerBArray());
assertArrayEquals(new int[] {0, 6, 10}, sparseView.getPointerEArray());
*/
}
@Test
public void shouldGetAViewFromView() {
double[] values = {1, -1, -3, -2, 5, 4, 6, 4, -4, 2, 7, 8, 5};
int[] columns = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
int[] pointerB = {0, 3, 5, 8, 11};
// commented out on removal of createSparse methods from Nd4j
//INDArray sparseNDArray = Nd4j.createSparseCSR(values, columns, pointerB, pointerE, shape);
/* [0, -3, 0]
* sparseView = [0, 0, 0] subview = [[0,0], [4,6]]
* [4, 6, 4]
*/
/*BaseSparseNDArrayCSR sparseView = (BaseSparseNDArrayCSR) sparseNDArray.get(NDArrayIndex.interval(0, 3),
NDArrayIndex.interval(2, 5));
BaseSparseNDArrayCSR subview =
(BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.interval(1, 3), NDArrayIndex.interval(0, 2));
assertArrayEquals(new int[] {0, 1}, subview.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {0, 5}, subview.getPointerBArray());
assertArrayEquals(new int[] {0, 7}, subview.getPointerEArray());
// get the first column
subview = (BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.all(), NDArrayIndex.point(0));
assertArrayEquals(new int[] {0}, subview.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {0, 0, 5}, subview.getPointerBArray());
assertArrayEquals(new int[] {0, 0, 6}, subview.getPointerEArray());
// get a column in the middle
subview = (BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.all(), NDArrayIndex.point(1));
assertArrayEquals(new int[] {0, 0}, subview.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {2, 3, 6}, subview.getPointerBArray());
assertArrayEquals(new int[] {3, 3, 7}, subview.getPointerEArray());
// get the first row
subview = (BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.point(0), NDArrayIndex.all());
assertArrayEquals(new int[] {1}, subview.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {2}, subview.getPointerBArray());
assertArrayEquals(new int[] {3}, subview.getPointerEArray());
// get a row in the middle
subview = (BaseSparseNDArrayCSR) sparseView.get(NDArrayIndex.point(1), NDArrayIndex.all());
assertArrayEquals(new int[] {}, subview.getVectorCoordinates().asInt());
assertArrayEquals(new int[] {0}, subview.getPointerBArray());
assertArrayEquals(new int[] {0}, subview.getPointerEArray());
*/
}
}

View File

@ -1,167 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
/**
* @author Audrey Loeffel
*/
@Ignore // temporary ignored
@RunWith(Parameterized.class)
public class SparseCOOLevel1Test extends BaseNd4jTest {
// vector = [1, 2, 0, 4]
private double[] data = {1, 2, 4};
private int[][] indexes = new int[][] {{0, 0}, {0, 1}, {0, 3}};
private long[] shape = {1, 4};
public SparseCOOLevel1Test(Nd4jBackend backend) {
super(backend);
}
@Test
public void shouldComputeDot() {
// Commented out on removal of Nd4j createSparse methods
/*INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
//INDArray vec = Nd4j.create( new double[] {1 ,2, 3, 4});
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.FLOAT).reshape(1, 4);
INDArray vec = matrix.getRow(0);
assertEquals(21, Nd4j.getBlasWrapper().dot(sparseVec, vec), 1e-1);
*/
}
@Test
public void shouldComputeNrm2() {
// Commented out on removal of Nd4j createSparse methods
//INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
//assertEquals(Math.sqrt(21), Nd4j.getBlasWrapper().nrm2(sparseVec), 1e-1);
}
@Test
public void shouldComputeAsum() {
// Commented out on removal of Nd4j createSparse methods
//INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
//assertEquals(7, Nd4j.getBlasWrapper().asum(sparseVec), 1e-1);
}
@Test
public void shouldComputeIamax() {
// Commented out on removal of Nd4j createSparse methods
// INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
// assertEquals(2, Nd4j.getBlasWrapper().iamax(sparseVec), 1e-1);
}
@Test
public void shouldComputeIamin() {
// Commented out on removal of Nd4j createSparse methods
// INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
// assertEquals(0, Nd4j.getBlasWrapper().level1().iamin(sparseVec), 1e-1);
}
@Test
public void shouldComputeAxpy() {
// Commented out on removal of Nd4j createSparse methods
/* INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray expected = Nd4j.create(new double[] {2, 4, 3, 8});
Nd4j.getBlasWrapper().level1().axpy(vec.length(), 1, sparseVec, vec);
assertEquals(getFailureMessage(), expected, vec);
*/
}
@Test
public void shouldComputeRot() {
// try with dense vectors to get the expected result
INDArray temp1 = Nd4j.create(new double[] {1, 2, 0, 4});
INDArray temp2 = Nd4j.create(new double[] {1, 2, 3, 4});
System.out.println("before: " + temp1.data() + " " + temp2.data());
Nd4j.getBlasWrapper().level1().rot(temp1.length(), temp1, temp2, 1, 2);
System.out.println("after: " + temp1.data() + " " + temp2.data());
//before: [1.0,2.0,0.0,4.0] [1.0,2.0,3.0,4.0]
// after: [3.0,6.0,6.0,12.0] [-1.0,-2.0,3.0,-4.0]
// Commented out on removal of Nd4j createSparse methods
/*INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);
System.out.println(sparseVec.data() + " " + vec.data());
//System.out.println("indexes: " + ((BaseSparseNDArray) sparseVec).getVectorCoordinates().toString());
INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 6, 12}, new int[] {0, 1, 2, 3},
new int[] {0}, new int[] {4}, new long[] {1, 4});
INDArray expectedVec = Nd4j.create(new double[] {-1, -2, 3, -4});
assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
assertEquals(getFailureMessage(), expectedVec, vec);
*/
// TODO FIXME
}
@Test
public void shouldComputeRotWithFullVector() {
// try with dense vectors to get the expected result
/*
INDArray temp1 = Nd4j.create( new double[] {1 ,2, 3, 4});
INDArray temp2 = Nd4j.create( new double[] {1 ,2, 3, 4});
System.out.println("before: " + temp1.data() + " " + temp2.data());
Nd4j.getBlasWrapper().level1().rot(temp1.length(), temp1, temp2, 1, 2);
System.out.println("after: " + temp1.data() + " " + temp2.data());
*/
//before: [1.0,2.0,3.0,4.0] [1.0,2.0,3.0,4.0]
// after: [3.0,6.0,0.0,12.0] [-1.0,-2.0,-3.0,-4.0]
int[] cols = {0, 1, 2, 3};
double[] values = {1, 2, 3, 4};
// commented out on removal of createSparse methods from Nd4j
/* INDArray sparseVec = Nd4j.createSparseCOO(data, indexes, shape);
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);
INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 9, 12}, new int[] {0, 1, 2, 3},
new int[] {0}, new int[] {4}, new long[] {1, 4});
INDArray expectedVec = Nd4j.create(new double[] {-1, -2, -3, -4});
assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
assertEquals(getFailureMessage(), expectedVec, vec);
if (expectedSparseVec.isSparse() && sparseVec.isSparse()) {
BaseSparseNDArray vec2 = ((BaseSparseNDArray) expectedSparseVec);
BaseSparseNDArray vecSparse2 = ((BaseSparseNDArray) sparseVec);
assertEquals(getFailureMessage(), vec2.getVectorCoordinates(), vecSparse2);
}
*/
}
@Override
public char ordering() {
return 'c';
}
}

View File

@ -1,65 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
/**
* @author Audrey Loeffel
*/
@Ignore // temporary ignored
@RunWith(Parameterized.class)
public class SparseCOOLevel2Test extends BaseNd4jTest {
// matrix = [[1, 2], [0, 0]]
private double[] data = {1, 2};
private int[][] indexes = new int[][] {{0, 0}, {0, 1}};
private long[] shape = {2, 2};
public SparseCOOLevel2Test(Nd4jBackend backend) {
super(backend);
}
@Test
public void testGemv() {
// commented out on removal of createSparse methods from Nd4j
/* INDArray array1 = Nd4j.createSparseCOO(data, indexes, shape);
INDArray array2 = Nd4j.linspace(1, 2, 2).reshape(2, 1);
INDArray array3 = array1.mmul(array2); // should be [5, 0]
assertEquals(2, array3.length());
assertEquals(5, array3.getFloat(0), 1e-5);
assertEquals(0, array3.getFloat(1), 1e-5);
*/
}
@Override
public char ordering() {
return 'c';
}
}

View File

@ -1,168 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.blas;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.BaseSparseNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
/**
* @author Audrey Loeffel
*/
@Ignore // temporary ignored
@RunWith(Parameterized.class)
public class SparseCSRLevel1Test extends BaseNd4jTest {
private double[] data = {1, 2, 4};
private int[] col = {0, 1, 3};
private int[] pointerB = {0};
private int[] pointerE = {4};
private long[] shape = {1, 4};
public SparseCSRLevel1Test(Nd4jBackend backend) {
super(backend);
}
@Test
public void shouldComputeDot() {
// commented out on removal of createSparse methods from Nd4j
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
//INDArray vec = Nd4j.create( new double[] {1 ,2, 3, 4});
INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(1, 4);
INDArray vec = matrix.getRow(0);
//assertEquals(21, Nd4j.getBlasWrapper().dot(sparseVec, vec), 1e-1);
}
@Test
public void shouldComputeNrm2() {
// commented out on removal of createSparse methods from Nd4j
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
//assertEquals(Math.sqrt(21), Nd4j.getBlasWrapper().nrm2(sparseVec), 1e-1);
}
@Test
public void shouldComputeAsum() {
// commented out on removal of createSparse methods from Nd4j
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
//assertEquals(7, Nd4j.getBlasWrapper().asum(sparseVec), 1e-1);
}
@Test
public void shouldComputeIamax() {
// commented out on removal of createSparse methods from Nd4j
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
//assertEquals(2, Nd4j.getBlasWrapper().iamax(sparseVec), 1e-1);
}
@Test
public void shouldComputeIamin() {
// commented out on removal of createSparse methods from Nd4j
//INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
//assertEquals(0, Nd4j.getBlasWrapper().level1().iamin(sparseVec), 1e-1);
}
@Test
public void shouldComputeAxpy() {
// commented out on removal of createSparse methods from Nd4j
/* INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray expected = Nd4j.create(new double[] {2, 4, 3, 8});
Nd4j.getBlasWrapper().level1().axpy(vec.length(), 1, sparseVec, vec);
assertEquals(getFailureMessage(), expected, vec);
*/
}
@Test
public void shouldComputeRot() {
// try with dense vectors to get the expected result
INDArray temp1 = Nd4j.create(new double[] {1, 2, 0, 4});
INDArray temp2 = Nd4j.create(new double[] {1, 2, 3, 4});
System.out.println("before: " + temp1.data() + " " + temp2.data());
Nd4j.getBlasWrapper().level1().rot(temp1.length(), temp1, temp2, 1, 2);
System.out.println("after: " + temp1.data() + " " + temp2.data());
//before: [1.0,2.0,0.0,4.0] [1.0,2.0,3.0,4.0]
// after: [3.0,6.0,6.0,12.0] [-1.0,-2.0,3.0,-4.0]
// commented out on removal of createSparse methods from Nd4j
/* INDArray sparseVec = Nd4j.createSparseCSR(data, col, pointerB, pointerE, shape);
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);
System.out.println(sparseVec.data() + " " + vec.data());
//System.out.println("indexes: " + ((BaseSparseNDArray) sparseVec).getVectorCoordinates().toString());
INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 6, 12}, new int[] {0, 1, 2, 3},
new int[] {0}, new int[] {4}, new long[] {1, 4});
INDArray expectedVec = Nd4j.create(new double[] {-1, -2, 3, -4});
assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
assertEquals(getFailureMessage(), expectedVec, vec);
*/
// TODO fix it
}
@Test
public void shouldComputeRotWithFullVector() {
// try with dense vectors to get the expected result
/*
INDArray temp1 = Nd4j.create( new double[] {1 ,2, 3, 4});
INDArray temp2 = Nd4j.create( new double[] {1 ,2, 3, 4});
System.out.println("before: " + temp1.data() + " " + temp2.data());
Nd4j.getBlasWrapper().level1().rot(temp1.length(), temp1, temp2, 1, 2);
System.out.println("after: " + temp1.data() + " " + temp2.data());
*/
//before: [1.0,2.0,3.0,4.0] [1.0,2.0,3.0,4.0]
// after: [3.0,6.0,0.0,12.0] [-1.0,-2.0,-3.0,-4.0]
int[] cols = {0, 1, 2, 3};
double[] values = {1, 2, 3, 4};
// commented out on removal of createSparse methods from Nd4j
/* INDArray sparseVec = Nd4j.createSparseCSR(values, cols, pointerB, pointerE, shape);
INDArray vec = Nd4j.create(new double[] {1, 2, 3, 4});
Nd4j.getBlasWrapper().level1().rot(vec.length(), sparseVec, vec, 1, 2);
INDArray expectedSparseVec = Nd4j.createSparseCSR(new double[] {3, 6, 9, 12}, new int[] {0, 1, 2, 3},
new int[] {0}, new int[] {4}, new long[] {1, 4});
INDArray expectedVec = Nd4j.create(new double[] {-1, -2, -3, -4});
assertEquals(getFailureMessage(), expectedSparseVec.data(), sparseVec.data());
assertEquals(getFailureMessage(), expectedVec, vec);
if (expectedSparseVec.isSparse() && sparseVec.isSparse()) {
BaseSparseNDArray vec2 = ((BaseSparseNDArray) expectedSparseVec);
BaseSparseNDArray vecSparse2 = ((BaseSparseNDArray) sparseVec);
assertEquals(getFailureMessage(), vec2.getVectorCoordinates(), vecSparse2);
}
*/
}
@Override
public char ordering() {
return 'c';
}
}