2021-02-01 14:31:20 +09:00
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
2019-06-06 15:21:15 +03:00
package org.deeplearning4j.integration ;
import lombok.Data ;
import org.deeplearning4j.nn.api.Model ;
2020-03-07 22:44:41 +11:00
import org.nd4j.autodiff.samediff.SameDiff ;
import org.nd4j.evaluation.IEvaluation ;
2019-06-06 15:21:15 +03:00
import org.nd4j.linalg.api.ndarray.INDArray ;
import org.nd4j.linalg.dataset.api.MultiDataSet ;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator ;
2020-04-29 11:19:26 +10:00
import org.nd4j.common.primitives.Pair ;
2019-06-06 15:21:15 +03:00
import java.io.File ;
import java.util.List ;
2020-03-07 22:44:41 +11:00
import java.util.Map ;
2019-06-06 15:21:15 +03:00
@Data
public abstract class TestCase {
public enum TestType {
PRETRAINED , RANDOM_INIT
}
2020-03-07 22:44:41 +11:00
//See: readme.md for more details
protected String testName ; //Name of the test, for display purposes
protected TestType testType ; //Type of model - from a pretrained model, or a randomly initialized model
protected boolean testPredictions = true ; //If true: check the predictions/output. Requires getPredictionsTestData() to be implemented
protected boolean testGradients = true ; //If true: check the gradients. Requires getGradientsTestData() to be implemented
protected boolean testUnsupervisedTraining = false ; //If true: perform unsupervised training. Only applies to layers like autoencoders, VAEs, etc. Requires getUnsupervisedTrainData() to be implemented
protected boolean testTrainingCurves = true ; //If true: perform training, and compare loss vs. iteration. Requires getTrainingData() method
protected boolean testParamsPostTraining = true ; //If true: perform training, and compare parameters after training. Requires getTrainingData() method
protected boolean testEvaluation = true ; //If true: perform evaluation. Requires getNewEvaluations() and getEvaluationTestData() methods implemented
protected boolean testParallelInference = true ; //If true: run the model through ParallelInference. Requires getPredictionsTestData() method. Only applies to DL4J models, NOT SameDiff models
protected boolean testOverfitting = true ; //If true: perform overfitting, and ensure the predictions match the training data. Requires both getOverfittingData() and getOverfitNumIterations()
2019-06-06 15:21:15 +03:00
protected int [ ] unsupervisedTrainLayersMLN = null ;
protected String [ ] unsupervisedTrainLayersCG = null ;
//Relative errors for this test case:
protected double maxRelativeErrorOutput = 1e - 4 ;
protected double minAbsErrorOutput = 1e - 4 ;
protected double maxRelativeErrorGradients = 1e - 4 ;
protected double minAbsErrorGradients = 1e - 4 ;
protected double maxRelativeErrorPretrainParams = 1e - 5 ;
protected double minAbsErrorPretrainParams = 1e - 5 ;
protected double maxRelativeErrorScores = 1e - 6 ;
protected double minAbsErrorScores = 1e - 5 ;
protected double maxRelativeErrorParamsPostTraining = 1e - 4 ;
protected double minAbsErrorParamsPostTraining = 1e - 4 ;
protected double maxRelativeErrorOverfit = 1e - 2 ;
protected double minAbsErrorOverfit = 1e - 2 ;
2020-03-07 22:44:41 +11:00
public abstract ModelType modelType ( ) ;
2019-06-06 15:21:15 +03:00
/ * *
* Initialize the test case . . . many tests don ' t need this ; others may use it to download or create data
* @param testWorkingDir Working directory to use for test
* /
public void initialize ( File testWorkingDir ) throws Exception {
//No op by default
}
/ * *
* Required if NOT a pretrained model ( testType = = TestType . RANDOM_INIT )
* /
public Object getConfiguration ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
/ * *
* Required for pretrained models ( testType = = TestType . PRETRAINED )
* /
public Model getPretrainedModel ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
/ * *
2020-03-07 22:44:41 +11:00
* Required if testPredictions = = true & & DL4J model ( MultiLayerNetwork or ComputationGraph )
2019-06-06 15:21:15 +03:00
* /
public List < Pair < INDArray [ ] , INDArray [ ] > > getPredictionsTestData ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
/ * *
2020-03-07 22:44:41 +11:00
* Required if testPredictions = = true & & SameDiff model
* /
public List < Map < String , INDArray > > getPredictionsTestDataSameDiff ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
public List < String > getPredictionsNamesSameDiff ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
/ * *
* Required if testGradients = = true & & DL4J model
2019-06-06 15:21:15 +03:00
* /
public MultiDataSet getGradientsTestData ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
2020-03-07 22:44:41 +11:00
/ * *
* Required if testGradients = = true & & SameDiff model
* /
public Map < String , INDArray > getGradientsTestDataSameDiff ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
2019-06-06 15:21:15 +03:00
/ * *
* Required when testUnsupervisedTraining = = true
* /
public MultiDataSetIterator getUnsupervisedTrainData ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
/ * *
* @return Training data - DataSetIterator or MultiDataSetIterator
* /
public MultiDataSetIterator getTrainingData ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
/ * *
* Required if testEvaluation = = true
* /
public IEvaluation [ ] getNewEvaluations ( ) {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
2020-03-07 22:44:41 +11:00
public IEvaluation [ ] doEvaluationSameDiff ( SameDiff sd , MultiDataSetIterator iter , IEvaluation [ ] evaluations ) {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
2019-06-06 15:21:15 +03:00
/ * *
* Required if testEvaluation = = true
* /
public MultiDataSetIterator getEvaluationTestData ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
/ * *
2020-03-07 22:44:41 +11:00
* Required if testOverfitting = = true & & DL4J model
2019-06-06 15:21:15 +03:00
* /
public MultiDataSet getOverfittingData ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
2020-03-07 22:44:41 +11:00
/ * *
* Required if testOverfitting = = true & & SameDiff model
* /
public Map < String , INDArray > getOverfittingDataSameDiff ( ) throws Exception {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
2019-06-06 15:21:15 +03:00
/ * *
* Required if testOverfitting = = true
* /
public int getOverfitNumIterations ( ) {
throw new RuntimeException ( " Implementations must override this method if used " ) ;
}
}