Keras import: Add support for sparse cross entropy loss function (#73)

* #6377 Keras sparse cross entropy loss import support

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix small bug in reshape preprocessor

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-23 00:16:13 +11:00 committed by GitHub
parent 6f514e9431
commit 962575dd62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 33 deletions

View File

@@ -75,7 +75,6 @@ public class KerasReshape extends KerasLayer {
List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape); List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape);
this.targetShape = listToLongArray(targetShapeList); this.targetShape = listToLongArray(targetShapeList);
} }
} }
/** /**

View File

@@ -94,7 +94,6 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
if (!this.hasMiniBatchDimension) { if (!this.hasMiniBatchDimension) {
targetShape = prependMiniBatchSize(targetShape, miniBatchSize); targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
inputShape = prependMiniBatchSize(inputShape, miniBatchSize); inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
this.hasMiniBatchDimension = true;
this.miniBatchSize = miniBatchSize; this.miniBatchSize = miniBatchSize;
} }
if (this.miniBatchSize != miniBatchSize) { if (this.miniBatchSize != miniBatchSize) {

View File

@@ -54,7 +54,7 @@ public class KerasLossUtils {
} else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) {
dl4jLoss = LossFunctions.LossFunction.HINGE; dl4jLoss = LossFunctions.LossFunction.HINGE;
} else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) {
throw new UnsupportedKerasConfigurationException("Loss function " + kerasLoss + " not supported yet."); dl4jLoss = LossFunctions.LossFunction.SPARSE_MCXENT;
} else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) {
dl4jLoss = LossFunctions.LossFunction.XENT; dl4jLoss = LossFunctions.LossFunction.XENT;
} else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) {

View File

@@ -49,6 +49,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;
@@ -88,7 +89,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Test(expected = IllegalStateException.class) @Test(expected = IllegalStateException.class)
public void fileNotFoundEndToEnd() throws Exception { public void fileNotFoundEndToEnd() throws Exception {
String modelPath = "modelimport/keras/examples/foo/bar.h5"; String modelPath = "modelimport/keras/examples/foo/bar.h5";
importEndModelTest(modelPath, null, true, true, false); importEndModelTest(modelPath, null, true, true, false, false);
} }
/** /**
@@ -98,28 +99,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importMnistMlpTfKeras1() throws Exception { public void importMnistMlpTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importMnistMlpThKeras1() throws Exception { public void importMnistMlpThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, false); importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
} }
@Test @Test
public void importMnistMlpTfKeras2() throws Exception { public void importMnistMlpTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importMnistMlpReshapeTfKeras1() throws Exception { public void importMnistMlpReshapeTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, true); importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
} }
/** /**
@@ -129,21 +130,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importMnistCnnTfKeras1() throws Exception { public void importMnistCnnTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
} }
@Test @Test
public void importMnistCnnThKeras1() throws Exception { public void importMnistCnnThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, true); importEndModelTest(modelPath, inputsOutputPath, false, true, true, false);
} }
@Test @Test
public void importMnistCnnTfKeras2() throws Exception { public void importMnistCnnTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, true); importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
} }
/** /**
@@ -153,28 +154,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importImdbLstmTfKeras1() throws Exception { public void importImdbLstmTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importImdbLstmThKeras1() throws Exception { public void importImdbLstmThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importImdbLstmTfKeras2() throws Exception { public void importImdbLstmTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importImdbLstmThKeras2() throws Exception { public void importImdbLstmThKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, false); importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
} }
/** /**
@@ -185,21 +186,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importImdbFasttextTfKeras1() throws Exception { public void importImdbFasttextTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, false, false); importEndModelTest(modelPath, inputsOutputPath, false, false, false, false);
} }
@Test @Test
public void importImdbFasttextThKeras1() throws Exception { public void importImdbFasttextThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, false, false); importEndModelTest(modelPath, inputsOutputPath, false, false, false, false);
} }
@Test @Test
public void importImdbFasttextTfKeras2() throws Exception { public void importImdbFasttextTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
} }
/** /**
@@ -209,21 +210,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importSimpleLstmTfKeras1() throws Exception { public void importSimpleLstmTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importSimpleLstmThKeras1() throws Exception { public void importSimpleLstmThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importSimpleLstmTfKeras2() throws Exception { public void importSimpleLstmTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
} }
@@ -235,7 +236,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" + String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" +
"simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5"; "simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
/** /**
@@ -246,7 +247,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
/** /**
@@ -257,7 +258,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" + String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" +
"simple_rnn_tf_keras_2_inputs_and_outputs.h5"; "simple_rnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
/** /**
@@ -267,7 +268,18 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importCnnNoBiasTfKeras2() throws Exception { public void importCnnNoBiasTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, false, true); importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
}
@Test
public void importSparseXent() throws Exception {
String modelPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_inputs_and_outputs.h5";
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true);
Layer outLayer = net.getOutputLayer();
assertTrue(outLayer instanceof org.deeplearning4j.nn.layers.LossLayer);
LossLayer llConf = (LossLayer) outLayer.getConfig();
assertEquals(new LossSparseMCXENT(), llConf.getLossFn());
} }
/** /**
@@ -626,19 +638,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
} }
} }
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
private void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions) throws Exception { boolean checkGradients, boolean enforceTrainingConfig) throws Exception {
importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, false);
}
public void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
boolean checkGradients) throws Exception {
MultiLayerNetwork model; MultiLayerNetwork model;
try(InputStream is = Resources.asStream(modelPath)) { try(InputStream is = Resources.asStream(modelPath)) {
File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
.enforceTrainingConfig(false).buildSequential(); .enforceTrainingConfig(enforceTrainingConfig).buildSequential();
model = kerasModel.getMultiLayerNetwork(); model = kerasModel.getMultiLayerNetwork();
} }
@@ -699,6 +706,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
checkGradients(model, input, testLabels); checkGradients(model, input, testLabels);
} }
} }
return model;
} }
private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception { private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception {