162 lines
7.0 KiB
Java
Raw Normal View History

2021-02-01 14:31:20 +09:00
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
2021-02-01 17:47:29 +09:00
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
2021-02-01 14:31:20 +09:00
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
2019-06-06 15:21:15 +03:00
package org.deeplearning4j.util;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
2019-06-06 15:21:15 +03:00
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
Refactor packages to fix split package issues (#411) * Refactor nd4j-common: org.nd4j.* -> org.nd4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * Fix CUDA (missed nd4j-common package refactoring changes) Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-kryo: org.nd4j -> org.nd4j.kryo Signed-off-by: Alex Black <blacka101@gmail.com> * Fix nd4j-common for deeplearning4j-cuda Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-grppc-client: org.nd4j.graph -> org.nd4j.remote.grpc Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-common: org.deeplearning4.* -> org.deeplearning4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-core: org.deeplearning4j.* -> org.deeplearning.core.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-cuda: org.deeplearning4j.nn.layers.* -> org.deeplearning4j.cuda.* Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-nlp-*: org.deeplearning4.text.* -> org.deeplearning4j.nlp.(language).* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-ui-model: org.deeplearning4j.ui -> org.deeplearning4j.ui.model Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-spark-inference-{server/model/client}: org.datavec.spark.transform -> org.datavec.spark.inference.{server/model/client} Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-jdbc: org.datavec.api -> org.datavec.jdbc Signed-off-by: Alex Black <blacka101@gmail.com> * Delete org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter in favor of (essentially identical) org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter Signed-off-by: Alex Black <blacka101@gmail.com> * ND4S fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-common-tests: org.nd4j.* -> org.nd4j.common.tests Signed-off-by: Alex Black <blacka101@gmail.com> * Trigger CI Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8878 Ignore CUDA tests on modules with 'nd4j-native under cuda' issue Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad imports in tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add ignore on test (already failing) due to #8882 Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Additional import fixes Signed-off-by: Alex Black <blacka101@gmail.com>
2020-04-29 11:19:26 +10:00
import org.nd4j.common.validation.Nd4jCommonValidator;
import org.nd4j.common.validation.ValidationResult;
2019-06-06 15:21:15 +03:00
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
public class DL4JModelValidator {
private DL4JModelValidator(){ }
/**
* Validate whether the file represents a valid MultiLayerNetwork saved previously with {@link MultiLayerNetwork#save(File)}
* or {@link ModelSerializer#writeModel(IModel, File, boolean)}, to be read with {@link MultiLayerNetwork#load(File, boolean)}
2019-06-06 15:21:15 +03:00
*
* @param f File that should represent an saved MultiLayerNetwork
* @return Result of validation
*/
public static ValidationResult validateMultiLayerNetwork(@NonNull File f){
List<String> requiredEntries = Arrays.asList(ModelSerializer.CONFIGURATION_JSON, ModelSerializer.COEFFICIENTS_BIN); //TODO no-params models... might be OK to have no params, but basically useless in practice
ValidationResult vr = Nd4jCommonValidator.isValidZipFile(f, false, requiredEntries);
if(vr != null && !vr.isValid()) {
vr.setFormatClass(MultiLayerNetwork.class);
vr.setFormatType("MultiLayerNetwork");
return vr;
}
//Check that configuration (JSON) can actually be deserialized correctly...
String config;
try(ZipFile zf = new ZipFile(f)){
ZipEntry ze = zf.getEntry(ModelSerializer.CONFIGURATION_JSON);
config = IOUtils.toString(new BufferedReader(new InputStreamReader(zf.getInputStream(ze), StandardCharsets.UTF_8)));
} catch (IOException e){
return ValidationResult.builder()
.formatType("MultiLayerNetwork")
.formatClass(MultiLayerNetwork.class)
.valid(false)
.path(Nd4jCommonValidator.getPath(f))
.issues(Collections.singletonList("Unable to read configuration from model zip file"))
.exception(e)
.build();
}
try{
NeuralNetConfiguration.fromJson(config);
2019-06-06 15:21:15 +03:00
} catch (Throwable t){
return ValidationResult.builder()
.formatType("MultiLayerNetwork")
.formatClass(MultiLayerNetwork.class)
.valid(false)
.path(Nd4jCommonValidator.getPath(f))
.issues(Collections.singletonList("Zip file JSON model configuration does not appear to represent a valid NeuralNetConfiguration"))
2019-06-06 15:21:15 +03:00
.exception(t)
.build();
}
//TODO should we check params too?
return ValidationResult.builder()
.formatType("MultiLayerNetwork")
.formatClass(MultiLayerNetwork.class)
.valid(true)
.path(Nd4jCommonValidator.getPath(f))
.build();
}
/**
* Validate whether the file represents a valid ComputationGraph saved previously with {@link ComputationGraph#save(File)}
* or {@link ModelSerializer#writeModel(IModel, File, boolean)}, to be read with {@link ComputationGraph#load(File, boolean)}
2019-06-06 15:21:15 +03:00
*
* @param f File that should represent an saved MultiLayerNetwork
* @return Result of validation
*/
public static ValidationResult validateComputationGraph(@NonNull File f){
List<String> requiredEntries = Arrays.asList(ModelSerializer.CONFIGURATION_JSON, ModelSerializer.COEFFICIENTS_BIN); //TODO no-params models... might be OK to have no params, but basically useless in practice
ValidationResult vr = Nd4jCommonValidator.isValidZipFile(f, false, requiredEntries);
if(vr != null && !vr.isValid()) {
vr.setFormatClass(ComputationGraph.class);
vr.setFormatType("ComputationGraph");
return vr;
}
//Check that configuration (JSON) can actually be deserialized correctly...
String config;
try(ZipFile zf = new ZipFile(f)){
ZipEntry ze = zf.getEntry(ModelSerializer.CONFIGURATION_JSON);
config = IOUtils.toString(new BufferedReader(new InputStreamReader(zf.getInputStream(ze), StandardCharsets.UTF_8)));
} catch (IOException e){
return ValidationResult.builder()
.formatType("ComputationGraph")
.formatClass(ComputationGraph.class)
.valid(false)
.path(Nd4jCommonValidator.getPath(f))
.issues(Collections.singletonList("Unable to read configuration from model zip file"))
.exception(e)
.build();
}
try{
ComputationGraphConfiguration.fromJson(config);
2019-06-06 15:21:15 +03:00
} catch (Throwable t){
return ValidationResult.builder()
.formatType("ComputationGraph")
.formatClass(ComputationGraph.class)
.valid(false)
.path(Nd4jCommonValidator.getPath(f))
.issues(Collections.singletonList("Zip file JSON model configuration does not appear to represent a valid ComputationGraphConfiguration"))
.exception(t)
.build();
}
//TODO should we check params too? (a) that it can be read, and (b) that it matches config (number of parameters, etc)
return ValidationResult.builder()
.formatType("ComputationGraph")
.formatClass(ComputationGraph.class)
.valid(true)
.path(Nd4jCommonValidator.getPath(f))
.build();
}
}