Empty array casting fix (#457)
* Empty array casting fix Signed-off-by: Alex Black <blacka101@gmail.com> * Tests Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
bf017b458c
commit
5e55e92002
|
@ -5521,7 +5521,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
public INDArray castTo(DataType dataType) {
|
||||
if(dataType == dataType()) //No-op if correct datatype
|
||||
return this;
|
||||
if(isEmpty()){
|
||||
if(isEmpty() && rank() == 0){
|
||||
return Nd4j.empty(dataType);
|
||||
}
|
||||
val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering());
|
||||
|
|
|
@ -8414,6 +8414,26 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testShape0Casts(){
|
||||
for(DataType dt : DataType.values()){
|
||||
if(!dt.isNumerical())
|
||||
continue;
|
||||
|
||||
INDArray a1 = Nd4j.create(dt, 1,0,2);
|
||||
|
||||
for(DataType dt2 : DataType.values()){
|
||||
if(!dt2.isNumerical())
|
||||
continue;
|
||||
INDArray a2 = a1.castTo(dt2);
|
||||
|
||||
assertArrayEquals(a1.shape(), a2.shape());
|
||||
assertEquals(dt2, a2.dataType());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue