RL4J: Force shape fix (#352)

* fix edge case where input to network needs to have shape > 1

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* adding test for single dimension

Signed-off-by: Bam4d <chrisbam4d@gmail.com>
master
Chris Bamford 2020-04-01 06:28:01 +01:00 committed by GitHub
parent 48102c61d0
commit 8ac89aeb19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 1 deletions

View File

@ -32,7 +32,7 @@ public class INDArrayHelper {
* @return The source INDArray with the correct shape * @return The source INDArray with the correct shape
*/ */
public static INDArray forceCorrectShape(INDArray source) { public static INDArray forceCorrectShape(INDArray source) {
return source.shape()[0] == 1 return source.shape()[0] == 1 && source.shape().length > 1
? source ? source
: Nd4j.expandDims(source, 0); : Nd4j.expandDims(source, 0);
} }

View File

@ -35,4 +35,18 @@ public class INDArrayHelperTest {
assertEquals(3, output.shape()[1]); assertEquals(3, output.shape()[1]);
} }
@Test
public void when_inputHasOneDimension_expect_outputWithTwoDimensions() {
// Arrange
INDArray input = Nd4j.create(new double[] { 1.0 });
// Act
INDArray output = INDArrayHelper.forceCorrectShape(input);
// Assert
assertEquals(2, output.shape().length);
assertEquals(1, output.shape()[0]);
assertEquals(1, output.shape()[1]);
}
} }