cuda build fix for issues introduced by recent refactoring

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-02 11:50:54 +03:00 committed by AlexDBlack
parent dfec54242d
commit e565788329
2 changed files with 3 additions and 21 deletions

View File

@ -167,15 +167,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
return new JCublasNDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
}
@Override
public INDArray createUninitializedDetached(int[] shape, char ordering) {
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
INDArray ret = new JCublasNDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
return ret;
}
@Override
public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) {
return new JCublasNDArray(data, newShape, newStride, offset, ordering);
@ -1676,12 +1667,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
}
@Override
public INDArray createUninitializedDetached(long[] shape, char ordering) {
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
INDArray ret = new JCublasNDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
return ret;
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape) {
return new JCublasNDArray(Nd4j.createBufferDetached(shape, dataType), shape, Nd4j.getStrides(shape, order), order, dataType);
}
@Override

View File

@ -346,12 +346,7 @@ public class JCusparseNDArrayFactory extends BaseSparseNDArrayFactory{
}
@Override
public INDArray createUninitializedDetached(int[] shape, char ordering) {
return null;
}
@Override
public INDArray createUninitializedDetached(long[] shape, char ordering) {
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape) {
return null;
}