cavis/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestCase.java

183 lines
7.6 KiB
Java
Raw Normal View History

2021-02-01 06:31:20 +01:00
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
2021-02-01 09:47:29 +01:00
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
2021-02-01 06:31:20 +01:00
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
2019-06-06 14:21:15 +02:00
package org.deeplearning4j.integration;
import lombok.Data;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.evaluation.IEvaluation;
2019-06-06 14:21:15 +02:00
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
Refactor packages to fix split package issues (#411) * Refactor nd4j-common: org.nd4j.* -> org.nd4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * Fix CUDA (missed nd4j-common package refactoring changes) Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-kryo: org.nd4j -> org.nd4j.kryo Signed-off-by: Alex Black <blacka101@gmail.com> * Fix nd4j-common for deeplearning4j-cuda Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-grppc-client: org.nd4j.graph -> org.nd4j.remote.grpc Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-common: org.deeplearning4.* -> org.deeplearning4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-core: org.deeplearning4j.* -> org.deeplearning.core.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-cuda: org.deeplearning4j.nn.layers.* -> org.deeplearning4j.cuda.* Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-nlp-*: org.deeplearning4.text.* -> org.deeplearning4j.nlp.(language).* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-ui-model: org.deeplearning4j.ui -> org.deeplearning4j.ui.model Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-spark-inference-{server/model/client}: org.datavec.spark.transform -> org.datavec.spark.inference.{server/model/client} Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-jdbc: org.datavec.api -> org.datavec.jdbc Signed-off-by: Alex Black <blacka101@gmail.com> * Delete org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter in favor of (essentially identical) org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter Signed-off-by: Alex Black <blacka101@gmail.com> * ND4S fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-common-tests: org.nd4j.* -> org.nd4j.common.tests Signed-off-by: Alex Black <blacka101@gmail.com> * Trigger CI Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8878 Ignore CUDA tests on modules with 'nd4j-native under cuda' issue Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad imports in tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add ignore on test (already failing) due to #8882 Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Additional import fixes Signed-off-by: Alex Black <blacka101@gmail.com>
2020-04-29 03:19:26 +02:00
import org.nd4j.common.primitives.Pair;
2019-06-06 14:21:15 +02:00
import java.io.File;
import java.util.List;
import java.util.Map;
2019-06-06 14:21:15 +02:00
@Data
public abstract class TestCase {
public enum TestType {
PRETRAINED, RANDOM_INIT
}
//See: readme.md for more details
protected String testName; //Name of the test, for display purposes
protected TestType testType; //Type of model - from a pretrained model, or a randomly initialized model
protected boolean testPredictions = true; //If true: check the predictions/output. Requires getPredictionsTestData() to be implemented
protected boolean testGradients = true; //If true: check the gradients. Requires getGradientsTestData() to be implemented
protected boolean testUnsupervisedTraining = false; //If true: perform unsupervised training. Only applies to layers like autoencoders, VAEs, etc. Requires getUnsupervisedTrainData() to be implemented
protected boolean testTrainingCurves = true; //If true: perform training, and compare loss vs. iteration. Requires getTrainingData() method
protected boolean testParamsPostTraining = true; //If true: perform training, and compare parameters after training. Requires getTrainingData() method
protected boolean testEvaluation = true; //If true: perform evaluation. Requires getNewEvaluations() and getEvaluationTestData() methods implemented
protected boolean testParallelInference = true; //If true: run the model through ParallelInference. Requires getPredictionsTestData() method. Only applies to DL4J models, NOT SameDiff models
protected boolean testOverfitting = true; //If true: perform overfitting, and ensure the predictions match the training data. Requires both getOverfittingData() and getOverfitNumIterations()
2019-06-06 14:21:15 +02:00
protected int[] unsupervisedTrainLayersMLN = null;
protected String[] unsupervisedTrainLayersCG = null;
//Relative errors for this test case:
protected double maxRelativeErrorOutput = 1e-4;
protected double minAbsErrorOutput = 1e-4;
protected double maxRelativeErrorGradients = 1e-4;
protected double minAbsErrorGradients = 1e-4;
protected double maxRelativeErrorPretrainParams = 1e-5;
protected double minAbsErrorPretrainParams = 1e-5;
protected double maxRelativeErrorScores = 1e-6;
protected double minAbsErrorScores = 1e-5;
protected double maxRelativeErrorParamsPostTraining = 1e-4;
protected double minAbsErrorParamsPostTraining = 1e-4;
protected double maxRelativeErrorOverfit = 1e-2;
protected double minAbsErrorOverfit = 1e-2;
public abstract ModelType modelType();
2019-06-06 14:21:15 +02:00
/**
* Initialize the test case... many tests don't need this; others may use it to download or create data
* @param testWorkingDir Working directory to use for test
*/
public void initialize(File testWorkingDir) throws Exception {
//No op by default
}
/**
* Required if NOT a pretrained model (testType == TestType.RANDOM_INIT)
*/
public Object getConfiguration() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required for pretrained models (testType == TestType.PRETRAINED)
*/
public Model getPretrainedModel() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testPredictions == true && DL4J model (MultiLayerNetwork or ComputationGraph)
2019-06-06 14:21:15 +02:00
*/
public List<Pair<INDArray[],INDArray[]>> getPredictionsTestData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testPredictions == true && SameDiff model
*/
public List<Map<String,INDArray>> getPredictionsTestDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
public List<String> getPredictionsNamesSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testGradients == true && DL4J model
2019-06-06 14:21:15 +02:00
*/
public MultiDataSet getGradientsTestData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testGradients == true && SameDiff model
*/
public Map<String,INDArray> getGradientsTestDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
2019-06-06 14:21:15 +02:00
/**
* Required when testUnsupervisedTraining == true
*/
public MultiDataSetIterator getUnsupervisedTrainData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* @return Training data - DataSetIterator or MultiDataSetIterator
*/
public MultiDataSetIterator getTrainingData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testEvaluation == true
*/
public IEvaluation[] getNewEvaluations() {
throw new RuntimeException("Implementations must override this method if used");
}
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations){
throw new RuntimeException("Implementations must override this method if used");
}
2019-06-06 14:21:15 +02:00
/**
* Required if testEvaluation == true
*/
public MultiDataSetIterator getEvaluationTestData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testOverfitting == true && DL4J model
2019-06-06 14:21:15 +02:00
*/
public MultiDataSet getOverfittingData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testOverfitting == true && SameDiff model
*/
public Map<String,INDArray> getOverfittingDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
2019-06-06 14:21:15 +02:00
/**
* Required if testOverfitting == true
*/
public int getOverfitNumIterations() {
throw new RuntimeException("Implementations must override this method if used");
}
}