Testing slice and concat (#8362)
This commit is contained in:
		
							parent
							
								
									7583ccfa15
								
							
						
					
					
						commit
						52c9918c6f
					
				@ -3453,4 +3453,42 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
			
		||||
        assertEquals(outExp, outs);
 | 
			
		||||
        assertEquals(gExp, g);
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    @Test
 | 
			
		||||
	public void testConcatVariableGrad() {
 | 
			
		||||
		SameDiff sd = SameDiff.create();
 | 
			
		||||
		SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
 | 
			
		||||
		SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
 | 
			
		||||
		SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
 | 
			
		||||
		INDArray inputArr = Nd4j.rand(3,4);
 | 
			
		||||
		INDArray labelArr =  Nd4j.rand(3,4);
 | 
			
		||||
		SDVariable c = sd.concat("concat", 1, a, b);
 | 
			
		||||
		SDVariable loss = sd.math().pow(c.sub(label), 2);
 | 
			
		||||
		sd.setLossVariables(loss);
 | 
			
		||||
		sd.associateArrayWithVariable(labelArr, label);
 | 
			
		||||
		sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
 | 
			
		||||
		sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
 | 
			
		||||
		Map<String, INDArray> map = sd.calculateGradients(null, "a", "b", "concat");
 | 
			
		||||
		INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
 | 
			
		||||
		assertEquals(concatArray, map.get("concat"));
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void testSliceVariableGrad() {
 | 
			
		||||
		SameDiff sd = SameDiff.create();
 | 
			
		||||
		SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
 | 
			
		||||
		SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
 | 
			
		||||
		INDArray inputArr =  Nd4j.rand(3,4);
 | 
			
		||||
		INDArray labelArr =  Nd4j.rand(3,4);
 | 
			
		||||
		SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
 | 
			
		||||
		SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
 | 
			
		||||
		SDVariable c = sd.concat("concat", 1, a, b);
 | 
			
		||||
		SDVariable loss = sd.math().pow(c.sub(label), 2);
 | 
			
		||||
		sd.setLossVariables(loss);
 | 
			
		||||
		sd.associateArrayWithVariable(labelArr, label);
 | 
			
		||||
		sd.associateArrayWithVariable(inputArr, input);
 | 
			
		||||
		Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
 | 
			
		||||
		assertEquals(map.get("input"), map.get("concat"));
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user