RL4J: Force shape fix (#352)
* fix edge case where input to network needs to have shape > 1 Signed-off-by: Bam4d <chrisbam4d@gmail.com> * adding test for single dimension Signed-off-by: Bam4d <chrisbam4d@gmail.com>
This commit is contained in:
		
							parent
							
								
									48102c61d0
								
							
						
					
					
						commit
						8ac89aeb19
					
				@ -32,7 +32,7 @@ public class INDArrayHelper {
 | 
			
		||||
     * @return The source INDArray with the correct shape
 | 
			
		||||
     */
 | 
			
		||||
    public static INDArray forceCorrectShape(INDArray source) {
 | 
			
		||||
        return source.shape()[0] == 1
 | 
			
		||||
        return source.shape()[0] == 1 && source.shape().length > 1
 | 
			
		||||
                ? source
 | 
			
		||||
                : Nd4j.expandDims(source, 0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -35,4 +35,18 @@ public class INDArrayHelperTest {
 | 
			
		||||
        assertEquals(3, output.shape()[1]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_inputHasOneDimension_expect_outputWithTwoDimensions() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray output = INDArrayHelper.forceCorrectShape(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(2, output.shape().length);
 | 
			
		||||
        assertEquals(1, output.shape()[0]);
 | 
			
		||||
        assertEquals(1, output.shape()[1]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user