1132 lines
49 KiB
Java
1132 lines
49 KiB
Java
/*
|
|
* ******************************************************************************
|
|
* *
|
|
* *
|
|
* * This program and the accompanying materials are made available under the
|
|
* * terms of the Apache License, Version 2.0 which is available at
|
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
* *
|
|
* * See the NOTICE file distributed with this work for additional
|
|
* * information regarding copyright ownership.
|
|
* * Unless required by applicable law or agreed to in writing, software
|
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* * License for the specific language governing permissions and limitations
|
|
* * under the License.
|
|
* *
|
|
* * SPDX-License-Identifier: Apache-2.0
|
|
* *****************************************************************************
|
|
*/
|
|
|
|
package org.deeplearning4j.integration;
|
|
|
|
|
|
import com.google.common.collect.ImmutableSet;
|
|
import com.google.common.reflect.ClassPath;
|
|
import lombok.NonNull;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import org.apache.commons.io.FileUtils;
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
|
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
|
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
|
|
import org.deeplearning4j.nn.api.Model;
|
|
import org.deeplearning4j.nn.conf.BackpropType;
|
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
|
import org.deeplearning4j.nn.conf.layers.Layer;
|
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
|
|
import org.deeplearning4j.nn.layers.BaseOutputLayer;
|
|
import org.deeplearning4j.nn.layers.FrozenLayer;
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
import org.deeplearning4j.optimize.listeners.CollectScoresListener;
|
|
import org.deeplearning4j.parallelism.ParallelInference;
|
|
import org.deeplearning4j.parallelism.inference.InferenceMode;
|
|
import org.deeplearning4j.util.ModelSerializer;
|
|
import org.nd4j.autodiff.listeners.records.History;
|
|
import org.nd4j.autodiff.samediff.SDVariable;
|
|
import org.nd4j.autodiff.samediff.SameDiff;
|
|
import org.nd4j.autodiff.samediff.VariableType;
|
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
|
import org.nd4j.common.base.Preconditions;
|
|
import org.nd4j.common.primitives.Pair;
|
|
import org.nd4j.common.resources.Resources;
|
|
import org.nd4j.evaluation.IEvaluation;
|
|
import org.nd4j.evaluation.classification.*;
|
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
|
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
import org.nd4j.linalg.api.ops.Op;
|
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.RelativeError;
|
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
|
|
|
import java.io.*;
|
|
import java.lang.reflect.Modifier;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.nio.file.Path;
|
|
import java.util.*;
|
|
import java.util.concurrent.atomic.AtomicInteger;
|
|
|
|
import static org.junit.jupiter.api.Assertions.*;
|
|
|
|
@Slf4j
|
|
public class IntegrationTestRunner {
|
|
|
|
public static final String RANDOM_INIT_UNTRAINED_MODEL_FILENAME = "Model_RANDOM_INIT_UNTRAINED.zip";
|
|
public static final String FLAT_GRADIENTS_FILENAME = "flattenedGradients.bin";
|
|
public static final String TRAINING_CURVE_FILENAME = "trainingCurve.csv";
|
|
public static final String PARAMS_POST_TRAIN_FILENAME = "paramsPostTrain.bin";
|
|
public static final String PARAMS_POST_TRAIN_SAMEDIFF_DIR = "paramsPostTrain";
|
|
public static final String PARAMS_POST_UNSUPERVISED_FILENAME = "paramsPostUnsupervised.bin";
|
|
|
|
public static final double MAX_REL_ERROR_SCORES = 1e-4;
|
|
|
|
private static final List<Class<?>> layerClasses = new ArrayList<>();
|
|
private static final List<Class<?>> preprocClasses = new ArrayList<>();
|
|
private static final List<Class<?>> graphVertexClasses = new ArrayList<>();
|
|
private static final List<Class<?>> evaluationClasses = new ArrayList<>();
|
|
|
|
private static Map<Class<?>, Integer> layerConfClassesSeen = new HashMap<>();
|
|
private static Map<Class<?>, Integer> preprocessorConfClassesSeen = new HashMap<>();
|
|
private static Map<Class<?>, Integer> vertexConfClassesSeen = new HashMap<>();
|
|
private static Map<Class<?>, Integer> evaluationClassesSeen = new HashMap<>();
|
|
|
|
static {
|
|
try {
|
|
setup();
|
|
} catch (Exception e){
|
|
throw new RuntimeException(e);
|
|
}
|
|
}
|
|
|
|
|
|
public static void setup() throws Exception {
|
|
|
|
//First: discover all layers, preprocessors, etc
|
|
|
|
ImmutableSet<ClassPath.ClassInfo> info;
|
|
try {
|
|
//Dependency note: this ClassPath class was added in Guava 14
|
|
info = ClassPath.from(DifferentialFunctionClassHolder.class.getClassLoader())
|
|
.getTopLevelClassesRecursive("org.deeplearning4j");
|
|
} catch (IOException e) {
|
|
//Should never happen
|
|
throw new RuntimeException(e);
|
|
}
|
|
|
|
for (ClassPath.ClassInfo c : info) {
|
|
Class<?> clazz = DL4JClassLoading.loadClassByName(c.getName());
|
|
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface())
|
|
continue;
|
|
|
|
if (isLayerConfig(clazz)) {
|
|
layerClasses.add(clazz);
|
|
} else if (isPreprocessorConfig(clazz)) {
|
|
preprocClasses.add(clazz);
|
|
} else if (isGraphVertexConfig(clazz)) {
|
|
graphVertexClasses.add(clazz);
|
|
} else if (isEvaluationClass(clazz)) {
|
|
evaluationClasses.add(clazz);
|
|
}
|
|
}
|
|
|
|
layerClasses.sort(Comparator.comparing(Class::getName));
|
|
preprocClasses.sort(Comparator.comparing(Class::getName));
|
|
graphVertexClasses.sort(Comparator.comparing(Class::getName));
|
|
|
|
log.info("Found {} layers", layerClasses.size());
|
|
log.info("Found {} preprocessors", preprocClasses.size());
|
|
log.info("Found {} graph vertices", graphVertexClasses.size());
|
|
log.info("Found {} IEvaluation classes", evaluationClasses.size());
|
|
|
|
layerConfClassesSeen = new HashMap<>();
|
|
preprocessorConfClassesSeen = new HashMap<>();
|
|
vertexConfClassesSeen = new HashMap<>();
|
|
evaluationClassesSeen = new HashMap<>();
|
|
}
|
|
|
|
public static void runTest(TestCase tc, Path testDir) throws Exception {
|
|
runTest(tc, testDir.toFile());
|
|
}
|
|
public static void runTest(TestCase tc, File testDir) throws Exception {
|
|
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
|
|
//This could alternatively be done via maven surefire configuration
|
|
|
|
final ModelType modelType = tc.modelType();
|
|
log.info("Starting test case: {} - type = {}", tc.getTestName(), modelType);
|
|
long start = System.currentTimeMillis();
|
|
|
|
File workingDir = testDir;
|
|
tc.initialize(workingDir);
|
|
|
|
File testBaseDir = testDir;
|
|
// new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir);
|
|
Resources.copyDirectory((modelType == ModelType.SAMEDIFF ? "samediff-integration-tests/" : "dl4j-integration-tests/") + tc.getTestName(), testBaseDir);
|
|
|
|
|
|
MultiLayerNetwork mln = null;
|
|
ComputationGraph cg = null;
|
|
SameDiff sd = null;
|
|
Model m = null;
|
|
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
|
|
log.info("Checking RANDOM_INIT test case: saved model vs. initialized model");
|
|
//Checking randomly initialized model:
|
|
File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME);
|
|
Object config = tc.getConfiguration();
|
|
if (config instanceof MultiLayerConfiguration) {
|
|
MultiLayerConfiguration mlc = (MultiLayerConfiguration) config;
|
|
mln = new MultiLayerNetwork(mlc);
|
|
mln.init();
|
|
m = mln;
|
|
|
|
MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
|
|
assertEquals(loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations(), "Configs not equal");
|
|
assertEquals( loaded.params(), mln.params(), "Params not equal");
|
|
assertEquals( loaded.paramTable(), mln.paramTable(), "Param table not equal");
|
|
} else if(config instanceof ComputationGraphConfiguration ){
|
|
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
|
|
cg = new ComputationGraph(cgc);
|
|
cg.init();
|
|
m = cg;
|
|
|
|
ComputationGraph loaded = ComputationGraph.load(savedModel, true);
|
|
assertEquals(loaded.getConfiguration(), cg.getConfiguration(), "Configs not equal" );
|
|
assertEquals( loaded.params(), cg.params(), "Params not equal");
|
|
assertEquals(loaded.paramTable(), cg.paramTable(), "Param table not equal");
|
|
} else if(config instanceof SameDiff){
|
|
sd = (SameDiff)config;
|
|
SameDiff loaded = SameDiff.load(savedModel, true);
|
|
|
|
assertSameDiffEquals(sd, loaded);
|
|
} else {
|
|
throw new IllegalStateException("Unknown configuration/model type: " + config.getClass());
|
|
}
|
|
} else {
|
|
m = tc.getPretrainedModel();
|
|
if (m instanceof MultiLayerNetwork) {
|
|
mln = (MultiLayerNetwork) m;
|
|
} else if(m instanceof ComputationGraph) {
|
|
cg = (ComputationGraph) m;
|
|
} else if(m instanceof SameDiff){
|
|
sd = (SameDiff)m;
|
|
} else {
|
|
throw new IllegalStateException("Unknown model type: " + m.getClass());
|
|
}
|
|
}
|
|
|
|
//Collect information for test coverage
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
collectCoverageInformation(m);
|
|
}
|
|
|
|
|
|
//Check network output (predictions)
|
|
if (tc.isTestPredictions()) {
|
|
log.info("Checking predictions: saved output vs. initialized model");
|
|
|
|
|
|
List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
|
|
List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
|
|
Preconditions.checkState(modelType == ModelType.SAMEDIFF || inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
|
|
|
|
|
|
File predictionsTestDir = new File(testBaseDir, "predictions");
|
|
predictionsTestDir.mkdirs();
|
|
|
|
int count = 0;
|
|
if (modelType == ModelType.MLN) {
|
|
for (Pair<INDArray[], INDArray[]> p : inputs) {
|
|
INDArray f = p.getFirst()[0];
|
|
INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
|
|
INDArray out = mln.output(f, false, fm, null);
|
|
|
|
//Load the previously saved array
|
|
File outFile = new File(predictionsTestDir, "output_" + (count++) + "_0.bin");
|
|
INDArray outSaved;
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(outFile))) {
|
|
outSaved = Nd4j.read(dis);
|
|
}
|
|
|
|
INDArray predictionExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
|
|
int countExceeds = predictionExceedsRE.sumNumber().intValue();
|
|
assertEquals( 0, countExceeds, "Predictions do not match saved predictions - output");
|
|
}
|
|
} else if(modelType == ModelType.CG){
|
|
for (Pair<INDArray[], INDArray[]> p : inputs) {
|
|
INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
|
|
|
|
//Load the previously saved arrays
|
|
INDArray[] outSaved = new INDArray[out.length];
|
|
for (int i = 0; i < out.length; i++) {
|
|
File outFile = new File(predictionsTestDir, "output_" + (count++) + "_" + i + ".bin");
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(outFile))) {
|
|
outSaved[i] = Nd4j.read(dis);
|
|
}
|
|
}
|
|
|
|
for( int i=0; i<outSaved.length; i++ ){
|
|
INDArray predictionExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
|
|
int countExceeds = predictionExceedsRE.sumNumber().intValue();
|
|
assertEquals( 0, countExceeds, "Predictions do not match saved predictions - output " + i);
|
|
}
|
|
}
|
|
} else {
|
|
List<String> outNames = tc.getPredictionsNamesSameDiff();
|
|
for( Map<String,INDArray> ph : inputsSd ){
|
|
Map<String,INDArray> out = sd.output(ph, outNames);
|
|
|
|
//Load the previously saved placeholder arrays
|
|
Map<String,INDArray> outSaved = new HashMap<>();
|
|
for(String s : outNames){
|
|
File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f))) {
|
|
outSaved.put(s, Nd4j.read(dis));
|
|
}
|
|
}
|
|
|
|
for(String s : outNames){
|
|
INDArray predictionExceedsRE = exceedsRelError(outSaved.get(s), out.get(s), tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
|
|
int countExceeds = predictionExceedsRE.sumNumber().intValue();
|
|
assertEquals( 0, countExceeds, "Predictions do not match saved predictions - output \"" + s + "\"");
|
|
}
|
|
}
|
|
}
|
|
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
checkLayerClearance(m);
|
|
}
|
|
}
|
|
|
|
|
|
//Test gradients
|
|
if (tc.isTestGradients()) {
|
|
log.info("Checking gradients: saved output vs. initialized model");
|
|
|
|
INDArray gradientFlat = null;
|
|
org.deeplearning4j.nn.api.Layer[] layers = null;
|
|
Map<String,INDArray> grad;
|
|
if (modelType == ModelType.MLN) {
|
|
MultiDataSet data = tc.getGradientsTestData();
|
|
mln.setInput(data.getFeatures(0));
|
|
mln.setLabels(data.getLabels(0));
|
|
mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
|
|
mln.computeGradientAndScore();
|
|
gradientFlat = mln.getFlattenedGradients();
|
|
layers = mln.getLayers();
|
|
grad = mln.gradient().gradientForVariable();
|
|
} else if(modelType == ModelType.CG) {
|
|
MultiDataSet data = tc.getGradientsTestData();
|
|
cg.setInputs(data.getFeatures());
|
|
cg.setLabels(data.getLabels());
|
|
cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
|
|
cg.computeGradientAndScore();
|
|
gradientFlat = cg.getFlattenedGradients();
|
|
layers = cg.getLayers();
|
|
grad = cg.gradient().gradientForVariable();
|
|
} else {
|
|
Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
|
|
List<String> allVars = new ArrayList<>();
|
|
for(SDVariable v : sd.variables()){
|
|
if(v.getVariableType() == VariableType.VARIABLE){
|
|
allVars.add(v.name());
|
|
}
|
|
}
|
|
grad = sd.calculateGradients(ph, allVars);
|
|
}
|
|
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
|
|
INDArray gradientFlatSaved = read(gFlatFile);
|
|
|
|
INDArray gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
|
|
int count = gradExceedsRE.sumNumber().intValue();
|
|
if (count > 0) {
|
|
logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat);
|
|
}
|
|
assertEquals( 0, count, "Saved flattened gradients: not equal (using relative error)");
|
|
}
|
|
|
|
//Load the gradient table:
|
|
File gradientDir = new File(testBaseDir, "gradients");
|
|
for (File f : gradientDir.listFiles()) {
|
|
if (!f.isFile()) {
|
|
continue;
|
|
}
|
|
String key = f.getName();
|
|
key = key.substring(0, key.length() - 4); //remove ".bin"
|
|
INDArray loaded = read(f);
|
|
INDArray now = grad.get(key);
|
|
|
|
|
|
INDArray gradExceedsRE = exceedsRelError(loaded, now, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
|
|
int count = gradExceedsRE.sumNumber().intValue();
|
|
assertEquals( 0, count, "Gradients: not equal (using relative error) for parameter: " + key);
|
|
}
|
|
}
|
|
|
|
//Test layerwise pretraining
|
|
if(tc.isTestUnsupervisedTraining()){
|
|
log.info("Performing layerwise pretraining");
|
|
MultiDataSetIterator iter = tc.getUnsupervisedTrainData();
|
|
|
|
INDArray paramsPostTraining;
|
|
org.deeplearning4j.nn.api.Layer[] layers;
|
|
if(modelType == ModelType.MLN){
|
|
int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
|
|
Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
|
|
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
|
|
|
|
for( int i : layersToTrain){
|
|
mln.pretrainLayer(i, dsi);
|
|
}
|
|
paramsPostTraining = mln.params();
|
|
layers = mln.getLayers();
|
|
} else if(modelType == ModelType.CG) {
|
|
String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
|
|
Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
|
|
|
|
for( String i : layersToTrain){
|
|
cg.pretrainLayer(i, iter);
|
|
}
|
|
paramsPostTraining = cg.params();
|
|
layers = cg.getLayers();
|
|
} else {
|
|
throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models");
|
|
}
|
|
|
|
File f = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_UNSUPERVISED_FILENAME);
|
|
INDArray expParams = read(f);
|
|
|
|
INDArray exceedsRelError = exceedsRelError(expParams, paramsPostTraining, tc.getMaxRelativeErrorPretrainParams(),
|
|
tc.getMinAbsErrorPretrainParams());
|
|
int count = exceedsRelError.sumNumber().intValue();
|
|
if(count > 0){
|
|
logFailedParams(20, "Parameter", layers, exceedsRelError, expParams, paramsPostTraining);
|
|
}
|
|
assertEquals( 0, count, "Number of parameters exceeding relative error");
|
|
|
|
//Set params to saved ones - to avoid accumulation of roundoff errors causing later failures...
|
|
m.setParams(expParams);
|
|
}
|
|
|
|
|
|
//Test training curves:
|
|
if (tc.isTestTrainingCurves() || tc.isTestParamsPostTraining()) {
|
|
MultiDataSetIterator trainData = tc.getTrainingData();
|
|
boolean isTbptt;
|
|
int tbpttLength;
|
|
if(modelType == ModelType.MLN){
|
|
isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT;
|
|
tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength();
|
|
} else if(modelType == ModelType.CG) {
|
|
isTbptt = cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
|
|
tbpttLength = cg.getConfiguration().getTbpttFwdLength();
|
|
} else {
|
|
isTbptt = false;
|
|
tbpttLength = 0;
|
|
}
|
|
|
|
CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength);
|
|
CollectScoresListener l = new CollectScoresListener(1);
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
m.setListeners(l);
|
|
}
|
|
|
|
int iterBefore;
|
|
int epochBefore;
|
|
int iterAfter;
|
|
int epochAfter;
|
|
|
|
Map<String,INDArray> frozenParamsBefore = modelType != ModelType.SAMEDIFF ? getFrozenLayerParamCopies(m) : getConstantCopies(sd);
|
|
org.deeplearning4j.nn.api.Layer[] layers = null;
|
|
History h = null;
|
|
if (modelType == ModelType.MLN) {
|
|
iterBefore = mln.getIterationCount();
|
|
epochBefore = mln.getEpochCount();
|
|
mln.fit(countingIter);
|
|
iterAfter = mln.getIterationCount();
|
|
epochAfter = mln.getEpochCount();
|
|
layers = mln.getLayers();
|
|
} else if(modelType == ModelType.CG){
|
|
iterBefore = cg.getConfiguration().getIterationCount();
|
|
epochBefore = cg.getConfiguration().getEpochCount();
|
|
cg.fit(countingIter);
|
|
iterAfter = cg.getConfiguration().getIterationCount();
|
|
epochAfter = cg.getConfiguration().getEpochCount();
|
|
layers = cg.getLayers();
|
|
} else {
|
|
iterBefore = sd.getTrainingConfig().getIterationCount();
|
|
epochBefore = sd.getTrainingConfig().getEpochCount();
|
|
h = sd.fit(countingIter, 1);
|
|
iterAfter = sd.getTrainingConfig().getIterationCount();
|
|
epochAfter = sd.getTrainingConfig().getEpochCount();
|
|
}
|
|
|
|
//Check that frozen params (if any) haven't changed during training:
|
|
if(modelType == ModelType.SAMEDIFF) {
|
|
checkConstants(frozenParamsBefore, sd);
|
|
} else {
|
|
checkFrozenParams(frozenParamsBefore, m);
|
|
}
|
|
|
|
//Validate the iteration and epoch counts - both for the net, and for the layers
|
|
int newIters = countingIter.getCurrIter();
|
|
assertEquals(iterBefore + newIters, iterAfter);
|
|
assertEquals(epochBefore + 1, epochAfter);
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
validateLayerIterCounts(m, epochBefore + 1, iterBefore + newIters);
|
|
}
|
|
|
|
|
|
double[] scores;
|
|
if(modelType == ModelType.SAMEDIFF){
|
|
scores = h.lossCurve().getLossValues().toDoubleVector();
|
|
} else {
|
|
scores = l.getListScore().toDoubleArray();
|
|
}
|
|
|
|
File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
|
|
String[] s = FileUtils.readFileToString(f, StandardCharsets.UTF_8).split(",");
|
|
|
|
if(tc.isTestTrainingCurves()) {
|
|
assertEquals( s.length, scores.length, "Different number of scores");
|
|
|
|
boolean pass = true;
|
|
for (int i = 0; i < s.length; i++) {
|
|
double exp = Double.parseDouble(s[i]);
|
|
double re = relError(exp, scores[i]);
|
|
if (re > MAX_REL_ERROR_SCORES) {
|
|
pass = false;
|
|
break;
|
|
}
|
|
}
|
|
if (!pass) {
|
|
fail("Scores differ: expected/saved: " + Arrays.toString(s) + "\nActual: " + Arrays.toString(scores));
|
|
}
|
|
}
|
|
|
|
if (tc.isTestParamsPostTraining()) {
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
|
|
INDArray paramsExp = read(p);
|
|
INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
|
|
int count = z.sumNumber().intValue();
|
|
if (count > 0) {
|
|
logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
|
|
}
|
|
assertEquals( 0, count, "Number of params exceeded max relative error");
|
|
} else {
|
|
File dir = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
|
|
for(SDVariable v : sd.variables()){
|
|
if(v.getVariableType() != VariableType.VARIABLE)
|
|
continue;
|
|
INDArray paramNow = v.getArr();
|
|
File paramFile = new File(dir, v.name() + ".bin");
|
|
INDArray exp = read(paramFile);
|
|
INDArray z = exceedsRelError(paramNow, exp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
|
|
int count = z.sumNumber().intValue();
|
|
if (count > 0) {
|
|
logFailedParams(20, "Parameter: " + v.name(), layers, z, exp, paramNow);
|
|
}
|
|
assertEquals( 0, count, "Number of params exceeded max relative error for parameter: \"" + v.name() + "\"");
|
|
}
|
|
}
|
|
}
|
|
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
checkLayerClearance(m);
|
|
}
|
|
}
|
|
|
|
//Check evaluation:
|
|
if (tc.isTestEvaluation()) {
|
|
log.info("Testing evaluation");
|
|
IEvaluation[] evals = tc.getNewEvaluations();
|
|
MultiDataSetIterator iter = tc.getEvaluationTestData();
|
|
|
|
if (modelType == ModelType.MLN) {
|
|
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
|
|
mln.doEvaluation(dsi, evals);
|
|
} else if(modelType == ModelType.CG){
|
|
cg.doEvaluation(iter, evals);
|
|
} else {
|
|
evals = tc.doEvaluationSameDiff(sd, iter, evals);
|
|
}
|
|
|
|
File evalDir = new File(testBaseDir, "evaluation");
|
|
for (int i = 0; i < evals.length; i++) {
|
|
File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
|
|
String json = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
|
|
IEvaluation e;
|
|
if (evals[i].getClass() == Evaluation.class) {
|
|
e = Evaluation.fromJson(json);
|
|
} else if (evals[i].getClass() == RegressionEvaluation.class) {
|
|
e = RegressionEvaluation.fromJson(json, RegressionEvaluation.class);
|
|
} else if (evals[i].getClass() == ROC.class) {
|
|
e = ROC.fromJson(json, ROC.class);
|
|
} else if (evals[i].getClass() == ROCBinary.class) {
|
|
e = ROCBinary.fromJson(json, ROCBinary.class);
|
|
} else if (evals[i].getClass() == ROCMultiClass.class) {
|
|
e = ROCMultiClass.fromJson(json, ROCMultiClass.class);
|
|
} else if (evals[i].getClass() == EvaluationCalibration.class) {
|
|
e = EvaluationCalibration.fromJson(json, EvaluationCalibration.class);
|
|
} else {
|
|
throw new RuntimeException("Unknown/not implemented evaluation type: " + evals[i].getClass());
|
|
}
|
|
|
|
|
|
assertEquals( e, evals[i], "Evaluation not equal: " + evals[i].getClass());
|
|
|
|
//Evaluation coverage information:
|
|
evaluationClassesSeen.put(evals[i].getClass(), evaluationClassesSeen.getOrDefault(evals[i].getClass(), 0) + 1);
|
|
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
checkLayerClearance(m);
|
|
}
|
|
}
|
|
}
|
|
|
|
//Check model serialization
|
|
{
|
|
log.info("Testing model serialization");
|
|
|
|
File f = new File(testDir, UUID.randomUUID().toString());
|
|
f.delete();
|
|
|
|
if (modelType == ModelType.MLN) {
|
|
ModelSerializer.writeModel(m, f, true);
|
|
MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
|
|
assertEquals(mln.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
|
assertEquals(mln.params(), restored.params());
|
|
} else if(modelType == ModelType.CG){
|
|
ModelSerializer.writeModel(m, f, true);
|
|
ComputationGraph restored = ComputationGraph.load(f, true);
|
|
assertEquals(cg.getConfiguration(), restored.getConfiguration());
|
|
assertEquals(cg.params(), restored.params());
|
|
} else {
|
|
sd.save(f, true);
|
|
SameDiff restored = SameDiff.load(f, true);
|
|
assertSameDiffEquals(sd, restored);
|
|
}
|
|
|
|
System.gc();
|
|
}
|
|
|
|
|
|
//Check parallel inference
|
|
if (modelType != ModelType.SAMEDIFF && tc.isTestParallelInference()) {
|
|
|
|
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
|
|
|
|
int numThreads = 2; //TODO allow customization of this?
|
|
|
|
List<INDArray[]> exp = new ArrayList<>();
|
|
for(Pair<INDArray[], INDArray[]> p : inputs){
|
|
INDArray[] out;
|
|
if(modelType == ModelType.MLN){
|
|
INDArray fm = p.getSecond() == null ? null : p.getSecond()[0];
|
|
out = new INDArray[]{mln.output(p.getFirst()[0], false, fm, null)};
|
|
} else {
|
|
out = cg.output(false, p.getFirst(), p.getSecond(), null);
|
|
}
|
|
exp.add(out);
|
|
}
|
|
|
|
ParallelInference inf =
|
|
new ParallelInference.Builder(m)
|
|
.inferenceMode(InferenceMode.BATCHED)
|
|
.batchLimit(3)
|
|
.queueLimit(8)
|
|
.workers(numThreads)
|
|
.build();
|
|
|
|
|
|
testParallelInference(inf, inputs, exp);
|
|
|
|
inf.shutdown();
|
|
inf = null;
|
|
System.gc();
|
|
}
|
|
|
|
|
|
//Test overfitting single example
|
|
if (tc.isTestOverfitting()) {
|
|
log.info("Testing overfitting on single example");
|
|
|
|
MultiDataSet toOverfit = tc.getOverfittingData();
|
|
for (int i = 0; i < tc.getOverfitNumIterations(); i++) {
|
|
if (modelType == ModelType.MLN) {
|
|
mln.fit(toOverfit);
|
|
} else if(modelType == ModelType.CG){
|
|
cg.fit(toOverfit);
|
|
} else {
|
|
sd.fit(toOverfit);
|
|
}
|
|
}
|
|
|
|
//Check:
|
|
INDArray[] output = null;
|
|
Map<String,INDArray> outSd = null;
|
|
if (modelType == ModelType.MLN) {
|
|
mln.setLayerMaskArrays(toOverfit.getFeaturesMaskArray(0), null);
|
|
output = new INDArray[]{mln.output(toOverfit.getFeatures(0))};
|
|
} else if(modelType == ModelType.CG ){
|
|
cg.setLayerMaskArrays(toOverfit.getFeaturesMaskArrays(), null);
|
|
output = cg.output(toOverfit.getFeatures());
|
|
} else {
|
|
List<String> l = sd.getTrainingConfig().getDataSetFeatureMapping();
|
|
Map<String,INDArray> phMap = new HashMap<>();
|
|
int i=0;
|
|
for(String s : l){
|
|
phMap.put(s, toOverfit.getFeatures(i++));
|
|
}
|
|
outSd = sd.output(phMap, tc.getPredictionsNamesSameDiff());
|
|
}
|
|
|
|
int n = modelType == ModelType.SAMEDIFF ? outSd.size() : output.length;
|
|
for (int i = 0; i < n; i++) {
|
|
INDArray out = modelType == ModelType.SAMEDIFF ? outSd.get(tc.getPredictionsNamesSameDiff().get(i)) : output[i];
|
|
INDArray label = toOverfit.getLabels(i);
|
|
|
|
INDArray z = exceedsRelError(out, label, tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit());
|
|
int count = z.sumNumber().intValue();
|
|
if (count > 0) {
|
|
System.out.println(out);
|
|
System.out.println(label);
|
|
INDArray re = relativeError(out, label, tc.getMinAbsErrorOverfit());
|
|
System.out.println("Relative error:");
|
|
System.out.println(re);
|
|
}
|
|
assertEquals(0, count,"Number of outputs exceeded max relative error");
|
|
}
|
|
|
|
if(modelType != ModelType.SAMEDIFF) {
|
|
checkLayerClearance(m);
|
|
}
|
|
}
|
|
|
|
long end = System.currentTimeMillis();
|
|
|
|
|
|
log.info("Completed test case {} in {} sec", tc.getTestName(), (end - start) / 1000L);
|
|
}
|
|
|
|
//Work out which layers, vertices etc we have seen - so we can (at the end of all tests) log our integration test coverage
|
|
private static void collectCoverageInformation(Model m){
|
|
boolean isMLN = (m instanceof MultiLayerNetwork);
|
|
MultiLayerNetwork mln = (isMLN ? (MultiLayerNetwork)m : null);
|
|
ComputationGraph cg = (!isMLN ? (ComputationGraph)m : null);
|
|
|
|
//Collect layer coverage information:
|
|
org.deeplearning4j.nn.api.Layer[] layers;
|
|
if (isMLN) {
|
|
layers = mln.getLayers();
|
|
} else {
|
|
layers = cg.getLayers();
|
|
}
|
|
for (org.deeplearning4j.nn.api.Layer l : layers) {
|
|
Layer lConf = l.conf().getLayer();
|
|
layerConfClassesSeen.put(lConf.getClass(), layerConfClassesSeen.getOrDefault(lConf.getClass(), 0) + 1);
|
|
}
|
|
|
|
//Collect preprocessor coverage information:
|
|
Collection<InputPreProcessor> preProcessors;
|
|
if (isMLN) {
|
|
preProcessors = mln.getLayerWiseConfigurations().getInputPreProcessors().values();
|
|
} else {
|
|
preProcessors = new ArrayList<>();
|
|
for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getConfiguration().getVertices().values()) {
|
|
if (gv instanceof LayerVertex) {
|
|
InputPreProcessor pp = ((LayerVertex) gv).getPreProcessor();
|
|
if (pp != null) {
|
|
preProcessors.add(pp);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
for (InputPreProcessor ipp : preProcessors) {
|
|
preprocessorConfClassesSeen.put(ipp.getClass(), preprocessorConfClassesSeen.getOrDefault(ipp.getClass(), 0) + 1);
|
|
}
|
|
|
|
//Collect vertex coverage information
|
|
if (!isMLN) {
|
|
for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getConfiguration().getVertices().values()) {
|
|
vertexConfClassesSeen.put(gv.getClass(), vertexConfClassesSeen.getOrDefault(gv.getClass(), 0) + 1);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
private static void checkLayerClearance(Model m) {
|
|
//Check that the input fields for all layers have been cleared
|
|
org.deeplearning4j.nn.api.Layer[] layers;
|
|
if (m instanceof MultiLayerNetwork) {
|
|
layers = ((MultiLayerNetwork) m).getLayers();
|
|
} else {
|
|
layers = ((ComputationGraph) m).getLayers();
|
|
}
|
|
|
|
for (org.deeplearning4j.nn.api.Layer l : layers) {
|
|
assertNull(l.input());
|
|
assertNull(l.getMaskArray());
|
|
if (l instanceof BaseOutputLayer) {
|
|
BaseOutputLayer b = (BaseOutputLayer) l;
|
|
assertNull(b.getLabels());
|
|
}
|
|
}
|
|
|
|
|
|
if (m instanceof ComputationGraph) {
|
|
//Also check the vertices:
|
|
GraphVertex[] vertices = ((ComputationGraph) m).getVertices();
|
|
for (GraphVertex v : vertices) {
|
|
int numInputs = v.getNumInputArrays();
|
|
INDArray[] arr = v.getInputs();
|
|
if (arr != null) {
|
|
for (int i = 0; i < numInputs; i++) {
|
|
assertNull(arr[i]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
private static void validateLayerIterCounts(Model m, int expEpoch, int expIter){
|
|
//Check that the iteration and epoch counts - on the layers - are synced
|
|
org.deeplearning4j.nn.api.Layer[] layers;
|
|
if (m instanceof MultiLayerNetwork) {
|
|
layers = ((MultiLayerNetwork) m).getLayers();
|
|
} else {
|
|
layers = ((ComputationGraph) m).getLayers();
|
|
}
|
|
|
|
for(org.deeplearning4j.nn.api.Layer l : layers){
|
|
assertEquals( expEpoch, l.getEpochCount(), "Epoch count");
|
|
assertEquals( expIter, l.getIterationCount(), "Iteration count");
|
|
}
|
|
}
|
|
|
|
|
|
private static Map<String,INDArray> getFrozenLayerParamCopies(Model m){
|
|
Map<String,INDArray> out = new LinkedHashMap<>();
|
|
org.deeplearning4j.nn.api.Layer[] layers;
|
|
if (m instanceof MultiLayerNetwork) {
|
|
layers = ((MultiLayerNetwork) m).getLayers();
|
|
} else {
|
|
layers = ((ComputationGraph) m).getLayers();
|
|
}
|
|
|
|
for(org.deeplearning4j.nn.api.Layer l : layers){
|
|
if(l instanceof FrozenLayer){
|
|
String paramPrefix;
|
|
if(m instanceof MultiLayerNetwork){
|
|
paramPrefix = l.getIndex() + "_";
|
|
} else {
|
|
paramPrefix = l.conf().getLayer().getLayerName() + "_";
|
|
}
|
|
Map<String,INDArray> paramTable = l.paramTable();
|
|
for(Map.Entry<String,INDArray> e : paramTable.entrySet()){
|
|
out.put(paramPrefix + e.getKey(), e.getValue().dup());
|
|
}
|
|
}
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
private static Map<String,INDArray> getConstantCopies(SameDiff sd){
|
|
Map<String,INDArray> out = new HashMap<>();
|
|
for(SDVariable v : sd.variables()){
|
|
if(v.isConstant()){
|
|
out.put(v.name(), v.getArr());
|
|
}
|
|
}
|
|
return out;
|
|
}
|
|
|
|
public static void checkFrozenParams(Map<String,INDArray> copiesBeforeTraining, Model m){
|
|
for(Map.Entry<String,INDArray> e : copiesBeforeTraining.entrySet()){
|
|
INDArray actual = m.getParam(e.getKey());
|
|
assertEquals(e.getValue(), actual, e.getKey());
|
|
}
|
|
}
|
|
|
|
public static void checkConstants(Map<String,INDArray> copiesBefore, SameDiff sd){
|
|
for(Map.Entry<String,INDArray> e : copiesBefore.entrySet()){
|
|
INDArray actual = sd.getArrForVarName(e.getKey());
|
|
assertEquals(e.getValue(), actual, e.getKey());
|
|
}
|
|
}
|
|
|
|
public static void printCoverageInformation(){
|
|
|
|
log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
|
|
|
|
log.info("Layer coverage - classes seen:");
|
|
for (Class<?> c : layerClasses) {
|
|
if (layerConfClassesSeen.containsKey(c)) {
|
|
log.info("Class seen {} times in tests: {}", layerConfClassesSeen.get(c), c.getName());
|
|
}
|
|
}
|
|
|
|
log.info("Layer classes NOT seen in any tests:");
|
|
for (Class<?> c : layerClasses) {
|
|
if (!layerConfClassesSeen.containsKey(c)) {
|
|
log.info("Class NOT seen in any tests: {}", c.getName());
|
|
}
|
|
}
|
|
|
|
log.info("----------------------------------------------------------------------------------------------------");
|
|
|
|
log.info("GraphVertex coverage - classes seen:");
|
|
for (Class<?> c : graphVertexClasses) {
|
|
if (vertexConfClassesSeen.containsKey(c)) {
|
|
log.info("Preprocessor seen {} times in tests: {}", preprocessorConfClassesSeen.get(c), c.getName());
|
|
}
|
|
}
|
|
|
|
log.info("GraphVertexcoverage - classes NOT seen:");
|
|
for (Class<?> c : graphVertexClasses) {
|
|
if (!vertexConfClassesSeen.containsKey(c)) {
|
|
log.info("Preprocessor NOT seen in any tests: {}", c.getName());
|
|
}
|
|
}
|
|
|
|
log.info("----------------------------------------------------------------------------------------------------");
|
|
|
|
log.info("Preprocessor coverage - classes seen:");
|
|
for (Class<?> c : preprocClasses) {
|
|
if (preprocessorConfClassesSeen.containsKey(c)) {
|
|
log.info("Preprocessor seen {} times in tests: {}", preprocessorConfClassesSeen.get(c), c.getName());
|
|
}
|
|
}
|
|
|
|
log.info("Preprocessor coverage - classes NOT seen:");
|
|
for (Class<?> c : preprocClasses) {
|
|
if (!preprocessorConfClassesSeen.containsKey(c)) {
|
|
log.info("Preprocessor NOT seen in any tests: {}", c.getName());
|
|
}
|
|
}
|
|
|
|
log.info("----------------------------------------------------------------------------------------------------");
|
|
|
|
|
|
log.info("Evaluation coverage - classes seen:");
|
|
for (Class<?> c : evaluationClasses) {
|
|
if (evaluationClassesSeen.containsKey(c)) {
|
|
log.info("Evaluation class seen {} times in tests: {}", evaluationClassesSeen.get(c), c.getName());
|
|
}
|
|
}
|
|
|
|
log.info("Evaluation coverage - classes NOT seen:");
|
|
for (Class<?> c : evaluationClasses) {
|
|
if (!evaluationClassesSeen.containsKey(c)) {
|
|
log.info("Evaluation class NOT seen in any tests: {}", c.getName());
|
|
}
|
|
}
|
|
|
|
log.info("----------------------------------------------------------------------------------------------------");
|
|
}
|
|
|
|
private static boolean isLayerConfig(Class<?> c) {
|
|
return Layer.class.isAssignableFrom(c);
|
|
}
|
|
|
|
private static boolean isPreprocessorConfig(Class<?> c) {
|
|
return InputPreProcessor.class.isAssignableFrom(c);
|
|
}
|
|
|
|
private static boolean isGraphVertexConfig(Class<?> c) {
|
|
return GraphVertex.class.isAssignableFrom(c);
|
|
}
|
|
|
|
private static boolean isEvaluationClass(Class<?> c) {
|
|
return IEvaluation.class.isAssignableFrom(c);
|
|
}
|
|
|
|
private static INDArray read(File f) {
|
|
try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(f)))) {
|
|
return Nd4j.read(dis);
|
|
} catch (IOException e) {
|
|
throw new RuntimeException(e);
|
|
}
|
|
}
|
|
|
|
public static void write(INDArray arr, File f) {
|
|
try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(f)))) {
|
|
Nd4j.write(arr, dos);
|
|
} catch (IOException e) {
|
|
throw new RuntimeException(e);
|
|
}
|
|
}
|
|
|
|
private static double relError(double d1, double d2) {
|
|
Preconditions.checkState(!Double.isNaN(d1), "d1 is NaN");
|
|
Preconditions.checkState(!Double.isNaN(d2), "d2 is NaN");
|
|
if (d1 == 0.0 && d2 == 0.0) {
|
|
return 0.0;
|
|
}
|
|
|
|
return Math.abs(d1 - d2) / (Math.abs(d1) + Math.abs(d2));
|
|
}
|
|
|
|
private static INDArray exceedsRelError(INDArray first, INDArray second, double maxRel, double minAbs) {
|
|
// INDArray z = Nd4j.createUninitialized(first.shape());
|
|
// Op op = new BinaryMinimalRelativeError(first, second, z, maxRel, minAbs);
|
|
// Nd4j.getExecutioner().exec(op);
|
|
// return z;
|
|
INDArray z = relativeError(first, second, minAbs);
|
|
BooleanIndexing.replaceWhere(z, 0.0, Conditions.lessThan(maxRel));
|
|
BooleanIndexing.replaceWhere(z, 1.0, Conditions.greaterThan(0.0));
|
|
return z;
|
|
}
|
|
|
|
private static INDArray relativeError(INDArray first, INDArray second) {
|
|
INDArray z = Nd4j.createUninitialized(first.shape());
|
|
Op op = new RelativeError(first, second, z);
|
|
Nd4j.getExecutioner().exec(op);
|
|
return z;
|
|
}
|
|
|
|
private static INDArray relativeError(@NonNull INDArray a1, @NonNull INDArray a2, double minAbsError) {
|
|
long numNaN1 = Nd4j.getExecutioner().exec(new MatchCondition(a1, Conditions.isNan(), Integer.MAX_VALUE)).getInt(0);
|
|
long numNaN2 = Nd4j.getExecutioner().exec(new MatchCondition(a2, Conditions.isNan(), Integer.MAX_VALUE)).getInt(0);
|
|
Preconditions.checkState(numNaN1 == 0, "Array 1 has NaNs");
|
|
Preconditions.checkState(numNaN2 == 0, "Array 2 has NaNs");
|
|
|
|
|
|
// INDArray isZero1 = a1.eq(0.0);
|
|
// INDArray isZero2 = a2.eq(0.0);
|
|
// INDArray bothZero = isZero1.muli(isZero2);
|
|
|
|
INDArray abs1 = Transforms.abs(a1, true);
|
|
INDArray abs2 = Transforms.abs(a2, true);
|
|
INDArray absDiff = Transforms.abs(a1.sub(a2), false);
|
|
|
|
//abs(a1-a2) < minAbsError ? 1 : 0
|
|
INDArray greaterThanMinAbs = Transforms.abs(a1.sub(a2), false);
|
|
BooleanIndexing.replaceWhere(greaterThanMinAbs, 0.0, Conditions.lessThan(minAbsError));
|
|
BooleanIndexing.replaceWhere(greaterThanMinAbs, 1.0, Conditions.greaterThan(0.0));
|
|
|
|
INDArray result = absDiff.divi(abs1.add(abs2));
|
|
//Only way to have NaNs given there weren't any in original : both 0s
|
|
BooleanIndexing.replaceWhere(result, 0.0, Conditions.isNan());
|
|
//Finally, set to 0 if less than min abs error, or unchanged otherwise
|
|
result.muli(greaterThanMinAbs);
|
|
|
|
// double maxRE = result.maxNumber().doubleValue();
|
|
// if(maxRE > MAX_REL_ERROR){
|
|
// System.out.println();
|
|
// }
|
|
return result;
|
|
}
|
|
|
|
public static void testParallelInference(@NonNull ParallelInference inf, List<Pair<INDArray[],INDArray[]>> in, List<INDArray[]> exp) throws Exception {
|
|
final INDArray[][] act = new INDArray[in.size()][0];
|
|
final AtomicInteger counter = new AtomicInteger(0);
|
|
final AtomicInteger failedCount = new AtomicInteger(0);
|
|
|
|
for( int i=0; i<in.size(); i++ ){
|
|
final int j=i;
|
|
new Thread(new Runnable() {
|
|
@Override
|
|
public void run() {
|
|
try{
|
|
INDArray[] inMask = in.get(j).getSecond();
|
|
act[j] = inf.output(in.get(j).getFirst(), inMask);
|
|
counter.incrementAndGet();
|
|
} catch (Exception e){
|
|
log.error("",e);
|
|
failedCount.incrementAndGet();
|
|
}
|
|
}
|
|
}).start();
|
|
}
|
|
|
|
long start = System.currentTimeMillis();
|
|
long current = System.currentTimeMillis();
|
|
while(current < start + 20000 && failedCount.get() == 0 && counter.get() < in.size()){
|
|
Thread.sleep(1000L);
|
|
}
|
|
|
|
assertEquals(0, failedCount.get());
|
|
assertEquals(in.size(), counter.get());
|
|
for( int i=0; i<in.size(); i++ ){
|
|
INDArray[] e = exp.get(i);
|
|
INDArray[] a = act[i];
|
|
|
|
assertArrayEquals(e, a);
|
|
}
|
|
}
|
|
|
|
|
|
public static void logFailedParams(int maxNumToPrintOnFailure, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){
|
|
long length = exceedsRelError.length();
|
|
int logCount = 0;
|
|
for(int i=0; i<length; i++ ){
|
|
if(exceedsRelError.getDouble(i) > 0){
|
|
double dExp = exp.getDouble(i);
|
|
double dAct = act.getDouble(i);
|
|
double re = relError(dExp, dAct);
|
|
double ae = Math.abs(dExp - dAct);
|
|
|
|
//Work out parameter key:
|
|
long pSoFar = 0;
|
|
String pName = null;
|
|
for(org.deeplearning4j.nn.api.Layer l : layers){
|
|
long n = l.numParams();
|
|
if(pSoFar + n < i){
|
|
pSoFar += n;
|
|
} else {
|
|
for(Map.Entry<String,INDArray> e : l.paramTable().entrySet()){
|
|
pSoFar += e.getValue().length();
|
|
if(pSoFar >= i){
|
|
pName = e.getKey();
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
log.info("{} {} ({}) failed: expected {} vs actual {} (RelativeError: {}, AbsError: {})", i, prefix, pName, dExp, dAct, re, ae);
|
|
if(++logCount >= maxNumToPrintOnFailure){
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
public static void assertSameDiffEquals(SameDiff sd1, SameDiff sd2){
|
|
assertEquals(sd1.variableMap().keySet(), sd2.variableMap().keySet());
|
|
assertEquals(sd1.getOps().keySet(), sd2.getOps().keySet());
|
|
assertEquals(sd1.inputs(), sd2.inputs());
|
|
|
|
//Check constant and variable arrays:
|
|
for(SDVariable v : sd1.variables()){
|
|
String n = v.name();
|
|
assertEquals(v.getVariableType(), sd2.getVariable(n).getVariableType(), n);
|
|
if(v.isConstant() || v.getVariableType() == VariableType.VARIABLE){
|
|
INDArray a1 = v.getArr();
|
|
INDArray a2 = sd2.getVariable(n).getArr();
|
|
assertEquals(a1, a2, n);
|
|
}
|
|
}
|
|
|
|
//Check ops:
|
|
for(SameDiffOp o : sd1.getOps().values()){
|
|
SameDiffOp o2 = sd2.getOps().get(o.getName());
|
|
assertEquals(o.getOp().getClass(), o2.getOp().getClass());
|
|
}
|
|
}
|
|
}
|