diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 1821a30a0..955677ca8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -3367,7 +3367,9 @@ public class SameDiff extends SDBaseOps { */ public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { - + for(long l : shape){ + Preconditions.checkArgument(l != 0, "Cannot create variable with a shape that contains zeros (empty array shape) - got shape %s", shape); + } if (name == null || name.length() < 1) name = getNewVarName(); @@ -3582,7 +3584,7 @@ public class SameDiff extends SDBaseOps { Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type:" + " provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\n" + "For non floating point types, these should be created as placeholders or constants instead.", arr.dataType()); - + Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", arr); if (name == null || name.length() < 1) name = getNewVarName(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index 1edf0d651..a664d9ee5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.factory; +import lombok.NonNull; import lombok.val; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.blas.*; @@ -959,8 +960,18 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { * * @param arrs */ - public INDArray hstack(INDArray... arrs) { - return Nd4j.concat(1, arrs); + public INDArray hstack(@NonNull INDArray... arrs) { + int firstRank = arrs[0].rank(); + Preconditions.checkState(firstRank > 0 && firstRank <= 2, "Only rank 1 and 2 arrays may be horizontally stacked; first input has rank %ndRank shape %nhShape", arrs[0], arrs[0]); + for( int i=1; i