317 lines
14 KiB
Java
317 lines
14 KiB
Java
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
|
|
|
|
package org.deeplearning4j.util;
|
|
|
|
import org.apache.commons.io.FileUtils;
|
|
import org.apache.commons.io.IOUtils;
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
import org.junit.jupiter.api.Test;
|
|
import org.junit.jupiter.api.io.TempDir;
|
|
import org.nd4j.linalg.learning.config.Adam;
|
|
import org.nd4j.common.validation.ValidationResult;
|
|
|
|
import java.io.BufferedOutputStream;
|
|
import java.io.File;
|
|
import java.io.FileOutputStream;
|
|
import java.net.URI;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.nio.file.FileSystem;
|
|
import java.nio.file.FileSystems;
|
|
import java.nio.file.Files;
|
|
import java.nio.file.Path;
|
|
import java.util.Collections;
|
|
import java.util.Enumeration;
|
|
import java.util.zip.ZipEntry;
|
|
import java.util.zip.ZipFile;
|
|
import java.util.zip.ZipInputStream;
|
|
import java.util.zip.ZipOutputStream;
|
|
|
|
import static org.junit.jupiter.api.Assertions.*;
|
|
|
|
public class ModelValidatorTests extends BaseDL4JTest {
|
|
|
|
@TempDir
|
|
public File testDir;
|
|
|
|
@Test
|
|
public void testMultiLayerNetworkValidation() throws Exception {
|
|
File f = testDir;
|
|
|
|
//Test non-existent file
|
|
File f0 = new File(f, "doesntExist.bin");
|
|
ValidationResult vr0 = DL4JModelValidator.validateMultiLayerNetwork(f0);
|
|
assertFalse(vr0.isValid());
|
|
assertTrue(vr0.getIssues().get(0).contains("exist"));
|
|
assertEquals("MultiLayerNetwork", vr0.getFormatType());
|
|
assertEquals(MultiLayerNetwork.class, vr0.getFormatClass());
|
|
assertNull(vr0.getException());
|
|
// System.out.println(vr0.toString());
|
|
|
|
//Test empty file
|
|
File f1 = new File(f, "empty.bin");
|
|
f1.createNewFile();
|
|
assertTrue(f1.exists());
|
|
ValidationResult vr1 = DL4JModelValidator.validateMultiLayerNetwork(f1);
|
|
assertFalse(vr1.isValid());
|
|
assertTrue(vr1.getIssues().get(0).contains("empty"));
|
|
assertEquals("MultiLayerNetwork", vr1.getFormatType());
|
|
assertEquals(MultiLayerNetwork.class, vr1.getFormatClass());
|
|
assertNull(vr1.getException());
|
|
// System.out.println(vr1.toString());
|
|
|
|
//Test invalid zip file
|
|
File f2 = new File(f, "notReallyZip.zip");
|
|
FileUtils.writeStringToFile(f2, "This isn't actually a zip file", StandardCharsets.UTF_8);
|
|
ValidationResult vr2 = DL4JModelValidator.validateMultiLayerNetwork(f2);
|
|
assertFalse(vr2.isValid());
|
|
String s = vr2.getIssues().get(0);
|
|
assertTrue(s.contains("zip") && s.contains("corrupt"), s);
|
|
assertEquals("MultiLayerNetwork", vr2.getFormatType());
|
|
assertEquals(MultiLayerNetwork.class, vr2.getFormatClass());
|
|
assertNotNull(vr2.getException());
|
|
// System.out.println(vr2.toString());
|
|
|
|
//Test valid zip, but missing configuration
|
|
File f3 = new File(f, "modelNoConfig.zip");
|
|
getSimpleNet().save(f3);
|
|
try (FileSystem zipfs = FileSystems.newFileSystem(URI.create("jar:" + f3.toURI().toString()), Collections.singletonMap("create", "false"))) {
|
|
Path p = zipfs.getPath(ModelSerializer.CONFIGURATION_JSON);
|
|
Files.delete(p);
|
|
}
|
|
ValidationResult vr3 = DL4JModelValidator.validateMultiLayerNetwork(f3);
|
|
assertFalse(vr3.isValid());
|
|
s = vr3.getIssues().get(0);
|
|
assertEquals(1, vr3.getIssues().size());
|
|
assertTrue(s.contains("missing") && s.contains("configuration"), s);
|
|
assertEquals("MultiLayerNetwork", vr3.getFormatType());
|
|
assertEquals(MultiLayerNetwork.class, vr3.getFormatClass());
|
|
assertNull(vr3.getException());
|
|
// System.out.println(vr3.toString());
|
|
|
|
|
|
//Test valid sip, but missing params
|
|
File f4 = new File(f, "modelNoParams.zip");
|
|
getSimpleNet().save(f4);
|
|
try (FileSystem zipfs = FileSystems.newFileSystem(URI.create("jar:" + f4.toURI().toString()), Collections.singletonMap("create", "false"))) {
|
|
Path p = zipfs.getPath(ModelSerializer.COEFFICIENTS_BIN);
|
|
Files.delete(p);
|
|
}
|
|
ValidationResult vr4 = DL4JModelValidator.validateMultiLayerNetwork(f4);
|
|
assertFalse(vr4.isValid());
|
|
s = vr4.getIssues().get(0);
|
|
assertEquals(1, vr4.getIssues().size());
|
|
assertTrue(s.contains("missing") && s.contains("coefficients"), s);
|
|
assertEquals("MultiLayerNetwork", vr4.getFormatType());
|
|
assertEquals(MultiLayerNetwork.class, vr4.getFormatClass());
|
|
assertNull(vr4.getException());
|
|
// System.out.println(vr4.toString());
|
|
|
|
|
|
//Test valid model
|
|
File f5 = new File(f, "modelValid.zip");
|
|
getSimpleNet().save(f5);
|
|
ValidationResult vr5 = DL4JModelValidator.validateMultiLayerNetwork(f5);
|
|
assertTrue(vr5.isValid());
|
|
assertNull(vr5.getIssues());
|
|
assertEquals("MultiLayerNetwork", vr5.getFormatType());
|
|
assertEquals(MultiLayerNetwork.class, vr5.getFormatClass());
|
|
assertNull(vr5.getException());
|
|
// System.out.println(vr5.toString());
|
|
|
|
|
|
//Test valid model with corrupted JSON
|
|
File f6 = new File(f, "modelBadJson.zip");
|
|
getSimpleNet().save(f6);
|
|
try(ZipFile zf = new ZipFile(f5); ZipOutputStream zo = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(f6)))){
|
|
Enumeration<? extends ZipEntry> e = zf.entries();
|
|
while(e.hasMoreElements()){
|
|
ZipEntry ze = e.nextElement();
|
|
zo.putNextEntry(new ZipEntry(ze.getName()));
|
|
if(ze.getName().equals(ModelSerializer.CONFIGURATION_JSON)){
|
|
zo.write("totally not valid json! - {}".getBytes(StandardCharsets.UTF_8));
|
|
} else {
|
|
byte[] bytes;
|
|
try(ZipInputStream zis = new ZipInputStream(zf.getInputStream(ze))){
|
|
bytes = IOUtils.toByteArray(zis);
|
|
}
|
|
zo.write(bytes);
|
|
// System.out.println("WROTE: " + ze.getName());
|
|
}
|
|
}
|
|
}
|
|
ValidationResult vr6 = DL4JModelValidator.validateMultiLayerNetwork(f6);
|
|
assertFalse(vr6.isValid());
|
|
s = vr6.getIssues().get(0);
|
|
assertEquals(1, vr6.getIssues().size());
|
|
assertTrue(s.contains("JSON") && s.contains("valid") && s.contains("NeuralNetConfiguration"), s);
|
|
assertEquals("MultiLayerNetwork", vr6.getFormatType());
|
|
assertEquals(MultiLayerNetwork.class, vr6.getFormatClass());
|
|
assertNotNull(vr6.getException());
|
|
// System.out.println(vr6.toString());
|
|
}
|
|
|
|
|
|
@Test
|
|
public void testComputationGraphNetworkValidation() throws Exception {
|
|
File f = testDir;
|
|
|
|
//Test non-existent file
|
|
File f0 = new File(f, "doesntExist.bin");
|
|
ValidationResult vr0 = DL4JModelValidator.validateComputationGraph(f0);
|
|
assertFalse(vr0.isValid());
|
|
assertTrue(vr0.getIssues().get(0).contains("exist"));
|
|
assertEquals("ComputationGraph", vr0.getFormatType());
|
|
assertEquals(ComputationGraph.class, vr0.getFormatClass());
|
|
assertNull(vr0.getException());
|
|
// System.out.println(vr0.toString());
|
|
|
|
//Test empty file
|
|
File f1 = new File(f, "empty.bin");
|
|
f1.createNewFile();
|
|
assertTrue(f1.exists());
|
|
ValidationResult vr1 = DL4JModelValidator.validateComputationGraph(f1);
|
|
assertFalse(vr1.isValid());
|
|
assertTrue(vr1.getIssues().get(0).contains("empty"));
|
|
assertEquals("ComputationGraph", vr1.getFormatType());
|
|
assertEquals(ComputationGraph.class, vr1.getFormatClass());
|
|
assertNull(vr1.getException());
|
|
// System.out.println(vr1.toString());
|
|
|
|
//Test invalid zip file
|
|
File f2 = new File(f, "notReallyZip.zip");
|
|
FileUtils.writeStringToFile(f2, "This isn't actually a zip file", StandardCharsets.UTF_8);
|
|
ValidationResult vr2 = DL4JModelValidator.validateComputationGraph(f2);
|
|
assertFalse(vr2.isValid());
|
|
String s = vr2.getIssues().get(0);
|
|
assertTrue(s.contains("zip") && s.contains("corrupt"), s);
|
|
assertEquals("ComputationGraph", vr2.getFormatType());
|
|
assertEquals(ComputationGraph.class, vr2.getFormatClass());
|
|
assertNotNull(vr2.getException());
|
|
// System.out.println(vr2.toString());
|
|
|
|
//Test valid zip, but missing configuration
|
|
File f3 = new File(f, "modelNoConfig.zip");
|
|
getSimpleNet().save(f3);
|
|
try (FileSystem zipfs = FileSystems.newFileSystem(URI.create("jar:" + f3.toURI().toString()), Collections.singletonMap("create", "false"))) {
|
|
Path p = zipfs.getPath(ModelSerializer.CONFIGURATION_JSON);
|
|
Files.delete(p);
|
|
}
|
|
ValidationResult vr3 = DL4JModelValidator.validateComputationGraph(f3);
|
|
assertFalse(vr3.isValid());
|
|
s = vr3.getIssues().get(0);
|
|
assertEquals(1, vr3.getIssues().size());
|
|
assertTrue(s.contains("missing") && s.contains("configuration"), s);
|
|
assertEquals("ComputationGraph", vr3.getFormatType());
|
|
assertEquals(ComputationGraph.class, vr3.getFormatClass());
|
|
assertNull(vr3.getException());
|
|
// System.out.println(vr3.toString());
|
|
|
|
|
|
//Test valid sip, but missing params
|
|
File f4 = new File(f, "modelNoParams.zip");
|
|
getSimpleNet().save(f4);
|
|
try (FileSystem zipfs = FileSystems.newFileSystem(URI.create("jar:" + f4.toURI().toString()), Collections.singletonMap("create", "false"))) {
|
|
Path p = zipfs.getPath(ModelSerializer.COEFFICIENTS_BIN);
|
|
Files.delete(p);
|
|
}
|
|
ValidationResult vr4 = DL4JModelValidator.validateComputationGraph(f4);
|
|
assertFalse(vr4.isValid());
|
|
s = vr4.getIssues().get(0);
|
|
assertEquals(1, vr4.getIssues().size());
|
|
assertTrue(s.contains("missing") && s.contains("coefficients"), s);
|
|
assertEquals("ComputationGraph", vr4.getFormatType());
|
|
assertEquals(ComputationGraph.class, vr4.getFormatClass());
|
|
assertNull(vr4.getException());
|
|
// System.out.println(vr4.toString());
|
|
|
|
|
|
//Test valid model
|
|
File f5 = new File(f, "modelValid.zip");
|
|
getSimpleNet().save(f5);
|
|
ValidationResult vr5 = DL4JModelValidator.validateComputationGraph(f5);
|
|
assertTrue(vr5.isValid());
|
|
assertNull(vr5.getIssues());
|
|
assertEquals("ComputationGraph", vr5.getFormatType());
|
|
assertEquals(ComputationGraph.class, vr5.getFormatClass());
|
|
assertNull(vr5.getException());
|
|
// System.out.println(vr5.toString());
|
|
|
|
|
|
//Test valid model with corrupted JSON
|
|
File f6 = new File(f, "modelBadJson.zip");
|
|
getSimpleNet().save(f6);
|
|
try(ZipFile zf = new ZipFile(f5); ZipOutputStream zo = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(f6)))){
|
|
Enumeration<? extends ZipEntry> e = zf.entries();
|
|
while(e.hasMoreElements()){
|
|
ZipEntry ze = e.nextElement();
|
|
zo.putNextEntry(new ZipEntry(ze.getName()));
|
|
if(ze.getName().equals(ModelSerializer.CONFIGURATION_JSON)){
|
|
zo.write("totally not valid json! - {}".getBytes(StandardCharsets.UTF_8));
|
|
} else {
|
|
byte[] bytes;
|
|
try(ZipInputStream zis = new ZipInputStream(zf.getInputStream(ze))){
|
|
bytes = IOUtils.toByteArray(zis);
|
|
}
|
|
zo.write(bytes);
|
|
// System.out.println("WROTE: " + ze.getName());
|
|
}
|
|
}
|
|
}
|
|
ValidationResult vr6 = DL4JModelValidator.validateComputationGraph(f6);
|
|
assertFalse(vr6.isValid());
|
|
s = vr6.getIssues().get(0);
|
|
assertEquals(1, vr6.getIssues().size());
|
|
assertTrue(s.contains("JSON") && s.contains("valid") && s.contains("ComputationGraphConfiguration"), s);
|
|
assertEquals("ComputationGraph", vr6.getFormatType());
|
|
assertEquals(ComputationGraph.class, vr6.getFormatClass());
|
|
assertNotNull(vr6.getException());
|
|
// System.out.println(vr6.toString());
|
|
}
|
|
|
|
|
|
|
|
public static MultiLayerNetwork getSimpleNet(){
|
|
|
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
|
.seed(12345)
|
|
.updater(new Adam(0.01))
|
|
.list()
|
|
.layer(DenseLayer.builder().nIn(10).nOut(10).build())
|
|
.layer(DenseLayer.builder().nIn(10).nOut(10).build())
|
|
.layer(OutputLayer.builder().nIn(10).nOut(10).build())
|
|
.build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
return net;
|
|
}
|
|
|
|
public static ComputationGraph getSimpleCG(){
|
|
return getSimpleNet().toComputationGraph();
|
|
}
|
|
}
|