diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index bd88683b6..45ca51a39 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -167,15 +167,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { return new JCublasNDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false); } - @Override - public INDArray createUninitializedDetached(int[] shape, char ordering) { - MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - INDArray ret = new JCublasNDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false); - Nd4j.getMemoryManager().setCurrentWorkspace(workspace); - return ret; - } - @Override public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) { return new JCublasNDArray(data, newShape, newStride, offset, ordering); @@ -1676,12 +1667,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { } @Override - public INDArray createUninitializedDetached(long[] shape, char ordering) { - MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - INDArray ret = new JCublasNDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false); - Nd4j.getMemoryManager().setCurrentWorkspace(workspace); - return ret; + public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape) { + return new JCublasNDArray(Nd4j.createBufferDetached(shape, dataType), shape, Nd4j.getStrides(shape, order), order, dataType); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCusparseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCusparseNDArrayFactory.java index 6e5776fbd..e4e85d531 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCusparseNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCusparseNDArrayFactory.java @@ -346,12 +346,7 @@ public class JCusparseNDArrayFactory extends BaseSparseNDArrayFactory{ } @Override - public INDArray createUninitializedDetached(int[] shape, char ordering) { - return null; - } - - @Override - public INDArray createUninitializedDetached(long[] shape, char ordering) { + public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape) { return null; }