RL4J: Force shape fix (#352)
* fix edge case where input to network needs to have shape > 1 Signed-off-by: Bam4d <chrisbam4d@gmail.com> * adding test for single dimension Signed-off-by: Bam4d <chrisbam4d@gmail.com>master
parent
48102c61d0
commit
8ac89aeb19
|
@ -32,7 +32,7 @@ public class INDArrayHelper {
|
|||
* @return The source INDArray with the correct shape
|
||||
*/
|
||||
public static INDArray forceCorrectShape(INDArray source) {
|
||||
return source.shape()[0] == 1
|
||||
return source.shape()[0] == 1 && source.shape().length > 1
|
||||
? source
|
||||
: Nd4j.expandDims(source, 0);
|
||||
}
|
||||
|
|
|
@ -35,4 +35,18 @@ public class INDArrayHelperTest {
|
|||
assertEquals(3, output.shape()[1]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_inputHasOneDimension_expect_outputWithTwoDimensions() {
|
||||
// Arrange
|
||||
INDArray input = Nd4j.create(new double[] { 1.0 });
|
||||
|
||||
// Act
|
||||
INDArray output = INDArrayHelper.forceCorrectShape(input);
|
||||
|
||||
// Assert
|
||||
assertEquals(2, output.shape().length);
|
||||
assertEquals(1, output.shape()[0]);
|
||||
assertEquals(1, output.shape()[1]);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue