RL4J: Force shape fix (#352)
* fix edge case where input to network needs to have shape > 1 Signed-off-by: Bam4d <chrisbam4d@gmail.com> * adding test for single dimension Signed-off-by: Bam4d <chrisbam4d@gmail.com>master
parent
48102c61d0
commit
8ac89aeb19
|
@ -32,7 +32,7 @@ public class INDArrayHelper {
|
||||||
* @return The source INDArray with the correct shape
|
* @return The source INDArray with the correct shape
|
||||||
*/
|
*/
|
||||||
public static INDArray forceCorrectShape(INDArray source) {
|
public static INDArray forceCorrectShape(INDArray source) {
|
||||||
return source.shape()[0] == 1
|
return source.shape()[0] == 1 && source.shape().length > 1
|
||||||
? source
|
? source
|
||||||
: Nd4j.expandDims(source, 0);
|
: Nd4j.expandDims(source, 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,4 +35,18 @@ public class INDArrayHelperTest {
|
||||||
assertEquals(3, output.shape()[1]);
|
assertEquals(3, output.shape()[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_inputHasOneDimension_expect_outputWithTwoDimensions() {
|
||||||
|
// Arrange
|
||||||
|
INDArray input = Nd4j.create(new double[] { 1.0 });
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray output = INDArrayHelper.forceCorrectShape(input);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(2, output.shape().length);
|
||||||
|
assertEquals(1, output.shape()[0]);
|
||||||
|
assertEquals(1, output.shape()[1]);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue