DL4J integrations tests updates + add SameDiff support (#298)

* Revive and start updating DL4J integration tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Add SameDiff support - first pass

Signed-off-by: Alex Black <blacka101@gmail.com>

* SameDiff test case generation

Signed-off-by: Alex Black <blacka101@gmail.com>

* SameDiff integration tests polishing

Signed-off-by: Alex Black <blacka101@gmail.com>

* More SameDiff integration test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Final polish

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small test tweak

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-07 22:44:41 +11:00 committed by GitHub
parent ead5162c97
commit a80fb99a5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 909 additions and 378 deletions

View File

@ -89,12 +89,12 @@ public abstract class BaseDL4JTest {
return getDataType();
}
protected Boolean integrationTest;
protected static Boolean integrationTest;
/**
* @return True if integration tests maven profile is enabled, false otherwise.
*/
public boolean isIntegrationTests(){
public static boolean isIntegrationTests(){
if(integrationTest == null){
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
integrationTest = Boolean.parseBoolean(prop);
@ -107,7 +107,7 @@ public abstract class BaseDL4JTest {
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
* Note that the integration test profile is not enabled by default - "integration-tests" profile
*/
public void skipUnlessIntegrationTests(){
public static void skipUnlessIntegrationTests(){
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
}

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.optimize.listeners;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
@ -32,6 +33,7 @@ import java.io.Serializable;
* @author Alex Black
*/
@Data
@EqualsAndHashCode(callSuper = true)
@Slf4j
public class CollectScoresListener extends BaseTrainingListener implements Serializable {

View File

@ -1,16 +1,15 @@
#DL4J Integration Tests
#DL4J and SameDiff Integration Tests
These tests are designed to check a number of aspects of DL4J:
1. Predictions
These tests are designed to check a number of aspects of DL4J and SameDiff:
1. Predictions (i.e., network output)
2. Training (training curves, parameters, gradient calculation)
3. Evaluation
4. Model serialization
5. Overfitting sanity checks
3. Evaluation (accuracy, etc)
4. Model serialization (saving + loading models)
5. Overfitting sanity checks (make sure we can overfit a single example)
6. Data pipelines
7. Evaluation classes
8. Parallel Wrapper
9. Validating conditions that should always hold (frozen layer params don't change, for example)
7. Parallel Wrapper
8. Validating conditions that should always hold (frozen layer params don't change, for example)
They are designed for the following purposes:
@ -19,32 +18,46 @@ They are designed for the following purposes:
3. Detecting significant differences between CPU and CUDA backends
4. Validating implementation via sanity checks on training - i.e., can we overfit a single example?
5. Checking networks and data pipelines on real-world scale data and nets
6. Operating as fully automated pre-release checks (replacing previously used manual checks)
6. Operating as fully automated pre-release checks (replacing manual sanity checks)
## Types of Tests
## Main Classes
The integration tests are set up to be able to run multiple tests on each network configuration.
Explanation of the main classes:
* **IntegrationTestBaselineGenerator**: Run *manually* to generate and save "expected results" for comparing in the future.
Output goes to dl4j-test-resources, for saving/uploading.
* **IntegrationTestRunner**: Actually runs the tests, and compares the output/result to those generated by the baseline generator
* **TestCase**: integration tests extend this
* **testcases/\*.java**: the actual integration test definitions
* **IntegrationTestsDL4J**: entry point for running the DL4J integration tests
* **IntegrationTestsSameDiff**: entry point for running the SameDiff integration tests
## Types of Test Components
The integration tests are set up to be able to run multiple types of tests on each network configuration.
Networks may be pretrained (from model zoo) or randomly initialized (from specified configuration).
Specifically, test cases can be run with any subset of the following components to be tested, by setting TestCase.XYZ boolean options to true or false:
1. testPredictions: Testing output (predictions) on some specified data vs. saved/known good arrays
2. testGradients: Testing gradients on some specified data vs. saved/known good arrays
3. testPretrain: Test layerwise pretraining parameters and training curves
4. testTrainingCurves: Train, and check score vs. iteration
5. testParamsPostTraining: validate params match post training
6. testEvaluation: test the evaluation performance (post training, if 4 or 5 are true)
7. testParallelInference: validate that single net and parallel inference results match
8. testOverfitting: sanity check - try to overfit a single example
1. **testPredictions**: Testing output (predictions) on some specified data vs. saved/known good arrays
2. **testGradients**: Testing gradients on some specified data vs. saved/known good arrays
3. **testPretrain**: Test layerwise pretraining parameters and training curves
4. **testTrainingCurves**: Train, and check score vs. iteration
5. **testParamsPostTraining**: validate params match post training
6. **testEvaluation**: test the evaluation performance (post training, if 4 or 5 are true)
7. **testParallelInference**: validate that single net and parallel inference results match
8. **testOverfitting**: sanity check - try to overfit a single example
See TestCase.java for more details.
## Adding a New Integration Test
The process to add a new test is simple:
1. Add a method that creates and returns a TestCase object
2. Add it as a unit test to IntegrationTests class
3. Run IntegrationTestBaselineGenerator (if required) to generate and save the "known good" results.
1. Add a method that creates and returns a TestCase object (example: testcases/MLPTestCases.getMLPMnist())
2. Add it as a unit test to IntegrationTests class (example: IntegrationTestsDL4J.testMLPMnist())
3. Run IntegrationTestBaselineGenerator with the new test case, to generate and save the "known good" results.
4. Run the new integration test to make sure it passes, on both CPU and CUDA backends
5. Commit the generated test resources from step 3 to dl4j-test-resources repo
Note that IntegrationTestBaselineGenerator assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo.

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,15 +17,10 @@
package org.deeplearning4j.integration;
import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.integration.testcases.CNN2DTestCases;
import org.deeplearning4j.integration.testcases.MLPTestCases;
import org.deeplearning4j.integration.testcases.RNNTestCases;
import org.deeplearning4j.integration.testcases.UnsupervisedTestCases;
import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -32,20 +28,27 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.CollectScoresListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.guava.io.Files;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
/**
* Run this manually to generate - or update - the saved files for a specific test.
* Places results in dl4j-test-resources: assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo.
@ -53,32 +56,31 @@ import java.util.stream.Collectors;
@Slf4j
public class IntegrationTestBaselineGenerator {
public static final File OUTPUT_DIR = new File("../../dl4j-test-resources/src/main/resources/dl4j-integration-tests").getAbsoluteFile();
public static final File OUTPUT_DIR_DL4J = new File("../../dl4j-test-resources/src/main/resources/dl4j-integration-tests").getAbsoluteFile();
public static final File OUTPUT_DIR_SAMEDIFF = new File("../../dl4j-test-resources/src/main/resources/samediff-integration-tests").getAbsoluteFile();
public static void main(String[] args) throws Exception {
if (!OUTPUT_DIR.exists()) {
throw new RuntimeException("output directory (test resources) does not exist!");
if (!OUTPUT_DIR_DL4J.exists() && !OUTPUT_DIR_SAMEDIFF.exists()) {
throw new RuntimeException("output directories in test resources do not exist!");
}
//All integration tests are run with float precision!
Nd4j.setDataType(DataType.FLOAT);
// runGeneration(
// MLPTestCases.getMLPMnist(),
// );
runGeneration(
SameDiffMLPTestCases.getMLPMnist()
);
}
private static void runGeneration(TestCase... testCases) throws Exception {
for( TestCase tc : testCases ) {
final ModelType modelType = tc.modelType();
//Basic validation:
Preconditions.checkState(tc.getTestName() != null, "Test case name is null");
//Run through each test case:
File testBaseDir = new File(OUTPUT_DIR, tc.getTestName());
File testBaseDir = new File(modelType == ModelType.SAMEDIFF ? OUTPUT_DIR_SAMEDIFF : OUTPUT_DIR_DL4J, tc.getTestName());
if (testBaseDir.exists()) {
FileUtils.forceDelete(testBaseDir);
}
@ -109,56 +111,62 @@ public class IntegrationTestBaselineGenerator {
//First: if test is a random init test: generate the config, and save it
MultiLayerNetwork mln = null;
ComputationGraph cg = null;
Model m;
boolean isMLN;
SameDiff sd = null;
Model m = null;
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
Object config = tc.getConfiguration();
String json;
String json = null;
if (config instanceof MultiLayerConfiguration) {
MultiLayerConfiguration mlc = (MultiLayerConfiguration) config;
isMLN = true;
json = mlc.toJson();
mln = new MultiLayerNetwork(mlc);
mln.init();
m = mln;
} else {
} else if (config instanceof ComputationGraphConfiguration){
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
isMLN = false;
json = cgc.toJson();
cg = new ComputationGraph(cgc);
cg.init();
m = cg;
} else {
sd = (SameDiff)config;
}
File configFile = new File(testBaseDir, "config." + (isMLN ? "mlc.json" : "cgc.json"));
FileUtils.writeStringToFile(configFile, json);
log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME);
ModelSerializer.writeModel(m, savedModel, true);
if(modelType != ModelType.SAMEDIFF) {
File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? "mlc.json" : "cgc.json"));
FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8);
log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
ModelSerializer.writeModel(m, savedModel, true);
} else {
sd.save(savedModel, true);
}
log.info("RANDOM_INIT test - saved randomly initialized model to: {}", savedModel.getAbsolutePath());
} else {
//Pretrained model
m = tc.getPretrainedModel();
isMLN = (m instanceof MultiLayerNetwork);
if (isMLN) {
if (m instanceof MultiLayerNetwork) {
mln = (MultiLayerNetwork) m;
} else {
} else if(m instanceof ComputationGraph){
cg = (ComputationGraph) m;
} else {
sd = (SameDiff)m;
}
}
//Generate predictions to compare against
if (tc.isTestPredictions()) {
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
// Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
File predictionsTestDir = new File(testBaseDir, "predictions");
predictionsTestDir.mkdirs();
int count = 0;
if (isMLN) {
if (modelType == ModelType.MLN) {
for (Pair<INDArray[], INDArray[]> p : inputs) {
INDArray f = p.getFirst()[0];
INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
@ -170,7 +178,7 @@ public class IntegrationTestBaselineGenerator {
Nd4j.write(out, dos);
}
}
} else {
} else if(modelType == ModelType.CG) {
for (Pair<INDArray[], INDArray[]> p : inputs) {
INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
@ -182,6 +190,19 @@ public class IntegrationTestBaselineGenerator {
}
}
}
} else {
List<String> outNames = tc.getPredictionsNamesSameDiff();
for( Map<String,INDArray> ph : inputsSd ){
Map<String,INDArray> out = sd.output(ph, outNames);
//Save the output...
for(String s : outNames){
File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) {
Nd4j.write(out.get(s), dos);
}
}
}
}
log.info("Saved predictions for {} inputs to disk in directory: {}", tc.getTestName(), predictionsTestDir);
@ -189,32 +210,46 @@ public class IntegrationTestBaselineGenerator {
//Compute and save gradients:
if (tc.isTestGradients()) {
MultiDataSet data = tc.getGradientsTestData();
INDArray gradientFlat;
if (isMLN) {
INDArray gradientFlat = null;
Map<String,INDArray> grad;
if (modelType == ModelType.MLN) {
MultiDataSet data = tc.getGradientsTestData();
mln.setInput(data.getFeatures(0));
mln.setLabels(data.getLabels(0));
mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
mln.computeGradientAndScore();
gradientFlat = mln.getFlattenedGradients();
} else {
grad = m.gradient().gradientForVariable();
} else if(modelType == ModelType.CG) {
MultiDataSet data = tc.getGradientsTestData();
cg.setInputs(data.getFeatures());
cg.setLabels(data.getLabels());
cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
cg.computeGradientAndScore();
gradientFlat = cg.getFlattenedGradients();
grad = m.gradient().gradientForVariable();
} else {
Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
List<String> allVars = new ArrayList<>();
for(SDVariable v : sd.variables()){
if(v.getVariableType() == VariableType.VARIABLE){
allVars.add(v.name());
}
}
grad = sd.calculateGradients(ph, allVars);
}
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
IntegrationTestRunner.write(gradientFlat, gFlatFile);
if(modelType != ModelType.SAMEDIFF) {
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
IntegrationTestRunner.write(gradientFlat, gFlatFile);
}
//Also save the gradient param table:
Map<String, INDArray> g = m.gradient().gradientForVariable();
File gradientDir = new File(testBaseDir, "gradients");
gradientDir.mkdir();
for (String s : g.keySet()) {
for (String s : grad.keySet()) {
File f = new File(gradientDir, s + ".bin");
IntegrationTestRunner.write(g.get(s), f);
IntegrationTestRunner.write(grad.get(s), f);
}
}
@ -224,7 +259,7 @@ public class IntegrationTestBaselineGenerator {
MultiDataSetIterator iter = tc.getUnsupervisedTrainData();
INDArray paramsPostTraining;
if(isMLN){
if(modelType == ModelType.MLN){
int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
@ -233,7 +268,7 @@ public class IntegrationTestBaselineGenerator {
mln.pretrainLayer(i, dsi);
}
paramsPostTraining = mln.params();
} else {
} else if(modelType == ModelType.CG) {
String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
@ -241,6 +276,8 @@ public class IntegrationTestBaselineGenerator {
cg.pretrainLayer(i, iter);
}
paramsPostTraining = cg.params();
} else {
throw new UnsupportedOperationException("SameDiff not supported for unsupervised training tests");
}
//Save params
@ -251,23 +288,46 @@ public class IntegrationTestBaselineGenerator {
//Test training curves:
if (tc.isTestTrainingCurves()) {
MultiDataSetIterator trainData = tc.getTrainingData();
CollectScoresListener l = new CollectScoresListener(1);
m.setListeners(l);
if (isMLN) {
CollectScoresListener l = new CollectScoresListener(1);
if(modelType != ModelType.SAMEDIFF)
m.setListeners(l);
History h = null;
if (modelType == ModelType.MLN) {
mln.fit(trainData);
} else {
} else if(modelType == ModelType.CG) {
cg.fit(trainData);
} else {
h = sd.fit(trainData, 1);
}
double[] scores;
if(modelType != ModelType.SAMEDIFF){
scores = l.getListScore().toDoubleArray();
} else {
scores = h.lossCurve().getLossValues().toDoubleVector();
}
double[] scores = l.getListScore().toDoubleArray();
File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
List<String> s = Arrays.stream(scores).mapToObj(String::valueOf).collect(Collectors.toList());
FileUtils.writeStringToFile(f, String.join(",", s));
FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8);
if (tc.isTestParamsPostTraining()) {
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
IntegrationTestRunner.write(m.params(), p);
if(modelType == ModelType.SAMEDIFF){
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
p.mkdirs();
for(SDVariable v : sd.variables()){
if(v.getVariableType() == VariableType.VARIABLE){
INDArray arr = v.getArr();
File p2 = new File(p, v.name() + ".bin");
IntegrationTestRunner.write(arr, p2);
}
}
} else {
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
IntegrationTestRunner.write(m.params(), p);
}
}
}
@ -276,11 +336,13 @@ public class IntegrationTestBaselineGenerator {
IEvaluation[] evals = tc.getNewEvaluations();
MultiDataSetIterator iter = tc.getEvaluationTestData();
if (isMLN) {
if (modelType == ModelType.MLN) {
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
mln.doEvaluation(dsi, evals);
} else {
} else if(modelType == ModelType.CG){
cg.doEvaluation(iter, evals);
} else {
evals = tc.doEvaluationSameDiff(sd, iter, evals);
}
File evalDir = new File(testBaseDir, "evaluation");
@ -288,7 +350,7 @@ public class IntegrationTestBaselineGenerator {
for (int i = 0; i < evals.length; i++) {
String json = evals[i].toJson();
File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
FileUtils.writeStringToFile(f, json);
FileUtils.writeStringToFile(f, json, StandardCharsets.UTF_8);
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -17,14 +18,12 @@
package org.deeplearning4j.integration;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.*;
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -42,9 +41,16 @@ import org.deeplearning4j.parallelism.ParallelInference;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.rules.TemporaryFolder;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
@ -55,12 +61,15 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.resources.Resources;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import java.io.*;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
@ -79,6 +88,7 @@ public class IntegrationTestRunner {
public static final String FLAT_GRADIENTS_FILENAME = "flattenedGradients.bin";
public static final String TRAINING_CURVE_FILENAME = "trainingCurve.csv";
public static final String PARAMS_POST_TRAIN_FILENAME = "paramsPostTrain.bin";
public static final String PARAMS_POST_TRAIN_SAMEDIFF_DIR = "paramsPostTrain";
public static final String PARAMS_POST_UNSUPERVISED_FILENAME = "paramsPostUnsupervised.bin";
public static final double MAX_REL_ERROR_SCORES = 1e-4;
@ -148,21 +158,25 @@ public class IntegrationTestRunner {
}
public static void runTest(TestCase tc, TemporaryFolder testDir) throws Exception {
Preconditions.checkState(Nd4j.dataType() == DataType.FLOAT, "Integration tests must be run with float precision!");
log.info("Starting test case: {}", tc.getTestName());
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
//This could alternatively be done via maven surefire configuration
final ModelType modelType = tc.modelType();
log.info("Starting test case: {} - type = {}", tc.getTestName(), modelType);
long start = System.currentTimeMillis();
File workingDir = testDir.newFolder();
tc.initialize(workingDir);
File testBaseDir = testDir.newFolder();
new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir);
// new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir);
Resources.copyDirectory((modelType == ModelType.SAMEDIFF ? "samediff-integration-tests/" : "dl4j-integration-tests/") + tc.getTestName(), testBaseDir);
MultiLayerNetwork mln = null;
ComputationGraph cg = null;
Model m;
boolean isMLN;
SameDiff sd = null;
Model m = null;
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
log.info("Checking RANDOM_INIT test case: saved model vs. initialized model");
//Checking randomly initialized model:
@ -173,36 +187,46 @@ public class IntegrationTestRunner {
mln = new MultiLayerNetwork(mlc);
mln.init();
m = mln;
isMLN = true;
MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
assertEquals("Configs not equal", loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations());
assertEquals("Params not equal", loaded.params(), mln.params());
assertEquals("Param table not equal", loaded.paramTable(), mln.paramTable());
} else {
} else if(config instanceof ComputationGraphConfiguration ){
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
cg = new ComputationGraph(cgc);
cg.init();
m = cg;
isMLN = false;
ComputationGraph loaded = ComputationGraph.load(savedModel, true);
assertEquals("Configs not equal", loaded.getConfiguration(), cg.getConfiguration());
assertEquals("Params not equal", loaded.params(), cg.params());
assertEquals("Param table not equal", loaded.paramTable(), cg.paramTable());
} else if(config instanceof SameDiff){
sd = (SameDiff)config;
SameDiff loaded = SameDiff.load(savedModel, true);
assertSameDiffEquals(sd, loaded);
} else {
throw new IllegalStateException("Unknown configuration/model type: " + config.getClass());
}
} else {
m = tc.getPretrainedModel();
isMLN = (m instanceof MultiLayerNetwork);
if (isMLN) {
if (m instanceof MultiLayerNetwork) {
mln = (MultiLayerNetwork) m;
} else {
} else if(m instanceof ComputationGraph) {
cg = (ComputationGraph) m;
} else if(m instanceof SameDiff){
sd = (SameDiff)m;
} else {
throw new IllegalStateException("Unknown model type: " + m.getClass());
}
}
//Collect information for test coverage
collectCoverageInformation(m);
if(modelType != ModelType.SAMEDIFF) {
collectCoverageInformation(m);
}
//Check network output (predictions)
@ -210,15 +234,16 @@ public class IntegrationTestRunner {
log.info("Checking predictions: saved output vs. initialized model");
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
Preconditions.checkState(modelType == ModelType.SAMEDIFF || inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
File predictionsTestDir = new File(testBaseDir, "predictions");
predictionsTestDir.mkdirs();
int count = 0;
if (isMLN) {
if (modelType == ModelType.MLN) {
for (Pair<INDArray[], INDArray[]> p : inputs) {
INDArray f = p.getFirst()[0];
INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
@ -231,15 +256,15 @@ public class IntegrationTestRunner {
outSaved = Nd4j.read(dis);
}
INDArray gradExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
int countExceeds = gradExceedsRE.sumNumber().intValue();
INDArray predictionExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
int countExceeds = predictionExceedsRE.sumNumber().intValue();
assertEquals("Predictions do not match saved predictions - output", 0, countExceeds);
}
} else {
} else if(modelType == ModelType.CG){
for (Pair<INDArray[], INDArray[]> p : inputs) {
INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
//Save the array(s)...
//Load the previously saved arrays
INDArray[] outSaved = new INDArray[out.length];
for (int i = 0; i < out.length; i++) {
File outFile = new File(predictionsTestDir, "output_" + (count++) + "_" + i + ".bin");
@ -249,14 +274,36 @@ public class IntegrationTestRunner {
}
for( int i=0; i<outSaved.length; i++ ){
INDArray gradExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
int countExceeds = gradExceedsRE.sumNumber().intValue();
INDArray predictionExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
int countExceeds = predictionExceedsRE.sumNumber().intValue();
assertEquals("Predictions do not match saved predictions - output " + i, 0, countExceeds);
}
}
} else {
List<String> outNames = tc.getPredictionsNamesSameDiff();
for( Map<String,INDArray> ph : inputsSd ){
Map<String,INDArray> out = sd.output(ph, outNames);
//Load the previously saved placeholder arrays
Map<String,INDArray> outSaved = new HashMap<>();
for(String s : outNames){
File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f))) {
outSaved.put(s, Nd4j.read(dis));
}
}
for(String s : outNames){
INDArray predictionExceedsRE = exceedsRelError(outSaved.get(s), out.get(s), tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
int countExceeds = predictionExceedsRE.sumNumber().intValue();
assertEquals("Predictions do not match saved predictions - output \"" + s + "\"", 0, countExceeds);
}
}
}
checkLayerClearance(m);
if(modelType != ModelType.SAMEDIFF) {
checkLayerClearance(m);
}
}
@ -264,34 +311,49 @@ public class IntegrationTestRunner {
if (tc.isTestGradients()) {
log.info("Checking gradients: saved output vs. initialized model");
MultiDataSet data = tc.getGradientsTestData();
INDArray gradientFlat;
org.deeplearning4j.nn.api.Layer[] layers;
if (isMLN) {
INDArray gradientFlat = null;
org.deeplearning4j.nn.api.Layer[] layers = null;
Map<String,INDArray> grad;
if (modelType == ModelType.MLN) {
MultiDataSet data = tc.getGradientsTestData();
mln.setInput(data.getFeatures(0));
mln.setLabels(data.getLabels(0));
mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
mln.computeGradientAndScore();
gradientFlat = mln.getFlattenedGradients();
layers = mln.getLayers();
} else {
grad = mln.gradient().gradientForVariable();
} else if(modelType == ModelType.CG) {
MultiDataSet data = tc.getGradientsTestData();
cg.setInputs(data.getFeatures());
cg.setLabels(data.getLabels());
cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
cg.computeGradientAndScore();
gradientFlat = cg.getFlattenedGradients();
layers = cg.getLayers();
grad = cg.gradient().gradientForVariable();
} else {
Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
List<String> allVars = new ArrayList<>();
for(SDVariable v : sd.variables()){
if(v.getVariableType() == VariableType.VARIABLE){
allVars.add(v.name());
}
}
grad = sd.calculateGradients(ph, allVars);
}
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
INDArray gradientFlatSaved = read(gFlatFile);
if(modelType != ModelType.SAMEDIFF) {
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
INDArray gradientFlatSaved = read(gFlatFile);
INDArray gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
int count = gradExceedsRE.sumNumber().intValue();
if(count > 0){
logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat);
INDArray gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
int count = gradExceedsRE.sumNumber().intValue();
if (count > 0) {
logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat);
}
assertEquals("Saved flattened gradients: not equal (using relative error)", 0, count);
}
assertEquals("Saved flattened gradients: not equal (using relative error)", 0, count);
//Load the gradient table:
File gradientDir = new File(testBaseDir, "gradients");
@ -302,12 +364,12 @@ public class IntegrationTestRunner {
String key = f.getName();
key = key.substring(0, key.length() - 4); //remove ".bin"
INDArray loaded = read(f);
INDArray now = m.gradient().gradientForVariable().get(key);
INDArray now = grad.get(key);
gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
count = gradExceedsRE.sumNumber().intValue();
assertEquals("Saved flattened gradients: not equal (using relative error) for parameter: " + key, 0, count);
INDArray gradExceedsRE = exceedsRelError(loaded, now, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
int count = gradExceedsRE.sumNumber().intValue();
assertEquals("Gradients: not equal (using relative error) for parameter: " + key, 0, count);
}
}
@ -318,7 +380,7 @@ public class IntegrationTestRunner {
INDArray paramsPostTraining;
org.deeplearning4j.nn.api.Layer[] layers;
if(isMLN){
if(modelType == ModelType.MLN){
int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
@ -328,7 +390,7 @@ public class IntegrationTestRunner {
}
paramsPostTraining = mln.params();
layers = mln.getLayers();
} else {
} else if(modelType == ModelType.CG) {
String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
@ -337,6 +399,8 @@ public class IntegrationTestRunner {
}
paramsPostTraining = cg.params();
layers = cg.getLayers();
} else {
throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models");
}
File f = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_UNSUPERVISED_FILENAME);
@ -360,53 +424,78 @@ public class IntegrationTestRunner {
MultiDataSetIterator trainData = tc.getTrainingData();
boolean isTbptt;
int tbpttLength;
if(isMLN){
if(modelType == ModelType.MLN){
isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT;
tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength();
} else {
} else if(modelType == ModelType.CG) {
isTbptt = cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
tbpttLength = cg.getConfiguration().getTbpttFwdLength();
} else {
isTbptt = false;
tbpttLength = 0;
}
CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength);
CollectScoresListener l = new CollectScoresListener(1);
m.setListeners(l);
if(modelType != ModelType.SAMEDIFF) {
m.setListeners(l);
}
int iterBefore;
int epochBefore;
int iterAfter;
int epochAfter;
Map<String,INDArray> frozenParamsBefore = getFrozenLayerParamCopies(m);
org.deeplearning4j.nn.api.Layer[] layers;
if (isMLN) {
Map<String,INDArray> frozenParamsBefore = modelType != ModelType.SAMEDIFF ? getFrozenLayerParamCopies(m) : getConstantCopies(sd);
org.deeplearning4j.nn.api.Layer[] layers = null;
History h = null;
if (modelType == ModelType.MLN) {
iterBefore = mln.getIterationCount();
epochBefore = mln.getEpochCount();
mln.fit(countingIter);
iterAfter = mln.getIterationCount();
epochAfter = mln.getEpochCount();
layers = mln.getLayers();
} else {
} else if(modelType == ModelType.CG){
iterBefore = cg.getConfiguration().getIterationCount();
epochBefore = cg.getConfiguration().getEpochCount();
cg.fit(countingIter);
iterAfter = cg.getConfiguration().getIterationCount();
epochAfter = cg.getConfiguration().getEpochCount();
layers = cg.getLayers();
} else {
iterBefore = sd.getTrainingConfig().getIterationCount();
epochBefore = sd.getTrainingConfig().getEpochCount();
h = sd.fit(countingIter, 1);
iterAfter = sd.getTrainingConfig().getIterationCount();
epochAfter = sd.getTrainingConfig().getEpochCount();
}
//Check that frozen params (if any) haven't changed during training:
checkFrozenParams(frozenParamsBefore, m);
if(modelType == ModelType.SAMEDIFF) {
checkConstants(frozenParamsBefore, sd);
} else {
checkFrozenParams(frozenParamsBefore, m);
}
//Validate the iteration and epoch counts - both for the net, and for the layers
int newIters = countingIter.getCurrIter();
assertEquals(iterBefore + newIters, iterAfter);
assertEquals(epochBefore + 1, epochAfter);
validateLayerIterCounts(m, epochBefore + 1, iterBefore+newIters); //TODO CURRENTLY FAILING
double[] scores = l.getListScore().toDoubleArray();
if(modelType != ModelType.SAMEDIFF) {
validateLayerIterCounts(m, epochBefore + 1, iterBefore + newIters);
}
double[] scores;
if(modelType == ModelType.SAMEDIFF){
scores = h.lossCurve().getLossValues().toDoubleVector();
} else {
scores = l.getListScore().toDoubleArray();
}
File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
String[] s = FileUtils.readFileToString(f).split(",");
String[] s = FileUtils.readFileToString(f, StandardCharsets.UTF_8).split(",");
if(tc.isTestTrainingCurves()) {
assertEquals("Different number of scores", s.length, scores.length);
@ -426,17 +515,36 @@ public class IntegrationTestRunner {
}
if (tc.isTestParamsPostTraining()) {
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
INDArray paramsExp = read(p);
INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
int count = z.sumNumber().intValue();
if(count > 0){
logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
if(modelType != ModelType.SAMEDIFF) {
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
INDArray paramsExp = read(p);
INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
int count = z.sumNumber().intValue();
if (count > 0) {
logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
}
assertEquals("Number of params exceeded max relative error", 0, count);
} else {
File dir = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
for(SDVariable v : sd.variables()){
if(v.getVariableType() != VariableType.VARIABLE)
continue;
INDArray paramNow = v.getArr();
File paramFile = new File(dir, v.name() + ".bin");
INDArray exp = read(paramFile);
INDArray z = exceedsRelError(paramNow, exp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
int count = z.sumNumber().intValue();
if (count > 0) {
logFailedParams(20, "Parameter: " + v.name(), layers, z, exp, paramNow);
}
assertEquals("Number of params exceeded max relative error for parameter: \"" + v.name() + "\"", 0, count);
}
}
assertEquals("Number of params exceeded max relative error", 0, count);
}
checkLayerClearance(m);
if(modelType != ModelType.SAMEDIFF) {
checkLayerClearance(m);
}
}
//Check evaluation:
@ -445,17 +553,19 @@ public class IntegrationTestRunner {
IEvaluation[] evals = tc.getNewEvaluations();
MultiDataSetIterator iter = tc.getEvaluationTestData();
if (isMLN) {
if (modelType == ModelType.MLN) {
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
mln.doEvaluation(dsi, evals);
} else {
} else if(modelType == ModelType.CG){
cg.doEvaluation(iter, evals);
} else {
evals = tc.doEvaluationSameDiff(sd, iter, evals);
}
File evalDir = new File(testBaseDir, "evaluation");
for (int i = 0; i < evals.length; i++) {
File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
String json = FileUtils.readFileToString(f);
String json = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
IEvaluation e;
if (evals[i].getClass() == Evaluation.class) {
e = Evaluation.fromJson(json);
@ -479,7 +589,9 @@ public class IntegrationTestRunner {
//Evaluation coverage information:
evaluationClassesSeen.put(evals[i].getClass(), evaluationClassesSeen.getOrDefault(evals[i].getClass(), 0) + 1);
checkLayerClearance(m);
if(modelType != ModelType.SAMEDIFF) {
checkLayerClearance(m);
}
}
}
@ -490,15 +602,20 @@ public class IntegrationTestRunner {
File f = testDir.newFile();
f.delete();
ModelSerializer.writeModel(m, f, true);
if (isMLN) {
if (modelType == ModelType.MLN) {
ModelSerializer.writeModel(m, f, true);
MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
assertEquals(mln.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
assertEquals(mln.params(), restored.params());
} else {
} else if(modelType == ModelType.CG){
ModelSerializer.writeModel(m, f, true);
ComputationGraph restored = ComputationGraph.load(f, true);
assertEquals(cg.getConfiguration(), restored.getConfiguration());
assertEquals(cg.params(), restored.params());
} else {
sd.save(f, true);
SameDiff restored = SameDiff.load(f, true);
assertSameDiffEquals(sd, restored);
}
System.gc();
@ -506,7 +623,7 @@ public class IntegrationTestRunner {
//Check parallel inference
if (tc.isTestParallelInference()) {
if (modelType != ModelType.SAMEDIFF && tc.isTestParallelInference()) {
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
@ -515,7 +632,7 @@ public class IntegrationTestRunner {
List<INDArray[]> exp = new ArrayList<>();
for(Pair<INDArray[], INDArray[]> p : inputs){
INDArray[] out;
if(isMLN){
if(modelType == ModelType.MLN){
INDArray fm = p.getSecond() == null ? null : p.getSecond()[0];
out = new INDArray[]{mln.output(p.getFirst()[0], false, fm, null)};
} else {
@ -547,37 +664,54 @@ public class IntegrationTestRunner {
MultiDataSet toOverfit = tc.getOverfittingData();
for (int i = 0; i < tc.getOverfitNumIterations(); i++) {
if (isMLN) {
if (modelType == ModelType.MLN) {
mln.fit(toOverfit);
} else {
} else if(modelType == ModelType.CG){
cg.fit(toOverfit);
} else {
sd.fit(toOverfit);
}
}
//Check:
INDArray[] output;
if (isMLN) {
INDArray[] output = null;
Map<String,INDArray> outSd = null;
if (modelType == ModelType.MLN) {
mln.setLayerMaskArrays(toOverfit.getFeaturesMaskArray(0), null);
output = new INDArray[]{mln.output(toOverfit.getFeatures(0))};
} else {
} else if(modelType == ModelType.CG ){
cg.setLayerMaskArrays(toOverfit.getFeaturesMaskArrays(), null);
output = cg.output(toOverfit.getFeatures());
} else {
List<String> l = sd.getTrainingConfig().getDataSetFeatureMapping();
Map<String,INDArray> phMap = new HashMap<>();
int i=0;
for(String s : l){
phMap.put(s, toOverfit.getFeatures(i++));
}
outSd = sd.output(phMap, tc.getPredictionsNamesSameDiff());
}
for (int i = 0; i < output.length; i++) {
INDArray z = exceedsRelError(output[i], toOverfit.getLabels(i), tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit());
int n = modelType == ModelType.SAMEDIFF ? outSd.size() : output.length;
for (int i = 0; i < n; i++) {
INDArray out = modelType == ModelType.SAMEDIFF ? outSd.get(tc.getPredictionsNamesSameDiff().get(i)) : output[i];
INDArray label = toOverfit.getLabels(i);
INDArray z = exceedsRelError(out, label, tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit());
int count = z.sumNumber().intValue();
if (count > 0) {
System.out.println(output[i]);
System.out.println(toOverfit.getLabels(i));
INDArray re = relativeError(output[i], toOverfit.getLabels(i), tc.getMinAbsErrorOverfit());
System.out.println(out);
System.out.println(label);
INDArray re = relativeError(out, label, tc.getMinAbsErrorOverfit());
System.out.println("Relative error:");
System.out.println(re);
}
assertEquals("Number of outputs exceeded max relative error", 0, count);
}
checkLayerClearance(m);
if(modelType != ModelType.SAMEDIFF) {
checkLayerClearance(m);
}
}
long end = System.currentTimeMillis();
@ -709,6 +843,16 @@ public class IntegrationTestRunner {
return out;
}
private static Map<String,INDArray> getConstantCopies(SameDiff sd){
Map<String,INDArray> out = new HashMap<>();
for(SDVariable v : sd.variables()){
if(v.isConstant()){
out.put(v.name(), v.getArr());
}
}
return out;
}
public static void checkFrozenParams(Map<String,INDArray> copiesBeforeTraining, Model m){
for(Map.Entry<String,INDArray> e : copiesBeforeTraining.entrySet()){
INDArray actual = m.getParam(e.getKey());
@ -716,6 +860,13 @@ public class IntegrationTestRunner {
}
}
public static void checkConstants(Map<String,INDArray> copiesBefore, SameDiff sd){
for(Map.Entry<String,INDArray> e : copiesBefore.entrySet()){
INDArray actual = sd.getArrForVarName(e.getKey());
assertEquals(e.getKey(), e.getValue(), actual);
}
}
public static void printCoverageInformation(){
log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
@ -918,7 +1069,7 @@ public class IntegrationTestRunner {
}
public static void logFailedParams(int maxNum, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){
public static void logFailedParams(int maxNumToPrintOnFailure, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){
long length = exceedsRelError.length();
int logCount = 0;
for(int i=0; i<length; i++ ){
@ -947,10 +1098,33 @@ public class IntegrationTestRunner {
}
log.info("{} {} ({}) failed: expected {} vs actual {} (RelativeError: {}, AbsError: {})", i, prefix, pName, dExp, dAct, re, ae);
if(++logCount >= maxNum){
if(++logCount >= maxNumToPrintOnFailure){
break;
}
}
}
}
public static void assertSameDiffEquals(SameDiff sd1, SameDiff sd2){
assertEquals(sd1.variableMap().keySet(), sd2.variableMap().keySet());
assertEquals(sd1.getOps().keySet(), sd2.getOps().keySet());
assertEquals(sd1.inputs(), sd2.inputs());
//Check constant and variable arrays:
for(SDVariable v : sd1.variables()){
String n = v.name();
assertEquals(n, v.getVariableType(), sd2.getVariable(n).getVariableType());
if(v.isConstant() || v.getVariableType() == VariableType.VARIABLE){
INDArray a1 = v.getArr();
INDArray a2 = sd2.getVariable(n).getArr();
assertEquals(n, a1, a2);
}
}
//Check ops:
for(SameDiffOp o : sd1.getOps().values()){
SameDiffOp o2 = sd2.getOps().get(o.getName());
assertEquals(o.getOp().getClass(), o2.getOp().getClass());
}
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -17,15 +18,19 @@
package org.deeplearning4j.integration;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.integration.testcases.*;
import org.deeplearning4j.integration.testcases.dl4j.*;
import org.junit.AfterClass;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
public class IntegrationTests extends BaseDL4JTest {
//@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
public class IntegrationTestsDL4J extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@ -36,79 +41,72 @@ public class IntegrationTests extends BaseDL4JTest {
}
// ***** MLPTestCases *****
@Test(timeout = 20000L)
@Test
public void testMLPMnist() throws Exception {
IntegrationTestRunner.runTest(MLPTestCases.getMLPMnist(), testDir);
}
@Test(timeout = 30000L)
@Test
public void testMlpMoon() throws Exception {
IntegrationTestRunner.runTest(MLPTestCases.getMLPMoon(), testDir);
}
// ***** RNNTestCases *****
@Test(timeout = 30000L)
@Test
public void testRnnSeqClassification1() throws Exception {
IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase1(), testDir);
}
@Test(timeout = 60000L)
@Test
public void testRnnSeqClassification2() throws Exception {
IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase2(), testDir);
}
@Test(timeout = 120000L)
@Test
public void testRnnCharacter() throws Exception {
IntegrationTestRunner.runTest(RNNTestCases.getRnnCharacterTestCase(), testDir);
}
// ***** CNN1DTestCases *****
@Test(timeout = 180000L)
@Test
public void testCnn1dCharacter() throws Exception {
IntegrationTestRunner.runTest(CNN1DTestCases.getCnn1dTestCaseCharRNN(), testDir);
}
// ***** CNN2DTestCases *****
@Test(timeout = 120000L)
@Test
public void testLenetMnist() throws Exception {
IntegrationTestRunner.runTest(CNN2DTestCases.getLenetMnist(), testDir);
}
@Ignore //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6017
@Test(timeout = 180000L)
@Test
public void testYoloHouseNumbers() throws Exception {
IntegrationTestRunner.runTest(CNN2DTestCases.getYoloHouseNumbers(), testDir);
}
@Test(timeout = 120000L)
@Test
public void testCnn2DLenetTransferDropoutRepeatability() throws Exception {
IntegrationTestRunner.runTest(CNN2DTestCases.testLenetTransferDropoutRepeatability(), testDir);
}
// ***** CNN3DTestCases *****
@Test(timeout = 180000L)
@Test
public void testCnn3dSynthetic() throws Exception {
IntegrationTestRunner.runTest(CNN3DTestCases.getCnn3dTestCaseSynthetic(), testDir);
}
// ***** UnsupervisedTestCases *****
@Test(timeout = 120000L)
@Test
public void testVAEMnistAnomaly() throws Exception {
IntegrationTestRunner.runTest(UnsupervisedTestCases.getVAEMnistAnomaly(), testDir);
}
// ***** TransferLearningTestCases *****
@Test(timeout = 360000L)
@Test
public void testVgg16Transfer() throws Exception {
IntegrationTestRunner.runTest(CNN2DTestCases.getVGG16TransferTinyImagenet(), testDir);
}
// ***** KerasImportTestCases *****
//TODO
}

View File

@ -0,0 +1,40 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
public class IntegrationTestsSameDiff extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testMLPMnist() throws Exception {
IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir);
}
}

View File

@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -13,13 +13,8 @@
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration;
package org.deeplearning4j.integration.testcases;
/**
* Integration tests starting from Keras model
*/
public class KerasImportTestCases {
public enum ModelType {
MLN, CG, SAMEDIFF
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -17,8 +18,9 @@
package org.deeplearning4j.integration;
import lombok.Data;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@ -26,6 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
import java.io.File;
import java.util.List;
import java.util.Map;
/**
* A single test case for integration tests
@ -37,16 +40,17 @@ public abstract class TestCase {
PRETRAINED, RANDOM_INIT
}
protected String testName;
protected TestType testType;
protected boolean testPredictions = true;
protected boolean testGradients = true;
protected boolean testUnsupervisedTraining = false;
protected boolean testTrainingCurves = true;
protected boolean testParamsPostTraining = true;
protected boolean testEvaluation = true;
protected boolean testParallelInference = true;
protected boolean testOverfitting = true;
//See: readme.md for more details
protected String testName; //Name of the test, for display purposes
protected TestType testType; //Type of model - from a pretrained model, or a randomly initialized model
protected boolean testPredictions = true; //If true: check the predictions/output. Requires getPredictionsTestData() to be implemented
protected boolean testGradients = true; //If true: check the gradients. Requires getGradientsTestData() to be implemented
protected boolean testUnsupervisedTraining = false; //If true: perform unsupervised training. Only applies to layers like autoencoders, VAEs, etc. Requires getUnsupervisedTrainData() to be implemented
protected boolean testTrainingCurves = true; //If true: perform training, and compare loss vs. iteration. Requires getTrainingData() method
protected boolean testParamsPostTraining = true; //If true: perform training, and compare parameters after training. Requires getTrainingData() method
protected boolean testEvaluation = true; //If true: perform evaluation. Requires getNewEvaluations() and getEvaluationTestData() methods implemented
protected boolean testParallelInference = true; //If true: run the model through ParallelInference. Requires getPredictionsTestData() method. Only applies to DL4J models, NOT SameDiff models
protected boolean testOverfitting = true; //If true: perform overfitting, and ensure the predictions match the training data. Requires both getOverfittingData() and getOverfitNumIterations()
protected int[] unsupervisedTrainLayersMLN = null;
protected String[] unsupervisedTrainLayersCG = null;
@ -65,6 +69,8 @@ public abstract class TestCase {
protected double maxRelativeErrorOverfit = 1e-2;
protected double minAbsErrorOverfit = 1e-2;
public abstract ModelType modelType();
/**
* Initialize the test case... many tests don't need this; others may use it to download or create data
* @param testWorkingDir Working directory to use for test
@ -88,19 +94,37 @@ public abstract class TestCase {
}
/**
* Required if testPredictions == true
* Required if testPredictions == true && DL4J model (MultiLayerNetwork or ComputationGraph)
*/
public List<Pair<INDArray[],INDArray[]>> getPredictionsTestData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testGradients == true
* Required if testPredictions == true && SameDiff model
*/
public List<Map<String,INDArray>> getPredictionsTestDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
public List<String> getPredictionsNamesSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testGradients == true && DL4J model
*/
public MultiDataSet getGradientsTestData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testGradients == true && SameDiff model
*/
public Map<String,INDArray> getGradientsTestDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required when testUnsupervisedTraining == true
*/
@ -122,6 +146,10 @@ public abstract class TestCase {
throw new RuntimeException("Implementations must override this method if used");
}
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations){
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testEvaluation == true
*/
@ -130,12 +158,19 @@ public abstract class TestCase {
}
/**
* Required if testOverfitting == true
* Required if testOverfitting == true && DL4J model
*/
public MultiDataSet getOverfittingData() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testOverfitting == true && SameDiff model
*/
public Map<String,INDArray> getOverfittingDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testOverfitting == true
*/

View File

@ -1,36 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases;
import org.deeplearning4j.integration.TestCase;
public class TransferLearningTestCases {
public static TestCase testPartFrozenResNet50(){
throw new UnsupportedOperationException("Not yet implemented");
}
public static TestCase testPartFrozenNASNET(){
throw new UnsupportedOperationException("Not yet implemented");
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -14,22 +15,24 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases;
package org.deeplearning4j.integration.testcases.dl4j;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.integration.testcases.misc.CharacterIterator;
import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -64,12 +67,18 @@ public class CNN1DTestCases {
int miniBatchSize = 16;
int exampleLength = 128;
@Override
public ModelType modelType() {
return ModelType.CG;
}
@Override
public Object getConfiguration() throws Exception {
CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
int nOut = iter.totalOutcomes();
return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Adam(0.01))

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -14,7 +15,7 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases;
package org.deeplearning4j.integration.testcases.dl4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
@ -22,16 +23,13 @@ import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
import org.datavec.image.recordreader.objdetect.impl.SvhnLabelProvider;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.fetchers.SvhnDataFetcher;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.datasets.fetchers.DataSetType;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.TinyImageNetDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationCalibration;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
@ -47,7 +45,12 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.TinyYOLO;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
@ -82,12 +85,18 @@ public class CNN2DTestCases {
testOverfitting = false;
}
@Override
public ModelType modelType() {
return ModelType.MLN;
}
public Object getConfiguration() throws Exception {
int nChannels = 1; // Number of input channels
int outputNum = 10; // The number of possible outcomes
int seed = 123;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(seed)
.l2(0.0005)
.weightInit(WeightInit.XAVIER)
@ -187,6 +196,11 @@ public class CNN2DTestCases {
testOverfitting = false;
}
@Override
public ModelType modelType() {
return ModelType.CG;
}
@Override
public Model getPretrainedModel() throws Exception {
VGG16 vgg16 = VGG16.builder()
@ -269,6 +283,11 @@ public class CNN2DTestCases {
testOverfitting = false;
}
@Override
public ModelType modelType() {
return ModelType.CG;
}
@Override
public Model getPretrainedModel() throws Exception {
int nClasses = 10;
@ -372,6 +391,11 @@ public class CNN2DTestCases {
testOverfitting = true;
}
@Override
public ModelType modelType() {
return ModelType.CG;
}
@Override
public Model getPretrainedModel() throws Exception {
@ -381,6 +405,7 @@ public class CNN2DTestCases {
lrSchedule.put(3000, 0.001);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.l2(0.0005)
.weightInit(WeightInit.XAVIER)

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -14,35 +15,31 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases;
package org.deeplearning4j.integration.testcases.dl4j;
import org.apache.commons.math3.stat.inference.TestUtils;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@ -66,6 +63,11 @@ public class CNN3DTestCases {
testOverfitting = false;
}
@Override
public ModelType modelType() {
return ModelType.MLN;
}
public Object getConfiguration() throws Exception {
int nChannels = 3; // Number of input channels
int outputNum = 10; // The number of possible outcomes

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -14,8 +15,9 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases;
package org.deeplearning4j.integration.testcases.dl4j;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@ -24,10 +26,6 @@ import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationCalibration;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
@ -35,7 +33,12 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
@ -76,9 +79,15 @@ public class MLPTestCases {
minAbsErrorOverfit = 1e-2;
}
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override
public Object getConfiguration() {
return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.updater(new Adam(new MapSchedule.Builder(ScheduleType.ITERATION)
.add(0, 5e-2)
@ -168,6 +177,11 @@ public class MLPTestCases {
testOverfitting = false; //Not much point here: very simple training data
}
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override
public Object getConfiguration() {
int seed = 123;
@ -179,6 +193,7 @@ public class MLPTestCases {
//log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(seed)
.updater(new Nesterovs(learningRate, 0.9))
.list()

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -14,22 +15,24 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases;
package org.deeplearning4j.integration.testcases.dl4j;
import org.deeplearning4j.integration.ModelType;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
import org.nd4j.shade.guava.io.Files;
import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.integration.testcases.misc.CharacterIterator;
import org.deeplearning4j.integration.testcases.misc.CompositeMultiDataSetPreProcessor;
import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationCalibration;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
@ -91,6 +94,11 @@ public class RNNTestCases {
private int exampleLength = 1000;
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override
public Object getConfiguration() throws Exception {
@ -101,6 +109,7 @@ public class RNNTestCases {
int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters
return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.l2(0.001)
.weightInit(WeightInit.XAVIER)
@ -175,9 +184,15 @@ public class RNNTestCases {
return normalizer;
}
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override
public Object getConfiguration() throws Exception {
return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.updater(new Adam(5e-2))
.l1(1e-3).l2(1e-3)
@ -298,6 +313,7 @@ public class RNNTestCases {
@Override
public Object getConfiguration() throws Exception {
return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.updater(new Adam(5e-2))
.l1(1e-3).l2(1e-3)

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -14,18 +15,20 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases;
package org.deeplearning4j.integration.testcases.dl4j;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@ -59,9 +62,15 @@ public class UnsupervisedTestCases {
minAbsErrorPretrainParams = 5e-4;
}
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override
public Object getConfiguration() {
return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.updater(new Adam(0.05))
.weightInit(WeightInit.XAVIER)

View File

@ -14,7 +14,7 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.misc;
package org.deeplearning4j.integration.testcases.dl4j.misc;
import org.apache.commons.io.FileUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -1,36 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.misc;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
public class CompositeMultiDataSetPreProcessor implements MultiDataSetPreProcessor {
private MultiDataSetPreProcessor[] preProcessors;
public CompositeMultiDataSetPreProcessor(MultiDataSetPreProcessor... preProcessors){
this.preProcessors = preProcessors;
}
@Override
public void preProcess(MultiDataSet multiDataSet) {
for(MultiDataSetPreProcessor p : preProcessors){
p.preProcess(multiDataSet);
}
}
}

View File

@ -0,0 +1,155 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import java.util.*;
public class SameDiffMLPTestCases {
public static TestCase getMLPMnist(){
return new TestCase() {
{
testName = "MLPMnistSD";
testType = TestType.RANDOM_INIT;
testPredictions = true;
testTrainingCurves = true;
testGradients = true;
testParamsPostTraining = true;
testEvaluation = true;
testOverfitting = true;
maxRelativeErrorOverfit = 2e-2;
minAbsErrorOverfit = 1e-2;
}
@Override
public ModelType modelType() {
return ModelType.SAMEDIFF;
}
@Override
public Object getConfiguration() throws Exception {
Nd4j.getRandom().setSeed(12345);
//Define the network structure:
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256));
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256));
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10));
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0));
SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
SDVariable loss = sd.loss.logLoss("loss", label, out);
//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Adam(0.01))
.weightDecay(1e-3, true)
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
.build());
return sd;
}
@Override
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
List<Map<String,INDArray>> out = new ArrayList<>();
DataSetIterator iter = new MnistDataSetIterator(1, true, 12345);
out.add(Collections.singletonMap("in", iter.next().getFeatures()));
iter = new MnistDataSetIterator(8, true, 12345);
out.add(Collections.singletonMap("in", iter.next().getFeatures()));
return out;
}
@Override
public List<String> getPredictionsNamesSameDiff() throws Exception {
return Collections.singletonList("out");
}
@Override
public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
Map<String,INDArray> map = new HashMap<>();
map.put("in", ds.getFeatures());
map.put("label", ds.getLabels());
return map;
}
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);
iter = new EarlyTerminationDataSetIterator(iter, 32);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public IEvaluation[] getNewEvaluations() {
return new IEvaluation[]{new Evaluation()};
}
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(8, false, 12345);
iter = new EarlyTerminationDataSetIterator(iter, 32);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
sd.evaluate(iter, "out", 0, evaluations);
return evaluations;
}
@Override
public MultiDataSet getOverfittingData() throws Exception {
return new MnistDataSetIterator(1, true, 12345).next().toMultiDataSet();
}
@Override
public int getOverfitNumIterations() {
return 100;
}
};
}
}

View File

@ -872,6 +872,8 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public native void setLeaksDetector(@Cast("bool") boolean reallyDetect);
public native @Cast("bool") boolean helpersAllowed();
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
public native @Cast("bool") boolean blasFallback();
public native int tadThreshold();
public native void setTadThreshold(int threshold);
@ -4165,15 +4167,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
*/
public native void transposei();
/**
* return array pointing on certain range of this array
* index - the number of array to be returned among set of possible arrays
* dimensions - array of dimensions to point on
*/
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions);
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions);
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions);
/**
* returns the number of arrays pointing on specified dimension(s)
* dimensions - array of dimensions to point on
@ -6881,9 +6874,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3);
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shape, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shape, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shape, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native void traceNew(int id);
@ -7323,14 +7316,12 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo);
@Namespace("shape") public static native int rank(@Const int[] shapeInfo);
// returns pointer on elementWiseStride
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
/**
* returns pointer on elementWiseStride
*/
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
/**
* Converts a raw int buffer of the layout:
@ -8010,12 +8001,33 @@ public static final int PREALLOC_SIZE = 33554432;
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
*/
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
/**
* processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array
* arguments:
* idx - input argument, intervals of indexes which define the sub-array to point on,
* when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank)
* when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank)
* when (dimStart == dimEnd) then whole range will be used for current dimension
* maxShapeInfo - input argument, shapeInfo of original array
* minShapeInfo - output argument, shapeInfo of sub-array to be deduced
* minOffset - output argument, offset of sub-array buffer offsets from original buffer
* keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
* isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd,
* numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1
*/
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset);
/**
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
@ -8036,6 +8048,14 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo);
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo);
/**
* get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
@ -8908,6 +8928,8 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@ -9103,6 +9125,10 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// #endif /* SHAPE_H_ */

View File

@ -875,6 +875,8 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public native void setLeaksDetector(@Cast("bool") boolean reallyDetect);
public native @Cast("bool") boolean helpersAllowed();
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
public native @Cast("bool") boolean blasFallback();
public native int tadThreshold();
public native void setTadThreshold(int threshold);
@ -4168,15 +4170,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
*/
public native void transposei();
/**
* return array pointing on certain range of this array
* index - the number of array to be returned among set of possible arrays
* dimensions - array of dimensions to point on
*/
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions);
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions);
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions);
/**
* returns the number of arrays pointing on specified dimension(s)
* dimensions - array of dimensions to point on
@ -6884,9 +6877,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3);
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shape, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shape, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shape, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native void traceNew(int id);
@ -7326,14 +7319,12 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo);
@Namespace("shape") public static native int rank(@Const int[] shapeInfo);
// returns pointer on elementWiseStride
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
/**
* returns pointer on elementWiseStride
*/
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
/**
* Converts a raw int buffer of the layout:
@ -8013,12 +8004,33 @@ public static final int PREALLOC_SIZE = 33554432;
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
*/
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
/**
* processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array
* arguments:
* idx - input argument, intervals of indexes which define the sub-array to point on,
* when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank)
* when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank)
* when (dimStart == dimEnd) then whole range will be used for current dimension
* maxShapeInfo - input argument, shapeInfo of original array
* minShapeInfo - output argument, shapeInfo of sub-array to be deduced
* minOffset - output argument, offset of sub-array buffer offsets from original buffer
* keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
* isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd,
* numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1
*/
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset);
/**
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
@ -8039,6 +8051,14 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo);
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo);
/**
* get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
@ -8911,6 +8931,8 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@ -9106,6 +9128,10 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// #endif /* SHAPE_H_ */

View File

@ -102,6 +102,7 @@ public class RngTests extends BaseNd4jTest {
@Test
public void testRandomBinomial() {
Nd4j.getRandom().setSeed(12345);
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);
assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial