Keras import: Add support for sparse cross entropy loss function (#73)
* #6377 Keras sparse cross entropy loss import support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix small bug in reshape preprocessor Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
6f514e9431
commit
962575dd62
|
@ -75,7 +75,6 @@ public class KerasReshape extends KerasLayer {
|
||||||
List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape);
|
List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape);
|
||||||
this.targetShape = listToLongArray(targetShapeList);
|
this.targetShape = listToLongArray(targetShapeList);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -94,7 +94,6 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
if (!this.hasMiniBatchDimension) {
|
if (!this.hasMiniBatchDimension) {
|
||||||
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
|
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
|
||||||
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
|
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
|
||||||
this.hasMiniBatchDimension = true;
|
|
||||||
this.miniBatchSize = miniBatchSize;
|
this.miniBatchSize = miniBatchSize;
|
||||||
}
|
}
|
||||||
if (this.miniBatchSize != miniBatchSize) {
|
if (this.miniBatchSize != miniBatchSize) {
|
||||||
|
|
|
@ -54,7 +54,7 @@ public class KerasLossUtils {
|
||||||
} else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) {
|
} else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) {
|
||||||
dl4jLoss = LossFunctions.LossFunction.HINGE;
|
dl4jLoss = LossFunctions.LossFunction.HINGE;
|
||||||
} else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) {
|
} else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) {
|
||||||
throw new UnsupportedKerasConfigurationException("Loss function " + kerasLoss + " not supported yet.");
|
dl4jLoss = LossFunctions.LossFunction.SPARSE_MCXENT;
|
||||||
} else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) {
|
} else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) {
|
||||||
dl4jLoss = LossFunctions.LossFunction.XENT;
|
dl4jLoss = LossFunctions.LossFunction.XENT;
|
||||||
} else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) {
|
} else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) {
|
||||||
|
|
|
@ -49,6 +49,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.NoOp;
|
import org.nd4j.linalg.learning.config.NoOp;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
|
||||||
import org.nd4j.resources.Resources;
|
import org.nd4j.resources.Resources;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -88,7 +89,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
@Test(expected = IllegalStateException.class)
|
@Test(expected = IllegalStateException.class)
|
||||||
public void fileNotFoundEndToEnd() throws Exception {
|
public void fileNotFoundEndToEnd() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/foo/bar.h5";
|
String modelPath = "modelimport/keras/examples/foo/bar.h5";
|
||||||
importEndModelTest(modelPath, null, true, true, false);
|
importEndModelTest(modelPath, null, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -98,28 +99,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
public void importMnistMlpTfKeras1() throws Exception {
|
public void importMnistMlpTfKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importMnistMlpThKeras1() throws Exception {
|
public void importMnistMlpThKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, false, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importMnistMlpTfKeras2() throws Exception {
|
public void importMnistMlpTfKeras2() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importMnistMlpReshapeTfKeras1() throws Exception {
|
public void importMnistMlpReshapeTfKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, true);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -129,21 +130,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
public void importMnistCnnTfKeras1() throws Exception {
|
public void importMnistCnnTfKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, false, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importMnistCnnThKeras1() throws Exception {
|
public void importMnistCnnThKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, false, true, true);
|
importEndModelTest(modelPath, inputsOutputPath, false, true, true, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importMnistCnnTfKeras2() throws Exception {
|
public void importMnistCnnTfKeras2() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, true);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -153,28 +154,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
public void importImdbLstmTfKeras1() throws Exception {
|
public void importImdbLstmTfKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importImdbLstmThKeras1() throws Exception {
|
public void importImdbLstmThKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importImdbLstmTfKeras2() throws Exception {
|
public void importImdbLstmTfKeras2() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importImdbLstmThKeras2() throws Exception {
|
public void importImdbLstmThKeras2() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, false, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -185,21 +186,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
public void importImdbFasttextTfKeras1() throws Exception {
|
public void importImdbFasttextTfKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, false, false, false);
|
importEndModelTest(modelPath, inputsOutputPath, false, false, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importImdbFasttextThKeras1() throws Exception {
|
public void importImdbFasttextThKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, false, false, false);
|
importEndModelTest(modelPath, inputsOutputPath, false, false, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importImdbFasttextTfKeras2() throws Exception {
|
public void importImdbFasttextTfKeras2() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, false, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -209,21 +210,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
public void importSimpleLstmTfKeras1() throws Exception {
|
public void importSimpleLstmTfKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importSimpleLstmThKeras1() throws Exception {
|
public void importSimpleLstmThKeras1() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5";
|
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void importSimpleLstmTfKeras2() throws Exception {
|
public void importSimpleLstmTfKeras2() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, false, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -235,7 +236,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" +
|
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" +
|
||||||
"simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5";
|
"simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -246,7 +247,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
|
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
|
||||||
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
|
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -257,7 +258,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" +
|
String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" +
|
||||||
"simple_rnn_tf_keras_2_inputs_and_outputs.h5";
|
"simple_rnn_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, true, false);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -267,7 +268,18 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
public void importCnnNoBiasTfKeras2() throws Exception {
|
public void importCnnNoBiasTfKeras2() throws Exception {
|
||||||
String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5";
|
String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5";
|
||||||
String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5";
|
String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5";
|
||||||
importEndModelTest(modelPath, inputsOutputPath, true, false, true);
|
importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void importSparseXent() throws Exception {
|
||||||
|
String modelPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_model.h5";
|
||||||
|
String inputsOutputPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_inputs_and_outputs.h5";
|
||||||
|
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true);
|
||||||
|
Layer outLayer = net.getOutputLayer();
|
||||||
|
assertTrue(outLayer instanceof org.deeplearning4j.nn.layers.LossLayer);
|
||||||
|
LossLayer llConf = (LossLayer) outLayer.getConfig();
|
||||||
|
assertEquals(new LossSparseMCXENT(), llConf.getLossFn());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -626,19 +638,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
|
||||||
private void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions) throws Exception {
|
boolean checkGradients, boolean enforceTrainingConfig) throws Exception {
|
||||||
importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
|
|
||||||
boolean checkGradients) throws Exception {
|
|
||||||
MultiLayerNetwork model;
|
MultiLayerNetwork model;
|
||||||
try(InputStream is = Resources.asStream(modelPath)) {
|
try(InputStream is = Resources.asStream(modelPath)) {
|
||||||
File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
|
File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
|
||||||
Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
|
Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
|
||||||
KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
|
KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
|
||||||
.enforceTrainingConfig(false).buildSequential();
|
.enforceTrainingConfig(enforceTrainingConfig).buildSequential();
|
||||||
|
|
||||||
model = kerasModel.getMultiLayerNetwork();
|
model = kerasModel.getMultiLayerNetwork();
|
||||||
}
|
}
|
||||||
|
@ -699,6 +706,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
checkGradients(model, input, testLabels);
|
checkGradients(model, input, testLabels);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return model;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception {
|
private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception {
|
||||||
|
|
Loading…
Reference in New Issue