From 81efa5c3b6e8bfebd9a787fdd47ae8445b49756a Mon Sep 17 00:00:00 2001 From: raver119 Date: Sun, 2 Feb 2020 19:17:26 +0300 Subject: [PATCH] [WIP] one small fix (#207) * one small fix Signed-off-by: raver119 * assert added Signed-off-by: raver119 --- .../linalg/jcublas/buffer/BaseCudaDataBuffer.java | 12 ++++++++++++ .../src/main/java/org/nd4j/nativeblas/Nd4jCuda.java | 1 - .../src/main/java/org/nd4j/nativeblas/Nd4jCpu.java | 1 - .../test/java/org/nd4j/linalg/shape/EmptyTests.java | 6 ++++++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 07614a0ad..02b857f7f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1050,21 +1050,33 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public void setData(int[] data) { + if (data.length == 0) + return; + set(data, data.length, 0, 0); } @Override public void setData(long[] data) { + if (data.length == 0) + return; + set(data, data.length, 0, 0); } @Override public void setData(float[] data) { + if (data.length == 0) + return; + set(data, data.length, 0, 0); } @Override public void setData(double[] data) { + if (data.length == 0) + return; + set(data, data.length, 0, 0); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 8d0029bc3..f85ae9cf1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3804,7 +3804,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype); - /** * This method returns new array with the same shape & data type * @return diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 93fbb71d7..5522141be 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3807,7 +3807,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype); - /** * This method returns new array with the same shape & data type * @return diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java index 3bef69c19..aa81097d1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java @@ -320,6 +320,12 @@ public class EmptyTests extends BaseNd4jTest { Nd4j.exec(op); } + @Test + public void testEmptyConstructor_1() { + val x = Nd4j.create(new double[0]); + assertTrue(x.isEmpty()); + } + @Override public char ordering() { return 'c';