Keras import: Add support for sparse cross entropy loss function (#73)
* #6377 Keras sparse cross entropy loss import support
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix small bug in reshape preprocessor
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									6f514e9431
								
							
						
					
					
						commit
						962575dd62
					
				| @ -75,7 +75,6 @@ public class KerasReshape extends KerasLayer { | ||||
|             List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape); | ||||
|             this.targetShape = listToLongArray(targetShapeList); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|  | ||||
| @ -94,7 +94,6 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { | ||||
|         if (!this.hasMiniBatchDimension) { | ||||
|             targetShape = prependMiniBatchSize(targetShape, miniBatchSize); | ||||
|             inputShape = prependMiniBatchSize(inputShape, miniBatchSize); | ||||
|             this.hasMiniBatchDimension = true; | ||||
|             this.miniBatchSize = miniBatchSize; | ||||
|         } | ||||
|         if (this.miniBatchSize != miniBatchSize) { | ||||
|  | ||||
| @ -54,7 +54,7 @@ public class KerasLossUtils { | ||||
|         } else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) { | ||||
|             dl4jLoss = LossFunctions.LossFunction.HINGE; | ||||
|         } else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) { | ||||
|             throw new UnsupportedKerasConfigurationException("Loss function " + kerasLoss + " not supported yet."); | ||||
|             dl4jLoss = LossFunctions.LossFunction.SPARSE_MCXENT; | ||||
|         } else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) { | ||||
|             dl4jLoss = LossFunctions.LossFunction.XENT; | ||||
|         } else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) { | ||||
|  | ||||
| @ -49,6 +49,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.learning.config.NoOp; | ||||
| import org.nd4j.linalg.lossfunctions.LossFunctions; | ||||
| import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; | ||||
| import org.nd4j.resources.Resources; | ||||
| 
 | ||||
| import java.io.File; | ||||
| @ -88,7 +89,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     @Test(expected = IllegalStateException.class) | ||||
|     public void fileNotFoundEndToEnd() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/foo/bar.h5"; | ||||
|         importEndModelTest(modelPath, null, true, true, false); | ||||
|         importEndModelTest(modelPath, null, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -98,28 +99,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     public void importMnistMlpTfKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importMnistMlpThKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importMnistMlpTfKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importMnistMlpReshapeTfKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, true); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -129,21 +130,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     public void importMnistCnnTfKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importMnistCnnThKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, true, true); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, true, true, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importMnistCnnTfKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, true); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -153,28 +154,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     public void importImdbLstmTfKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importImdbLstmThKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importImdbLstmTfKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importImdbLstmThKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -185,21 +186,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     public void importImdbFasttextTfKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, false, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importImdbFasttextThKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, false, false, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importImdbFasttextTfKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -209,21 +210,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     public void importSimpleLstmTfKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importSimpleLstmThKeras1() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importSimpleLstmTfKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, false, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @ -235,7 +236,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|         String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" + | ||||
|                 "simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -246,7 +247,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|         String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + | ||||
|                 "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -257,7 +258,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|         String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" + | ||||
|                 "simple_rnn_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -267,7 +268,18 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|     public void importCnnNoBiasTfKeras2() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5"; | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, false, true); | ||||
|         importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void importSparseXent() throws Exception { | ||||
|         String modelPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_model.h5"; | ||||
|         String inputsOutputPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_inputs_and_outputs.h5"; | ||||
|         MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true); | ||||
|         Layer outLayer = net.getOutputLayer(); | ||||
|         assertTrue(outLayer instanceof org.deeplearning4j.nn.layers.LossLayer); | ||||
|         LossLayer llConf = (LossLayer) outLayer.getConfig(); | ||||
|         assertEquals(new LossSparseMCXENT(), llConf.getLossFn()); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -626,19 +638,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     private void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions) throws Exception { | ||||
|         importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, false); | ||||
|     } | ||||
| 
 | ||||
|     public void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, | ||||
|                                     boolean checkGradients) throws Exception { | ||||
|     public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, | ||||
|                                     boolean checkGradients, boolean enforceTrainingConfig) throws Exception { | ||||
|         MultiLayerNetwork model; | ||||
|         try(InputStream is = Resources.asStream(modelPath)) { | ||||
|             File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); | ||||
|             Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); | ||||
|             KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) | ||||
|                     .enforceTrainingConfig(false).buildSequential(); | ||||
|                     .enforceTrainingConfig(enforceTrainingConfig).buildSequential(); | ||||
| 
 | ||||
|             model = kerasModel.getMultiLayerNetwork(); | ||||
|         } | ||||
| @ -699,6 +706,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
|                 checkGradients(model, input, testLabels); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return model; | ||||
|     } | ||||
| 
 | ||||
|     private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user