DL4J integrations tests updates + add SameDiff support (#298)
* Revive and start updating DL4J integration tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add SameDiff support - first pass Signed-off-by: Alex Black <blacka101@gmail.com> * SameDiff test case generation Signed-off-by: Alex Black <blacka101@gmail.com> * SameDiff integration tests polishing Signed-off-by: Alex Black <blacka101@gmail.com> * More SameDiff integration test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Final polish Signed-off-by: Alex Black <blacka101@gmail.com> * Small test tweak Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
ead5162c97
commit
a80fb99a5f
|
@ -89,12 +89,12 @@ public abstract class BaseDL4JTest {
|
|||
return getDataType();
|
||||
}
|
||||
|
||||
protected Boolean integrationTest;
|
||||
protected static Boolean integrationTest;
|
||||
|
||||
/**
|
||||
* @return True if integration tests maven profile is enabled, false otherwise.
|
||||
*/
|
||||
public boolean isIntegrationTests(){
|
||||
public static boolean isIntegrationTests(){
|
||||
if(integrationTest == null){
|
||||
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
|
||||
integrationTest = Boolean.parseBoolean(prop);
|
||||
|
@ -107,7 +107,7 @@ public abstract class BaseDL4JTest {
|
|||
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
|
||||
* Note that the integration test profile is not enabled by default - "integration-tests" profile
|
||||
*/
|
||||
public void skipUnlessIntegrationTests(){
|
||||
public static void skipUnlessIntegrationTests(){
|
||||
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.optimize.listeners;
|
|||
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
|
||||
import it.unimi.dsi.fastutil.ints.IntArrayList;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
||||
|
@ -32,6 +33,7 @@ import java.io.Serializable;
|
|||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Slf4j
|
||||
public class CollectScoresListener extends BaseTrainingListener implements Serializable {
|
||||
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
|
||||
#DL4J Integration Tests
|
||||
#DL4J and SameDiff Integration Tests
|
||||
|
||||
These tests are designed to check a number of aspects of DL4J:
|
||||
1. Predictions
|
||||
These tests are designed to check a number of aspects of DL4J and SameDiff:
|
||||
1. Predictions (i.e., network output)
|
||||
2. Training (training curves, parameters, gradient calculation)
|
||||
3. Evaluation
|
||||
4. Model serialization
|
||||
5. Overfitting sanity checks
|
||||
3. Evaluation (accuracy, etc)
|
||||
4. Model serialization (saving + loading models)
|
||||
5. Overfitting sanity checks (make sure we can overfit a single example)
|
||||
6. Data pipelines
|
||||
7. Evaluation classes
|
||||
8. Parallel Wrapper
|
||||
9. Validating conditions that should always hold (frozen layer params don't change, for example)
|
||||
7. Parallel Wrapper
|
||||
8. Validating conditions that should always hold (frozen layer params don't change, for example)
|
||||
|
||||
|
||||
They are designed for the following purposes:
|
||||
|
@ -19,32 +18,46 @@ They are designed for the following purposes:
|
|||
3. Detecting significant differences between CPU and CUDA backends
|
||||
4. Validating implementation via sanity checks on training - i.e., can we overfit a single example?
|
||||
5. Checking networks and data pipelines on real-world scale data and nets
|
||||
6. Operating as fully automated pre-release checks (replacing previously used manual checks)
|
||||
6. Operating as fully automated pre-release checks (replacing manual sanity checks)
|
||||
|
||||
## Types of Tests
|
||||
## Main Classes
|
||||
|
||||
The integration tests are set up to be able to run multiple tests on each network configuration.
|
||||
Explanation of the main classes:
|
||||
* **IntegrationTestBaselineGenerator**: Run *manually* to generate and save "expected results" for comparing in the future.
|
||||
Output goes to dl4j-test-resources, for saving/uploading.
|
||||
* **IntegrationTestRunner**: Actually runs the tests, and compares the output/result to those generated by the baseline generator
|
||||
* **TestCase**: integration tests extend this
|
||||
* **testcases/\*.java**: the actual integration test definitions
|
||||
* **IntegrationTestsDL4J**: entry point for running the DL4J integration tests
|
||||
* **IntegrationTestsSameDiff**: entry point for running the SameDiff integration tests
|
||||
|
||||
## Types of Test Components
|
||||
|
||||
The integration tests are set up to be able to run multiple types of tests on each network configuration.
|
||||
|
||||
Networks may be pretrained (from model zoo) or randomly initialized (from specified configuration).
|
||||
|
||||
Specifically, test cases can be run with any subset of the following components to be tested, by setting TestCase.XYZ boolean options to true or false:
|
||||
|
||||
1. testPredictions: Testing output (predictions) on some specified data vs. saved/known good arrays
|
||||
2. testGradients: Testing gradients on some specified data vs. saved/known good arrays
|
||||
3. testPretrain: Test layerwise pretraining parameters and training curves
|
||||
4. testTrainingCurves: Train, and check score vs. iteration
|
||||
5. testParamsPostTraining: validate params match post training
|
||||
6. testEvaluation: test the evaluation performance (post training, if 4 or 5 are true)
|
||||
7. testParallelInference: validate that single net and parallel inference results match
|
||||
8. testOverfitting: sanity check - try to overfit a single example
|
||||
1. **testPredictions**: Testing output (predictions) on some specified data vs. saved/known good arrays
|
||||
2. **testGradients**: Testing gradients on some specified data vs. saved/known good arrays
|
||||
3. **testPretrain**: Test layerwise pretraining parameters and training curves
|
||||
4. **testTrainingCurves**: Train, and check score vs. iteration
|
||||
5. **testParamsPostTraining**: validate params match post training
|
||||
6. **testEvaluation**: test the evaluation performance (post training, if 4 or 5 are true)
|
||||
7. **testParallelInference**: validate that single net and parallel inference results match
|
||||
8. **testOverfitting**: sanity check - try to overfit a single example
|
||||
|
||||
See TestCase.java for more details.
|
||||
|
||||
|
||||
## Adding a New Integration Test
|
||||
|
||||
The process to add a new test is simple:
|
||||
1. Add a method that creates and returns a TestCase object
|
||||
2. Add it as a unit test to IntegrationTests class
|
||||
3. Run IntegrationTestBaselineGenerator (if required) to generate and save the "known good" results.
|
||||
1. Add a method that creates and returns a TestCase object (example: testcases/MLPTestCases.getMLPMnist())
|
||||
2. Add it as a unit test to IntegrationTests class (example: IntegrationTestsDL4J.testMLPMnist())
|
||||
3. Run IntegrationTestBaselineGenerator with the new test case, to generate and save the "known good" results.
|
||||
4. Run the new integration test to make sure it passes, on both CPU and CUDA backends
|
||||
5. Commit the generated test resources from step 3 to dl4j-test-resources repo
|
||||
|
||||
Note that IntegrationTestBaselineGenerator assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo.
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -16,15 +17,10 @@
|
|||
|
||||
package org.deeplearning4j.integration;
|
||||
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||
import org.deeplearning4j.eval.IEvaluation;
|
||||
import org.deeplearning4j.integration.testcases.CNN2DTestCases;
|
||||
import org.deeplearning4j.integration.testcases.MLPTestCases;
|
||||
import org.deeplearning4j.integration.testcases.RNNTestCases;
|
||||
import org.deeplearning4j.integration.testcases.UnsupervisedTestCases;
|
||||
import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
|
@ -32,20 +28,27 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
|||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.optimize.listeners.CollectScoresListener;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
* Run this manually to generate - or update - the saved files for a specific test.
|
||||
* Places results in dl4j-test-resources: assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo.
|
||||
|
@ -53,32 +56,31 @@ import java.util.stream.Collectors;
|
|||
@Slf4j
|
||||
public class IntegrationTestBaselineGenerator {
|
||||
|
||||
public static final File OUTPUT_DIR = new File("../../dl4j-test-resources/src/main/resources/dl4j-integration-tests").getAbsoluteFile();
|
||||
public static final File OUTPUT_DIR_DL4J = new File("../../dl4j-test-resources/src/main/resources/dl4j-integration-tests").getAbsoluteFile();
|
||||
public static final File OUTPUT_DIR_SAMEDIFF = new File("../../dl4j-test-resources/src/main/resources/samediff-integration-tests").getAbsoluteFile();
|
||||
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (!OUTPUT_DIR.exists()) {
|
||||
throw new RuntimeException("output directory (test resources) does not exist!");
|
||||
if (!OUTPUT_DIR_DL4J.exists() && !OUTPUT_DIR_SAMEDIFF.exists()) {
|
||||
throw new RuntimeException("output directories in test resources do not exist!");
|
||||
}
|
||||
|
||||
//All integration tests are run with float precision!
|
||||
Nd4j.setDataType(DataType.FLOAT);
|
||||
|
||||
// runGeneration(
|
||||
// MLPTestCases.getMLPMnist(),
|
||||
// );
|
||||
runGeneration(
|
||||
SameDiffMLPTestCases.getMLPMnist()
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
private static void runGeneration(TestCase... testCases) throws Exception {
|
||||
|
||||
for( TestCase tc : testCases ) {
|
||||
final ModelType modelType = tc.modelType();
|
||||
|
||||
//Basic validation:
|
||||
Preconditions.checkState(tc.getTestName() != null, "Test case name is null");
|
||||
|
||||
//Run through each test case:
|
||||
File testBaseDir = new File(OUTPUT_DIR, tc.getTestName());
|
||||
File testBaseDir = new File(modelType == ModelType.SAMEDIFF ? OUTPUT_DIR_SAMEDIFF : OUTPUT_DIR_DL4J, tc.getTestName());
|
||||
if (testBaseDir.exists()) {
|
||||
FileUtils.forceDelete(testBaseDir);
|
||||
}
|
||||
|
@ -109,56 +111,62 @@ public class IntegrationTestBaselineGenerator {
|
|||
//First: if test is a random init test: generate the config, and save it
|
||||
MultiLayerNetwork mln = null;
|
||||
ComputationGraph cg = null;
|
||||
Model m;
|
||||
boolean isMLN;
|
||||
SameDiff sd = null;
|
||||
Model m = null;
|
||||
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
|
||||
Object config = tc.getConfiguration();
|
||||
String json;
|
||||
String json = null;
|
||||
if (config instanceof MultiLayerConfiguration) {
|
||||
MultiLayerConfiguration mlc = (MultiLayerConfiguration) config;
|
||||
isMLN = true;
|
||||
json = mlc.toJson();
|
||||
mln = new MultiLayerNetwork(mlc);
|
||||
mln.init();
|
||||
m = mln;
|
||||
} else {
|
||||
} else if (config instanceof ComputationGraphConfiguration){
|
||||
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
|
||||
isMLN = false;
|
||||
json = cgc.toJson();
|
||||
cg = new ComputationGraph(cgc);
|
||||
cg.init();
|
||||
m = cg;
|
||||
} else {
|
||||
sd = (SameDiff)config;
|
||||
}
|
||||
|
||||
File configFile = new File(testBaseDir, "config." + (isMLN ? "mlc.json" : "cgc.json"));
|
||||
FileUtils.writeStringToFile(configFile, json);
|
||||
log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
|
||||
File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME);
|
||||
ModelSerializer.writeModel(m, savedModel, true);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? "mlc.json" : "cgc.json"));
|
||||
FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8);
|
||||
log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
|
||||
ModelSerializer.writeModel(m, savedModel, true);
|
||||
} else {
|
||||
sd.save(savedModel, true);
|
||||
}
|
||||
log.info("RANDOM_INIT test - saved randomly initialized model to: {}", savedModel.getAbsolutePath());
|
||||
} else {
|
||||
//Pretrained model
|
||||
m = tc.getPretrainedModel();
|
||||
isMLN = (m instanceof MultiLayerNetwork);
|
||||
if (isMLN) {
|
||||
if (m instanceof MultiLayerNetwork) {
|
||||
mln = (MultiLayerNetwork) m;
|
||||
} else {
|
||||
} else if(m instanceof ComputationGraph){
|
||||
cg = (ComputationGraph) m;
|
||||
} else {
|
||||
sd = (SameDiff)m;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//Generate predictions to compare against
|
||||
if (tc.isTestPredictions()) {
|
||||
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
|
||||
Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
|
||||
List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
|
||||
List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
|
||||
// Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
|
||||
|
||||
|
||||
File predictionsTestDir = new File(testBaseDir, "predictions");
|
||||
predictionsTestDir.mkdirs();
|
||||
|
||||
int count = 0;
|
||||
if (isMLN) {
|
||||
if (modelType == ModelType.MLN) {
|
||||
for (Pair<INDArray[], INDArray[]> p : inputs) {
|
||||
INDArray f = p.getFirst()[0];
|
||||
INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
|
||||
|
@ -170,7 +178,7 @@ public class IntegrationTestBaselineGenerator {
|
|||
Nd4j.write(out, dos);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG) {
|
||||
for (Pair<INDArray[], INDArray[]> p : inputs) {
|
||||
INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
|
||||
|
||||
|
@ -182,6 +190,19 @@ public class IntegrationTestBaselineGenerator {
|
|||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
List<String> outNames = tc.getPredictionsNamesSameDiff();
|
||||
for( Map<String,INDArray> ph : inputsSd ){
|
||||
Map<String,INDArray> out = sd.output(ph, outNames);
|
||||
|
||||
//Save the output...
|
||||
for(String s : outNames){
|
||||
File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
|
||||
try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) {
|
||||
Nd4j.write(out.get(s), dos);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.info("Saved predictions for {} inputs to disk in directory: {}", tc.getTestName(), predictionsTestDir);
|
||||
|
@ -189,32 +210,46 @@ public class IntegrationTestBaselineGenerator {
|
|||
|
||||
//Compute and save gradients:
|
||||
if (tc.isTestGradients()) {
|
||||
MultiDataSet data = tc.getGradientsTestData();
|
||||
INDArray gradientFlat;
|
||||
if (isMLN) {
|
||||
INDArray gradientFlat = null;
|
||||
Map<String,INDArray> grad;
|
||||
if (modelType == ModelType.MLN) {
|
||||
MultiDataSet data = tc.getGradientsTestData();
|
||||
mln.setInput(data.getFeatures(0));
|
||||
mln.setLabels(data.getLabels(0));
|
||||
mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
|
||||
mln.computeGradientAndScore();
|
||||
gradientFlat = mln.getFlattenedGradients();
|
||||
} else {
|
||||
grad = m.gradient().gradientForVariable();
|
||||
} else if(modelType == ModelType.CG) {
|
||||
MultiDataSet data = tc.getGradientsTestData();
|
||||
cg.setInputs(data.getFeatures());
|
||||
cg.setLabels(data.getLabels());
|
||||
cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
|
||||
cg.computeGradientAndScore();
|
||||
gradientFlat = cg.getFlattenedGradients();
|
||||
grad = m.gradient().gradientForVariable();
|
||||
} else {
|
||||
Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
|
||||
List<String> allVars = new ArrayList<>();
|
||||
for(SDVariable v : sd.variables()){
|
||||
if(v.getVariableType() == VariableType.VARIABLE){
|
||||
allVars.add(v.name());
|
||||
}
|
||||
}
|
||||
grad = sd.calculateGradients(ph, allVars);
|
||||
}
|
||||
|
||||
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
|
||||
IntegrationTestRunner.write(gradientFlat, gFlatFile);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
|
||||
IntegrationTestRunner.write(gradientFlat, gFlatFile);
|
||||
}
|
||||
|
||||
//Also save the gradient param table:
|
||||
Map<String, INDArray> g = m.gradient().gradientForVariable();
|
||||
File gradientDir = new File(testBaseDir, "gradients");
|
||||
gradientDir.mkdir();
|
||||
for (String s : g.keySet()) {
|
||||
for (String s : grad.keySet()) {
|
||||
File f = new File(gradientDir, s + ".bin");
|
||||
IntegrationTestRunner.write(g.get(s), f);
|
||||
IntegrationTestRunner.write(grad.get(s), f);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -224,7 +259,7 @@ public class IntegrationTestBaselineGenerator {
|
|||
MultiDataSetIterator iter = tc.getUnsupervisedTrainData();
|
||||
|
||||
INDArray paramsPostTraining;
|
||||
if(isMLN){
|
||||
if(modelType == ModelType.MLN){
|
||||
int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
|
||||
Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
|
||||
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
|
||||
|
@ -233,7 +268,7 @@ public class IntegrationTestBaselineGenerator {
|
|||
mln.pretrainLayer(i, dsi);
|
||||
}
|
||||
paramsPostTraining = mln.params();
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG) {
|
||||
String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
|
||||
Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
|
||||
|
||||
|
@ -241,6 +276,8 @@ public class IntegrationTestBaselineGenerator {
|
|||
cg.pretrainLayer(i, iter);
|
||||
}
|
||||
paramsPostTraining = cg.params();
|
||||
} else {
|
||||
throw new UnsupportedOperationException("SameDiff not supported for unsupervised training tests");
|
||||
}
|
||||
|
||||
//Save params
|
||||
|
@ -251,23 +288,46 @@ public class IntegrationTestBaselineGenerator {
|
|||
//Test training curves:
|
||||
if (tc.isTestTrainingCurves()) {
|
||||
MultiDataSetIterator trainData = tc.getTrainingData();
|
||||
CollectScoresListener l = new CollectScoresListener(1);
|
||||
m.setListeners(l);
|
||||
|
||||
if (isMLN) {
|
||||
CollectScoresListener l = new CollectScoresListener(1);
|
||||
if(modelType != ModelType.SAMEDIFF)
|
||||
m.setListeners(l);
|
||||
|
||||
History h = null;
|
||||
if (modelType == ModelType.MLN) {
|
||||
mln.fit(trainData);
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG) {
|
||||
cg.fit(trainData);
|
||||
} else {
|
||||
h = sd.fit(trainData, 1);
|
||||
}
|
||||
|
||||
double[] scores;
|
||||
if(modelType != ModelType.SAMEDIFF){
|
||||
scores = l.getListScore().toDoubleArray();
|
||||
} else {
|
||||
scores = h.lossCurve().getLossValues().toDoubleVector();
|
||||
}
|
||||
|
||||
double[] scores = l.getListScore().toDoubleArray();
|
||||
File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
|
||||
List<String> s = Arrays.stream(scores).mapToObj(String::valueOf).collect(Collectors.toList());
|
||||
FileUtils.writeStringToFile(f, String.join(",", s));
|
||||
FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8);
|
||||
|
||||
if (tc.isTestParamsPostTraining()) {
|
||||
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
|
||||
IntegrationTestRunner.write(m.params(), p);
|
||||
if(modelType == ModelType.SAMEDIFF){
|
||||
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
|
||||
p.mkdirs();
|
||||
for(SDVariable v : sd.variables()){
|
||||
if(v.getVariableType() == VariableType.VARIABLE){
|
||||
INDArray arr = v.getArr();
|
||||
File p2 = new File(p, v.name() + ".bin");
|
||||
IntegrationTestRunner.write(arr, p2);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
|
||||
IntegrationTestRunner.write(m.params(), p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -276,11 +336,13 @@ public class IntegrationTestBaselineGenerator {
|
|||
IEvaluation[] evals = tc.getNewEvaluations();
|
||||
MultiDataSetIterator iter = tc.getEvaluationTestData();
|
||||
|
||||
if (isMLN) {
|
||||
if (modelType == ModelType.MLN) {
|
||||
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
|
||||
mln.doEvaluation(dsi, evals);
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG){
|
||||
cg.doEvaluation(iter, evals);
|
||||
} else {
|
||||
evals = tc.doEvaluationSameDiff(sd, iter, evals);
|
||||
}
|
||||
|
||||
File evalDir = new File(testBaseDir, "evaluation");
|
||||
|
@ -288,7 +350,7 @@ public class IntegrationTestBaselineGenerator {
|
|||
for (int i = 0; i < evals.length; i++) {
|
||||
String json = evals[i].toJson();
|
||||
File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
|
||||
FileUtils.writeStringToFile(f, json);
|
||||
FileUtils.writeStringToFile(f, json, StandardCharsets.UTF_8);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -17,14 +18,12 @@
|
|||
package org.deeplearning4j.integration;
|
||||
|
||||
|
||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||
import org.deeplearning4j.eval.*;
|
||||
import org.deeplearning4j.integration.util.CountingMultiDataSetIterator;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
|
@ -42,9 +41,16 @@ import org.deeplearning4j.parallelism.ParallelInference;
|
|||
import org.deeplearning4j.parallelism.inference.InferenceMode;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.*;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
|
@ -55,12 +61,15 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
|||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.resources.Resources;
|
||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
|
@ -79,6 +88,7 @@ public class IntegrationTestRunner {
|
|||
public static final String FLAT_GRADIENTS_FILENAME = "flattenedGradients.bin";
|
||||
public static final String TRAINING_CURVE_FILENAME = "trainingCurve.csv";
|
||||
public static final String PARAMS_POST_TRAIN_FILENAME = "paramsPostTrain.bin";
|
||||
public static final String PARAMS_POST_TRAIN_SAMEDIFF_DIR = "paramsPostTrain";
|
||||
public static final String PARAMS_POST_UNSUPERVISED_FILENAME = "paramsPostUnsupervised.bin";
|
||||
|
||||
public static final double MAX_REL_ERROR_SCORES = 1e-4;
|
||||
|
@ -148,21 +158,25 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
|
||||
public static void runTest(TestCase tc, TemporaryFolder testDir) throws Exception {
|
||||
Preconditions.checkState(Nd4j.dataType() == DataType.FLOAT, "Integration tests must be run with float precision!");
|
||||
log.info("Starting test case: {}", tc.getTestName());
|
||||
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
|
||||
//This could alternatively be done via maven surefire configuration
|
||||
|
||||
final ModelType modelType = tc.modelType();
|
||||
log.info("Starting test case: {} - type = {}", tc.getTestName(), modelType);
|
||||
long start = System.currentTimeMillis();
|
||||
|
||||
File workingDir = testDir.newFolder();
|
||||
tc.initialize(workingDir);
|
||||
|
||||
File testBaseDir = testDir.newFolder();
|
||||
new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir);
|
||||
// new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir);
|
||||
Resources.copyDirectory((modelType == ModelType.SAMEDIFF ? "samediff-integration-tests/" : "dl4j-integration-tests/") + tc.getTestName(), testBaseDir);
|
||||
|
||||
|
||||
MultiLayerNetwork mln = null;
|
||||
ComputationGraph cg = null;
|
||||
Model m;
|
||||
boolean isMLN;
|
||||
SameDiff sd = null;
|
||||
Model m = null;
|
||||
if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) {
|
||||
log.info("Checking RANDOM_INIT test case: saved model vs. initialized model");
|
||||
//Checking randomly initialized model:
|
||||
|
@ -173,36 +187,46 @@ public class IntegrationTestRunner {
|
|||
mln = new MultiLayerNetwork(mlc);
|
||||
mln.init();
|
||||
m = mln;
|
||||
isMLN = true;
|
||||
|
||||
MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
|
||||
assertEquals("Configs not equal", loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations());
|
||||
assertEquals("Params not equal", loaded.params(), mln.params());
|
||||
assertEquals("Param table not equal", loaded.paramTable(), mln.paramTable());
|
||||
} else {
|
||||
} else if(config instanceof ComputationGraphConfiguration ){
|
||||
ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
|
||||
cg = new ComputationGraph(cgc);
|
||||
cg.init();
|
||||
m = cg;
|
||||
isMLN = false;
|
||||
|
||||
ComputationGraph loaded = ComputationGraph.load(savedModel, true);
|
||||
assertEquals("Configs not equal", loaded.getConfiguration(), cg.getConfiguration());
|
||||
assertEquals("Params not equal", loaded.params(), cg.params());
|
||||
assertEquals("Param table not equal", loaded.paramTable(), cg.paramTable());
|
||||
} else if(config instanceof SameDiff){
|
||||
sd = (SameDiff)config;
|
||||
SameDiff loaded = SameDiff.load(savedModel, true);
|
||||
|
||||
assertSameDiffEquals(sd, loaded);
|
||||
} else {
|
||||
throw new IllegalStateException("Unknown configuration/model type: " + config.getClass());
|
||||
}
|
||||
} else {
|
||||
m = tc.getPretrainedModel();
|
||||
isMLN = (m instanceof MultiLayerNetwork);
|
||||
if (isMLN) {
|
||||
if (m instanceof MultiLayerNetwork) {
|
||||
mln = (MultiLayerNetwork) m;
|
||||
} else {
|
||||
} else if(m instanceof ComputationGraph) {
|
||||
cg = (ComputationGraph) m;
|
||||
} else if(m instanceof SameDiff){
|
||||
sd = (SameDiff)m;
|
||||
} else {
|
||||
throw new IllegalStateException("Unknown model type: " + m.getClass());
|
||||
}
|
||||
}
|
||||
|
||||
//Collect information for test coverage
|
||||
collectCoverageInformation(m);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
collectCoverageInformation(m);
|
||||
}
|
||||
|
||||
|
||||
//Check network output (predictions)
|
||||
|
@ -210,15 +234,16 @@ public class IntegrationTestRunner {
|
|||
log.info("Checking predictions: saved output vs. initialized model");
|
||||
|
||||
|
||||
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
|
||||
Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
|
||||
List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
|
||||
List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
|
||||
Preconditions.checkState(modelType == ModelType.SAMEDIFF || inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
|
||||
|
||||
|
||||
File predictionsTestDir = new File(testBaseDir, "predictions");
|
||||
predictionsTestDir.mkdirs();
|
||||
|
||||
int count = 0;
|
||||
if (isMLN) {
|
||||
if (modelType == ModelType.MLN) {
|
||||
for (Pair<INDArray[], INDArray[]> p : inputs) {
|
||||
INDArray f = p.getFirst()[0];
|
||||
INDArray fm = (p.getSecond() == null ? null : p.getSecond()[0]);
|
||||
|
@ -231,15 +256,15 @@ public class IntegrationTestRunner {
|
|||
outSaved = Nd4j.read(dis);
|
||||
}
|
||||
|
||||
INDArray gradExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
|
||||
int countExceeds = gradExceedsRE.sumNumber().intValue();
|
||||
INDArray predictionExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
|
||||
int countExceeds = predictionExceedsRE.sumNumber().intValue();
|
||||
assertEquals("Predictions do not match saved predictions - output", 0, countExceeds);
|
||||
}
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG){
|
||||
for (Pair<INDArray[], INDArray[]> p : inputs) {
|
||||
INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
|
||||
|
||||
//Save the array(s)...
|
||||
//Load the previously saved arrays
|
||||
INDArray[] outSaved = new INDArray[out.length];
|
||||
for (int i = 0; i < out.length; i++) {
|
||||
File outFile = new File(predictionsTestDir, "output_" + (count++) + "_" + i + ".bin");
|
||||
|
@ -249,14 +274,36 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
|
||||
for( int i=0; i<outSaved.length; i++ ){
|
||||
INDArray gradExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
|
||||
int countExceeds = gradExceedsRE.sumNumber().intValue();
|
||||
INDArray predictionExceedsRE = exceedsRelError(outSaved[i], out[i], tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
|
||||
int countExceeds = predictionExceedsRE.sumNumber().intValue();
|
||||
assertEquals("Predictions do not match saved predictions - output " + i, 0, countExceeds);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
List<String> outNames = tc.getPredictionsNamesSameDiff();
|
||||
for( Map<String,INDArray> ph : inputsSd ){
|
||||
Map<String,INDArray> out = sd.output(ph, outNames);
|
||||
|
||||
//Load the previously saved placeholder arrays
|
||||
Map<String,INDArray> outSaved = new HashMap<>();
|
||||
for(String s : outNames){
|
||||
File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
|
||||
try (DataInputStream dis = new DataInputStream(new FileInputStream(f))) {
|
||||
outSaved.put(s, Nd4j.read(dis));
|
||||
}
|
||||
}
|
||||
|
||||
for(String s : outNames){
|
||||
INDArray predictionExceedsRE = exceedsRelError(outSaved.get(s), out.get(s), tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput());
|
||||
int countExceeds = predictionExceedsRE.sumNumber().intValue();
|
||||
assertEquals("Predictions do not match saved predictions - output \"" + s + "\"", 0, countExceeds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
checkLayerClearance(m);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
checkLayerClearance(m);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -264,34 +311,49 @@ public class IntegrationTestRunner {
|
|||
if (tc.isTestGradients()) {
|
||||
log.info("Checking gradients: saved output vs. initialized model");
|
||||
|
||||
MultiDataSet data = tc.getGradientsTestData();
|
||||
INDArray gradientFlat;
|
||||
org.deeplearning4j.nn.api.Layer[] layers;
|
||||
if (isMLN) {
|
||||
INDArray gradientFlat = null;
|
||||
org.deeplearning4j.nn.api.Layer[] layers = null;
|
||||
Map<String,INDArray> grad;
|
||||
if (modelType == ModelType.MLN) {
|
||||
MultiDataSet data = tc.getGradientsTestData();
|
||||
mln.setInput(data.getFeatures(0));
|
||||
mln.setLabels(data.getLabels(0));
|
||||
mln.setLayerMaskArrays(data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
|
||||
mln.computeGradientAndScore();
|
||||
gradientFlat = mln.getFlattenedGradients();
|
||||
layers = mln.getLayers();
|
||||
} else {
|
||||
grad = mln.gradient().gradientForVariable();
|
||||
} else if(modelType == ModelType.CG) {
|
||||
MultiDataSet data = tc.getGradientsTestData();
|
||||
cg.setInputs(data.getFeatures());
|
||||
cg.setLabels(data.getLabels());
|
||||
cg.setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
|
||||
cg.computeGradientAndScore();
|
||||
gradientFlat = cg.getFlattenedGradients();
|
||||
layers = cg.getLayers();
|
||||
grad = cg.gradient().gradientForVariable();
|
||||
} else {
|
||||
Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
|
||||
List<String> allVars = new ArrayList<>();
|
||||
for(SDVariable v : sd.variables()){
|
||||
if(v.getVariableType() == VariableType.VARIABLE){
|
||||
allVars.add(v.name());
|
||||
}
|
||||
}
|
||||
grad = sd.calculateGradients(ph, allVars);
|
||||
}
|
||||
|
||||
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
|
||||
INDArray gradientFlatSaved = read(gFlatFile);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
|
||||
INDArray gradientFlatSaved = read(gFlatFile);
|
||||
|
||||
INDArray gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
|
||||
int count = gradExceedsRE.sumNumber().intValue();
|
||||
if(count > 0){
|
||||
logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat);
|
||||
INDArray gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
|
||||
int count = gradExceedsRE.sumNumber().intValue();
|
||||
if (count > 0) {
|
||||
logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat);
|
||||
}
|
||||
assertEquals("Saved flattened gradients: not equal (using relative error)", 0, count);
|
||||
}
|
||||
assertEquals("Saved flattened gradients: not equal (using relative error)", 0, count);
|
||||
|
||||
//Load the gradient table:
|
||||
File gradientDir = new File(testBaseDir, "gradients");
|
||||
|
@ -302,12 +364,12 @@ public class IntegrationTestRunner {
|
|||
String key = f.getName();
|
||||
key = key.substring(0, key.length() - 4); //remove ".bin"
|
||||
INDArray loaded = read(f);
|
||||
INDArray now = m.gradient().gradientForVariable().get(key);
|
||||
INDArray now = grad.get(key);
|
||||
|
||||
|
||||
gradExceedsRE = exceedsRelError(gradientFlatSaved, gradientFlat, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
|
||||
count = gradExceedsRE.sumNumber().intValue();
|
||||
assertEquals("Saved flattened gradients: not equal (using relative error) for parameter: " + key, 0, count);
|
||||
INDArray gradExceedsRE = exceedsRelError(loaded, now, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients());
|
||||
int count = gradExceedsRE.sumNumber().intValue();
|
||||
assertEquals("Gradients: not equal (using relative error) for parameter: " + key, 0, count);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -318,7 +380,7 @@ public class IntegrationTestRunner {
|
|||
|
||||
INDArray paramsPostTraining;
|
||||
org.deeplearning4j.nn.api.Layer[] layers;
|
||||
if(isMLN){
|
||||
if(modelType == ModelType.MLN){
|
||||
int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
|
||||
Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
|
||||
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
|
||||
|
@ -328,7 +390,7 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
paramsPostTraining = mln.params();
|
||||
layers = mln.getLayers();
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG) {
|
||||
String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
|
||||
Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
|
||||
|
||||
|
@ -337,6 +399,8 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
paramsPostTraining = cg.params();
|
||||
layers = cg.getLayers();
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models");
|
||||
}
|
||||
|
||||
File f = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_UNSUPERVISED_FILENAME);
|
||||
|
@ -360,53 +424,78 @@ public class IntegrationTestRunner {
|
|||
MultiDataSetIterator trainData = tc.getTrainingData();
|
||||
boolean isTbptt;
|
||||
int tbpttLength;
|
||||
if(isMLN){
|
||||
if(modelType == ModelType.MLN){
|
||||
isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT;
|
||||
tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength();
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG) {
|
||||
isTbptt = cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
|
||||
tbpttLength = cg.getConfiguration().getTbpttFwdLength();
|
||||
} else {
|
||||
isTbptt = false;
|
||||
tbpttLength = 0;
|
||||
}
|
||||
|
||||
CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength);
|
||||
CollectScoresListener l = new CollectScoresListener(1);
|
||||
m.setListeners(l);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
m.setListeners(l);
|
||||
}
|
||||
|
||||
int iterBefore;
|
||||
int epochBefore;
|
||||
int iterAfter;
|
||||
int epochAfter;
|
||||
|
||||
Map<String,INDArray> frozenParamsBefore = getFrozenLayerParamCopies(m);
|
||||
org.deeplearning4j.nn.api.Layer[] layers;
|
||||
if (isMLN) {
|
||||
Map<String,INDArray> frozenParamsBefore = modelType != ModelType.SAMEDIFF ? getFrozenLayerParamCopies(m) : getConstantCopies(sd);
|
||||
org.deeplearning4j.nn.api.Layer[] layers = null;
|
||||
History h = null;
|
||||
if (modelType == ModelType.MLN) {
|
||||
iterBefore = mln.getIterationCount();
|
||||
epochBefore = mln.getEpochCount();
|
||||
mln.fit(countingIter);
|
||||
iterAfter = mln.getIterationCount();
|
||||
epochAfter = mln.getEpochCount();
|
||||
layers = mln.getLayers();
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG){
|
||||
iterBefore = cg.getConfiguration().getIterationCount();
|
||||
epochBefore = cg.getConfiguration().getEpochCount();
|
||||
cg.fit(countingIter);
|
||||
iterAfter = cg.getConfiguration().getIterationCount();
|
||||
epochAfter = cg.getConfiguration().getEpochCount();
|
||||
layers = cg.getLayers();
|
||||
} else {
|
||||
iterBefore = sd.getTrainingConfig().getIterationCount();
|
||||
epochBefore = sd.getTrainingConfig().getEpochCount();
|
||||
h = sd.fit(countingIter, 1);
|
||||
iterAfter = sd.getTrainingConfig().getIterationCount();
|
||||
epochAfter = sd.getTrainingConfig().getEpochCount();
|
||||
}
|
||||
|
||||
//Check that frozen params (if any) haven't changed during training:
|
||||
checkFrozenParams(frozenParamsBefore, m);
|
||||
if(modelType == ModelType.SAMEDIFF) {
|
||||
checkConstants(frozenParamsBefore, sd);
|
||||
} else {
|
||||
checkFrozenParams(frozenParamsBefore, m);
|
||||
}
|
||||
|
||||
//Validate the iteration and epoch counts - both for the net, and for the layers
|
||||
int newIters = countingIter.getCurrIter();
|
||||
assertEquals(iterBefore + newIters, iterAfter);
|
||||
assertEquals(epochBefore + 1, epochAfter);
|
||||
validateLayerIterCounts(m, epochBefore + 1, iterBefore+newIters); //TODO CURRENTLY FAILING
|
||||
double[] scores = l.getListScore().toDoubleArray();
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
validateLayerIterCounts(m, epochBefore + 1, iterBefore + newIters);
|
||||
}
|
||||
|
||||
|
||||
double[] scores;
|
||||
if(modelType == ModelType.SAMEDIFF){
|
||||
scores = h.lossCurve().getLossValues().toDoubleVector();
|
||||
} else {
|
||||
scores = l.getListScore().toDoubleArray();
|
||||
}
|
||||
|
||||
File f = new File(testBaseDir, IntegrationTestRunner.TRAINING_CURVE_FILENAME);
|
||||
String[] s = FileUtils.readFileToString(f).split(",");
|
||||
String[] s = FileUtils.readFileToString(f, StandardCharsets.UTF_8).split(",");
|
||||
|
||||
if(tc.isTestTrainingCurves()) {
|
||||
assertEquals("Different number of scores", s.length, scores.length);
|
||||
|
@ -426,17 +515,36 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
|
||||
if (tc.isTestParamsPostTraining()) {
|
||||
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
|
||||
INDArray paramsExp = read(p);
|
||||
INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
|
||||
int count = z.sumNumber().intValue();
|
||||
if(count > 0){
|
||||
logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
|
||||
INDArray paramsExp = read(p);
|
||||
INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
|
||||
int count = z.sumNumber().intValue();
|
||||
if (count > 0) {
|
||||
logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
|
||||
}
|
||||
assertEquals("Number of params exceeded max relative error", 0, count);
|
||||
} else {
|
||||
File dir = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
|
||||
for(SDVariable v : sd.variables()){
|
||||
if(v.getVariableType() != VariableType.VARIABLE)
|
||||
continue;
|
||||
INDArray paramNow = v.getArr();
|
||||
File paramFile = new File(dir, v.name() + ".bin");
|
||||
INDArray exp = read(paramFile);
|
||||
INDArray z = exceedsRelError(paramNow, exp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
|
||||
int count = z.sumNumber().intValue();
|
||||
if (count > 0) {
|
||||
logFailedParams(20, "Parameter: " + v.name(), layers, z, exp, paramNow);
|
||||
}
|
||||
assertEquals("Number of params exceeded max relative error for parameter: \"" + v.name() + "\"", 0, count);
|
||||
}
|
||||
}
|
||||
assertEquals("Number of params exceeded max relative error", 0, count);
|
||||
}
|
||||
|
||||
checkLayerClearance(m);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
checkLayerClearance(m);
|
||||
}
|
||||
}
|
||||
|
||||
//Check evaluation:
|
||||
|
@ -445,17 +553,19 @@ public class IntegrationTestRunner {
|
|||
IEvaluation[] evals = tc.getNewEvaluations();
|
||||
MultiDataSetIterator iter = tc.getEvaluationTestData();
|
||||
|
||||
if (isMLN) {
|
||||
if (modelType == ModelType.MLN) {
|
||||
DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
|
||||
mln.doEvaluation(dsi, evals);
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG){
|
||||
cg.doEvaluation(iter, evals);
|
||||
} else {
|
||||
evals = tc.doEvaluationSameDiff(sd, iter, evals);
|
||||
}
|
||||
|
||||
File evalDir = new File(testBaseDir, "evaluation");
|
||||
for (int i = 0; i < evals.length; i++) {
|
||||
File f = new File(evalDir, i + "." + evals[i].getClass().getSimpleName() + ".json");
|
||||
String json = FileUtils.readFileToString(f);
|
||||
String json = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
|
||||
IEvaluation e;
|
||||
if (evals[i].getClass() == Evaluation.class) {
|
||||
e = Evaluation.fromJson(json);
|
||||
|
@ -479,7 +589,9 @@ public class IntegrationTestRunner {
|
|||
//Evaluation coverage information:
|
||||
evaluationClassesSeen.put(evals[i].getClass(), evaluationClassesSeen.getOrDefault(evals[i].getClass(), 0) + 1);
|
||||
|
||||
checkLayerClearance(m);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
checkLayerClearance(m);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -490,15 +602,20 @@ public class IntegrationTestRunner {
|
|||
File f = testDir.newFile();
|
||||
f.delete();
|
||||
|
||||
ModelSerializer.writeModel(m, f, true);
|
||||
if (isMLN) {
|
||||
if (modelType == ModelType.MLN) {
|
||||
ModelSerializer.writeModel(m, f, true);
|
||||
MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
|
||||
assertEquals(mln.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
||||
assertEquals(mln.params(), restored.params());
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG){
|
||||
ModelSerializer.writeModel(m, f, true);
|
||||
ComputationGraph restored = ComputationGraph.load(f, true);
|
||||
assertEquals(cg.getConfiguration(), restored.getConfiguration());
|
||||
assertEquals(cg.params(), restored.params());
|
||||
} else {
|
||||
sd.save(f, true);
|
||||
SameDiff restored = SameDiff.load(f, true);
|
||||
assertSameDiffEquals(sd, restored);
|
||||
}
|
||||
|
||||
System.gc();
|
||||
|
@ -506,7 +623,7 @@ public class IntegrationTestRunner {
|
|||
|
||||
|
||||
//Check parallel inference
|
||||
if (tc.isTestParallelInference()) {
|
||||
if (modelType != ModelType.SAMEDIFF && tc.isTestParallelInference()) {
|
||||
|
||||
List<Pair<INDArray[], INDArray[]>> inputs = tc.getPredictionsTestData();
|
||||
|
||||
|
@ -515,7 +632,7 @@ public class IntegrationTestRunner {
|
|||
List<INDArray[]> exp = new ArrayList<>();
|
||||
for(Pair<INDArray[], INDArray[]> p : inputs){
|
||||
INDArray[] out;
|
||||
if(isMLN){
|
||||
if(modelType == ModelType.MLN){
|
||||
INDArray fm = p.getSecond() == null ? null : p.getSecond()[0];
|
||||
out = new INDArray[]{mln.output(p.getFirst()[0], false, fm, null)};
|
||||
} else {
|
||||
|
@ -547,37 +664,54 @@ public class IntegrationTestRunner {
|
|||
|
||||
MultiDataSet toOverfit = tc.getOverfittingData();
|
||||
for (int i = 0; i < tc.getOverfitNumIterations(); i++) {
|
||||
if (isMLN) {
|
||||
if (modelType == ModelType.MLN) {
|
||||
mln.fit(toOverfit);
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG){
|
||||
cg.fit(toOverfit);
|
||||
} else {
|
||||
sd.fit(toOverfit);
|
||||
}
|
||||
}
|
||||
|
||||
//Check:
|
||||
INDArray[] output;
|
||||
if (isMLN) {
|
||||
INDArray[] output = null;
|
||||
Map<String,INDArray> outSd = null;
|
||||
if (modelType == ModelType.MLN) {
|
||||
mln.setLayerMaskArrays(toOverfit.getFeaturesMaskArray(0), null);
|
||||
output = new INDArray[]{mln.output(toOverfit.getFeatures(0))};
|
||||
} else {
|
||||
} else if(modelType == ModelType.CG ){
|
||||
cg.setLayerMaskArrays(toOverfit.getFeaturesMaskArrays(), null);
|
||||
output = cg.output(toOverfit.getFeatures());
|
||||
} else {
|
||||
List<String> l = sd.getTrainingConfig().getDataSetFeatureMapping();
|
||||
Map<String,INDArray> phMap = new HashMap<>();
|
||||
int i=0;
|
||||
for(String s : l){
|
||||
phMap.put(s, toOverfit.getFeatures(i++));
|
||||
}
|
||||
outSd = sd.output(phMap, tc.getPredictionsNamesSameDiff());
|
||||
}
|
||||
|
||||
for (int i = 0; i < output.length; i++) {
|
||||
INDArray z = exceedsRelError(output[i], toOverfit.getLabels(i), tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit());
|
||||
int n = modelType == ModelType.SAMEDIFF ? outSd.size() : output.length;
|
||||
for (int i = 0; i < n; i++) {
|
||||
INDArray out = modelType == ModelType.SAMEDIFF ? outSd.get(tc.getPredictionsNamesSameDiff().get(i)) : output[i];
|
||||
INDArray label = toOverfit.getLabels(i);
|
||||
|
||||
INDArray z = exceedsRelError(out, label, tc.getMaxRelativeErrorOverfit(), tc.getMinAbsErrorOverfit());
|
||||
int count = z.sumNumber().intValue();
|
||||
if (count > 0) {
|
||||
System.out.println(output[i]);
|
||||
System.out.println(toOverfit.getLabels(i));
|
||||
INDArray re = relativeError(output[i], toOverfit.getLabels(i), tc.getMinAbsErrorOverfit());
|
||||
System.out.println(out);
|
||||
System.out.println(label);
|
||||
INDArray re = relativeError(out, label, tc.getMinAbsErrorOverfit());
|
||||
System.out.println("Relative error:");
|
||||
System.out.println(re);
|
||||
}
|
||||
assertEquals("Number of outputs exceeded max relative error", 0, count);
|
||||
}
|
||||
|
||||
checkLayerClearance(m);
|
||||
if(modelType != ModelType.SAMEDIFF) {
|
||||
checkLayerClearance(m);
|
||||
}
|
||||
}
|
||||
|
||||
long end = System.currentTimeMillis();
|
||||
|
@ -709,6 +843,16 @@ public class IntegrationTestRunner {
|
|||
return out;
|
||||
}
|
||||
|
||||
private static Map<String,INDArray> getConstantCopies(SameDiff sd){
|
||||
Map<String,INDArray> out = new HashMap<>();
|
||||
for(SDVariable v : sd.variables()){
|
||||
if(v.isConstant()){
|
||||
out.put(v.name(), v.getArr());
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
public static void checkFrozenParams(Map<String,INDArray> copiesBeforeTraining, Model m){
|
||||
for(Map.Entry<String,INDArray> e : copiesBeforeTraining.entrySet()){
|
||||
INDArray actual = m.getParam(e.getKey());
|
||||
|
@ -716,6 +860,13 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
}
|
||||
|
||||
public static void checkConstants(Map<String,INDArray> copiesBefore, SameDiff sd){
|
||||
for(Map.Entry<String,INDArray> e : copiesBefore.entrySet()){
|
||||
INDArray actual = sd.getArrForVarName(e.getKey());
|
||||
assertEquals(e.getKey(), e.getValue(), actual);
|
||||
}
|
||||
}
|
||||
|
||||
public static void printCoverageInformation(){
|
||||
|
||||
log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
|
||||
|
@ -918,7 +1069,7 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
|
||||
|
||||
public static void logFailedParams(int maxNum, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){
|
||||
public static void logFailedParams(int maxNumToPrintOnFailure, String prefix, org.deeplearning4j.nn.api.Layer[] layers, INDArray exceedsRelError, INDArray exp, INDArray act){
|
||||
long length = exceedsRelError.length();
|
||||
int logCount = 0;
|
||||
for(int i=0; i<length; i++ ){
|
||||
|
@ -947,10 +1098,33 @@ public class IntegrationTestRunner {
|
|||
}
|
||||
|
||||
log.info("{} {} ({}) failed: expected {} vs actual {} (RelativeError: {}, AbsError: {})", i, prefix, pName, dExp, dAct, re, ae);
|
||||
if(++logCount >= maxNum){
|
||||
if(++logCount >= maxNumToPrintOnFailure){
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void assertSameDiffEquals(SameDiff sd1, SameDiff sd2){
|
||||
assertEquals(sd1.variableMap().keySet(), sd2.variableMap().keySet());
|
||||
assertEquals(sd1.getOps().keySet(), sd2.getOps().keySet());
|
||||
assertEquals(sd1.inputs(), sd2.inputs());
|
||||
|
||||
//Check constant and variable arrays:
|
||||
for(SDVariable v : sd1.variables()){
|
||||
String n = v.name();
|
||||
assertEquals(n, v.getVariableType(), sd2.getVariable(n).getVariableType());
|
||||
if(v.isConstant() || v.getVariableType() == VariableType.VARIABLE){
|
||||
INDArray a1 = v.getArr();
|
||||
INDArray a2 = sd2.getVariable(n).getArr();
|
||||
assertEquals(n, a1, a2);
|
||||
}
|
||||
}
|
||||
|
||||
//Check ops:
|
||||
for(SameDiffOp o : sd1.getOps().values()){
|
||||
SameDiffOp o2 = sd2.getOps().get(o.getName());
|
||||
assertEquals(o.getOp().getClass(), o2.getOp().getClass());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -17,15 +18,19 @@
|
|||
package org.deeplearning4j.integration;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.integration.testcases.*;
|
||||
import org.deeplearning4j.integration.testcases.dl4j.*;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
|
||||
public class IntegrationTests extends BaseDL4JTest {
|
||||
//@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
|
||||
public class IntegrationTestsDL4J extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 300_000L;
|
||||
}
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
@ -36,79 +41,72 @@ public class IntegrationTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
// ***** MLPTestCases *****
|
||||
@Test(timeout = 20000L)
|
||||
@Test
|
||||
public void testMLPMnist() throws Exception {
|
||||
IntegrationTestRunner.runTest(MLPTestCases.getMLPMnist(), testDir);
|
||||
}
|
||||
|
||||
@Test(timeout = 30000L)
|
||||
@Test
|
||||
public void testMlpMoon() throws Exception {
|
||||
IntegrationTestRunner.runTest(MLPTestCases.getMLPMoon(), testDir);
|
||||
}
|
||||
|
||||
// ***** RNNTestCases *****
|
||||
@Test(timeout = 30000L)
|
||||
@Test
|
||||
public void testRnnSeqClassification1() throws Exception {
|
||||
IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase1(), testDir);
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Test
|
||||
public void testRnnSeqClassification2() throws Exception {
|
||||
IntegrationTestRunner.runTest(RNNTestCases.getRnnCsvSequenceClassificationTestCase2(), testDir);
|
||||
}
|
||||
|
||||
@Test(timeout = 120000L)
|
||||
@Test
|
||||
public void testRnnCharacter() throws Exception {
|
||||
IntegrationTestRunner.runTest(RNNTestCases.getRnnCharacterTestCase(), testDir);
|
||||
}
|
||||
|
||||
|
||||
// ***** CNN1DTestCases *****
|
||||
@Test(timeout = 180000L)
|
||||
@Test
|
||||
public void testCnn1dCharacter() throws Exception {
|
||||
IntegrationTestRunner.runTest(CNN1DTestCases.getCnn1dTestCaseCharRNN(), testDir);
|
||||
}
|
||||
|
||||
|
||||
// ***** CNN2DTestCases *****
|
||||
@Test(timeout = 120000L)
|
||||
@Test
|
||||
public void testLenetMnist() throws Exception {
|
||||
IntegrationTestRunner.runTest(CNN2DTestCases.getLenetMnist(), testDir);
|
||||
}
|
||||
|
||||
@Ignore //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6017
|
||||
@Test(timeout = 180000L)
|
||||
@Test
|
||||
public void testYoloHouseNumbers() throws Exception {
|
||||
IntegrationTestRunner.runTest(CNN2DTestCases.getYoloHouseNumbers(), testDir);
|
||||
}
|
||||
|
||||
@Test(timeout = 120000L)
|
||||
@Test
|
||||
public void testCnn2DLenetTransferDropoutRepeatability() throws Exception {
|
||||
IntegrationTestRunner.runTest(CNN2DTestCases.testLenetTransferDropoutRepeatability(), testDir);
|
||||
}
|
||||
|
||||
|
||||
// ***** CNN3DTestCases *****
|
||||
@Test(timeout = 180000L)
|
||||
@Test
|
||||
public void testCnn3dSynthetic() throws Exception {
|
||||
IntegrationTestRunner.runTest(CNN3DTestCases.getCnn3dTestCaseSynthetic(), testDir);
|
||||
}
|
||||
|
||||
|
||||
// ***** UnsupervisedTestCases *****
|
||||
@Test(timeout = 120000L)
|
||||
@Test
|
||||
public void testVAEMnistAnomaly() throws Exception {
|
||||
IntegrationTestRunner.runTest(UnsupervisedTestCases.getVAEMnistAnomaly(), testDir);
|
||||
}
|
||||
|
||||
// ***** TransferLearningTestCases *****
|
||||
@Test(timeout = 360000L)
|
||||
@Test
|
||||
public void testVgg16Transfer() throws Exception {
|
||||
IntegrationTestRunner.runTest(CNN2DTestCases.getVGG16TransferTinyImagenet(), testDir);
|
||||
}
|
||||
|
||||
|
||||
// ***** KerasImportTestCases *****
|
||||
//TODO
|
||||
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.integration;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
public class IntegrationTestsSameDiff extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 300_000L;
|
||||
}
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
|
||||
@Test
|
||||
public void testMLPMnist() throws Exception {
|
||||
IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -13,13 +13,8 @@
|
|||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.integration;
|
||||
|
||||
package org.deeplearning4j.integration.testcases;
|
||||
|
||||
/**
|
||||
* Integration tests starting from Keras model
|
||||
*/
|
||||
public class KerasImportTestCases {
|
||||
|
||||
|
||||
public enum ModelType {
|
||||
MLN, CG, SAMEDIFF
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -17,8 +18,9 @@
|
|||
package org.deeplearning4j.integration;
|
||||
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.eval.IEvaluation;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
|
@ -26,6 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A single test case for integration tests
|
||||
|
@ -37,16 +40,17 @@ public abstract class TestCase {
|
|||
PRETRAINED, RANDOM_INIT
|
||||
}
|
||||
|
||||
protected String testName;
|
||||
protected TestType testType;
|
||||
protected boolean testPredictions = true;
|
||||
protected boolean testGradients = true;
|
||||
protected boolean testUnsupervisedTraining = false;
|
||||
protected boolean testTrainingCurves = true;
|
||||
protected boolean testParamsPostTraining = true;
|
||||
protected boolean testEvaluation = true;
|
||||
protected boolean testParallelInference = true;
|
||||
protected boolean testOverfitting = true;
|
||||
//See: readme.md for more details
|
||||
protected String testName; //Name of the test, for display purposes
|
||||
protected TestType testType; //Type of model - from a pretrained model, or a randomly initialized model
|
||||
protected boolean testPredictions = true; //If true: check the predictions/output. Requires getPredictionsTestData() to be implemented
|
||||
protected boolean testGradients = true; //If true: check the gradients. Requires getGradientsTestData() to be implemented
|
||||
protected boolean testUnsupervisedTraining = false; //If true: perform unsupervised training. Only applies to layers like autoencoders, VAEs, etc. Requires getUnsupervisedTrainData() to be implemented
|
||||
protected boolean testTrainingCurves = true; //If true: perform training, and compare loss vs. iteration. Requires getTrainingData() method
|
||||
protected boolean testParamsPostTraining = true; //If true: perform training, and compare parameters after training. Requires getTrainingData() method
|
||||
protected boolean testEvaluation = true; //If true: perform evaluation. Requires getNewEvaluations() and getEvaluationTestData() methods implemented
|
||||
protected boolean testParallelInference = true; //If true: run the model through ParallelInference. Requires getPredictionsTestData() method. Only applies to DL4J models, NOT SameDiff models
|
||||
protected boolean testOverfitting = true; //If true: perform overfitting, and ensure the predictions match the training data. Requires both getOverfittingData() and getOverfitNumIterations()
|
||||
|
||||
protected int[] unsupervisedTrainLayersMLN = null;
|
||||
protected String[] unsupervisedTrainLayersCG = null;
|
||||
|
@ -65,6 +69,8 @@ public abstract class TestCase {
|
|||
protected double maxRelativeErrorOverfit = 1e-2;
|
||||
protected double minAbsErrorOverfit = 1e-2;
|
||||
|
||||
public abstract ModelType modelType();
|
||||
|
||||
/**
|
||||
* Initialize the test case... many tests don't need this; others may use it to download or create data
|
||||
* @param testWorkingDir Working directory to use for test
|
||||
|
@ -88,19 +94,37 @@ public abstract class TestCase {
|
|||
}
|
||||
|
||||
/**
|
||||
* Required if testPredictions == true
|
||||
* Required if testPredictions == true && DL4J model (MultiLayerNetwork or ComputationGraph)
|
||||
*/
|
||||
public List<Pair<INDArray[],INDArray[]>> getPredictionsTestData() throws Exception {
|
||||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
/**
|
||||
* Required if testGradients == true
|
||||
* Required if testPredictions == true && SameDiff model
|
||||
*/
|
||||
public List<Map<String,INDArray>> getPredictionsTestDataSameDiff() throws Exception {
|
||||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
public List<String> getPredictionsNamesSameDiff() throws Exception {
|
||||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
/**
|
||||
* Required if testGradients == true && DL4J model
|
||||
*/
|
||||
public MultiDataSet getGradientsTestData() throws Exception {
|
||||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
/**
|
||||
* Required if testGradients == true && SameDiff model
|
||||
*/
|
||||
public Map<String,INDArray> getGradientsTestDataSameDiff() throws Exception {
|
||||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
/**
|
||||
* Required when testUnsupervisedTraining == true
|
||||
*/
|
||||
|
@ -122,6 +146,10 @@ public abstract class TestCase {
|
|||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations){
|
||||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
/**
|
||||
* Required if testEvaluation == true
|
||||
*/
|
||||
|
@ -130,12 +158,19 @@ public abstract class TestCase {
|
|||
}
|
||||
|
||||
/**
|
||||
* Required if testOverfitting == true
|
||||
* Required if testOverfitting == true && DL4J model
|
||||
*/
|
||||
public MultiDataSet getOverfittingData() throws Exception {
|
||||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
/**
|
||||
* Required if testOverfitting == true && SameDiff model
|
||||
*/
|
||||
public Map<String,INDArray> getOverfittingDataSameDiff() throws Exception {
|
||||
throw new RuntimeException("Implementations must override this method if used");
|
||||
}
|
||||
|
||||
/**
|
||||
* Required if testOverfitting == true
|
||||
*/
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases;
|
||||
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
|
||||
public class TransferLearningTestCases {
|
||||
|
||||
public static TestCase testPartFrozenResNet50(){
|
||||
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
||||
public static TestCase testPartFrozenNASNET(){
|
||||
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,22 +15,24 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases;
|
||||
package org.deeplearning4j.integration.testcases.dl4j;
|
||||
|
||||
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.eval.Evaluation;
|
||||
import org.deeplearning4j.eval.IEvaluation;
|
||||
import org.deeplearning4j.eval.ROCMultiClass;
|
||||
import org.deeplearning4j.integration.ModelType;
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
import org.deeplearning4j.integration.testcases.misc.CharacterIterator;
|
||||
import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
|
@ -64,12 +67,18 @@ public class CNN1DTestCases {
|
|||
int miniBatchSize = 16;
|
||||
int exampleLength = 128;
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.CG;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getConfiguration() throws Exception {
|
||||
CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
|
||||
int nOut = iter.totalOutcomes();
|
||||
|
||||
return new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(12345)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.updater(new Adam(0.01))
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,7 +15,7 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases;
|
||||
package org.deeplearning4j.integration.testcases.dl4j;
|
||||
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
|
@ -22,16 +23,13 @@ import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
|
|||
import org.datavec.image.recordreader.objdetect.impl.SvhnLabelProvider;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.datasets.fetchers.SvhnDataFetcher;
|
||||
import org.deeplearning4j.integration.ModelType;
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
import org.deeplearning4j.datasets.fetchers.DataSetType;
|
||||
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.datasets.iterator.impl.TinyImageNetDataSetIterator;
|
||||
import org.deeplearning4j.eval.Evaluation;
|
||||
import org.deeplearning4j.eval.EvaluationCalibration;
|
||||
import org.deeplearning4j.eval.IEvaluation;
|
||||
import org.deeplearning4j.eval.ROCMultiClass;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.*;
|
||||
|
@ -47,7 +45,12 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
|||
import org.deeplearning4j.zoo.PretrainedType;
|
||||
import org.deeplearning4j.zoo.model.TinyYOLO;
|
||||
import org.deeplearning4j.zoo.model.VGG16;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
@ -82,12 +85,18 @@ public class CNN2DTestCases {
|
|||
testOverfitting = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.MLN;
|
||||
}
|
||||
|
||||
public Object getConfiguration() throws Exception {
|
||||
int nChannels = 1; // Number of input channels
|
||||
int outputNum = 10; // The number of possible outcomes
|
||||
int seed = 123;
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(seed)
|
||||
.l2(0.0005)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
|
@ -187,6 +196,11 @@ public class CNN2DTestCases {
|
|||
testOverfitting = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.CG;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Model getPretrainedModel() throws Exception {
|
||||
VGG16 vgg16 = VGG16.builder()
|
||||
|
@ -269,6 +283,11 @@ public class CNN2DTestCases {
|
|||
testOverfitting = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.CG;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Model getPretrainedModel() throws Exception {
|
||||
int nClasses = 10;
|
||||
|
@ -372,6 +391,11 @@ public class CNN2DTestCases {
|
|||
testOverfitting = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.CG;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Model getPretrainedModel() throws Exception {
|
||||
|
||||
|
@ -381,6 +405,7 @@ public class CNN2DTestCases {
|
|||
lrSchedule.put(3000, 0.001);
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(12345)
|
||||
.l2(0.0005)
|
||||
.weightInit(WeightInit.XAVIER)
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,35 +15,31 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases;
|
||||
package org.deeplearning4j.integration.testcases.dl4j;
|
||||
|
||||
import org.apache.commons.math3.stat.inference.TestUtils;
|
||||
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
||||
import org.deeplearning4j.eval.Evaluation;
|
||||
import org.deeplearning4j.eval.IEvaluation;
|
||||
import org.deeplearning4j.eval.ROCMultiClass;
|
||||
import org.deeplearning4j.integration.ModelType;
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
||||
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -66,6 +63,11 @@ public class CNN3DTestCases {
|
|||
testOverfitting = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.MLN;
|
||||
}
|
||||
|
||||
public Object getConfiguration() throws Exception {
|
||||
int nChannels = 3; // Number of input channels
|
||||
int outputNum = 10; // The number of possible outcomes
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,8 +15,9 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases;
|
||||
package org.deeplearning4j.integration.testcases.dl4j;
|
||||
|
||||
import org.deeplearning4j.integration.ModelType;
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
|
@ -24,10 +26,6 @@ import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
|||
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.eval.Evaluation;
|
||||
import org.deeplearning4j.eval.EvaluationCalibration;
|
||||
import org.deeplearning4j.eval.IEvaluation;
|
||||
import org.deeplearning4j.eval.ROCMultiClass;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
|
@ -35,7 +33,12 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
@ -76,9 +79,15 @@ public class MLPTestCases {
|
|||
minAbsErrorOverfit = 1e-2;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.MLN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getConfiguration() {
|
||||
return new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(12345)
|
||||
.updater(new Adam(new MapSchedule.Builder(ScheduleType.ITERATION)
|
||||
.add(0, 5e-2)
|
||||
|
@ -168,6 +177,11 @@ public class MLPTestCases {
|
|||
testOverfitting = false; //Not much point here: very simple training data
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.MLN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getConfiguration() {
|
||||
int seed = 123;
|
||||
|
@ -179,6 +193,7 @@ public class MLPTestCases {
|
|||
|
||||
//log.info("Build model....");
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(seed)
|
||||
.updater(new Nesterovs(learningRate, 0.9))
|
||||
.list()
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,22 +15,24 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases;
|
||||
package org.deeplearning4j.integration.testcases.dl4j;
|
||||
|
||||
import org.deeplearning4j.integration.ModelType;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
import org.deeplearning4j.integration.testcases.misc.CharacterIterator;
|
||||
import org.deeplearning4j.integration.testcases.misc.CompositeMultiDataSetPreProcessor;
|
||||
import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
|
||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||
import org.datavec.api.split.NumberedFileInputSplit;
|
||||
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.eval.Evaluation;
|
||||
import org.deeplearning4j.eval.EvaluationCalibration;
|
||||
import org.deeplearning4j.eval.IEvaluation;
|
||||
import org.deeplearning4j.eval.ROCMultiClass;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
|
@ -91,6 +94,11 @@ public class RNNTestCases {
|
|||
private int exampleLength = 1000;
|
||||
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.MLN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getConfiguration() throws Exception {
|
||||
|
||||
|
@ -101,6 +109,7 @@ public class RNNTestCases {
|
|||
int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters
|
||||
|
||||
return new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(12345)
|
||||
.l2(0.001)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
|
@ -175,9 +184,15 @@ public class RNNTestCases {
|
|||
return normalizer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.MLN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getConfiguration() throws Exception {
|
||||
return new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(12345)
|
||||
.updater(new Adam(5e-2))
|
||||
.l1(1e-3).l2(1e-3)
|
||||
|
@ -298,6 +313,7 @@ public class RNNTestCases {
|
|||
@Override
|
||||
public Object getConfiguration() throws Exception {
|
||||
return new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(12345)
|
||||
.updater(new Adam(5e-2))
|
||||
.l1(1e-3).l2(1e-3)
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,18 +15,20 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases;
|
||||
package org.deeplearning4j.integration.testcases.dl4j;
|
||||
|
||||
|
||||
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.integration.ModelType;
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
|
||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
|
@ -59,9 +62,15 @@ public class UnsupervisedTestCases {
|
|||
minAbsErrorPretrainParams = 5e-4;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.MLN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getConfiguration() {
|
||||
return new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.FLOAT)
|
||||
.seed(12345)
|
||||
.updater(new Adam(0.05))
|
||||
.weightInit(WeightInit.XAVIER)
|
|
@ -14,7 +14,7 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases.misc;
|
||||
package org.deeplearning4j.integration.testcases.dl4j.misc;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
@ -1,36 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.integration.testcases.misc;
|
||||
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
|
||||
|
||||
public class CompositeMultiDataSetPreProcessor implements MultiDataSetPreProcessor {
|
||||
|
||||
private MultiDataSetPreProcessor[] preProcessors;
|
||||
|
||||
public CompositeMultiDataSetPreProcessor(MultiDataSetPreProcessor... preProcessors){
|
||||
this.preProcessors = preProcessors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void preProcess(MultiDataSet multiDataSet) {
|
||||
for(MultiDataSetPreProcessor p : preProcessors){
|
||||
p.preProcess(multiDataSet);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.integration.testcases.samediff;
|
||||
|
||||
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.integration.ModelType;
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class SameDiffMLPTestCases {
|
||||
|
||||
|
||||
public static TestCase getMLPMnist(){
|
||||
return new TestCase() {
|
||||
{
|
||||
testName = "MLPMnistSD";
|
||||
testType = TestType.RANDOM_INIT;
|
||||
testPredictions = true;
|
||||
testTrainingCurves = true;
|
||||
testGradients = true;
|
||||
testParamsPostTraining = true;
|
||||
testEvaluation = true;
|
||||
testOverfitting = true;
|
||||
maxRelativeErrorOverfit = 2e-2;
|
||||
minAbsErrorOverfit = 1e-2;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelType modelType() {
|
||||
return ModelType.SAMEDIFF;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getConfiguration() throws Exception {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
//Define the network structure:
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
|
||||
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
|
||||
|
||||
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256));
|
||||
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256));
|
||||
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10));
|
||||
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
|
||||
|
||||
SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0));
|
||||
SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
|
||||
SDVariable loss = sd.loss.logLoss("loss", label, out);
|
||||
|
||||
//Also set the training configuration:
|
||||
sd.setTrainingConfig(TrainingConfig.builder()
|
||||
.updater(new Adam(0.01))
|
||||
.weightDecay(1e-3, true)
|
||||
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
|
||||
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
|
||||
.build());
|
||||
|
||||
return sd;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
|
||||
List<Map<String,INDArray>> out = new ArrayList<>();
|
||||
|
||||
DataSetIterator iter = new MnistDataSetIterator(1, true, 12345);
|
||||
out.add(Collections.singletonMap("in", iter.next().getFeatures()));
|
||||
|
||||
iter = new MnistDataSetIterator(8, true, 12345);
|
||||
out.add(Collections.singletonMap("in", iter.next().getFeatures()));
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getPredictionsNamesSameDiff() throws Exception {
|
||||
return Collections.singletonList("out");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
|
||||
DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
|
||||
Map<String,INDArray> map = new HashMap<>();
|
||||
map.put("in", ds.getFeatures());
|
||||
map.put("label", ds.getLabels());
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MultiDataSetIterator getTrainingData() throws Exception {
|
||||
DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);
|
||||
iter = new EarlyTerminationDataSetIterator(iter, 32);
|
||||
return new MultiDataSetIteratorAdapter(iter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public IEvaluation[] getNewEvaluations() {
|
||||
return new IEvaluation[]{new Evaluation()};
|
||||
}
|
||||
|
||||
@Override
|
||||
public MultiDataSetIterator getEvaluationTestData() throws Exception {
|
||||
DataSetIterator iter = new MnistDataSetIterator(8, false, 12345);
|
||||
iter = new EarlyTerminationDataSetIterator(iter, 32);
|
||||
return new MultiDataSetIteratorAdapter(iter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
|
||||
sd.evaluate(iter, "out", 0, evaluations);
|
||||
return evaluations;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MultiDataSet getOverfittingData() throws Exception {
|
||||
return new MnistDataSetIterator(1, true, 12345).next().toMultiDataSet();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOverfitNumIterations() {
|
||||
return 100;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
}
|
|
@ -872,6 +872,8 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
|||
public native void setLeaksDetector(@Cast("bool") boolean reallyDetect);
|
||||
public native @Cast("bool") boolean helpersAllowed();
|
||||
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
|
||||
|
||||
public native @Cast("bool") boolean blasFallback();
|
||||
|
||||
public native int tadThreshold();
|
||||
public native void setTadThreshold(int threshold);
|
||||
|
@ -4165,15 +4167,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
|||
*/
|
||||
public native void transposei();
|
||||
|
||||
/**
|
||||
* return array pointing on certain range of this array
|
||||
* index - the number of array to be returned among set of possible arrays
|
||||
* dimensions - array of dimensions to point on
|
||||
*/
|
||||
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions);
|
||||
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions);
|
||||
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions);
|
||||
|
||||
/**
|
||||
* returns the number of arrays pointing on specified dimension(s)
|
||||
* dimensions - array of dimensions to point on
|
||||
|
@ -6881,9 +6874,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3);
|
||||
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3);
|
||||
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shape, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shape, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shape, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||
|
||||
@Namespace("shape") public static native void traceNew(int id);
|
||||
|
||||
|
@ -7323,14 +7316,12 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo);
|
||||
@Namespace("shape") public static native int rank(@Const int[] shapeInfo);
|
||||
|
||||
// returns pointer on elementWiseStride
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
|
||||
|
||||
/**
|
||||
* returns pointer on elementWiseStride
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
|
||||
|
||||
/**
|
||||
* Converts a raw int buffer of the layout:
|
||||
|
@ -8010,12 +8001,33 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
|
||||
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
|
||||
*/
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
|
||||
|
||||
/**
|
||||
* processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array
|
||||
* arguments:
|
||||
* idx - input argument, intervals of indexes which define the sub-array to point on,
|
||||
* when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank)
|
||||
* when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank)
|
||||
* when (dimStart == dimEnd) then whole range will be used for current dimension
|
||||
* maxShapeInfo - input argument, shapeInfo of original array
|
||||
* minShapeInfo - output argument, shapeInfo of sub-array to be deduced
|
||||
* minOffset - output argument, offset of sub-array buffer offsets from original buffer
|
||||
* keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
|
||||
* isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd,
|
||||
* numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1
|
||||
*/
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset);
|
||||
|
||||
/**
|
||||
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
|
||||
|
@ -8036,6 +8048,14 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo);
|
||||
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo);
|
||||
|
||||
/**
|
||||
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
||||
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -8908,6 +8928,8 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
||||
|
||||
|
@ -9103,6 +9125,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// #endif /* SHAPE_H_ */
|
||||
|
|
|
@ -875,6 +875,8 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
|||
public native void setLeaksDetector(@Cast("bool") boolean reallyDetect);
|
||||
public native @Cast("bool") boolean helpersAllowed();
|
||||
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
|
||||
|
||||
public native @Cast("bool") boolean blasFallback();
|
||||
|
||||
public native int tadThreshold();
|
||||
public native void setTadThreshold(int threshold);
|
||||
|
@ -4168,15 +4170,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
|||
*/
|
||||
public native void transposei();
|
||||
|
||||
/**
|
||||
* return array pointing on certain range of this array
|
||||
* index - the number of array to be returned among set of possible arrays
|
||||
* dimensions - array of dimensions to point on
|
||||
*/
|
||||
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntPointer dimensions);
|
||||
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector IntBuffer dimensions);
|
||||
public native @ByVal NDArray tensorAlongDimension(@Cast("Nd4jLong") long index, @StdVector int[] dimensions);
|
||||
|
||||
/**
|
||||
* returns the number of arrays pointing on specified dimension(s)
|
||||
* dimensions - array of dimensions to point on
|
||||
|
@ -6884,9 +6877,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3);
|
||||
@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3);
|
||||
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shape, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shape, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shape, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||
|
||||
@Namespace("shape") public static native void traceNew(int id);
|
||||
|
||||
|
@ -7326,14 +7319,12 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo);
|
||||
@Namespace("shape") public static native int rank(@Const int[] shapeInfo);
|
||||
|
||||
// returns pointer on elementWiseStride
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
|
||||
|
||||
/**
|
||||
* returns pointer on elementWiseStride
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo);
|
||||
|
||||
/**
|
||||
* Converts a raw int buffer of the layout:
|
||||
|
@ -8013,12 +8004,33 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
|
||||
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
|
||||
*/
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/);
|
||||
@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets);
|
||||
|
||||
/**
|
||||
* processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array
|
||||
* arguments:
|
||||
* idx - input argument, intervals of indexes which define the sub-array to point on,
|
||||
* when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank)
|
||||
* when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank)
|
||||
* when (dimStart == dimEnd) then whole range will be used for current dimension
|
||||
* maxShapeInfo - input argument, shapeInfo of original array
|
||||
* minShapeInfo - output argument, shapeInfo of sub-array to be deduced
|
||||
* minOffset - output argument, offset of sub-array buffer offsets from original buffer
|
||||
* keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
|
||||
* isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd,
|
||||
* numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1
|
||||
*/
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/);
|
||||
@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset);
|
||||
|
||||
/**
|
||||
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
|
||||
|
@ -8039,6 +8051,14 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo);
|
||||
@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo);
|
||||
|
||||
/**
|
||||
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
||||
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -8911,6 +8931,8 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
||||
|
||||
|
@ -9106,6 +9128,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// #endif /* SHAPE_H_ */
|
||||
|
|
|
@ -102,6 +102,7 @@ public class RngTests extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testRandomBinomial() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
|
||||
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);
|
||||
assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial
|
||||
|
|
Loading…
Reference in New Issue