diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java index 7d93b1175..2e608db19 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java @@ -32,7 +32,7 @@ public class INDArrayHelper { * @return The source INDArray with the correct shape */ public static INDArray forceCorrectShape(INDArray source) { - return source.shape()[0] == 1 + return source.shape()[0] == 1 && source.shape().length > 1 ? source : Nd4j.expandDims(source, 0); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java index 9bfceadad..e1c5c64ed 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java @@ -35,4 +35,18 @@ public class INDArrayHelperTest { assertEquals(3, output.shape()[1]); } + @Test + public void when_inputHasOneDimension_expect_outputWithTwoDimensions() { + // Arrange + INDArray input = Nd4j.create(new double[] { 1.0 }); + + // Act + INDArray output = INDArrayHelper.forceCorrectShape(input); + + // Assert + assertEquals(2, output.shape().length); + assertEquals(1, output.shape()[0]); + assertEquals(1, output.shape()[1]); + } + }