diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 69c69388b..7ee7b5bc2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -3453,4 +3453,42 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(outExp, outs); assertEquals(gExp, g); } + + @Test + public void testConcatVariableGrad() { + SameDiff sd = SameDiff.create(); + SDVariable label = sd.var("label", DataType.FLOAT, 3, 4); + SDVariable a = sd.var("a", DataType.FLOAT, 3, 2); + SDVariable b = sd.var("b", DataType.FLOAT, 3, 2); + INDArray inputArr = Nd4j.rand(3,4); + INDArray labelArr = Nd4j.rand(3,4); + SDVariable c = sd.concat("concat", 1, a, b); + SDVariable loss = sd.math().pow(c.sub(label), 2); + sd.setLossVariables(loss); + sd.associateArrayWithVariable(labelArr, label); + sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a); + sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b); + Map map = sd.calculateGradients(null, "a", "b", "concat"); + INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b")); + assertEquals(concatArray, map.get("concat")); + + } + + @Test + public void testSliceVariableGrad() { + SameDiff sd = SameDiff.create(); + SDVariable label = sd.var("label", DataType.FLOAT, 3, 4); + SDVariable input = sd.var("input", DataType.FLOAT, 3, 4); + INDArray inputArr = Nd4j.rand(3,4); + INDArray labelArr = Nd4j.rand(3,4); + SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2)); + SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4)); + SDVariable c = sd.concat("concat", 1, a, b); + SDVariable loss = sd.math().pow(c.sub(label), 2); + sd.setLossVariables(loss); + sd.associateArrayWithVariable(labelArr, label); + sd.associateArrayWithVariable(inputArr, input); + Map map = sd.calculateGradients(null,"input", "concat"); + assertEquals(map.get("input"), map.get("concat")); + } }