Small ND4J/SameDiff fixes (#248)

* #8218 Fix Nd4j.hstack rank 1 case

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8209 SameDiff: don't allow empty arrays (with 0s in shape) for variables

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-09 22:54:07 +10:00 committed by GitHub
parent c99f980513
commit b582e69e3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 52 additions and 5 deletions

View File

@ -3367,7 +3367,9 @@ public class SameDiff extends SDBaseOps {
*/
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
for(long l : shape){
Preconditions.checkArgument(l != 0, "Cannot create variable with a shape that contains zeros (empty array shape) - got shape %s", shape);
}
if (name == null || name.length() < 1)
name = getNewVarName();
@ -3582,7 +3584,7 @@ public class SameDiff extends SDBaseOps {
Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type:" +
" provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\n" +
"For non floating point types, these should be created as placeholders or constants instead.", arr.dataType());
Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", arr);
if (name == null || name.length() < 1)
name = getNewVarName();

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.factory;
import lombok.NonNull;
import lombok.val;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.blas.*;
@ -959,8 +960,18 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
*
* @param arrs
*/
public INDArray hstack(INDArray... arrs) {
return Nd4j.concat(1, arrs);
public INDArray hstack(@NonNull INDArray... arrs) {
int firstRank = arrs[0].rank();
Preconditions.checkState(firstRank > 0 && firstRank <= 2, "Only rank 1 and 2 arrays may be horizontally stacked; first input has rank %ndRank shape %nhShape", arrs[0], arrs[0]);
for( int i=1; i<arrs.length; i++ ){
Preconditions.checkState(firstRank == arrs[i].rank(), "Array ranks must be equal for horizontal stacking, arrs[0].rank=%s, arrs[%s].rank=%s",
arrs[0].rank(), i, arrs[i].rank());
}
if(firstRank == 1){
return Nd4j.concat(0, arrs);
} else {
return Nd4j.concat(1, arrs);
}
}
/**
@ -972,7 +983,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
@Override
public INDArray vstack(final INDArray... arrs) {
return Nd4j.concat(0, arrs);
}

View File

@ -3563,4 +3563,26 @@ public class SameDiffTests extends BaseNd4jTest {
assertEquals(casted.dataType(), DataType.FLOAT);
assertTrue(casted.getShapeDescriptor().isEmpty());
}
@Test
public void testEmptyShapeVar(){
SameDiff sd = SameDiff.create();
try {
sd.var(DataType.FLOAT, 1, 0, 2);
fail("Expected exception");
} catch (IllegalArgumentException e){
String m = e.getMessage();
assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
}
try {
sd.var(Nd4j.create(1, 0, 2));
fail("Expected exception");
} catch (IllegalArgumentException e){
String m = e.getMessage().toLowerCase();
assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
}
}
}

View File

@ -8122,6 +8122,19 @@ public class Nd4jTestsC extends BaseNd4jTest {
return ret;
}
@Test
public void testVStackHStack1d() {
INDArray rowVector1 = Nd4j.create(new double[]{1,2,3});
INDArray rowVector2 = Nd4j.create(new double[]{4,5,6});
INDArray vStack = Nd4j.vstack(rowVector1, rowVector2); //Vertical stack: [3]+[3] to [2,3]
INDArray hStack = Nd4j.hstack(rowVector1, rowVector2); //Horizontal stack: [3]+[3] to [6]
assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6).reshape('c', 2, 3), vStack);
assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack);
}
@Override
public char ordering() {
return 'c';