diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 835a2f4cb..052251734 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -5521,7 +5521,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { public INDArray castTo(DataType dataType) { if(dataType == dataType()) //No-op if correct datatype return this; - if(isEmpty()){ + if(isEmpty() && rank() == 0){ return Nd4j.empty(dataType); } val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index da8983118..46f47017e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8414,6 +8414,26 @@ public class Nd4jTestsC extends BaseNd4jTest { } } + + @Test + public void testShape0Casts(){ + for(DataType dt : DataType.values()){ + if(!dt.isNumerical()) + continue; + + INDArray a1 = Nd4j.create(dt, 1,0,2); + + for(DataType dt2 : DataType.values()){ + if(!dt2.isNumerical()) + continue; + INDArray a2 = a1.castTo(dt2); + + assertArrayEquals(a1.shape(), a2.shape()); + assertEquals(dt2, a2.dataType()); + } + } + } + @Override public char ordering() { return 'c';