Fixes Keras import issue (8373) (#54)
* test * fix * rem prn * add test
This commit is contained in:
		
							parent
							
								
									66b84b38cf
								
							
						
					
					
						commit
						630409cd53
					
				| @ -43,13 +43,14 @@ import static org.nd4j.linalg.util.ArrayUtil.prodLong; | ||||
| @Data | ||||
| @Slf4j | ||||
| @EqualsAndHashCode(callSuper = false) | ||||
| @JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"}) | ||||
| @JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"}) | ||||
| public class ReshapePreprocessor extends BaseInputPreProcessor { | ||||
| 
 | ||||
|     private long[] inputShape; | ||||
|     private long[] targetShape; | ||||
|     private boolean hasMiniBatchDimension = false; | ||||
|     private int miniBatchSize; | ||||
|     private long[] staticTargetShape; | ||||
| 
 | ||||
|     public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) { | ||||
|         this.inputShape = inputShape; | ||||
| @ -80,6 +81,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { | ||||
|     public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { | ||||
|         // the target shape read from a keras config does not have mini-batch size | ||||
|         // included. We prepend it here dynamically. | ||||
| 
 | ||||
|         long[] targetShape; | ||||
|         if (staticTargetShape != null){ | ||||
|             targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize); | ||||
|             hasMiniBatchDimension = true; | ||||
|             this.miniBatchSize = miniBatchSize; | ||||
|         } | ||||
|         else{ | ||||
|             targetShape = this.targetShape; | ||||
|         } | ||||
|         if (!this.hasMiniBatchDimension) { | ||||
|             targetShape = prependMiniBatchSize(targetShape, miniBatchSize); | ||||
|             inputShape = prependMiniBatchSize(inputShape, miniBatchSize); | ||||
| @ -95,7 +106,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { | ||||
|             if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){ | ||||
|                 input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); | ||||
|             } | ||||
|             return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape)); | ||||
|             return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape)); | ||||
|         } else { | ||||
|             throw new IllegalStateException("Input shape " + Arrays.toString(input.shape()) | ||||
|                     + " and output shape" + Arrays.toString(inputShape) + " do not match"); | ||||
| @ -122,20 +133,27 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { | ||||
|     @Override | ||||
|     public InputType getOutputType(InputType inputType) throws InvalidInputTypeException { | ||||
|         val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0); | ||||
|         InputType ret; | ||||
|         switch (shape.length) { | ||||
|             case 2: | ||||
|                 return InputType.feedForward(shape[1]); | ||||
|                 ret = InputType.feedForward(shape[1]); | ||||
|                 break; | ||||
|             case 3: | ||||
|                 return InputType.recurrent(shape[2], shape[1]); | ||||
|                 ret = InputType.recurrent(shape[2], shape[1]); | ||||
|                 break; | ||||
|             case 4: | ||||
|                 if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){ | ||||
|                     return InputType.convolutional(shape[1], shape[2], shape[3]); | ||||
|                     ret = InputType.convolutional(shape[1], shape[2], shape[3]); | ||||
|                 }else { | ||||
|                     return InputType.convolutional(shape[2], shape[3], shape[1]); | ||||
|                     ret = InputType.convolutional(shape[2], shape[3], shape[1]); | ||||
|                 } | ||||
|                 break; | ||||
|             default: | ||||
|                 throw new UnsupportedOperationException( | ||||
|                         "Cannot infer input type for reshape array " + Arrays.toString(shape)); | ||||
| 
 | ||||
|         } | ||||
|         this.staticTargetShape = ret.getShape(); | ||||
|         return ret; | ||||
|     } | ||||
| } | ||||
| @ -255,6 +255,17 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void ReshapeEmbeddingConcatTest() throws Exception{ | ||||
|         try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) { | ||||
|             ComputationGraphConfiguration config = | ||||
|                     new KerasModel().modelBuilder().modelJsonInputStream(is) | ||||
|                             .enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration(); | ||||
|             ComputationGraph model = new ComputationGraph(config); | ||||
|             model.init(); | ||||
|             model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1)); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     private void runSequentialConfigTest(String path) throws Exception { | ||||
|         try(InputStream is = Resources.asStream(path)) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user