Testing slice and concat (#8362)
parent
7583ccfa15
commit
52c9918c6f
|
@ -3453,4 +3453,42 @@ public class SameDiffTests extends BaseNd4jTest {
|
|||
assertEquals(outExp, outs);
|
||||
assertEquals(gExp, g);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConcatVariableGrad() {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
|
||||
SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
|
||||
SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
|
||||
INDArray inputArr = Nd4j.rand(3,4);
|
||||
INDArray labelArr = Nd4j.rand(3,4);
|
||||
SDVariable c = sd.concat("concat", 1, a, b);
|
||||
SDVariable loss = sd.math().pow(c.sub(label), 2);
|
||||
sd.setLossVariables(loss);
|
||||
sd.associateArrayWithVariable(labelArr, label);
|
||||
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
|
||||
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
|
||||
Map<String, INDArray> map = sd.calculateGradients(null, "a", "b", "concat");
|
||||
INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
|
||||
assertEquals(concatArray, map.get("concat"));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSliceVariableGrad() {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
|
||||
SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
|
||||
INDArray inputArr = Nd4j.rand(3,4);
|
||||
INDArray labelArr = Nd4j.rand(3,4);
|
||||
SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
|
||||
SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
|
||||
SDVariable c = sd.concat("concat", 1, a, b);
|
||||
SDVariable loss = sd.math().pow(c.sub(label), 2);
|
||||
sd.setLossVariables(loss);
|
||||
sd.associateArrayWithVariable(labelArr, label);
|
||||
sd.associateArrayWithVariable(inputArr, input);
|
||||
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
|
||||
assertEquals(map.get("input"), map.get("concat"));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue