Small ND4J/SameDiff fixes (#248)
* #8218 Fix Nd4j.hstack rank 1 case Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8209 SameDiff: don't allow empty arrays (with 0s in shape) for variables Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
c99f980513
commit
b582e69e3b
|
@ -3367,7 +3367,9 @@ public class SameDiff extends SDBaseOps {
|
||||||
*/
|
*/
|
||||||
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
|
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
|
||||||
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
|
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
|
||||||
|
for(long l : shape){
|
||||||
|
Preconditions.checkArgument(l != 0, "Cannot create variable with a shape that contains zeros (empty array shape) - got shape %s", shape);
|
||||||
|
}
|
||||||
|
|
||||||
if (name == null || name.length() < 1)
|
if (name == null || name.length() < 1)
|
||||||
name = getNewVarName();
|
name = getNewVarName();
|
||||||
|
@ -3582,7 +3584,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type:" +
|
Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type:" +
|
||||||
" provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\n" +
|
" provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\n" +
|
||||||
"For non floating point types, these should be created as placeholders or constants instead.", arr.dataType());
|
"For non floating point types, these should be created as placeholders or constants instead.", arr.dataType());
|
||||||
|
Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", arr);
|
||||||
|
|
||||||
if (name == null || name.length() < 1)
|
if (name == null || name.length() < 1)
|
||||||
name = getNewVarName();
|
name = getNewVarName();
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.factory;
|
package org.nd4j.linalg.factory;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.blas.*;
|
import org.nd4j.linalg.api.blas.*;
|
||||||
|
@ -959,8 +960,18 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
||||||
*
|
*
|
||||||
* @param arrs
|
* @param arrs
|
||||||
*/
|
*/
|
||||||
public INDArray hstack(INDArray... arrs) {
|
public INDArray hstack(@NonNull INDArray... arrs) {
|
||||||
return Nd4j.concat(1, arrs);
|
int firstRank = arrs[0].rank();
|
||||||
|
Preconditions.checkState(firstRank > 0 && firstRank <= 2, "Only rank 1 and 2 arrays may be horizontally stacked; first input has rank %ndRank shape %nhShape", arrs[0], arrs[0]);
|
||||||
|
for( int i=1; i<arrs.length; i++ ){
|
||||||
|
Preconditions.checkState(firstRank == arrs[i].rank(), "Array ranks must be equal for horizontal stacking, arrs[0].rank=%s, arrs[%s].rank=%s",
|
||||||
|
arrs[0].rank(), i, arrs[i].rank());
|
||||||
|
}
|
||||||
|
if(firstRank == 1){
|
||||||
|
return Nd4j.concat(0, arrs);
|
||||||
|
} else {
|
||||||
|
return Nd4j.concat(1, arrs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -972,7 +983,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
||||||
@Override
|
@Override
|
||||||
public INDArray vstack(final INDArray... arrs) {
|
public INDArray vstack(final INDArray... arrs) {
|
||||||
return Nd4j.concat(0, arrs);
|
return Nd4j.concat(0, arrs);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3563,4 +3563,26 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
assertEquals(casted.dataType(), DataType.FLOAT);
|
assertEquals(casted.dataType(), DataType.FLOAT);
|
||||||
assertTrue(casted.getShapeDescriptor().isEmpty());
|
assertTrue(casted.getShapeDescriptor().isEmpty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEmptyShapeVar(){
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
try {
|
||||||
|
sd.var(DataType.FLOAT, 1, 0, 2);
|
||||||
|
fail("Expected exception");
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
String m = e.getMessage();
|
||||||
|
assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
sd.var(Nd4j.create(1, 0, 2));
|
||||||
|
fail("Expected exception");
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
String m = e.getMessage().toLowerCase();
|
||||||
|
assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8122,6 +8122,19 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVStackHStack1d() {
|
||||||
|
INDArray rowVector1 = Nd4j.create(new double[]{1,2,3});
|
||||||
|
INDArray rowVector2 = Nd4j.create(new double[]{4,5,6});
|
||||||
|
|
||||||
|
INDArray vStack = Nd4j.vstack(rowVector1, rowVector2); //Vertical stack: [3]+[3] to [2,3]
|
||||||
|
INDArray hStack = Nd4j.hstack(rowVector1, rowVector2); //Horizontal stack: [3]+[3] to [6]
|
||||||
|
|
||||||
|
assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6).reshape('c', 2, 3), vStack);
|
||||||
|
assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
|
Loading…
Reference in New Issue