
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.util;

import lombok.val;
import org.apache.commons.lang3.SerializationUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.primitives.Pair;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;

public class ModelSerializerTest extends BaseDL4JTest {
@TempDir
public File tempDir;
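
// Round-trip a MultiLayerNetwork through ModelSerializer.writeModel/restoreMultiLayerNetwork
// (saving the updater state) and verify that configuration, parameters and updater state all match.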
@Test
public void testWriteMLNModel() throws Exception {
int nIn = 5;
int nOut = 6;
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01)
.l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list()
.layer(0, DenseLayer.builder().nIn(nIn).nOut(20).build())
.layer(1, DenseLayer.builder().nIn(20).nOut(30).build()).layer(2, OutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
// The JUnit 5 @TempDir field is a directory; write the model to a file inside it
// (the file name is arbitrary) rather than to the directory itself.
File tempFile = new File(tempDir, "testWriteMLNModel.bin");
ModelSerializer.writeModel(net, tempFile, true);
MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile);
assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
assertEquals(net.getModelParams(), network.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
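
// Same round-trip as above, but writing via an OutputStream and restoring via an InputStream,
// with a fitted NormalizerMinMaxScaler appended to the model file in between.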
@Test
public void testWriteMlnModelInputStream() throws Exception {
int nIn = 5;
int nOut = 6;
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01)
.l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list()
.layer(0, DenseLayer.builder().nIn(nIn).nOut(20).build())
.layer(1, DenseLayer.builder().nIn(20).nOut(30).build()).layer(2, OutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
File tempFile = new File(tempDir, "testWriteMlnModelInputStream.bin");
try (FileOutputStream fos = new FileOutputStream(tempFile)) {
ModelSerializer.writeModel(net, fos, true);
}
// Check that a fitted normalizer can be appended to the model file and restored intact
NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler();
DataSetIterator iter = new IrisDataSetIterator(150, 150);
scaler.fit(iter);
ModelSerializer.addNormalizerToModel(tempFile, scaler);
NormalizerMinMaxScaler restoredScaler = ModelSerializer.restoreNormalizerFromFile(tempFile);
assertNotNull(scaler.getMax());
assertEquals(scaler.getMax(), restoredScaler.getMax());
assertEquals(scaler.getMin(), restoredScaler.getMin());
try (FileInputStream fis = new FileInputStream(tempFile)) {
MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis);
assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
assertEquals(net.getModelParams(), network.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
}
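
// Round-trip a ComputationGraph through file-based save/restore and verify config, params and updater state.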
@Test
public void testWriteCGModel() throws Exception {
ComputationGraphConfiguration config = NeuralNetConfiguration.builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1))
.graphBuilder().addInputs("in")
.addLayer("dense", DenseLayer.builder().nIn(4).nOut(2).build(), "in").addLayer("out",
OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3)
.activation(Activation.SOFTMAX).build(),
"dense")
.setOutputs("out").build();
ComputationGraph cg = new ComputationGraph(config);
cg.init();
File tempFile = new File(tempDir, "testWriteCGModel.bin");
ModelSerializer.writeModel(cg, tempFile, true);
ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);
assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
assertEquals(cg.getModelParams(), network.getModelParams());
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
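
// As above, but restoring the ComputationGraph from a FileInputStream rather than the File directly.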
@Test
public void testWriteCGModelInputStream() throws Exception {
ComputationGraphConfiguration config = NeuralNetConfiguration.builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1))
.graphBuilder().addInputs("in")
.addLayer("dense", DenseLayer.builder().nIn(4).nOut(2).build(), "in").addLayer("out",
OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3)
.activation(Activation.SOFTMAX).build(),
"dense")
.setOutputs("out").build();
ComputationGraph cg = new ComputationGraph(config);
cg.init();
File tempFile = new File(tempDir, "testWriteCGModelInputStream.bin");
ModelSerializer.writeModel(cg, tempFile, true);
try (FileInputStream fis = new FileInputStream(tempFile)) {
ComputationGraph network = ModelSerializer.restoreComputationGraph(fis);
assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
assertEquals(cg.getModelParams(), network.getModelParams());
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
}
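
// Minimal single-example DataSet used to fit normalizers in the tests below.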
private DataSet trivialDataSet() {
INDArray inputs = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f}, new int[]{1,3});
INDArray labels = Nd4j.create(new float[] {4.0f, 5.0f, 6.0f}, new int[]{1,3});
return new DataSet(inputs, labels);
}
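
// Small dense->softmax ComputationGraph shared by the normalizer and stream-reuse tests.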
private ComputationGraph simpleComputationGraph() {
ComputationGraphConfiguration config = NeuralNetConfiguration.builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1))
.graphBuilder().addInputs("in")
.addLayer("dense", DenseLayer.builder().nIn(4).nOut(2).build(), "in").addLayer("out",
OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3)
.activation(Activation.SOFTMAX).build(),
"dense")
.setOutputs("out").build();
return new ComputationGraph(config);
}
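
// A normalizer appended to a saved model should be restorable from an InputStream,
// and the restored instance should exactly revert the transform applied by the original.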
@Test
public void testSaveRestoreNormalizerFromInputStream() throws Exception {
DataSet dataSet = trivialDataSet();
NormalizerStandardize norm = new NormalizerStandardize();
norm.fit(dataSet);
ComputationGraph cg = simpleComputationGraph();
cg.init();
File tempFile = new File(tempDir, "testSaveRestoreNormalizer.bin");
ModelSerializer.writeModel(cg, tempFile, true);
ModelSerializer.addNormalizerToModel(tempFile, norm);
NormalizerStandardize restored;
try (FileInputStream fis = new FileInputStream(tempFile)) {
restored = ModelSerializer.restoreNormalizerFromInputStream(fis);
}
assertNotNull(restored);
DataSet dataSet2 = dataSet.copy();
norm.preProcess(dataSet2);
assertNotEquals(dataSet.getFeatures(), dataSet2.getFeatures());
restored.revert(dataSet2);
assertEquals(dataSet.getFeatures(), dataSet2.getFeatures());
}
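
// If no normalizer was ever added to the model file, restoring one from a stream should yield null.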
@Test
public void testRestoreUnsavedNormalizerFromInputStream() throws Exception {
DataSet dataSet = trivialDataSet();
NormalizerStandardize norm = new NormalizerStandardize();
norm.fit(dataSet);
ComputationGraph cg = simpleComputationGraph();
cg.init();
File tempFile = new File(tempDir, "testRestoreUnsavedNormalizer.bin");
ModelSerializer.writeModel(cg, tempFile, true);
NormalizerStandardize restored;
try (FileInputStream fis = new FileInputStream(tempFile)) {
restored = ModelSerializer.restoreNormalizerFromInputStream(fis);
}
assertNull(restored);
}
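
// Restoring a saved ComputationGraph via restoreMultiLayerNetwork should fail with a message
// pointing the user at restoreComputationGraph instead.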
@Test
public void testInvalidLoading1() throws Exception {
ComputationGraphConfiguration config = NeuralNetConfiguration.builder()
.graphBuilder().addInputs("in")
.addLayer("dense", DenseLayer.builder().nIn(4).nOut(2).build(), "in")
.addLayer("out",OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2).nOut(3).build(),
"dense")
.setOutputs("out").build();
ComputationGraph cg = new ComputationGraph(config);
cg.init();
File tempFile = new File(tempDir, "testInvalidLoading1.bin");
ModelSerializer.writeModel(cg, tempFile, true);
try {
ModelSerializer.restoreMultiLayerNetwork(tempFile);
fail();
} catch (Exception e){
String msg = e.getMessage();
assertTrue(msg.contains("JSON") && msg.contains("restoreComputationGraph"), msg);
}
}
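
// Conversely, restoring a saved MultiLayerNetwork via restoreComputationGraph should fail
// with a message pointing at restoreMultiLayerNetwork.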
@Test
public void testInvalidLoading2() throws Exception {
int nIn = 5;
int nOut = 6;
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01)
.l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list()
.layer(0, DenseLayer.builder().nIn(nIn).nOut(20).build())
.layer(1, DenseLayer.builder().nIn(20).nOut(30).build()).layer(2, OutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
File tempFile = new File(tempDir, "testInvalidLoading2.bin");
ModelSerializer.writeModel(net, tempFile, true);
try {
ModelSerializer.restoreComputationGraph(tempFile);
fail();
} catch (Exception e){
String msg = e.getMessage();
assertTrue(msg.contains("JSON") && msg.contains("restoreMultiLayerNetwork"), msg);
}
}
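
// A FileInputStream is consumed by the first restore call; reusing it should fail with a clear
// "may have been closed" message. Reading model + normalizer in a single call must still work.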
@Test
public void testInvalidStreamReuse() throws Exception {
int nIn = 5;
int nOut = 6;
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01)
.list()
.layer(OutputLayer.builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSet dataSet = trivialDataSet();
NormalizerStandardize norm = new NormalizerStandardize();
norm.fit(dataSet);
File tempFile = new File(tempDir, "testInvalidStreamReuse.bin");
ModelSerializer.writeModel(net, tempFile, true);
ModelSerializer.addNormalizerToModel(tempFile, norm);
InputStream is = new FileInputStream(tempFile);
ModelSerializer.restoreMultiLayerNetwork(is);
try{
ModelSerializer.restoreNormalizerFromInputStream(is);
fail("Expected exception");
} catch (Exception e){
String msg = e.getMessage();
assertTrue(msg.contains("may have been closed"), msg);
}
try{
ModelSerializer.restoreMultiLayerNetwork(is);
fail("Expected exception");
} catch (Exception e){
String msg = e.getMessage();
assertTrue(msg.contains("may have been closed"), msg);
}
//Also test reading both model and normalizer from stream (correctly)
Pair<MultiLayerNetwork,Normalizer> pair = ModelSerializer.restoreMultiLayerNetworkAndNormalizer(new FileInputStream(tempFile), true);
assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
assertNotNull(pair.getSecond());
}
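
// Same stream-reuse checks as above, for ComputationGraph.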
@Test
public void testInvalidStreamReuseCG() throws Exception {
int nIn = 5;
int nOut = 6;
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01)
.graphBuilder()
.addInputs("in")
.layer("0", OutputLayer.builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in")
.setOutputs("0")
.build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
DataSet dataSet = trivialDataSet();
NormalizerStandardize norm = new NormalizerStandardize();
norm.fit(dataSet);
File tempFile = new File(tempDir, "testInvalidStreamReuseCG.bin");
ModelSerializer.writeModel(net, tempFile, true);
ModelSerializer.addNormalizerToModel(tempFile, norm);
InputStream is = new FileInputStream(tempFile);
ModelSerializer.restoreComputationGraph(is);
try{
ModelSerializer.restoreNormalizerFromInputStream(is);
fail("Expected exception");
} catch (Exception e){
String msg = e.getMessage();
assertTrue(msg.contains("may have been closed"), msg);
}
try{
ModelSerializer.restoreComputationGraph(is);
fail("Expected exception");
} catch (Exception e){
String msg = e.getMessage();
assertTrue(msg.contains("may have been closed"), msg);
}
//Also test reading both model and normalizer from stream (correctly)
Pair<ComputationGraph,Normalizer> pair = ModelSerializer.restoreComputationGraphAndNormalizer(new FileInputStream(tempFile), true);
assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
assertNotNull(pair.getSecond());
}
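
// ComputationGraph should survive plain Java serialization (Serializable round-trip).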
@Test
public void testJavaSerde_1() throws Exception {
int nIn = 5;
int nOut = 6;
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01)
.graphBuilder()
.addInputs("in")
.layer("0", OutputLayer.builder().nIn(nIn).nOut(nOut).build(), "in")
.setOutputs("0")
.validateOutputLayerConfig(false)
.build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
DataSet dataSet = trivialDataSet();
NormalizerStandardize norm = new NormalizerStandardize();
norm.fit(dataSet);
val b = SerializationUtils.serialize(net);
ComputationGraph restored = SerializationUtils.deserialize(b);
assertEquals(net, restored);
}
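
// MultiLayerNetwork should likewise survive plain Java serialization.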
@Test
public void testJavaSerde_2() throws Exception {
int nIn = 5;
int nOut = 6;
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01)
.list()
.layer(0, OutputLayer.builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSet dataSet = trivialDataSet();
NormalizerStandardize norm = new NormalizerStandardize();
norm.fit(dataSet);
val b = SerializationUtils.serialize(net);
MultiLayerNetwork restored = SerializationUtils.deserialize(b);
assertEquals(net, restored);
}
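
// Arbitrary serializable objects can be stored in, listed from, and retrieved from a model file
// without corrupting the model itself.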
@Test
public void testPutGetObject() throws Exception {
int nIn = 5;
int nOut = 6;
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01)
.graphBuilder()
.addInputs("in")
.layer("0", OutputLayer.builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in")
.setOutputs("0")
.build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
File tempFile = new File(tempDir, "testPutGetObject.bin");
ModelSerializer.writeModel(net, tempFile, true);
List<String> toWrite = Arrays.asList("zero", "one", "two");
ModelSerializer.addObjectToFile(tempFile, "myLabels", toWrite);
List<String> restored = ModelSerializer.getObjectFromFile(tempFile, "myLabels");
assertEquals(toWrite, restored);
Map<String,Object> someOtherData = new HashMap<>();
someOtherData.put("x", new float[]{0,1,2});
someOtherData.put("y",Nd4j.linspace(1,10,10, Nd4j.dataType()));
ModelSerializer.addObjectToFile(tempFile, "otherData.bin", someOtherData);
Map<String,Object> dataRestored = ModelSerializer.getObjectFromFile(tempFile, "otherData.bin");
assertEquals(someOtherData.keySet(), dataRestored.keySet());
assertArrayEquals((float[])someOtherData.get("x"), (float[])dataRestored.get("x"), 0f);
assertEquals(someOtherData.get("y"), dataRestored.get("y"));
List<String> entries = ModelSerializer.listObjectsInFile(tempFile);
assertEquals(2, entries.size());
System.out.println(entries);
assertTrue(entries.contains("myLabels"));
assertTrue(entries.contains("otherData.bin"));
ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(tempFile);
assertEquals(net.getModelParams(), restoredNet.getModelParams());
}
}