From 962575dd62531cd9f48dbf38fe6b24c2bfe88a7a Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 23 Nov 2019 00:16:13 +1100 Subject: [PATCH] Keras import: Add support for sparse cross entropy loss function (#73) * #6377 Keras sparse cross entropy loss import support Signed-off-by: AlexDBlack * Fix small bug in reshape preprocessor Signed-off-by: AlexDBlack --- .../keras/layers/core/KerasReshape.java | 1 - .../preprocessors/ReshapePreprocessor.java | 1 - .../keras/utils/KerasLossUtils.java | 2 +- .../keras/e2e/KerasModelEndToEndTest.java | 69 +++++++++++-------- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java index 6a5e1ff2a..4035e9298 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java @@ -75,7 +75,6 @@ public class KerasReshape extends KerasLayer { List targetShapeList = (List) innerConfig.get(targetShape); this.targetShape = listToLongArray(targetShapeList); } - } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java index dbd5ccd8c..e9aef5b90 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java @@ -94,7 +94,6 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { if (!this.hasMiniBatchDimension) { targetShape = prependMiniBatchSize(targetShape, miniBatchSize); inputShape = prependMiniBatchSize(inputShape, miniBatchSize); - this.hasMiniBatchDimension = true; this.miniBatchSize = miniBatchSize; } if (this.miniBatchSize != miniBatchSize) { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java index 1149401c6..35cf34170 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java @@ -54,7 +54,7 @@ public class KerasLossUtils { } else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) { dl4jLoss = LossFunctions.LossFunction.HINGE; } else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) { - throw new UnsupportedKerasConfigurationException("Loss function " + kerasLoss + " not supported yet."); + dl4jLoss = LossFunctions.LossFunction.SPARSE_MCXENT; } else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) { dl4jLoss = LossFunctions.LossFunction.XENT; } else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index b33ff8d1f..0565cc091 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -49,6 +49,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; import org.nd4j.resources.Resources; import java.io.File; @@ -88,7 +89,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Test(expected = IllegalStateException.class) public void fileNotFoundEndToEnd() throws Exception { String modelPath = "modelimport/keras/examples/foo/bar.h5"; - importEndModelTest(modelPath, null, true, true, false); + importEndModelTest(modelPath, null, true, true, false, false); } /** @@ -98,28 +99,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importMnistMlpTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importMnistMlpThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false); + importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); } @Test public void importMnistMlpTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importMnistMlpReshapeTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, true); + importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); } /** @@ -129,21 +130,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importMnistCnnTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); } @Test public void importMnistCnnThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, true); + importEndModelTest(modelPath, inputsOutputPath, false, true, true, false); } @Test public void importMnistCnnTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, true); + importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); } /** @@ -153,28 +154,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importImdbLstmTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importImdbLstmThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importImdbLstmTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importImdbLstmThKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false); + importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); } /** @@ -185,21 +186,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importImdbFasttextTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, false, false); + importEndModelTest(modelPath, inputsOutputPath, false, false, false, false); } @Test public void importImdbFasttextThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, false, false); + importEndModelTest(modelPath, inputsOutputPath, false, false, false, false); } @Test public void importImdbFasttextTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); } /** @@ -209,21 +210,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importSimpleLstmTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importSimpleLstmThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } @Test public void importSimpleLstmTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); } @@ -235,7 +236,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" + "simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } /** @@ -246,7 +247,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } /** @@ -257,7 +258,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" + "simple_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); } /** @@ -267,7 +268,18 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importCnnNoBiasTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, true); + importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); + } + + @Test + public void importSparseXent() throws Exception { + String modelPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_model.h5"; + String inputsOutputPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_inputs_and_outputs.h5"; + MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true); + Layer outLayer = net.getOutputLayer(); + assertTrue(outLayer instanceof org.deeplearning4j.nn.layers.LossLayer); + LossLayer llConf = (LossLayer) outLayer.getConfig(); + assertEquals(new LossSparseMCXENT(), llConf.getLossFn()); } /** @@ -626,19 +638,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } } - - private void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions) throws Exception { - importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, false); - } - - public void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, - boolean checkGradients) throws Exception { + public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, + boolean checkGradients, boolean enforceTrainingConfig) throws Exception { MultiLayerNetwork model; try(InputStream is = Resources.asStream(modelPath)) { File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) - .enforceTrainingConfig(false).buildSequential(); + .enforceTrainingConfig(enforceTrainingConfig).buildSequential(); model = kerasModel.getMultiLayerNetwork(); } @@ -699,6 +706,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { checkGradients(model, input, testLabels); } } + + return model; } private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception {