Small ND4J/SameDiff fixes (#248)
* #8218 Fix Nd4j.hstack rank 1 case Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8209 SameDiff: don't allow empty arrays (with 0s in shape) for variables Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									c99f980513
								
							
						
					
					
						commit
						b582e69e3b
					
				| @ -3367,7 +3367,9 @@ public class SameDiff extends SDBaseOps { | ||||
|      */ | ||||
|     public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, | ||||
|                           org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { | ||||
| 
 | ||||
|         for(long l : shape){ | ||||
|             Preconditions.checkArgument(l != 0, "Cannot create variable with a shape that contains zeros (empty array shape) - got shape %s", shape); | ||||
|         } | ||||
| 
 | ||||
|         if (name == null || name.length() < 1) | ||||
|             name = getNewVarName(); | ||||
| @ -3582,7 +3584,7 @@ public class SameDiff extends SDBaseOps { | ||||
|         Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type:" + | ||||
|                 " provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\n" + | ||||
|                 "For non floating point types, these should be created as placeholders or constants instead.", arr.dataType()); | ||||
| 
 | ||||
|         Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", arr); | ||||
| 
 | ||||
|         if (name == null || name.length() < 1) | ||||
|             name = getNewVarName(); | ||||
|  | ||||
| @ -17,6 +17,7 @@ | ||||
| package org.nd4j.linalg.factory; | ||||
| 
 | ||||
| 
 | ||||
| import lombok.NonNull; | ||||
| import lombok.val; | ||||
| import org.nd4j.base.Preconditions; | ||||
| import org.nd4j.linalg.api.blas.*; | ||||
| @ -959,8 +960,18 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { | ||||
|      * | ||||
|      * @param arrs | ||||
|      */ | ||||
|     public INDArray hstack(INDArray... arrs) { | ||||
|         return Nd4j.concat(1, arrs); | ||||
|     public INDArray hstack(@NonNull INDArray... arrs) { | ||||
|         int firstRank = arrs[0].rank(); | ||||
|         Preconditions.checkState(firstRank > 0 && firstRank <= 2, "Only rank 1 and 2 arrays may be horizontally stacked; first input has rank %ndRank shape %nhShape", arrs[0], arrs[0]); | ||||
|         for( int i=1; i<arrs.length; i++ ){ | ||||
|             Preconditions.checkState(firstRank == arrs[i].rank(), "Array ranks must be equal for horizontal stacking, arrs[0].rank=%s, arrs[%s].rank=%s", | ||||
|                     arrs[0].rank(), i, arrs[i].rank()); | ||||
|         } | ||||
|         if(firstRank == 1){ | ||||
|             return Nd4j.concat(0, arrs); | ||||
|         } else { | ||||
|             return Nd4j.concat(1, arrs); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -972,7 +983,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { | ||||
|     @Override | ||||
|     public INDArray vstack(final INDArray... arrs) { | ||||
|         return Nd4j.concat(0, arrs); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -3563,4 +3563,26 @@ public class SameDiffTests extends BaseNd4jTest { | ||||
|         assertEquals(casted.dataType(), DataType.FLOAT); | ||||
|         assertTrue(casted.getShapeDescriptor().isEmpty()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testEmptyShapeVar(){ | ||||
|         SameDiff sd = SameDiff.create(); | ||||
| 
 | ||||
|         try { | ||||
|             sd.var(DataType.FLOAT, 1, 0, 2); | ||||
|             fail("Expected exception"); | ||||
|         } catch (IllegalArgumentException e){ | ||||
|             String m = e.getMessage(); | ||||
|             assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0")); | ||||
|         } | ||||
| 
 | ||||
|         try { | ||||
|             sd.var(Nd4j.create(1, 0, 2)); | ||||
|             fail("Expected exception"); | ||||
|         } catch (IllegalArgumentException e){ | ||||
|             String m = e.getMessage().toLowerCase(); | ||||
|             assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0")); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -8122,6 +8122,19 @@ public class Nd4jTestsC extends BaseNd4jTest { | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testVStackHStack1d() { | ||||
|         INDArray rowVector1 = Nd4j.create(new double[]{1,2,3}); | ||||
|         INDArray rowVector2 = Nd4j.create(new double[]{4,5,6}); | ||||
| 
 | ||||
|         INDArray vStack = Nd4j.vstack(rowVector1, rowVector2);      //Vertical stack:   [3]+[3] to [2,3] | ||||
|         INDArray hStack = Nd4j.hstack(rowVector1, rowVector2);      //Horizontal stack: [3]+[3] to [6] | ||||
| 
 | ||||
|         assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6).reshape('c', 2, 3), vStack); | ||||
|         assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public char ordering() { | ||||
|         return 'c'; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user