From 8ac89aeb190c55983835f0293161c4c04a15209f Mon Sep 17 00:00:00 2001 From: Chris Bamford Date: Wed, 1 Apr 2020 06:28:01 +0100 Subject: [PATCH] RL4J: Force shape fix (#352) * fix edge case where input to network needs to have shape > 1 Signed-off-by: Bam4d * adding test for single dimension Signed-off-by: Bam4d --- .../deeplearning4j/rl4j/helper/INDArrayHelper.java | 2 +- .../rl4j/helper/INDArrayHelperTest.java | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java index 7d93b1175..2e608db19 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java @@ -32,7 +32,7 @@ public class INDArrayHelper { * @return The source INDArray with the correct shape */ public static INDArray forceCorrectShape(INDArray source) { - return source.shape()[0] == 1 + return source.shape()[0] == 1 && source.shape().length > 1 ? source : Nd4j.expandDims(source, 0); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java index 9bfceadad..e1c5c64ed 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java @@ -35,4 +35,18 @@ public class INDArrayHelperTest { assertEquals(3, output.shape()[1]); } + @Test + public void when_inputHasOneDimension_expect_outputWithTwoDimensions() { + // Arrange + INDArray input = Nd4j.create(new double[] { 1.0 }); + + // Act + INDArray output = INDArrayHelper.forceCorrectShape(input); + + // Assert + assertEquals(2, output.shape().length); + assertEquals(1, output.shape()[0]); + assertEquals(1, output.shape()[1]); + } + }