Empty array casting fix (#457)
* Empty array casting fix Signed-off-by: Alex Black <blacka101@gmail.com> * Tests Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
bf017b458c
commit
5e55e92002
|
@ -5521,7 +5521,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
public INDArray castTo(DataType dataType) {
|
public INDArray castTo(DataType dataType) {
|
||||||
if(dataType == dataType()) //No-op if correct datatype
|
if(dataType == dataType()) //No-op if correct datatype
|
||||||
return this;
|
return this;
|
||||||
if(isEmpty()){
|
if(isEmpty() && rank() == 0){
|
||||||
return Nd4j.empty(dataType);
|
return Nd4j.empty(dataType);
|
||||||
}
|
}
|
||||||
val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering());
|
val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering());
|
||||||
|
|
|
@ -8414,6 +8414,26 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testShape0Casts(){
|
||||||
|
for(DataType dt : DataType.values()){
|
||||||
|
if(!dt.isNumerical())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
INDArray a1 = Nd4j.create(dt, 1,0,2);
|
||||||
|
|
||||||
|
for(DataType dt2 : DataType.values()){
|
||||||
|
if(!dt2.isNumerical())
|
||||||
|
continue;
|
||||||
|
INDArray a2 = a1.castTo(dt2);
|
||||||
|
|
||||||
|
assertArrayEquals(a1.shape(), a2.shape());
|
||||||
|
assertEquals(dt2, a2.dataType());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
|
Loading…
Reference in New Issue