Testing slice and concat (#8362)

master
longzhendong 2019-11-07 11:47:37 +08:00 committed by Alex Black
parent 7583ccfa15
commit 52c9918c6f
1 changed files with 38 additions and 0 deletions

View File

@ -3453,4 +3453,42 @@ public class SameDiffTests extends BaseNd4jTest {
assertEquals(outExp, outs);
assertEquals(gExp, g);
}
@Test
public void testConcatVariableGrad() {
SameDiff sd = SameDiff.create();
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
INDArray inputArr = Nd4j.rand(3,4);
INDArray labelArr = Nd4j.rand(3,4);
SDVariable c = sd.concat("concat", 1, a, b);
SDVariable loss = sd.math().pow(c.sub(label), 2);
sd.setLossVariables(loss);
sd.associateArrayWithVariable(labelArr, label);
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
Map<String, INDArray> map = sd.calculateGradients(null, "a", "b", "concat");
INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
assertEquals(concatArray, map.get("concat"));
}
@Test
public void testSliceVariableGrad() {
SameDiff sd = SameDiff.create();
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
INDArray inputArr = Nd4j.rand(3,4);
INDArray labelArr = Nd4j.rand(3,4);
SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
SDVariable c = sd.concat("concat", 1, a, b);
SDVariable loss = sd.math().pow(c.sub(label), 2);
sd.setLossVariables(loss);
sd.associateArrayWithVariable(labelArr, label);
sd.associateArrayWithVariable(inputArr, input);
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
assertEquals(map.get("input"), map.get("concat"));
}
}