Small ND4J/SameDiff fixes (#248)
* #8218 Fix Nd4j.hstack rank 1 case Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8209 SameDiff: don't allow empty arrays (with 0s in shape) for variables Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
c99f980513
commit
b582e69e3b
|
@ -3367,7 +3367,9 @@ public class SameDiff extends SDBaseOps {
|
|||
*/
|
||||
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
|
||||
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
|
||||
|
||||
for(long l : shape){
|
||||
Preconditions.checkArgument(l != 0, "Cannot create variable with a shape that contains zeros (empty array shape) - got shape %s", shape);
|
||||
}
|
||||
|
||||
if (name == null || name.length() < 1)
|
||||
name = getNewVarName();
|
||||
|
@ -3582,7 +3584,7 @@ public class SameDiff extends SDBaseOps {
|
|||
Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type:" +
|
||||
" provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\n" +
|
||||
"For non floating point types, these should be created as placeholders or constants instead.", arr.dataType());
|
||||
|
||||
Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", arr);
|
||||
|
||||
if (name == null || name.length() < 1)
|
||||
name = getNewVarName();
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.nd4j.linalg.factory;
|
||||
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.blas.*;
|
||||
|
@ -959,8 +960,18 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
*
|
||||
* @param arrs
|
||||
*/
|
||||
public INDArray hstack(INDArray... arrs) {
|
||||
return Nd4j.concat(1, arrs);
|
||||
public INDArray hstack(@NonNull INDArray... arrs) {
|
||||
int firstRank = arrs[0].rank();
|
||||
Preconditions.checkState(firstRank > 0 && firstRank <= 2, "Only rank 1 and 2 arrays may be horizontally stacked; first input has rank %ndRank shape %nhShape", arrs[0], arrs[0]);
|
||||
for( int i=1; i<arrs.length; i++ ){
|
||||
Preconditions.checkState(firstRank == arrs[i].rank(), "Array ranks must be equal for horizontal stacking, arrs[0].rank=%s, arrs[%s].rank=%s",
|
||||
arrs[0].rank(), i, arrs[i].rank());
|
||||
}
|
||||
if(firstRank == 1){
|
||||
return Nd4j.concat(0, arrs);
|
||||
} else {
|
||||
return Nd4j.concat(1, arrs);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -972,7 +983,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
@Override
|
||||
public INDArray vstack(final INDArray... arrs) {
|
||||
return Nd4j.concat(0, arrs);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -3563,4 +3563,26 @@ public class SameDiffTests extends BaseNd4jTest {
|
|||
assertEquals(casted.dataType(), DataType.FLOAT);
|
||||
assertTrue(casted.getShapeDescriptor().isEmpty());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEmptyShapeVar(){
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
try {
|
||||
sd.var(DataType.FLOAT, 1, 0, 2);
|
||||
fail("Expected exception");
|
||||
} catch (IllegalArgumentException e){
|
||||
String m = e.getMessage();
|
||||
assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
|
||||
}
|
||||
|
||||
try {
|
||||
sd.var(Nd4j.create(1, 0, 2));
|
||||
fail("Expected exception");
|
||||
} catch (IllegalArgumentException e){
|
||||
String m = e.getMessage().toLowerCase();
|
||||
assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8122,6 +8122,19 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testVStackHStack1d() {
|
||||
INDArray rowVector1 = Nd4j.create(new double[]{1,2,3});
|
||||
INDArray rowVector2 = Nd4j.create(new double[]{4,5,6});
|
||||
|
||||
INDArray vStack = Nd4j.vstack(rowVector1, rowVector2); //Vertical stack: [3]+[3] to [2,3]
|
||||
INDArray hStack = Nd4j.hstack(rowVector1, rowVector2); //Horizontal stack: [3]+[3] to [6]
|
||||
|
||||
assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6).reshape('c', 2, 3), vStack);
|
||||
assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue