DL4J integrations tests updates + add SameDiff support (#298)

* Revive and start updating DL4J integration tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Add SameDiff support - first pass

Signed-off-by: Alex Black <blacka101@gmail.com>

* SameDiff test case generation

Signed-off-by: Alex Black <blacka101@gmail.com>

* SameDiff integration tests polishing

Signed-off-by: Alex Black <blacka101@gmail.com>

* More SameDiff integration test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Final polish

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small test tweak

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-07 22:44:41 +11:00 committed by GitHub
parent ead5162c97
commit a80fb99a5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 909 additions and 378 deletions

View File

@ -89,12 +89,12 @@ public abstract class BaseDL4JTest {
return getDataType(); return getDataType();
} }
protected Boolean integrationTest; protected static Boolean integrationTest;
/** /**
* @return True if integration tests maven profile is enabled, false otherwise. * @return True if integration tests maven profile is enabled, false otherwise.
*/ */
public boolean isIntegrationTests(){ public static boolean isIntegrationTests(){
if(integrationTest == null){ if(integrationTest == null){
String prop = System.getenv("DL4J_INTEGRATION_TESTS"); String prop = System.getenv("DL4J_INTEGRATION_TESTS");
integrationTest = Boolean.parseBoolean(prop); integrationTest = Boolean.parseBoolean(prop);
@ -107,7 +107,7 @@ public abstract class BaseDL4JTest {
* This can be used to dynamically skip integration tests when the integration test profile is not enabled. * This can be used to dynamically skip integration tests when the integration test profile is not enabled.
* Note that the integration test profile is not enabled by default - "integration-tests" profile * Note that the integration test profile is not enabled by default - "integration-tests" profile
*/ */
public void skipUnlessIntegrationTests(){ public static void skipUnlessIntegrationTests(){
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests()); assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
} }

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.optimize.listeners;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList; import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntArrayList;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.api.BaseTrainingListener;
@ -32,6 +33,7 @@ import java.io.Serializable;
* @author Alex Black * @author Alex Black
*/ */
@Data @Data
@EqualsAndHashCode(callSuper = true)
@Slf4j @Slf4j
public class CollectScoresListener extends BaseTrainingListener implements Serializable { public class CollectScoresListener extends BaseTrainingListener implements Serializable {

View File

@ -1,16 +1,15 @@
#DL4J Integration Tests #DL4J and SameDiff Integration Tests
These tests are designed to check a number of aspects of DL4J: These tests are designed to check a number of aspects of DL4J and SameDiff:
1. Predictions 1. Predictions (i.e., network output)
2. Training (training curves, parameters, gradient calculation) 2. Training (training curves, parameters, gradient calculation)
3. Evaluation 3. Evaluation (accuracy, etc)
4. Model serialization 4. Model serialization (saving + loading models)
5. Overfitting sanity checks 5. Overfitting sanity checks (make sure we can overfit a single example)
6. Data pipelines 6. Data pipelines
7. Evaluation classes 7. Parallel Wrapper
8. Parallel Wrapper 8. Validating conditions that should always hold (frozen layer params don't change, for example)
9. Validating conditions that should always hold (frozen layer params don't change, for example)
They are designed for the following purposes: They are designed for the following purposes:
@ -19,32 +18,46 @@ They are designed for the following purposes:
3. Detecting significant differences between CPU and CUDA backends 3. Detecting significant differences between CPU and CUDA backends
4. Validating implementation via sanity checks on training - i.e., can we overfit a single example? 4. Validating implementation via sanity checks on training - i.e., can we overfit a single example?
5. Checking networks and data pipelines on real-world scale data and nets 5. Checking networks and data pipelines on real-world scale data and nets
6. Operating as fully automated pre-release checks (replacing previously used manual checks) 6. Operating as fully automated pre-release checks (replacing manual sanity checks)
## Types of Tests ## Main Classes
The integration tests are set up to be able to run multiple tests on each network configuration. Explanation of the main classes:
* **IntegrationTestBaselineGenerator**: Run *manually* to generate and save "expected results" for comparing in the future.
Output goes to dl4j-test-resources, for saving/uploading.
* **IntegrationTestRunner**: Actually runs the tests, and compares the output/result to those generated by the baseline generator
* **TestCase**: integration tests extend this
* **testcases/\*.java**: the actual integration test definitions
* **IntegrationTestsDL4J**: entry point for running the DL4J integration tests
* **IntegrationTestsSameDiff**: entry point for running the SameDiff integration tests
## Types of Test Components
The integration tests are set up to be able to run multiple types of tests on each network configuration.
Networks may be pretrained (from model zoo) or randomly initialized (from specified configuration). Networks may be pretrained (from model zoo) or randomly initialized (from specified configuration).
Specifically, test cases can be run with any subset of the following components to be tested, by setting TestCase.XYZ boolean options to true or false: Specifically, test cases can be run with any subset of the following components to be tested, by setting TestCase.XYZ boolean options to true or false:
1. testPredictions: Testing output (predictions) on some specified data vs. saved/known good arrays 1. **testPredictions**: Testing output (predictions) on some specified data vs. saved/known good arrays
2. testGradients: Testing gradients on some specified data vs. saved/known good arrays 2. **testGradients**: Testing gradients on some specified data vs. saved/known good arrays
3. testPretrain: Test layerwise pretraining parameters and training curves 3. **testPretrain**: Test layerwise pretraining parameters and training curves
4. testTrainingCurves: Train, and check score vs. iteration 4. **testTrainingCurves**: Train, and check score vs. iteration
5. testParamsPostTraining: validate params match post training 5. **testParamsPostTraining**: validate params match post training
6. testEvaluation: test the evaluation performance (post training, if 4 or 5 are true) 6. **testEvaluation**: test the evaluation performance (post training, if 4 or 5 are true)
7. testParallelInference: validate that single net and parallel inference results match 7. **testParallelInference**: validate that single net and parallel inference results match
8. testOverfitting: sanity check - try to overfit a single example 8. **testOverfitting**: sanity check - try to overfit a single example
See TestCase.java for more details.
## Adding a New Integration Test ## Adding a New Integration Test
The process to add a new test is simple: The process to add a new test is simple:
1. Add a method that creates and returns a TestCase object 1. Add a method that creates and returns a TestCase object (example: testcases/MLPTestCases.getMLPMnist())
2. Add it as a unit test to IntegrationTests class 2. Add it as a unit test to IntegrationTests class (example: IntegrationTestsDL4J.testMLPMnist())
3. Run IntegrationTestBaselineGenerator (if required) to generate and save the "known good" results. 3. Run IntegrationTestBaselineGenerator with the new test case, to generate and save the "known good" results.
4. Run the new integration test to make sure it passes, on both CPU and CUDA backends
5. Commit the generated test resources from step 3 to dl4j-test-resources repo
Note that IntegrationTestBaselineGenerator assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo. Note that IntegrationTestBaselineGenerator assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo.

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,15 +17,10 @@
package org.deeplearning4j.integration; package org.deeplearning4j.integration;
import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.IEvaluation; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
import org.deeplearning4j.integration.testcases.CNN2DTestCases;
import org.deeplearning4j.integration.testcases.MLPTestCases;
import org.deeplearning4j.integration.testcases.RNNTestCases;
import org.deeplearning4j.integration.testcases.UnsupervisedTestCases;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -32,20 +28,27 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.CollectScoresListener; import org.deeplearning4j.optimize.listeners.CollectScoresListener;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.guava.io.Files;
import java.io.*; import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
/** /**
* Run this manually to generate - or update - the saved files for a specific test. * Run this manually to generate - or update - the saved files for a specific test.
* Places results in dl4j-test-resources: assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo. * Places results in dl4j-test-resources: assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo.
@ -53,32 +56,31 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class IntegrationTestBaselineGenerator { public class IntegrationTestBaselineGenerator {
public static final File OUTPUT_DIR = new File("../../dl4j-test-resources/src/main/resources/dl4j-integration-tests").getAbsoluteFile(); public static final File OUTPUT_DIR_DL4J = new File("../../dl4j-test-resources/src/main/resources/dl4j-integration-tests").getAbsoluteFile();
public static final File OUTPUT_DIR_SAMEDIFF = new File("../../dl4j-test-resources/src/main/resources/samediff-integration-tests").getAbsoluteFile();
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
if (!OUTPUT_DIR.exists()) { if (!OUTPUT_DIR_DL4J.exists() && !OUTPUT_DIR_SAMEDIFF.exists()) {
throw new RuntimeException("output directory (test resources) does not exist!"); throw new RuntimeException("output directories in test resources do not exist!");
} }
//All integration tests are run with float precision! runGeneration(
Nd4j.setDataType(DataType.FLOAT); SameDiffMLPTestCases.getMLPMnist()
);
// runGeneration(
// MLPTestCases.getMLPMnist(),
// );
} }
private static void runGeneration(TestCase... testCases) throws Exception { private static void runGeneration(TestCase... testCases) throws Exception {
for( TestCase tc : testCases ) { for( TestCase tc : testCases ) {
final ModelType modelType = tc.modelType();
//Basic validation: //Basic validation:
Preconditions.checkState(tc.getTestName() != null, "Test case name is null"); Preconditions.checkState(tc.getTestName() != null, "Test case name is null");
//Run through each test case: //Run through each test case:
File testBaseDir = new File(OUTPUT_DIR, tc.getTestName()); File testBaseDir = new File(modelType == ModelType.SAMEDIFF ? OUTPUT_DIR_SAMEDIFF : OUTPUT_DIR_DL4J, tc.getTestName());
if (testBaseDir.exists()) { if (testBaseDir.exists()) {
FileUtils.forceDelete(testBaseDir); FileUtils.forceDelete(testBaseDir);
} }
@ -109,56 +111,62 @@ public class IntegrationTestBaselineGenerator {
//First: if test is a random init test: generate the config, and save it //First: if test is a random init test: generate the config, and save it
MultiLayerNetwork mln = null; MultiLayerNetwork mln = null;
ComputationGraph cg = null; ComputationGraph cg = null;
Model m; SameDiff sd = null;
boolean isMLN; Model m = null;
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) { if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
Object config = tc.getConfiguration(); Object config = tc.getConfiguration();
String json; String json = null;
if (config instanceof MultiLayerConfiguration) { if (config instanceof MultiLayerConfiguration) {
MultiLayerConfiguration mlc = (MultiLayerConfiguration) config; MultiLayerConfiguration mlc = (MultiLayerConfiguration) config;
isMLN = true;
json = mlc.toJson(); json = mlc.toJson();
mln = new MultiLayerNetwork(mlc); mln = new MultiLayerNetwork(mlc);
mln.init(); mln.init();
m = mln; m = mln;
} else { } else if (config instanceof ComputationGraphConfiguration){
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config; ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
isMLN = false;
json = cgc.toJson(); json = cgc.toJson();
cg = new ComputationGraph(cgc); cg = new ComputationGraph(cgc);
cg.init(); cg.init();
m = cg; m = cg;
} else {
sd = (SameDiff)config;
} }
File configFile = new File(testBaseDir, "config." + (isMLN ? "mlc.json" : "cgc.json"));
FileUtils.writeStringToFile(configFile, json);
log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME); File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME);
ModelSerializer.writeModel(m, savedModel, true); if(modelType != ModelType.SAMEDIFF) {
File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? "mlc.json" : "cgc.json"));
FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8);
log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
ModelSerializer.writeModel(m, savedModel, true);
} else {
sd.save(savedModel, true);
}
log.info("RANDOM_INIT test - saved randomly initialized model to: {}", savedModel.getAbsolutePath()); log.info("RANDOM_INIT test - saved randomly initialized model to: {}", savedModel.getAbsolutePath());
} else { } else {
//Pretrained model //Pretrained model
m = tc.getPretrainedModel(); m = tc.getPretrainedModel();
isMLN = (m instanceof MultiLayerNetwork); if (m instanceof MultiLayerNetwork) {
if (isMLN) {
mln = (MultiLayerNetwork) m; mln = (MultiLayerNetwork) m;
} else { } else if(m instanceof ComputationGraph){
cg = (ComputationGraph) m; cg = (ComputationGraph) m;
} else {
sd = (SameDiff)m;
} }
} }
//Generate predictions to compare against //Generate predictions to compare against
if (tc.isTestPredictions()) { if (tc.isTestPredictions()) {
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData(); List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName()); List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
// Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
File predictionsTestDir = new File(testBaseDir, "predictions"); File predictionsTestDir = new File(testBaseDir, "predictions");
predictionsTestDir.mkdirs(); predictionsTestDir.mkdirs();
int count = 0; int count = 0;
if (isMLN) { if (modelType == ModelType.MLN) {
for (Pair<INDArray[], INDArray[]> p : inputs) { for (Pair<INDArray[], INDArray[]> p : inputs) {
INDArray f = p.getFirst()[0]; INDArray f = p.getFirst()[0];
INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]); INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
@ -170,7 +178,7 @@ public class IntegrationTestBaselineGenerator {
Nd4j.write(out, dos); Nd4j.write(out, dos);
} }
} }
} else { } else if(modelType == ModelType.CG) {
for (Pair<INDArray[], INDArray[]> p : inputs) { for (Pair<INDArray[], INDArray[]> p : inputs) {
INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null); INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
@ -182,6 +190,19 @@ public class IntegrationTestBaselineGenerator {
} }
} }
} }
} else {
List<String> outNames = tc.getPredictionsNamesSameDiff();
for( Map<String,INDArray> ph : inputsSd ){
Map<String,INDArray> out = sd.output(ph, outNames);
//Save the output...
for(String s : outNames){
File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) {
Nd4j.write(out.get(s), dos);
}
}
}
} }
log.info("Saved predictions for {} inputs to disk in directory: {}", tc.getTestName(), predictionsTestDir); log.info("Saved predictions for {} inputs to disk in directory: {}", tc.getTestName(), predictionsTestDir);
@ -189,32 +210,46 @@ public class IntegrationTestBaselineGenerator {
//Compute and save gradients: //Compute and save gradients:
if (tc.isTestGradients()) { if (tc.isTestGradients()) {
MultiDataSet data = tc.getGradientsTestData(); INDArray gradientFlat = null;
INDArray gradientFlat; Map<String,INDArray> grad;
if (isMLN) { if (modelType == ModelType.MLN) {
MultiDataSet data = tc.getGradientsTestData();
mln.setInput(data.getFeatures(0)); mln.setInput(data.getFeatures(0));
mln.setLabels(data.getLabels(0)); mln.setLabels(data.getLabels(0));
mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0)); mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
mln.computeGradientAndScore(); mln.computeGradientAndScore();
gradientFlat = mln.getFlattenedGradients(); gradientFlat = mln.getFlattenedGradients();
} else { grad = m.gradient().gradientForVariable();
} else if(modelType == ModelType.CG) {
MultiDataSet data = tc.getGradientsTestData();
cg.setInputs(data.getFeatures()); cg.setInputs(data.getFeatures());
cg.setLabels(data.getLabels()); cg.setLabels(data.getLabels());
cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays()); cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
cg.computeGradientAndScore(); cg.computeGradientAndScore();
gradientFlat = cg.getFlattenedGradients(); gradientFlat = cg.getFlattenedGradients();
grad = m.gradient().gradientForVariable();
} else {
Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
List<String> allVars = new ArrayList<>();
for(SDVariable v : sd.variables()){
if(v.getVariableType() == VariableType.VARIABLE){
allVars.add(v.name());
}
}
grad = sd.calculateGradients(ph, allVars);
} }
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME); if(modelType != ModelType.SAMEDIFF) {
IntegrationTestRunner.write(gradientFlat, gFlatFile); File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
IntegrationTestRunner.write(gradientFlat, gFlatFile);
}
//Also save the gradient param table: //Also save the gradient param table:
Map<String, INDArray> g = m.gradient().gradientForVariable();
File gradientDir = new File(testBaseDir, "gradients"); File gradientDir = new File(testBaseDir, "gradients");
gradientDir.mkdir(); gradientDir.mkdir();
for (String s : g.keySet()) { for (String s : grad.keySet()) {
File f = new File(gradientDir, s + ".bin"); File f = new File(gradientDir, s + ".bin");
IntegrationTestRunner.write(g.get(s), f); IntegrationTestRunner.write(grad.get(s), f);
} }
} }
@ -224,7 +259,7 @@ public class IntegrationTestBaselineGenerator {
MultiDataSetIterator iter = tc.getUnsupervisedTrainData(); MultiDataSetIterator iter = tc.getUnsupervisedTrainData();
INDArray paramsPostTraining; INDArray paramsPostTraining;
if(isMLN){ if(modelType == ModelType.MLN){
int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN(); int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
Preconditions.checkState(layersToTrain != null, "Layer indices must not be null"); Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
@ -233,7 +268,7 @@ public class IntegrationTestBaselineGenerator {
mln.pretrainLayer(i, dsi); mln.pretrainLayer(i, dsi);
} }
paramsPostTraining = mln.params(); paramsPostTraining = mln.params();
} else { } else if(modelType == ModelType.CG) {
String[] layersToTrain = tc.getUnsupervisedTrainLayersCG(); String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
Preconditions.checkState(layersToTrain != null, "Layer names must not be null"); Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
@ -241,6 +276,8 @@ public class IntegrationTestBaselineGenerator {
cg.pretrainLayer(i, iter); cg.pretrainLayer(i, iter);
} }
paramsPostTraining = cg.params(); paramsPostTraining = cg.params();
} else {
throw new UnsupportedOperationException("SameDiff not supported for unsupervised training tests");
} }
//Save params //Save params
@ -251,23 +288,46 @@ public class IntegrationTestBaselineGenerator {
//Test training curves: //Test training curves:
if (tc.isTestTrainingCurves()) { if (tc.isTestTrainingCurves()) {
MultiDataSetIterator trainData = tc.getTrainingData(); MultiDataSetIterator trainData = tc.getTrainingData();
CollectScoresListener l = new CollectScoresListener(1);
m.setListeners(l);
if (isMLN) { CollectScoresListener l = new CollectScoresListener(1);
if(modelType != ModelType.SAMEDIFF)
m.setListeners(l);
History h = null;
if (modelType == ModelType.MLN) {
mln.fit(trainData); mln.fit(trainData);
} else { } else if(modelType == ModelType.CG) {
cg.fit(trainData); cg.fit(trainData);
} else {
h = sd.fit(trainData, 1);
}
double[] scores;
if(modelType != ModelType.SAMEDIFF){
scores = l.getListScore().toDoubleArray();
} else {
scores = h.lossCurve().getLossValues().toDoubleVector();
} }
double[] scores = l.getListScore().toDoubleArray();
File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME); File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
List<String> s = Arrays.stream(scores).mapToObj(String::valueOf).collect(Collectors.toList()); List<String> s = Arrays.stream(scores).mapToObj(String::valueOf).collect(Collectors.toList());
FileUtils.writeStringToFile(f, String.join(",", s)); FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8);
if (tc.isTestParamsPostTraining()) { if (tc.isTestParamsPostTraining()) {
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME); if(modelType == ModelType.SAMEDIFF){
IntegrationTestRunner.write(m.params(), p); File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
p.mkdirs();
for(SDVariable v : sd.variables()){
if(v.getVariableType() == VariableType.VARIABLE){
INDArray arr = v.getArr();
File p2 = new File(p, v.name() + ".bin");
IntegrationTestRunner.write(arr, p2);
}
}
} else {
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
IntegrationTestRunner.write(m.params(), p);
}
} }
} }
@ -276,11 +336,13 @@ public class IntegrationTestBaselineGenerator {
IEvaluation[] evals = tc.getNewEvaluations(); IEvaluation[] evals = tc.getNewEvaluations();
MultiDataSetIterator iter = tc.getEvaluationTestData(); MultiDataSetIterator iter = tc.getEvaluationTestData();
if (isMLN) { if (modelType == ModelType.MLN) {
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
mln.doEvaluation(dsi, evals); mln.doEvaluation(dsi, evals);
} else { } else if(modelType == ModelType.CG){
cg.doEvaluation(iter, evals); cg.doEvaluation(iter, evals);
} else {
evals = tc.doEvaluationSameDiff(sd, iter, evals);
} }
File evalDir = new File(testBaseDir, "evaluation"); File evalDir = new File(testBaseDir, "evaluation");
@ -288,7 +350,7 @@ public class IntegrationTestBaselineGenerator {
for (int i = 0; i < evals.length; i++) { for (int i = 0; i < evals.length; i++) {
String json = evals[i].toJson(); String json = evals[i].toJson();
File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json"); File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
FileUtils.writeStringToFile(f, json); FileUtils.writeStringToFile(f, json, StandardCharsets.UTF_8);
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -17,14 +18,12 @@
package org.deeplearning4j.integration; package org.deeplearning4j.integration;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.*; import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -42,9 +41,16 @@ import org.deeplearning4j.parallelism.ParallelInference;
import org.deeplearning4j.parallelism.inference.InferenceMode; import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
@ -55,12 +61,15 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.resources.Resources;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import java.io.*; import java.io.*;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -79,6 +88,7 @@ public class IntegrationTestRunner {
public static final String FLAT_GRADIENTS_FILENAME = "flattenedGradients.bin"; public static final String FLAT_GRADIENTS_FILENAME = "flattenedGradients.bin";
public static final String TRAINING_CURVE_FILENAME = "trainingCurve.csv"; public static final String TRAINING_CURVE_FILENAME = "trainingCurve.csv";
public static final String PARAMS_POST_TRAIN_FILENAME = "paramsPostTrain.bin"; public static final String PARAMS_POST_TRAIN_FILENAME = "paramsPostTrain.bin";
public static final String PARAMS_POST_TRAIN_SAMEDIFF_DIR = "paramsPostTrain";
public static final String PARAMS_POST_UNSUPERVISED_FILENAME = "paramsPostUnsupervised.bin"; public static final String PARAMS_POST_UNSUPERVISED_FILENAME = "paramsPostUnsupervised.bin";
public static final double MAX_REL_ERROR_SCORES = 1e-4; public static final double MAX_REL_ERROR_SCORES = 1e-4;
@ -148,21 +158,25 @@ public class IntegrationTestRunner {
} }
public static void runTest(TestCase tc, TemporaryFolder testDir) throws Exception { public static void runTest(TestCase tc, TemporaryFolder testDir) throws Exception {
Preconditions.checkState(Nd4j.dataType() == DataType.FLOAT, "Integration tests must be run with float precision!"); BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
log.info("Starting test case: {}", tc.getTestName()); //This could alternatively be done via maven surefire configuration
final ModelType modelType = tc.modelType();
log.info("Starting test case: {} - type = {}", tc.getTestName(), modelType);
long start = System.currentTimeMillis(); long start = System.currentTimeMillis();
File workingDir = testDir.newFolder(); File workingDir = testDir.newFolder();
tc.initialize(workingDir); tc.initialize(workingDir);
File testBaseDir = testDir.newFolder(); File testBaseDir = testDir.newFolder();
new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir); // new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir);
Resources.copyDirectory((modelType == ModelType.SAMEDIFF ? "samediff-integration-tests/" : "dl4j-integration-tests/") + tc.getTestName(), testBaseDir);
MultiLayerNetwork mln = null; MultiLayerNetwork mln = null;
ComputationGraph cg = null; ComputationGraph cg = null;
Model m; SameDiff sd = null;
boolean isMLN; Model m = null;
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) { if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
log.info("Checking RANDOM_INIT test case: saved model vs. initialized model"); log.info("Checking RANDOM_INIT test case: saved model vs. initialized model");
//Checking randomly initialized model: //Checking randomly initialized model:
@ -173,36 +187,46 @@ public class IntegrationTestRunner {
mln = new MultiLayerNetwork(mlc); mln = new MultiLayerNetwork(mlc);
mln.init(); mln.init();
m = mln; m = mln;
isMLN = true;
MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true); MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
assertEquals("Configs not equal", loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations()); assertEquals("Configs not equal", loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations());
assertEquals("Params not equal", loaded.params(), mln.params()); assertEquals("Params not equal", loaded.params(), mln.params());
assertEquals("Param table not equal", loaded.paramTable(), mln.paramTable()); assertEquals("Param table not equal", loaded.paramTable(), mln.paramTable());
} else { } else if(config instanceof ComputationGraphConfiguration ){
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config; ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
cg = new ComputationGraph(cgc); cg = new ComputationGraph(cgc);
cg.init(); cg.init();
m = cg; m = cg;
isMLN = false;
ComputationGraph loaded = ComputationGraph.load(savedModel, true); ComputationGraph loaded = ComputationGraph.load(savedModel, true);
assertEquals("Configs not equal", loaded.getConfiguration(), cg.getConfiguration()); assertEquals("Configs not equal", loaded.getConfiguration(), cg.getConfiguration());
assertEquals("Params not equal", loaded.params(), cg.params()); assertEquals("Params not equal", loaded.params(), cg.params());
assertEquals("Param table not equal", loaded.paramTable(), cg.paramTable()); assertEquals("Param table not equal", loaded.paramTable(), cg.paramTable());
} else if(config instanceof SameDiff){
sd = (SameDiff)config;
SameDiff loaded = SameDiff.load(savedModel, true);
assertSameDiffEquals(sd, loaded);
} else {
throw new IllegalStateException("Unknown configuration/model type: " + config.getClass());
} }
} else { } else {
m = tc.getPretrainedModel(); m = tc.getPretrainedModel();
isMLN = (m instanceof MultiLayerNetwork); if (m instanceof MultiLayerNetwork) {
if (isMLN) {
mln = (MultiLayerNetwork) m; mln = (MultiLayerNetwork) m;
} else { } else if(m instanceof ComputationGraph) {
cg = (ComputationGraph) m; cg = (ComputationGraph) m;
} else if(m instanceof SameDiff){
sd = (SameDiff)m;
} else {
throw new IllegalStateException("Unknown model type: " + m.getClass());
} }
} }
//Collect information for test coverage //Collect information for test coverage
collectCoverageInformation(m); if(modelType != ModelType.SAMEDIFF) {
collectCoverageInformation(m);
}
//Check network output (predictions) //Check network output (predictions)
@ -210,15 +234,16 @@ public class IntegrationTestRunner {
log.info("Checking predictions: saved output vs. initialized model"); log.info("Checking predictions: saved output vs. initialized model");
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData(); List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName()); List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
Preconditions.checkState(modelType == ModelType.SAMEDIFF || inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
File predictionsTestDir = new File(testBaseDir, "predictions"); File predictionsTestDir = new File(testBaseDir, "predictions");
predictionsTestDir.mkdirs(); predictionsTestDir.mkdirs();
int count = 0; int count = 0;
if (isMLN) { if (modelType == ModelType.MLN) {
for (Pair<INDArray[], INDArray[]> p : inputs) { for (Pair<INDArray[], INDArray[]> p : inputs) {
INDArray f = p.getFirst()[0]; INDArray f = p.getFirst()[0];
INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]); INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
@ -231,15 +256,15 @@ public class IntegrationTestRunner {
outSaved = Nd4j.read(dis); outSaved = Nd4j.read(dis);
} }
INDArray gradExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput()); INDArray predictionExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
int countExceeds = gradExceedsRE.sumNumber().intValue(); int countExceeds = predictionExceedsRE.sumNumber().intValue();
assertEquals("Predictions do not match saved predictions - output", 0, countExceeds); assertEquals("Predictions do not match saved predictions - output", 0, countExceeds);
} }
} else { } else if(modelType == ModelType.CG){
for (Pair<INDArray[], INDArray[]> p : inputs) { for (Pair<INDArray[], INDArray[]> p : inputs) {
INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null); INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
//Save the array(s)... //Load the previously saved arrays
INDArray[] outSaved = new INDArray[out.length]; INDArray[] outSaved = new INDArray[out.length];
for (int i = 0; i < out.length; i++) { for (int i = 0; i < out.length; i++) {
File outFile = new File(predictionsTestDir, "output_" + (count++) + "_" + i + ".bin"); File outFile = new File(predictionsTestDir, "output_" + (count++) + "_" + i + ".bin");
@ -249,14 +274,36 @@ public class IntegrationTestRunner {
} }
for( int i=0; i<outSaved.length; i++ ){ for( int i=0; i<outSaved.length; i++ ){
INDArray gradExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput()); INDArray predictionExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
int countExceeds = gradExceedsRE.sumNumber().intValue(); int countExceeds = predictionExceedsRE.sumNumber().intValue();
assertEquals("Predictions do not match saved predictions - output " + i, 0, countExceeds); assertEquals("Predictions do not match saved predictions - output " + i, 0, countExceeds);
} }
} }
} else {
List<String> outNames = tc.getPredictionsNamesSameDiff();
for( Map<String,INDArray> ph : inputsSd ){
Map<String,INDArray> out = sd.output(ph, outNames);
//Load the previously saved placeholder arrays
Map<String,INDArray> outSaved = new HashMap<>();
for(String s : outNames){
File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
try (DataInputStream dis = new DataInputStream(new FileInputStream(f))) {
outSaved.put(s, Nd4j.read(dis));
}
}
for(String s : outNames){
INDArray predictionExceedsRE = exceedsRelError(outSaved.get(s), out.get(s), tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
int countExceeds = predictionExceedsRE.sumNumber().intValue();
assertEquals("Predictions do not match saved predictions - output \"" + s + "\"", 0, countExceeds);
}
}
} }
checkLayerClearance(m); if(modelType != ModelType.SAMEDIFF) {
checkLayerClearance(m);
}
} }
@ -264,34 +311,49 @@ public class IntegrationTestRunner {
if (tc.isTestGradients()) { if (tc.isTestGradients()) {
log.info("Checking gradients: saved output vs. initialized model"); log.info("Checking gradients: saved output vs. initialized model");
MultiDataSet data = tc.getGradientsTestData(); INDArray gradientFlat = null;
INDArray gradientFlat; org.deeplearning4j.nn.api.Layer[] layers = null;
org.deeplearning4j.nn.api.Layer[] layers; Map<String,INDArray> grad;
if (isMLN) { if (modelType == ModelType.MLN) {
MultiDataSet data = tc.getGradientsTestData();
mln.setInput(data.getFeatures(0)); mln.setInput(data.getFeatures(0));
mln.setLabels(data.getLabels(0)); mln.setLabels(data.getLabels(0));
mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0)); mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
mln.computeGradientAndScore(); mln.computeGradientAndScore();
gradientFlat = mln.getFlattenedGradients(); gradientFlat = mln.getFlattenedGradients();
layers = mln.getLayers(); layers = mln.getLayers();
} else { grad = mln.gradient().gradientForVariable();
} else if(modelType == ModelType.CG) {
MultiDataSet data = tc.getGradientsTestData();
cg.setInputs(data.getFeatures()); cg.setInputs(data.getFeatures());
cg.setLabels(data.getLabels()); cg.setLabels(data.getLabels());
cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays()); cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
cg.computeGradientAndScore(); cg.computeGradientAndScore();
gradientFlat = cg.getFlattenedGradients(); gradientFlat = cg.getFlattenedGradients();
layers = cg.getLayers(); layers = cg.getLayers();
grad = cg.gradient().gradientForVariable();
} else {
Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
List<String> allVars = new ArrayList<>();
for(SDVariable v : sd.variables()){
if(v.getVariableType() == VariableType.VARIABLE){
allVars.add(v.name());
}
}
grad = sd.calculateGradients(ph, allVars);
} }
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME); if(modelType != ModelType.SAMEDIFF) {
INDArray gradientFlatSaved = read(gFlatFile); File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
INDArray gradientFlatSaved = read(gFlatFile);
INDArray gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients()); INDArray gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
int count = gradExceedsRE.sumNumber().intValue(); int count = gradExceedsRE.sumNumber().intValue();
if(count > 0){ if (count > 0) {
logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat); logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat);
}
assertEquals("Saved flattened gradients: not equal (using relative error)", 0, count);
} }
assertEquals("Saved flattened gradients: not equal (using relative error)", 0, count);
//Load the gradient table: //Load the gradient table:
File gradientDir = new File(testBaseDir, "gradients"); File gradientDir = new File(testBaseDir, "gradients");
@ -302,12 +364,12 @@ public class IntegrationTestRunner {
String key = f.getName(); String key = f.getName();
key = key.substring(0, key.length() - 4); //remove ".bin" key = key.substring(0, key.length() - 4); //remove ".bin"
INDArray loaded = read(f); INDArray loaded = read(f);
INDArray now = m.gradient().gradientForVariable().get(key); INDArray now = grad.get(key);
gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients()); INDArray gradExceedsRE = exceedsRelError(loaded, now, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
count = gradExceedsRE.sumNumber().intValue(); int count = gradExceedsRE.sumNumber().intValue();
assertEquals("Saved flattened gradients: not equal (using relative error) for parameter: " + key, 0, count); assertEquals("Gradients: not equal (using relative error) for parameter: " + key, 0, count);
} }
} }
@ -318,7 +380,7 @@ public class IntegrationTestRunner {
INDArray paramsPostTraining; INDArray paramsPostTraining;
org.deeplearning4j.nn.api.Layer[] layers; org.deeplearning4j.nn.api.Layer[] layers;
if(isMLN){ if(modelType == ModelType.MLN){
int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN(); int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
Preconditions.checkState(layersToTrain != null, "Layer indices must not be null"); Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
@ -328,7 +390,7 @@ public class IntegrationTestRunner {
} }
paramsPostTraining = mln.params(); paramsPostTraining = mln.params();
layers = mln.getLayers(); layers = mln.getLayers();
} else { } else if(modelType == ModelType.CG) {
String[] layersToTrain = tc.getUnsupervisedTrainLayersCG(); String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
Preconditions.checkState(layersToTrain != null, "Layer names must not be null"); Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
@ -337,6 +399,8 @@ public class IntegrationTestRunner {
} }
paramsPostTraining = cg.params(); paramsPostTraining = cg.params();
layers = cg.getLayers(); layers = cg.getLayers();
} else {
throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models");
} }
File f = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_UNSUPERVISED_FILENAME); File f = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_UNSUPERVISED_FILENAME);
@ -360,53 +424,78 @@ public class IntegrationTestRunner {
MultiDataSetIterator trainData = tc.getTrainingData(); MultiDataSetIterator trainData = tc.getTrainingData();
boolean isTbptt; boolean isTbptt;
int tbpttLength; int tbpttLength;
if(isMLN){ if(modelType == ModelType.MLN){
isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT; isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT;
tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength(); tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength();
} else { } else if(modelType == ModelType.CG) {
isTbptt = cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; isTbptt = cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
tbpttLength = cg.getConfiguration().getTbpttFwdLength(); tbpttLength = cg.getConfiguration().getTbpttFwdLength();
} else {
isTbptt = false;
tbpttLength = 0;
} }
CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength); CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength);
CollectScoresListener l = new CollectScoresListener(1); CollectScoresListener l = new CollectScoresListener(1);
m.setListeners(l); if(modelType != ModelType.SAMEDIFF) {
m.setListeners(l);
}
int iterBefore; int iterBefore;
int epochBefore; int epochBefore;
int iterAfter; int iterAfter;
int epochAfter; int epochAfter;
Map<String,INDArray> frozenParamsBefore = getFrozenLayerParamCopies(m); Map<String,INDArray> frozenParamsBefore = modelType != ModelType.SAMEDIFF ? getFrozenLayerParamCopies(m) : getConstantCopies(sd);
org.deeplearning4j.nn.api.Layer[] layers; org.deeplearning4j.nn.api.Layer[] layers = null;
if (isMLN) { History h = null;
if (modelType == ModelType.MLN) {
iterBefore = mln.getIterationCount(); iterBefore = mln.getIterationCount();
epochBefore = mln.getEpochCount(); epochBefore = mln.getEpochCount();
mln.fit(countingIter); mln.fit(countingIter);
iterAfter = mln.getIterationCount(); iterAfter = mln.getIterationCount();
epochAfter = mln.getEpochCount(); epochAfter = mln.getEpochCount();
layers = mln.getLayers(); layers = mln.getLayers();
} else { } else if(modelType == ModelType.CG){
iterBefore = cg.getConfiguration().getIterationCount(); iterBefore = cg.getConfiguration().getIterationCount();
epochBefore = cg.getConfiguration().getEpochCount(); epochBefore = cg.getConfiguration().getEpochCount();
cg.fit(countingIter); cg.fit(countingIter);
iterAfter = cg.getConfiguration().getIterationCount(); iterAfter = cg.getConfiguration().getIterationCount();
epochAfter = cg.getConfiguration().getEpochCount(); epochAfter = cg.getConfiguration().getEpochCount();
layers = cg.getLayers(); layers = cg.getLayers();
} else {
iterBefore = sd.getTrainingConfig().getIterationCount();
epochBefore = sd.getTrainingConfig().getEpochCount();
h = sd.fit(countingIter, 1);
iterAfter = sd.getTrainingConfig().getIterationCount();
epochAfter = sd.getTrainingConfig().getEpochCount();
} }
//Check that frozen params (if any) haven't changed during training: //Check that frozen params (if any) haven't changed during training:
checkFrozenParams(frozenParamsBefore, m); if(modelType == ModelType.SAMEDIFF) {
checkConstants(frozenParamsBefore, sd);
} else {
checkFrozenParams(frozenParamsBefore, m);
}
//Validate the iteration and epoch counts - both for the net, and for the layers //Validate the iteration and epoch counts - both for the net, and for the layers
int newIters = countingIter.getCurrIter(); int newIters = countingIter.getCurrIter();
assertEquals(iterBefore + newIters, iterAfter); assertEquals(iterBefore + newIters, iterAfter);
assertEquals(epochBefore + 1, epochAfter); assertEquals(epochBefore + 1, epochAfter);
validateLayerIterCounts(m, epochBefore + 1, iterBefore+newIters); //TODO CURRENTLY FAILING if(modelType != ModelType.SAMEDIFF) {
double[] scores = l.getListScore().toDoubleArray(); validateLayerIterCounts(m, epochBefore + 1, iterBefore + newIters);
}
double[] scores;
if(modelType == ModelType.SAMEDIFF){
scores = h.lossCurve().getLossValues().toDoubleVector();
} else {
scores = l.getListScore().toDoubleArray();
}
File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME); File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
String[] s = FileUtils.readFileToString(f).split(","); String[] s = FileUtils.readFileToString(f, StandardCharsets.UTF_8).split(",");
if(tc.isTestTrainingCurves()) { if(tc.isTestTrainingCurves()) {
assertEquals("Different number of scores", s.length, scores.length); assertEquals("Different number of scores", s.length, scores.length);
@ -426,17 +515,36 @@ public class IntegrationTestRunner {
} }
if (tc.isTestParamsPostTraining()) { if (tc.isTestParamsPostTraining()) {
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME); if(modelType != ModelType.SAMEDIFF) {
INDArray paramsExp = read(p); File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining()); INDArray paramsExp = read(p);
int count = z.sumNumber().intValue(); INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
if(count > 0){ int count = z.sumNumber().intValue();
logFailedParams(20, "Parameter", layers, z, paramsExp, m.params()); if (count > 0) {
logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
}
assertEquals("Number of params exceeded max relative error", 0, count);
} else {
File dir = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
for(SDVariable v : sd.variables()){
if(v.getVariableType() != VariableType.VARIABLE)
continue;
INDArray paramNow = v.getArr();
File paramFile = new File(dir, v.name() + ".bin");
INDArray exp = read(paramFile);
INDArray z = exceedsRelError(paramNow, exp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
int count = z.sumNumber().intValue();
if (count > 0) {
logFailedParams(20, "Parameter: " + v.name(), layers, z, exp, paramNow);
}
assertEquals("Number of params exceeded max relative error for parameter: \"" + v.name() + "\"", 0, count);
}
} }
assertEquals("Number of params exceeded max relative error", 0, count);
} }
checkLayerClearance(m); if(modelType != ModelType.SAMEDIFF) {
checkLayerClearance(m);
}
} }
//Check evaluation: //Check evaluation:
@ -445,17 +553,19 @@ public class IntegrationTestRunner {
IEvaluation[] evals = tc.getNewEvaluations(); IEvaluation[] evals = tc.getNewEvaluations();
MultiDataSetIterator iter = tc.getEvaluationTestData(); MultiDataSetIterator iter = tc.getEvaluationTestData();
if (isMLN) { if (modelType == ModelType.MLN) {
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
mln.doEvaluation(dsi, evals); mln.doEvaluation(dsi, evals);
} else { } else if(modelType == ModelType.CG){
cg.doEvaluation(iter, evals); cg.doEvaluation(iter, evals);
} else {
evals = tc.doEvaluationSameDiff(sd, iter, evals);
} }
File evalDir = new File(testBaseDir, "evaluation"); File evalDir = new File(testBaseDir, "evaluation");
for (int i = 0; i < evals.length; i++) { for (int i = 0; i < evals.length; i++) {
File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json"); File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
String json = FileUtils.readFileToString(f); String json = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
IEvaluation e; IEvaluation e;
if (evals[i].getClass() == Evaluation.class) { if (evals[i].getClass() == Evaluation.class) {
e = Evaluation.fromJson(json); e = Evaluation.fromJson(json);
@ -479,7 +589,9 @@ public class IntegrationTestRunner {
//Evaluation coverage information: //Evaluation coverage information:
evaluationClassesSeen.put(evals[i].getClass(), evaluationClassesSeen.getOrDefault(evals[i].getClass(), 0) + 1); evaluationClassesSeen.put(evals[i].getClass(), evaluationClassesSeen.getOrDefault(evals[i].getClass(), 0) + 1);
checkLayerClearance(m); if(modelType != ModelType.SAMEDIFF) {
checkLayerClearance(m);
}
} }
} }
@ -490,15 +602,20 @@ public class IntegrationTestRunner {
File f = testDir.newFile(); File f = testDir.newFile();
f.delete(); f.delete();
ModelSerializer.writeModel(m, f, true); if (modelType == ModelType.MLN) {
if (isMLN) { ModelSerializer.writeModel(m, f, true);
MultiLayerNetwork restored = MultiLayerNetwork.load(f, true); MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
assertEquals(mln.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); assertEquals(mln.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
assertEquals(mln.params(), restored.params()); assertEquals(mln.params(), restored.params());
} else { } else if(modelType == ModelType.CG){
ModelSerializer.writeModel(m, f, true);
ComputationGraph restored = ComputationGraph.load(f, true); ComputationGraph restored = ComputationGraph.load(f, true);
assertEquals(cg.getConfiguration(), restored.getConfiguration()); assertEquals(cg.getConfiguration(), restored.getConfiguration());
assertEquals(cg.params(), restored.params()); assertEquals(cg.params(), restored.params());
} else {
sd.save(f, true);
SameDiff restored = SameDiff.load(f, true);
assertSameDiffEquals(sd, restored);
} }
System.gc(); System.gc();
@ -506,7 +623,7 @@ public class IntegrationTestRunner {
//Check parallel inference //Check parallel inference
if (tc.isTestParallelInference()) { if (modelType != ModelType.SAMEDIFF && tc.isTestParallelInference()) {
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData(); List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
@ -515,7 +632,7 @@ public class IntegrationTestRunner {
List<INDArray[]> exp = new ArrayList<>(); List<INDArray[]> exp = new ArrayList<>();
for(Pair<INDArray[], INDArray[]> p : inputs){ for(Pair<INDArray[], INDArray[]> p : inputs){
INDArray[] out; INDArray[] out;
if(isMLN){ if(modelType == ModelType.MLN){
INDArray fm = p.getSecond() == null ? null : p.getSecond()[0]; INDArray fm = p.getSecond() == null ? null : p.getSecond()[0];
out = new INDArray[]{mln.output(p.getFirst()[0], false, fm, null)}; out = new INDArray[]{mln.output(p.getFirst()[0], false, fm, null)};
} else { } else {
@ -547,37 +664,54 @@ public class IntegrationTestRunner {
MultiDataSet toOverfit = tc.getOverfittingData(); MultiDataSet toOverfit = tc.getOverfittingData();
for (int i = 0; i < tc.getOverfitNumIterations(); i++) { for (int i = 0; i < tc.getOverfitNumIterations(); i++) {
if (isMLN) { if (modelType == ModelType.MLN) {
mln.fit(toOverfit); mln.fit(toOverfit);
} else { } else if(modelType == ModelType.CG){
cg.fit(toOverfit); cg.fit(toOverfit);
} else {
sd.fit(toOverfit);
} }
} }
//Check: //Check:
INDArray[] output; INDArray[] output = null;
if (isMLN) { Map<String,INDArray> outSd = null;
if (modelType == ModelType.MLN) {
mln.setLayerMaskArrays(toOverfit.getFeaturesMaskArray(0), null); mln.setLayerMaskArrays(toOverfit.getFeaturesMaskArray(0), null);
output = new INDArray[]{mln.output(toOverfit.getFeatures(0))}; output = new INDArray[]{mln.output(toOverfit.getFeatures(0))};
} else { } else if(modelType == ModelType.CG ){
cg.setLayerMaskArrays(toOverfit.getFeaturesMaskArrays(), null); cg.setLayerMaskArrays(toOverfit.getFeaturesMaskArrays(), null);
output = cg.output(toOverfit.getFeatures()); output = cg.output(toOverfit.getFeatures());
} else {
List<String> l = sd.getTrainingConfig().getDataSetFeatureMapping();
Map<String,INDArray> phMap = new HashMap<>();
int i=0;
for(String s : l){
phMap.put(s, toOverfit.getFeatures(i++));
}
outSd = sd.output(phMap, tc.getPredictionsNamesSameDiff());
} }
for (int i = 0; i < output.length; i++) { int n = modelType == ModelType.SAMEDIFF ? outSd.size() : output.length;
INDArray z = exceedsRelError(output[i], toOverfit.getLabels(i), tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit()); for (int i = 0; i < n; i++) {
INDArray out = modelType == ModelType.SAMEDIFF ? outSd.get(tc.getPredictionsNamesSameDiff().get(i)) : output[i];
INDArray label = toOverfit.getLabels(i);
INDArray z = exceedsRelError(out, label, tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit());
int count = z.sumNumber().intValue(); int count = z.sumNumber().intValue();
if (count > 0) { if (count > 0) {
System.out.println(output[i]); System.out.println(out);
System.out.println(toOverfit.getLabels(i)); System.out.println(label);
INDArray re = relativeError(output[i], toOverfit.getLabels(i), tc.getMinAbsErrorOverfit()); INDArray re = relativeError(out, label, tc.getMinAbsErrorOverfit());
System.out.println("Relative error:"); System.out.println("Relative error:");
System.out.println(re); System.out.println(re);
} }
assertEquals("Number of outputs exceeded max relative error", 0, count); assertEquals("Number of outputs exceeded max relative error", 0, count);
} }
checkLayerClearance(m); if(modelType != ModelType.SAMEDIFF) {
checkLayerClearance(m);
}
} }
long end = System.currentTimeMillis(); long end = System.currentTimeMillis();
@ -709,6 +843,16 @@ public class IntegrationTestRunner {
return out; return out;
} }
private static Map<String,INDArray> getConstantCopies(SameDiff sd){
Map<String,INDArray> out = new HashMap<>();
for(SDVariable v : sd.variables()){
if(v.isConstant()){
out.put(v.name(), v.getArr());
}
}
return out;
}
public static void checkFrozenParams(Map<String,INDArray> copiesBeforeTraining, Model m){ public static void checkFrozenParams(Map<String,INDArray> copiesBeforeTraining, Model m){
for(Map.Entry<String,INDArray> e : copiesBeforeTraining.entrySet()){ for(Map.Entry<String,INDArray> e : copiesBeforeTraining.entrySet()){
INDArray actual = m.getParam(e.getKey()); INDArray actual = m.getParam(e.getKey());
@ -716,6 +860,13 @@ public class IntegrationTestRunner {
} }
} }
public static void checkConstants(Map<String,INDArray> copiesBefore, SameDiff sd){
for(Map.Entry<String,INDArray> e : copiesBefore.entrySet()){
INDArray actual = sd.getArrForVarName(e.getKey());
assertEquals(e.getKey(), e.getValue(), actual);
}
}
public static void printCoverageInformation(){ public static void printCoverageInformation(){
log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"); log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
@ -918,7 +1069,7 @@ public class IntegrationTestRunner {
} }
public static void logFailedParams(int maxNum, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){ public static void logFailedParams(int maxNumToPrintOnFailure, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){
long length = exceedsRelError.length(); long length = exceedsRelError.length();
int logCount = 0; int logCount = 0;
for(int i=0; i<length; i++ ){ for(int i=0; i<length; i++ ){
@ -947,10 +1098,33 @@ public class IntegrationTestRunner {
} }
log.info("{} {} ({}) failed: expected {} vs actual {} (RelativeError: {}, AbsError: {})", i, prefix, pName, dExp, dAct, re, ae); log.info("{} {} ({}) failed: expected {} vs actual {} (RelativeError: {}, AbsError: {})", i, prefix, pName, dExp, dAct, re, ae);
if(++logCount >= maxNum){ if(++logCount >= maxNumToPrintOnFailure){
break; break;
} }
} }
} }
} }
public static void assertSameDiffEquals(SameDiff sd1, SameDiff sd2){
assertEquals(sd1.variableMap().keySet(), sd2.variableMap().keySet());
assertEquals(sd1.getOps().keySet(), sd2.getOps().keySet());
assertEquals(sd1.inputs(), sd2.inputs());
//Check constant and variable arrays:
for(SDVariable v : sd1.variables()){
String n = v.name();
assertEquals(n, v.getVariableType(), sd2.getVariable(n).getVariableType());
if(v.isConstant() || v.getVariableType() == VariableType.VARIABLE){
INDArray a1 = v.getArr();
INDArray a2 = sd2.getVariable(n).getArr();
assertEquals(n, a1, a2);
}
}
//Check ops:
for(SameDiffOp o : sd1.getOps().values()){
SameDiffOp o2 = sd2.getOps().get(o.getName());
assertEquals(o.getOp().getClass(), o2.getOp().getClass());
}
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -17,15 +18,19 @@
package org.deeplearning4j.integration; package org.deeplearning4j.integration;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.integration.testcases.*; import org.deeplearning4j.integration.testcases.dl4j.*;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
@Ignore("AB - 2019/05/27 - Integration tests need to be updated") //@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
public class IntegrationTests extends BaseDL4JTest { public class IntegrationTestsDL4J extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();
@ -36,79 +41,72 @@ public class IntegrationTests extends BaseDL4JTest {
} }
// ***** MLPTestCases ***** // ***** MLPTestCases *****
@Test(timeout = 20000L) @Test
public void testMLPMnist() throws Exception { public void testMLPMnist() throws Exception {
IntegrationTestRunner.runTest(MLPTestCases.getMLPMnist(), testDir); IntegrationTestRunner.runTest(MLPTestCases.getMLPMnist(), testDir);
} }
@Test(timeout = 30000L) @Test
public void testMlpMoon() throws Exception { public void testMlpMoon() throws Exception {
IntegrationTestRunner.runTest(MLPTestCases.getMLPMoon(), testDir); IntegrationTestRunner.runTest(MLPTestCases.getMLPMoon(), testDir);
} }
// ***** RNNTestCases ***** // ***** RNNTestCases *****
@Test(timeout = 30000L) @Test
public void testRnnSeqClassification1() throws Exception { public void testRnnSeqClassification1() throws Exception {
IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase1(), testDir); IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase1(), testDir);
} }
@Test(timeout = 60000L) @Test
public void testRnnSeqClassification2() throws Exception { public void testRnnSeqClassification2() throws Exception {
IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase2(), testDir); IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase2(), testDir);
} }
@Test(timeout = 120000L) @Test
public void testRnnCharacter() throws Exception { public void testRnnCharacter() throws Exception {
IntegrationTestRunner.runTest(RNNTestCases.getRnnCharacterTestCase(), testDir); IntegrationTestRunner.runTest(RNNTestCases.getRnnCharacterTestCase(), testDir);
} }
// ***** CNN1DTestCases ***** // ***** CNN1DTestCases *****
@Test(timeout = 180000L) @Test
public void testCnn1dCharacter() throws Exception { public void testCnn1dCharacter() throws Exception {
IntegrationTestRunner.runTest(CNN1DTestCases.getCnn1dTestCaseCharRNN(), testDir); IntegrationTestRunner.runTest(CNN1DTestCases.getCnn1dTestCaseCharRNN(), testDir);
} }
// ***** CNN2DTestCases ***** // ***** CNN2DTestCases *****
@Test(timeout = 120000L) @Test
public void testLenetMnist() throws Exception { public void testLenetMnist() throws Exception {
IntegrationTestRunner.runTest(CNN2DTestCases.getLenetMnist(), testDir); IntegrationTestRunner.runTest(CNN2DTestCases.getLenetMnist(), testDir);
} }
@Ignore //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6017 @Test
@Test(timeout = 180000L)
public void testYoloHouseNumbers() throws Exception { public void testYoloHouseNumbers() throws Exception {
IntegrationTestRunner.runTest(CNN2DTestCases.getYoloHouseNumbers(), testDir); IntegrationTestRunner.runTest(CNN2DTestCases.getYoloHouseNumbers(), testDir);
} }
@Test(timeout = 120000L) @Test
public void testCnn2DLenetTransferDropoutRepeatability() throws Exception { public void testCnn2DLenetTransferDropoutRepeatability() throws Exception {
IntegrationTestRunner.runTest(CNN2DTestCases.testLenetTransferDropoutRepeatability(), testDir); IntegrationTestRunner.runTest(CNN2DTestCases.testLenetTransferDropoutRepeatability(), testDir);
} }
// ***** CNN3DTestCases ***** // ***** CNN3DTestCases *****
@Test(timeout = 180000L) @Test
public void testCnn3dSynthetic() throws Exception { public void testCnn3dSynthetic() throws Exception {
IntegrationTestRunner.runTest(CNN3DTestCases.getCnn3dTestCaseSynthetic(), testDir); IntegrationTestRunner.runTest(CNN3DTestCases.getCnn3dTestCaseSynthetic(), testDir);
} }
// ***** UnsupervisedTestCases ***** // ***** UnsupervisedTestCases *****
@Test(timeout = 120000L) @Test
public void testVAEMnistAnomaly() throws Exception { public void testVAEMnistAnomaly() throws Exception {
IntegrationTestRunner.runTest(UnsupervisedTestCases.getVAEMnistAnomaly(), testDir); IntegrationTestRunner.runTest(UnsupervisedTestCases.getVAEMnistAnomaly(), testDir);
} }
// ***** TransferLearningTestCases ***** // ***** TransferLearningTestCases *****
@Test(timeout = 360000L) @Test
public void testVgg16Transfer() throws Exception { public void testVgg16Transfer() throws Exception {
IntegrationTestRunner.runTest(CNN2DTestCases.getVGG16TransferTinyImagenet(), testDir); IntegrationTestRunner.runTest(CNN2DTestCases.getVGG16TransferTinyImagenet(), testDir);
} }
// ***** KerasImportTestCases *****
//TODO
} }

View File

@ -0,0 +1,40 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
public class IntegrationTestsSameDiff extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testMLPMnist() throws Exception {
IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir);
}
}

View File

@ -1,5 +1,5 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -13,13 +13,8 @@
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration;
package org.deeplearning4j.integration.testcases; public enum ModelType {
MLN, CG, SAMEDIFF
/**
* Integration tests starting from Keras model
*/
public class KerasImportTestCases {
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -17,8 +18,9 @@
package org.deeplearning4j.integration; package org.deeplearning4j.integration;
import lombok.Data; import lombok.Data;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@ -26,6 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
import java.io.File; import java.io.File;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* A single test case for integration tests * A single test case for integration tests
@ -37,16 +40,17 @@ public abstract class TestCase {
PRETRAINED, RANDOM_INIT PRETRAINED, RANDOM_INIT
} }
protected String testName; //See: readme.md for more details
protected TestType testType; protected String testName; //Name of the test, for display purposes
protected boolean testPredictions = true; protected TestType testType; //Type of model - from a pretrained model, or a randomly initialized model
protected boolean testGradients = true; protected boolean testPredictions = true; //If true: check the predictions/output. Requires getPredictionsTestData() to be implemented
protected boolean testUnsupervisedTraining = false; protected boolean testGradients = true; //If true: check the gradients. Requires getGradientsTestData() to be implemented
protected boolean testTrainingCurves = true; protected boolean testUnsupervisedTraining = false; //If true: perform unsupervised training. Only applies to layers like autoencoders, VAEs, etc. Requires getUnsupervisedTrainData() to be implemented
protected boolean testParamsPostTraining = true; protected boolean testTrainingCurves = true; //If true: perform training, and compare loss vs. iteration. Requires getTrainingData() method
protected boolean testEvaluation = true; protected boolean testParamsPostTraining = true; //If true: perform training, and compare parameters after training. Requires getTrainingData() method
protected boolean testParallelInference = true; protected boolean testEvaluation = true; //If true: perform evaluation. Requires getNewEvaluations() and getEvaluationTestData() methods implemented
protected boolean testOverfitting = true; protected boolean testParallelInference = true; //If true: run the model through ParallelInference. Requires getPredictionsTestData() method. Only applies to DL4J models, NOT SameDiff models
protected boolean testOverfitting = true; //If true: perform overfitting, and ensure the predictions match the training data. Requires both getOverfittingData() and getOverfitNumIterations()
protected int[] unsupervisedTrainLayersMLN = null; protected int[] unsupervisedTrainLayersMLN = null;
protected String[] unsupervisedTrainLayersCG = null; protected String[] unsupervisedTrainLayersCG = null;
@ -65,6 +69,8 @@ public abstract class TestCase {
protected double maxRelativeErrorOverfit = 1e-2; protected double maxRelativeErrorOverfit = 1e-2;
protected double minAbsErrorOverfit = 1e-2; protected double minAbsErrorOverfit = 1e-2;
public abstract ModelType modelType();
/** /**
* Initialize the test case... many tests don't need this; others may use it to download or create data * Initialize the test case... many tests don't need this; others may use it to download or create data
* @param testWorkingDir Working directory to use for test * @param testWorkingDir Working directory to use for test
@ -88,19 +94,37 @@ public abstract class TestCase {
} }
/** /**
* Required if testPredictions == true * Required if testPredictions == true && DL4J model (MultiLayerNetwork or ComputationGraph)
*/ */
public List<Pair<INDArray[],INDArray[]>> getPredictionsTestData() throws Exception { public List<Pair<INDArray[],INDArray[]>> getPredictionsTestData() throws Exception {
throw new RuntimeException("Implementations must override this method if used"); throw new RuntimeException("Implementations must override this method if used");
} }
/** /**
* Required if testGradients == true * Required if testPredictions == true && SameDiff model
*/
public List<Map<String,INDArray>> getPredictionsTestDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
public List<String> getPredictionsNamesSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/**
* Required if testGradients == true && DL4J model
*/ */
public MultiDataSet getGradientsTestData() throws Exception { public MultiDataSet getGradientsTestData() throws Exception {
throw new RuntimeException("Implementations must override this method if used"); throw new RuntimeException("Implementations must override this method if used");
} }
/**
* Required if testGradients == true && SameDiff model
*/
public Map<String,INDArray> getGradientsTestDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/** /**
* Required when testUnsupervisedTraining == true * Required when testUnsupervisedTraining == true
*/ */
@ -122,6 +146,10 @@ public abstract class TestCase {
throw new RuntimeException("Implementations must override this method if used"); throw new RuntimeException("Implementations must override this method if used");
} }
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations){
throw new RuntimeException("Implementations must override this method if used");
}
/** /**
* Required if testEvaluation == true * Required if testEvaluation == true
*/ */
@ -130,12 +158,19 @@ public abstract class TestCase {
} }
/** /**
* Required if testOverfitting == true * Required if testOverfitting == true && DL4J model
*/ */
public MultiDataSet getOverfittingData() throws Exception { public MultiDataSet getOverfittingData() throws Exception {
throw new RuntimeException("Implementations must override this method if used"); throw new RuntimeException("Implementations must override this method if used");
} }
/**
* Required if testOverfitting == true && SameDiff model
*/
public Map<String,INDArray> getOverfittingDataSameDiff() throws Exception {
throw new RuntimeException("Implementations must override this method if used");
}
/** /**
* Required if testOverfitting == true * Required if testOverfitting == true
*/ */

View File

@ -1,36 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases;
import org.deeplearning4j.integration.TestCase;
public class TransferLearningTestCases {
public static TestCase testPartFrozenResNet50(){
throw new UnsupportedOperationException("Not yet implemented");
}
public static TestCase testPartFrozenNASNET(){
throw new UnsupportedOperationException("Not yet implemented");
}
}

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,22 +15,24 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration.testcases; package org.deeplearning4j.integration.testcases.dl4j;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.integration.testcases.misc.CharacterIterator; import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -64,12 +67,18 @@ public class CNN1DTestCases {
int miniBatchSize = 16; int miniBatchSize = 16;
int exampleLength = 128; int exampleLength = 128;
@Override
public ModelType modelType() {
return ModelType.CG;
}
@Override @Override
public Object getConfiguration() throws Exception { public Object getConfiguration() throws Exception {
CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength); CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
int nOut = iter.totalOutcomes(); int nOut = iter.totalOutcomes();
return new NeuralNetConfiguration.Builder() return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new Adam(0.01)) .updater(new Adam(0.01))

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,7 +15,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration.testcases; package org.deeplearning4j.integration.testcases.dl4j;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.loader.NativeImageLoader;
@ -22,16 +23,13 @@ import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
import org.datavec.image.recordreader.objdetect.impl.SvhnLabelProvider; import org.datavec.image.recordreader.objdetect.impl.SvhnLabelProvider;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.fetchers.SvhnDataFetcher; import org.deeplearning4j.datasets.fetchers.SvhnDataFetcher;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.datasets.fetchers.DataSetType; import org.deeplearning4j.datasets.fetchers.DataSetType;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.TinyImageNetDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.TinyImageNetDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationCalibration;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.*;
@ -47,7 +45,12 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.PretrainedType; import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.TinyYOLO; import org.deeplearning4j.zoo.model.TinyYOLO;
import org.deeplearning4j.zoo.model.VGG16; import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
@ -82,12 +85,18 @@ public class CNN2DTestCases {
testOverfitting = false; testOverfitting = false;
} }
@Override
public ModelType modelType() {
return ModelType.MLN;
}
public Object getConfiguration() throws Exception { public Object getConfiguration() throws Exception {
int nChannels = 1; // Number of input channels int nChannels = 1; // Number of input channels
int outputNum = 10; // The number of possible outcomes int outputNum = 10; // The number of possible outcomes
int seed = 123; int seed = 123;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(seed) .seed(seed)
.l2(0.0005) .l2(0.0005)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
@ -187,6 +196,11 @@ public class CNN2DTestCases {
testOverfitting = false; testOverfitting = false;
} }
@Override
public ModelType modelType() {
return ModelType.CG;
}
@Override @Override
public Model getPretrainedModel() throws Exception { public Model getPretrainedModel() throws Exception {
VGG16 vgg16 = VGG16.builder() VGG16 vgg16 = VGG16.builder()
@ -269,6 +283,11 @@ public class CNN2DTestCases {
testOverfitting = false; testOverfitting = false;
} }
@Override
public ModelType modelType() {
return ModelType.CG;
}
@Override @Override
public Model getPretrainedModel() throws Exception { public Model getPretrainedModel() throws Exception {
int nClasses = 10; int nClasses = 10;
@ -372,6 +391,11 @@ public class CNN2DTestCases {
testOverfitting = true; testOverfitting = true;
} }
@Override
public ModelType modelType() {
return ModelType.CG;
}
@Override @Override
public Model getPretrainedModel() throws Exception { public Model getPretrainedModel() throws Exception {
@ -381,6 +405,7 @@ public class CNN2DTestCases {
lrSchedule.put(3000, 0.001); lrSchedule.put(3000, 0.001);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.l2(0.0005) .l2(0.0005)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,35 +15,31 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration.testcases; package org.deeplearning4j.integration.testcases.dl4j;
import org.apache.commons.math3.stat.inference.TestUtils;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -66,6 +63,11 @@ public class CNN3DTestCases {
testOverfitting = false; testOverfitting = false;
} }
@Override
public ModelType modelType() {
return ModelType.MLN;
}
public Object getConfiguration() throws Exception { public Object getConfiguration() throws Exception {
int nChannels = 3; // Number of input channels int nChannels = 3; // Number of input channels
int outputNum = 10; // The number of possible outcomes int outputNum = 10; // The number of possible outcomes

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,8 +15,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration.testcases; package org.deeplearning4j.integration.testcases.dl4j;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.TestCase;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@ -24,10 +26,6 @@ import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationCalibration;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
@ -35,7 +33,12 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil; import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
@ -76,9 +79,15 @@ public class MLPTestCases {
minAbsErrorOverfit = 1e-2; minAbsErrorOverfit = 1e-2;
} }
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override @Override
public Object getConfiguration() { public Object getConfiguration() {
return new NeuralNetConfiguration.Builder() return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.updater(new Adam(new MapSchedule.Builder(ScheduleType.ITERATION) .updater(new Adam(new MapSchedule.Builder(ScheduleType.ITERATION)
.add(0, 5e-2) .add(0, 5e-2)
@ -168,6 +177,11 @@ public class MLPTestCases {
testOverfitting = false; //Not much point here: very simple training data testOverfitting = false; //Not much point here: very simple training data
} }
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override @Override
public Object getConfiguration() { public Object getConfiguration() {
int seed = 123; int seed = 123;
@ -179,6 +193,7 @@ public class MLPTestCases {
//log.info("Build model...."); //log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(seed) .seed(seed)
.updater(new Nesterovs(learningRate, 0.9)) .updater(new Nesterovs(learningRate, 0.9))
.list() .list()

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,22 +15,24 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration.testcases; package org.deeplearning4j.integration.testcases.dl4j;
import org.deeplearning4j.integration.ModelType;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.integration.testcases.misc.CharacterIterator; import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
import org.deeplearning4j.integration.testcases.misc.CompositeMultiDataSetPreProcessor;
import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationCalibration;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
@ -91,6 +94,11 @@ public class RNNTestCases {
private int exampleLength = 1000; private int exampleLength = 1000;
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override @Override
public Object getConfiguration() throws Exception { public Object getConfiguration() throws Exception {
@ -101,6 +109,7 @@ public class RNNTestCases {
int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters
return new NeuralNetConfiguration.Builder() return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.l2(0.001) .l2(0.001)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
@ -175,9 +184,15 @@ public class RNNTestCases {
return normalizer; return normalizer;
} }
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override @Override
public Object getConfiguration() throws Exception { public Object getConfiguration() throws Exception {
return new NeuralNetConfiguration.Builder() return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.updater(new Adam(5e-2)) .updater(new Adam(5e-2))
.l1(1e-3).l2(1e-3) .l1(1e-3).l2(1e-3)
@ -298,6 +313,7 @@ public class RNNTestCases {
@Override @Override
public Object getConfiguration() throws Exception { public Object getConfiguration() throws Exception {
return new NeuralNetConfiguration.Builder() return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.updater(new Adam(5e-2)) .updater(new Adam(5e-2))
.l1(1e-3).l2(1e-3) .l1(1e-3).l2(1e-3)

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,18 +15,20 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration.testcases; package org.deeplearning4j.integration.testcases.dl4j;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.TestCase;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution; import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@ -59,9 +62,15 @@ public class UnsupervisedTestCases {
minAbsErrorPretrainParams = 5e-4; minAbsErrorPretrainParams = 5e-4;
} }
@Override
public ModelType modelType() {
return ModelType.MLN;
}
@Override @Override
public Object getConfiguration() { public Object getConfiguration() {
return new NeuralNetConfiguration.Builder() return new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.updater(new Adam(0.05)) .updater(new Adam(0.05))
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)

View File

@ -14,7 +14,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.integration.testcases.misc; package org.deeplearning4j.integration.testcases.dl4j.misc;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -1,36 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.misc;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
public class CompositeMultiDataSetPreProcessor implements MultiDataSetPreProcessor {
private MultiDataSetPreProcessor[] preProcessors;
public CompositeMultiDataSetPreProcessor(MultiDataSetPreProcessor... preProcessors){
this.preProcessors = preProcessors;
}
@Override
public void preProcess(MultiDataSet multiDataSet) {
for(MultiDataSetPreProcessor p : preProcessors){
p.preProcess(multiDataSet);
}
}
}

View File

@ -0,0 +1,155 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import java.util.*;
public class SameDiffMLPTestCases {
public static TestCase getMLPMnist(){
return new TestCase() {
{
testName = "MLPMnistSD";
testType = TestType.RANDOM_INIT;
testPredictions = true;
testTrainingCurves = true;
testGradients = true;
testParamsPostTraining = true;
testEvaluation = true;
testOverfitting = true;
maxRelativeErrorOverfit = 2e-2;
minAbsErrorOverfit = 1e-2;
}
@Override
public ModelType modelType() {
return ModelType.SAMEDIFF;
}
@Override
public Object getConfiguration() throws Exception {
Nd4j.getRandom().setSeed(12345);
//Define the network structure:
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256));
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256));
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10));
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0));
SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
SDVariable loss = sd.loss.logLoss("loss", label, out);
//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Adam(0.01))
.weightDecay(1e-3, true)
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
.build());
return sd;
}
@Override
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
List<Map<String,INDArray>> out = new ArrayList<>();
DataSetIterator iter = new MnistDataSetIterator(1, true, 12345);
out.add(Collections.singletonMap("in", iter.next().getFeatures()));
iter = new MnistDataSetIterator(8, true, 12345);
out.add(Collections.singletonMap("in", iter.next().getFeatures()));
return out;
}
@Override
public List<String> getPredictionsNamesSameDiff() throws Exception {
return Collections.singletonList("out");
}
@Override
public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
Map<String,INDArray> map = new HashMap<>();
map.put("in", ds.getFeatures());
map.put("label", ds.getLabels());
return map;
}
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);
iter = new EarlyTerminationDataSetIterator(iter, 32);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public IEvaluation[] getNewEvaluations() {
return new IEvaluation[]{new Evaluation()};
}
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(8, false, 12345);
iter = new EarlyTerminationDataSetIterator(iter, 32);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
sd.evaluate(iter, "out", 0, evaluations);
return evaluations;
}
@Override
public MultiDataSet getOverfittingData() throws Exception {
return new MnistDataSetIterator(1, true, 12345).next().toMultiDataSet();
}
@Override
public int getOverfitNumIterations() {
return 100;
}
};
}
}

View File

@ -872,6 +872,8 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); public native void setLeaksDetector(@Cast("bool") boolean reallyDetect);
public native @Cast("bool") boolean helpersAllowed(); public native @Cast("bool") boolean helpersAllowed();
public native void allowHelpers(@Cast("bool") boolean reallyAllow); public native void allowHelpers(@Cast("bool") boolean reallyAllow);
public native @Cast("bool") boolean blasFallback();
public native int tadThreshold(); public native int tadThreshold();
public native void setTadThreshold(int threshold); public native void setTadThreshold(int threshold);
@ -4165,15 +4167,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
*/ */
public native void transposei(); public native void transposei();
/**
* return array pointing on certain range of this array
* index - the number of array to be returned among set of possible arrays
* dimensions - array of dimensions to point on
*/
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions);
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions);
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions);
/** /**
* returns the number of arrays pointing on specified dimension(s) * returns the number of arrays pointing on specified dimension(s)
* dimensions - array of dimensions to point on * dimensions - array of dimensions to point on
@ -6881,9 +6874,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3); @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3);
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3); @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shape, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shape, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shape, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native void traceNew(int id); @Namespace("shape") public static native void traceNew(int id);
@ -7323,14 +7316,12 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo); @Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo);
@Namespace("shape") public static native int rank(@Const int[] shapeInfo); @Namespace("shape") public static native int rank(@Const int[] shapeInfo);
// returns pointer on elementWiseStride
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
/** /**
* returns pointer on elementWiseStride * returns pointer on elementWiseStride
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
/** /**
* Converts a raw int buffer of the layout: * Converts a raw int buffer of the layout:
@ -8010,12 +8001,33 @@ public static final int PREALLOC_SIZE = 33554432;
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
*/ */
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
/**
* processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array
* arguments:
* idx - input argument, intervals of indexes which define the sub-array to point on,
* when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank)
* when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank)
* when (dimStart == dimEnd) then whole range will be used for current dimension
* maxShapeInfo - input argument, shapeInfo of original array
* minShapeInfo - output argument, shapeInfo of sub-array to be deduced
* minOffset - output argument, offset of sub-array buffer offsets from original buffer
* keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
* isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd,
* numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1
*/
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset);
/** /**
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
@ -8036,6 +8048,14 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo); @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo);
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo); @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo);
/**
* get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
@ -8908,6 +8928,8 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { // INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@ -9103,6 +9125,10 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// #endif /* SHAPE_H_ */ // #endif /* SHAPE_H_ */

View File

@ -875,6 +875,8 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); public native void setLeaksDetector(@Cast("bool") boolean reallyDetect);
public native @Cast("bool") boolean helpersAllowed(); public native @Cast("bool") boolean helpersAllowed();
public native void allowHelpers(@Cast("bool") boolean reallyAllow); public native void allowHelpers(@Cast("bool") boolean reallyAllow);
public native @Cast("bool") boolean blasFallback();
public native int tadThreshold(); public native int tadThreshold();
public native void setTadThreshold(int threshold); public native void setTadThreshold(int threshold);
@ -4168,15 +4170,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
*/ */
public native void transposei(); public native void transposei();
/**
* return array pointing on certain range of this array
* index - the number of array to be returned among set of possible arrays
* dimensions - array of dimensions to point on
*/
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions);
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions);
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions);
/** /**
* returns the number of arrays pointing on specified dimension(s) * returns the number of arrays pointing on specified dimension(s)
* dimensions - array of dimensions to point on * dimensions - array of dimensions to point on
@ -6884,9 +6877,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3); @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3);
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3); @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shape, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shape, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shape, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native void traceNew(int id); @Namespace("shape") public static native void traceNew(int id);
@ -7326,14 +7319,12 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo); @Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo);
@Namespace("shape") public static native int rank(@Const int[] shapeInfo); @Namespace("shape") public static native int rank(@Const int[] shapeInfo);
// returns pointer on elementWiseStride
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
/** /**
* returns pointer on elementWiseStride * returns pointer on elementWiseStride
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
/** /**
* Converts a raw int buffer of the layout: * Converts a raw int buffer of the layout:
@ -8013,12 +8004,33 @@ public static final int PREALLOC_SIZE = 33554432;
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
*/ */
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets); @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
/**
* processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array
* arguments:
* idx - input argument, intervals of indexes which define the sub-array to point on,
* when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank)
* when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank)
* when (dimStart == dimEnd) then whole range will be used for current dimension
* maxShapeInfo - input argument, shapeInfo of original array
* minShapeInfo - output argument, shapeInfo of sub-array to be deduced
* minOffset - output argument, offset of sub-array buffer offsets from original buffer
* keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
* isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd,
* numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1
*/
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset);
/** /**
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
@ -8039,6 +8051,14 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo); @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo);
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo); @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo);
/**
* get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
@ -8911,6 +8931,8 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { // INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@ -9106,6 +9128,10 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// #endif /* SHAPE_H_ */ // #endif /* SHAPE_H_ */

View File

@ -102,6 +102,7 @@ public class RngTests extends BaseNd4jTest {
@Test @Test
public void testRandomBinomial() { public void testRandomBinomial() {
Nd4j.getRandom().setSeed(12345);
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings. //silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3); INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);
assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial