Empty array casting fix (#457)

* Empty array casting fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Tests

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-13 01:37:11 +10:00 committed by GitHub
parent bf017b458c
commit 5e55e92002
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 1 deletions

View File

@ -5521,7 +5521,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
public INDArray castTo(DataType dataType) {
if(dataType == dataType()) //No-op if correct datatype
return this;
if(isEmpty()){
if(isEmpty() && rank() == 0){
return Nd4j.empty(dataType);
}
val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering());

View File

@ -8414,6 +8414,26 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testShape0Casts(){
for(DataType dt : DataType.values()){
if(!dt.isNumerical())
continue;
INDArray a1 = Nd4j.create(dt, 1,0,2);
for(DataType dt2 : DataType.values()){
if(!dt2.isNumerical())
continue;
INDArray a2 = a1.castTo(dt2);
assertArrayEquals(a1.shape(), a2.shape());
assertEquals(dt2, a2.dataType());
}
}
}
@Override
public char ordering() {
return 'c';