270 lines
11 KiB
Java
Raw Normal View History

2021-02-01 14:31:20 +09:00
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
2021-02-01 17:47:29 +09:00
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
2021-02-01 14:31:20 +09:00
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
2019-06-06 15:21:15 +03:00
package org.deeplearning4j.util;
import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.BaseDL4JTest;
Refactor packages to fix split package issues (#411) * Refactor nd4j-common: org.nd4j.* -> org.nd4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * Fix CUDA (missed nd4j-common package refactoring changes) Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-kryo: org.nd4j -> org.nd4j.kryo Signed-off-by: Alex Black <blacka101@gmail.com> * Fix nd4j-common for deeplearning4j-cuda Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-grppc-client: org.nd4j.graph -> org.nd4j.remote.grpc Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-common: org.deeplearning4.* -> org.deeplearning4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-core: org.deeplearning4j.* -> org.deeplearning.core.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-cuda: org.deeplearning4j.nn.layers.* -> org.deeplearning4j.cuda.* Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-nlp-*: org.deeplearning4.text.* -> org.deeplearning4j.nlp.(language).* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-ui-model: org.deeplearning4j.ui -> org.deeplearning4j.ui.model Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-spark-inference-{server/model/client}: org.datavec.spark.transform -> org.datavec.spark.inference.{server/model/client} Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-jdbc: org.datavec.api -> org.datavec.jdbc Signed-off-by: Alex Black <blacka101@gmail.com> * Delete org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter in favor of (essentially identical) org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter Signed-off-by: Alex Black <blacka101@gmail.com> * ND4S fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-common-tests: org.nd4j.* -> org.nd4j.common.tests Signed-off-by: Alex Black <blacka101@gmail.com> * Trigger CI Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8878 Ignore CUDA tests on modules with 'nd4j-native under cuda' issue Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad imports in tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add ignore on test (already failing) due to #8882 Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Additional import fixes Signed-off-by: Alex Black <blacka101@gmail.com>
2020-04-29 11:19:26 +10:00
import org.deeplearning4j.core.util.ModelGuesser;
2019-06-06 15:21:15 +03:00
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
2019-06-06 15:21:15 +03:00
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
Refactor packages to fix split package issues (#411) * Refactor nd4j-common: org.nd4j.* -> org.nd4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * Fix CUDA (missed nd4j-common package refactoring changes) Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-kryo: org.nd4j -> org.nd4j.kryo Signed-off-by: Alex Black <blacka101@gmail.com> * Fix nd4j-common for deeplearning4j-cuda Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-grppc-client: org.nd4j.graph -> org.nd4j.remote.grpc Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-common: org.deeplearning4.* -> org.deeplearning4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-core: org.deeplearning4j.* -> org.deeplearning.core.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-cuda: org.deeplearning4j.nn.layers.* -> org.deeplearning4j.cuda.* Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-nlp-*: org.deeplearning4.text.* -> org.deeplearning4j.nlp.(language).* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-ui-model: org.deeplearning4j.ui -> org.deeplearning4j.ui.model Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-spark-inference-{server/model/client}: org.datavec.spark.transform -> org.datavec.spark.inference.{server/model/client} Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-jdbc: org.datavec.api -> org.datavec.jdbc Signed-off-by: Alex Black <blacka101@gmail.com> * Delete org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter in favor of (essentially identical) org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter Signed-off-by: Alex Black <blacka101@gmail.com> * ND4S fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-common-tests: org.nd4j.* -> org.nd4j.common.tests Signed-off-by: Alex Black <blacka101@gmail.com> * Trigger CI Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8878 Ignore CUDA tests on modules with 'nd4j-native under cuda' issue Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad imports in tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add ignore on test (already failing) due to #8882 Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Additional import fixes Signed-off-by: Alex Black <blacka101@gmail.com>
2020-04-29 11:19:26 +10:00
import org.nd4j.common.io.ClassPathResource;
2019-06-06 15:21:15 +03:00
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
Refactor packages to fix split package issues (#411) * Refactor nd4j-common: org.nd4j.* -> org.nd4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * Fix CUDA (missed nd4j-common package refactoring changes) Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-kryo: org.nd4j -> org.nd4j.kryo Signed-off-by: Alex Black <blacka101@gmail.com> * Fix nd4j-common for deeplearning4j-cuda Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-grppc-client: org.nd4j.graph -> org.nd4j.remote.grpc Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-common: org.deeplearning4.* -> org.deeplearning4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-core: org.deeplearning4j.* -> org.deeplearning.core.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-cuda: org.deeplearning4j.nn.layers.* -> org.deeplearning4j.cuda.* Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-nlp-*: org.deeplearning4.text.* -> org.deeplearning4j.nlp.(language).* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-ui-model: org.deeplearning4j.ui -> org.deeplearning4j.ui.model Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-spark-inference-{server/model/client}: org.datavec.spark.transform -> org.datavec.spark.inference.{server/model/client} Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-jdbc: org.datavec.api -> org.datavec.jdbc Signed-off-by: Alex Black <blacka101@gmail.com> * Delete org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter in favor of (essentially identical) org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter Signed-off-by: Alex Black <blacka101@gmail.com> * ND4S fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-common-tests: org.nd4j.* -> org.nd4j.common.tests Signed-off-by: Alex Black <blacka101@gmail.com> * Trigger CI Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8878 Ignore CUDA tests on modules with 'nd4j-native under cuda' issue Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad imports in tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add ignore on test (already failing) due to #8882 Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Additional import fixes Signed-off-by: Alex Black <blacka101@gmail.com>
2020-04-29 11:19:26 +10:00
import org.nd4j.common.resources.Resources;
2019-06-06 15:21:15 +03:00
import java.io.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeNotNull;
public class ModelGuesserTest extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
2019-06-06 15:21:15 +03:00
@Test
public void testModelGuessFile() throws Exception {
File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5");
assertTrue(f.exists());
Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
assumeNotNull(guess1);
f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5");
assertTrue(f.exists());
Model guess2 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
assumeNotNull(guess2);
}
@Test
public void testModelGuessInputStream() throws Exception {
File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5");
assertTrue(f.exists());
try (InputStream inputStream = new FileInputStream(f)) {
Model guess1 = ModelGuesser.loadModelGuess(inputStream);
assumeNotNull(guess1);
}
f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5");
assertTrue(f.exists());
try (InputStream inputStream = new FileInputStream(f)) {
Model guess1 = ModelGuesser.loadModelGuess(inputStream);
assumeNotNull(guess1);
}
}
@Test
public void testLoadNormalizersFile() throws Exception {
MultiLayerNetwork net = getNetwork();
File tempFile = testDir.newFile("testLoadNormalizersFile.bin");
ModelSerializer.writeModel(net, tempFile, true);
NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
ModelSerializer.addNormalizerToModel(tempFile, normalizer);
Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath());
assertEquals(model, net);
assertEquals(normalizer, normalizer1);
}
@Test
public void testNormalizerInPlace() throws Exception {
MultiLayerNetwork net = getNetwork();
File tempFile = testDir.newFile("testNormalizerInPlace.bin");
NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
ModelSerializer.writeModel(net, tempFile, true,normalizer);
Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath());
assertEquals(model, net);
assertEquals(normalizer, normalizer1);
}
@Test
public void testLoadNormalizersInputStream() throws Exception {
MultiLayerNetwork net = getNetwork();
File tempFile = testDir.newFile("testLoadNormalizersInputStream.bin");
ModelSerializer.writeModel(net, tempFile, true);
NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
ModelSerializer.addNormalizerToModel(tempFile, normalizer);
Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
try (InputStream inputStream = new FileInputStream(tempFile)) {
Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(inputStream);
assertEquals(model, net);
assertEquals(normalizer, normalizer1);
}
}
@Test
public void testModelGuesserDl4jModelFile() throws Exception {
MultiLayerNetwork net = getNetwork();
File tempFile = testDir.newFile("testModelGuesserDl4jModelFile.bin");
ModelSerializer.writeModel(net, tempFile, true);
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson());
assertEquals(net.params(), network.params());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
@Test
public void testModelGuesserDl4jModelInputStream() throws Exception {
MultiLayerNetwork net = getNetwork();
File tempFile = testDir.newFile("testModelGuesserDl4jModelInputStream.bin");
ModelSerializer.writeModel(net, tempFile, true);
try (InputStream inputStream = new FileInputStream(tempFile)) {
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream);
assumeNotNull(network);
assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson());
assertEquals(net.params(), network.params());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
}
@Test
public void testModelGuessConfigFile() throws Exception {
ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/cnn_tf_config.json",
ModelGuesserTest.class.getClassLoader());
File f = getTempFile(resource);
String configFilename = f.getAbsolutePath();
Object conf = ModelGuesser.loadConfigGuess(configFilename);
assertTrue(conf instanceof MultiLayerConfiguration);
ClassPathResource sequenceResource = new ClassPathResource("/keras/simple/mlp_fapi_multiloss_config.json");
File f2 = getTempFile(sequenceResource);
Object sequenceConf = ModelGuesser.loadConfigGuess(f2.getAbsolutePath());
assertTrue(sequenceConf instanceof ComputationGraphConfiguration);
ClassPathResource resourceDl4j = new ClassPathResource("model.json");
File fDl4j = getTempFile(resourceDl4j);
String configFilenameDl4j = fDl4j.getAbsolutePath();
Object confDl4j = ModelGuesser.loadConfigGuess(configFilenameDl4j);
assertTrue(confDl4j instanceof ComputationGraphConfiguration);
}
@Test
public void testModelGuessConfigInputStream() throws Exception {
ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/cnn_tf_config.json",
ModelGuesserTest.class.getClassLoader());
File f = getTempFile(resource);
try (InputStream inputStream = new FileInputStream(f)) {
Object conf = ModelGuesser.loadConfigGuess(inputStream);
assertTrue(conf instanceof MultiLayerConfiguration);
}
ClassPathResource sequenceResource = new ClassPathResource("/keras/simple/mlp_fapi_multiloss_config.json");
File f2 = getTempFile(sequenceResource);
try (InputStream inputStream = new FileInputStream(f2)) {
Object sequenceConf = ModelGuesser.loadConfigGuess(inputStream);
assertTrue(sequenceConf instanceof ComputationGraphConfiguration);
}
ClassPathResource resourceDl4j = new ClassPathResource("model.json");
File fDl4j = getTempFile(resourceDl4j);
try (InputStream inputStream = new FileInputStream(fDl4j)) {
Object confDl4j = ModelGuesser.loadConfigGuess(inputStream);
assertTrue(confDl4j instanceof ComputationGraphConfiguration);
}
}
private File getTempFile(ClassPathResource classPathResource) throws Exception {
InputStream is = classPathResource.getInputStream();
File f = testDir.newFile();
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f));
IOUtils.copy(is, bos);
bos.flush();
bos.close();
return f;
}
private MultiLayerNetwork getNetwork() {
int nIn = 5;
int nOut = 6;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01)
.updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list()
.layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build())
.layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
return net;
}
}