DL4J and SameDiff integration tests + LSTMLayer java op class (#353)

* init in this branch

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* Lenet Mnist workflow

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* small fix for calculations

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* for Alex to check placeholder null pointer issue

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* CNN3D workflow

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* state for launching on dxg to regenerate dl4j examples

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* SD RNN test case workflow

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* small fixes

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* checkpoint at lstmBlock: Input array 1 (x) rank must be got input with rank 2 issue

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* Fix LSTMLayer inputs order

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* lstm mismatch with c++ op issue

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* LSTMLayer config draft

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* LSTMLayer config draft v2

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* have doubt I had to do this

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* NDRNN generated by codegen

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* LSTMLayerTestCases draft

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* minor fixes again

* added LSTMLayer test cases to nd4j-tests + set Preconditions in LSTMLayer constructors

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* added lost SDCNN test cases

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* overrode getNumOutputs from DynamicCustomOp in LSTMLayer and reorganized LSTMLayerOutputs according to the cpp op

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* finished with LSTMLayerOutputs

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* Fix MKLDNN platform checks (i.e., when MKLDNN can be used vs. not)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix LSTMLayerWeights input order

Signed-off-by: Alex Black <blacka101@gmail.com>

* More fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* minor fixes

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* fixed LSTMLayer testcases

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* finished SameDiffRNNTestCase

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* finished all testcases + minor fixes

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* Multiple generation-related fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix multiple issues

Signed-off-by: Alex Black <blacka101@gmail.com>

* More fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* LSTM fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Regenerate ND4J namespaces and fix multiple issues

Signed-off-by: Alex Black <blacka101@gmail.com>

* changed SameDiffRNNTestCase

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* Small fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* added  Nd4j.getRandom().setSeed(12345) where needed

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* #8828 Fix ND4J profiler NaN/Inf checks when using OpContext

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8828 Fix ND4J profiler NaN/Inf checks when using OpContext

Signed-off-by: Alex Black <blacka101@gmail.com>

* Tweak to weight init for SameDiff CNN test case

Signed-off-by: Alex Black <blacka101@gmail.com>

* Tweaks for test cases

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore failing tests until fixed

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
Branch: master
Author: Andrii T, 2020-04-08 17:20:48 +03:00, committed by GitHub
Parent: ab083b9167
Commit: d86dd5b131
72 changed files with 8063 additions and 3997 deletions


@@ -25,7 +25,7 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
  */
 @Deprecated
 @Getter
-@EqualsAndHashCode
+@EqualsAndHashCode(callSuper = true)
 public class EvaluationCalibration extends org.nd4j.evaluation.classification.EvaluationCalibration implements org.deeplearning4j.eval.IEvaluation<org.nd4j.evaluation.classification.EvaluationCalibration> {
     /**


@@ -185,7 +185,9 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
         final val R = paramTable.get(RECURRENT_WEIGHT_KEY);
         final val b = paramTable.get(BIAS_KEY);
-        SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2);
+        long[] shape = layerInput.getShape();
+        Preconditions.checkState(shape != null, "Null shape for input placeholder");
+        SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2, (int)shape[2]);
         this.timeSteps = inputSlices.length;
         SDVariable[] outputSlices = new SDVariable[timeSteps];
         SDVariable prev = null;
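
For context (not part of the diff): a minimal standalone sketch of the pattern this hunk adopts. When unstacking along a placeholder dimension, the number of slices is passed explicitly so the graph can be built before any concrete array is attached. Names and shapes below are illustrative only.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class UnstackWithKnownCountSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // Placeholder for [minibatch, features, timeSteps] activations; minibatch left as -1 (unknown)
        SDVariable layerInput = sd.placeHolder("in", DataType.FLOAT, -1, 8, 5);

        long[] shape = layerInput.getShape();
        if (shape == null) {
            throw new IllegalStateException("Null shape for input placeholder");
        }

        // Passing the output count explicitly (here shape[2] = 5) mirrors the fix above
        SDVariable[] slices = sd.unstack(layerInput, 2, (int) shape[2]);
        System.out.println("Time steps: " + slices.length);   // 5
    }
}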


@@ -20,7 +20,10 @@ package org.deeplearning4j.integration;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
+import org.deeplearning4j.integration.testcases.dl4j.*;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases;
 import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffRNNTestCases;
 import org.deeplearning4j.nn.api.Model;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@@ -66,14 +69,36 @@ public class IntegrationTestBaselineGenerator {
         }

         runGeneration(
-                SameDiffMLPTestCases.getMLPMnist()
+                // DL4J integration test cases.
+//                CNN1DTestCases.getCnn1dTestCaseCharRNN(),
+//                CNN2DTestCases.testLenetTransferDropoutRepeatability(),
+////                CNN2DTestCases.getCnn2DSynthetic(),
+//                CNN2DTestCases.getLenetMnist(),
+//                CNN2DTestCases.getVGG16TransferTinyImagenet(),
+//                CNN2DTestCases.getYoloHouseNumbers(),
+//                CNN3DTestCases.getCnn3dTestCaseSynthetic(),
+//                MLPTestCases.getMLPMnist(),
+//                MLPTestCases.getMLPMoon(),
+//                RNNTestCases.getRnnCharacterTestCase(),
+//                RNNTestCases.getRnnCsvSequenceClassificationTestCase1(),
+//                RNNTestCases.getRnnCsvSequenceClassificationTestCase2(),
+//                UnsupervisedTestCases.getVAEMnistAnomaly(),
+
+                // Samediff test cases done
+                SameDiffMLPTestCases.getMLPMnist(),
+                SameDiffMLPTestCases.getMLPMoon(),
+                SameDiffCNNCases.getLenetMnist(),
+                SameDiffCNNCases.getCnn3dSynthetic(),
+                SameDiffRNNTestCases.getRnnCsvSequenceClassificationTestCase1()
         );
     }

     private static void runGeneration(TestCase... testCases) throws Exception {
-        for( TestCase tc : testCases ) {
+        for (TestCase tc : testCases) {
             final ModelType modelType = tc.modelType();

             //Basic validation:
@@ -122,18 +147,18 @@ public class IntegrationTestBaselineGenerator {
                 mln = new MultiLayerNetwork(mlc);
                 mln.init();
                 m = mln;
-            } else if (config instanceof ComputationGraphConfiguration){
+            } else if (config instanceof ComputationGraphConfiguration) {
                 ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
                 json = cgc.toJson();
                 cg = new ComputationGraph(cgc);
                 cg.init();
                 m = cg;
             } else {
-                sd = (SameDiff)config;
+                sd = (SameDiff) config;
             }
             File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME);
-            if(modelType != ModelType.SAMEDIFF) {
+            if (modelType != ModelType.SAMEDIFF) {
                 File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? "mlc.json" : "cgc.json"));
                 FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8);
                 log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
@@ -147,10 +172,10 @@ public class IntegrationTestBaselineGenerator {
             m = tc.getPretrainedModel();
             if (m instanceof MultiLayerNetwork) {
                 mln = (MultiLayerNetwork) m;
-            } else if(m instanceof ComputationGraph){
+            } else if (m instanceof ComputationGraph) {
                 cg = (ComputationGraph) m;
             } else {
-                sd = (SameDiff)m;
+                sd = (SameDiff) m;
             }
         }
@@ -158,7 +183,7 @@ public class IntegrationTestBaselineGenerator {
         //Generate predictions to compare against
         if (tc.isTestPredictions()) {
             List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
-            List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
+            List<Map<String, INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
//            Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());
@@ -178,7 +203,7 @@ public class IntegrationTestBaselineGenerator {
                         Nd4j.write(out, dos);
                     }
                 }
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 for (Pair<INDArray[], INDArray[]> p : inputs) {
                     INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);
@@ -192,11 +217,11 @@ public class IntegrationTestBaselineGenerator {
                 }
             } else {
                 List<String> outNames = tc.getPredictionsNamesSameDiff();
-                for( Map<String,INDArray> ph : inputsSd ){
-                    Map<String,INDArray> out = sd.output(ph, outNames);
+                for (Map<String, INDArray> ph : inputsSd) {
+                    Map<String, INDArray> out = sd.output(ph, outNames);
                     //Save the output...
-                    for(String s : outNames){
+                    for (String s : outNames) {
                         File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
                         try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) {
                             Nd4j.write(out.get(s), dos);
@@ -211,7 +236,7 @@ public class IntegrationTestBaselineGenerator {
         //Compute and save gradients:
         if (tc.isTestGradients()) {
             INDArray gradientFlat = null;
-            Map<String,INDArray> grad;
+            Map<String, INDArray> grad;
             if (modelType == ModelType.MLN) {
                 MultiDataSet data = tc.getGradientsTestData();
                 mln.setInput(data.getFeatures(0));
@@ -220,7 +245,7 @@ public class IntegrationTestBaselineGenerator {
                 mln.computeGradientAndScore();
                 gradientFlat = mln.getFlattenedGradients();
                 grad = m.gradient().gradientForVariable();
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 MultiDataSet data = tc.getGradientsTestData();
                 cg.setInputs(data.getFeatures());
                 cg.setLabels(data.getLabels());
@@ -229,17 +254,17 @@ public class IntegrationTestBaselineGenerator {
                 gradientFlat = cg.getFlattenedGradients();
                 grad = m.gradient().gradientForVariable();
             } else {
-                Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
+                Map<String, INDArray> ph = tc.getGradientsTestDataSameDiff();
                 List<String> allVars = new ArrayList<>();
-                for(SDVariable v : sd.variables()){
-                    if(v.getVariableType() == VariableType.VARIABLE){
+                for (SDVariable v : sd.variables()) {
+                    if (v.getVariableType() == VariableType.VARIABLE) {
                         allVars.add(v.name());
                     }
                 }
                 grad = sd.calculateGradients(ph, allVars);
             }
-            if(modelType != ModelType.SAMEDIFF) {
+            if (modelType != ModelType.SAMEDIFF) {
                 File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
                 IntegrationTestRunner.write(gradientFlat, gFlatFile);
             }
@@ -254,25 +279,25 @@ public class IntegrationTestBaselineGenerator {
         }
         //Test pretraining
-        if(tc.isTestUnsupervisedTraining()){
+        if (tc.isTestUnsupervisedTraining()) {
             log.info("Performing layerwise pretraining");
             MultiDataSetIterator iter = tc.getUnsupervisedTrainData();
             INDArray paramsPostTraining;
-            if(modelType == ModelType.MLN){
+            if (modelType == ModelType.MLN) {
                 int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
                 Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
                 DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
-                for( int i : layersToTrain){
+                for (int i : layersToTrain) {
                     mln.pretrainLayer(i, dsi);
                 }
                 paramsPostTraining = mln.params();
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
                 Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
-                for( String i : layersToTrain){
+                for (String i : layersToTrain) {
                     cg.pretrainLayer(i, iter);
                 }
                 paramsPostTraining = cg.params();
@@ -290,20 +315,20 @@ public class IntegrationTestBaselineGenerator {
             MultiDataSetIterator trainData = tc.getTrainingData();
             CollectScoresListener l = new CollectScoresListener(1);
-            if(modelType != ModelType.SAMEDIFF)
+            if (modelType != ModelType.SAMEDIFF)
                 m.setListeners(l);
             History h = null;
             if (modelType == ModelType.MLN) {
                 mln.fit(trainData);
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 cg.fit(trainData);
             } else {
                 h = sd.fit(trainData, 1);
             }
             double[] scores;
-            if(modelType != ModelType.SAMEDIFF){
+            if (modelType != ModelType.SAMEDIFF) {
                 scores = l.getListScore().toDoubleArray();
             } else {
                 scores = h.lossCurve().getLossValues().toDoubleVector();
@@ -314,11 +339,11 @@ public class IntegrationTestBaselineGenerator {
             FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8);
             if (tc.isTestParamsPostTraining()) {
-                if(modelType == ModelType.SAMEDIFF){
+                if (modelType == ModelType.SAMEDIFF) {
                     File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
                     p.mkdirs();
-                    for(SDVariable v : sd.variables()){
-                        if(v.getVariableType() == VariableType.VARIABLE){
+                    for (SDVariable v : sd.variables()) {
+                        if (v.getVariableType() == VariableType.VARIABLE) {
                             INDArray arr = v.getArr();
                             File p2 = new File(p, v.name() + ".bin");
                             IntegrationTestRunner.write(arr, p2);
@@ -331,7 +356,6 @@ public class IntegrationTestBaselineGenerator {
             }
         }
         if (tc.isTestEvaluation()) {
             IEvaluation[] evals = tc.getNewEvaluations();
             MultiDataSetIterator iter = tc.getEvaluationTestData();
@@ -339,7 +363,7 @@ public class IntegrationTestBaselineGenerator {
             if (modelType == ModelType.MLN) {
                 DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
                 mln.doEvaluation(dsi, evals);
-            } else if(modelType == ModelType.CG){
+            } else if (modelType == ModelType.CG) {
                 cg.doEvaluation(iter, evals);
             } else {
                 evals = tc.doEvaluationSameDiff(sd, iter, evals);


@@ -16,6 +16,7 @@
 package org.deeplearning4j.integration;
 import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases;
 import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
 import org.junit.Rule;
 import org.junit.Test;
@@ -37,4 +38,20 @@ public class IntegrationTestsSameDiff extends BaseDL4JTest {
         IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir);
     }

+    @Test
+    public void testMLPMoon() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMoon(), testDir);
+    }
+
+    @Test
+    public void testLenetMnist() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffCNNCases.getLenetMnist(), testDir);
+    }
+
+    @Test
+    public void testCnn3dSynthetic() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffCNNCases.getCnn3dSynthetic(), testDir);
+    }
+
 }


@@ -194,6 +194,8 @@ public class CNN2DTestCases {
                 testParamsPostTraining = false;     //Skip - requires saving all params (approx 500mb)
                 testEvaluation = false;
                 testOverfitting = false;
+                maxRelativeErrorOutput = 0.2;
+                minAbsErrorOutput = 0.05;       //Max value is around 0.22
             }

             @Override
@@ -314,6 +316,7 @@ public class CNN2DTestCases {
                 ComputationGraph model = new TransferLearning.GraphBuilder(pretrained)
                         .fineTuneConfiguration(fineTuneConf)
                         .removeVertexKeepConnections("conv2d_9")
+                        .removeVertexAndConnections("outputs")
                         .addLayer("convolution2d_9",
                                 new ConvolutionLayer.Builder(1,1)
                                         .nIn(1024)
@@ -393,7 +396,7 @@ public class CNN2DTestCases {
             @Override
             public ModelType modelType() {
-                return ModelType.CG;
+                return ModelType.MLN;
             }

             @Override


@@ -77,6 +77,10 @@ public class MLPTestCases {
                 testOverfitting = true;
                 maxRelativeErrorOverfit = 2e-2;
                 minAbsErrorOverfit = 1e-2;
+                maxRelativeErrorGradients = 0.01;
+                minAbsErrorGradients = 0.05;
+                maxRelativeErrorParamsPostTraining = 0.01;
+                minAbsErrorParamsPostTraining = 0.05;
             }

             @Override
@@ -135,8 +139,7 @@ public class MLPTestCases {
             public IEvaluation[] getNewEvaluations(){
                 return new IEvaluation[]{
                         new Evaluation(),
-                        new ROCMultiClass(),
-                        new EvaluationCalibration()
+                        new ROCMultiClass()
                 };
             }


@@ -24,6 +24,7 @@ import org.nd4j.evaluation.classification.EvaluationCalibration;
 import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.shade.guava.io.Files;
 import org.deeplearning4j.integration.TestCase;
 import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
@@ -91,7 +92,7 @@ public class RNNTestCases {
         }

         private int miniBatchSize = 32;
-        private int exampleLength = 1000;
+        private int exampleLength = 200;

         @Override
@@ -101,6 +102,7 @@ public class RNNTestCases {
         @Override
         public Object getConfiguration() throws Exception {
+            Nd4j.getRandom().setSeed(12345);
             CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
             int nOut = iter.totalOutcomes();
@@ -113,7 +115,7 @@ public class RNNTestCases {
                     .seed(12345)
                     .l2(0.001)
                     .weightInit(WeightInit.XAVIER)
-                    .updater(new RmsProp(0.1))
+                    .updater(new Adam(1e-3))
                     .list()
                     .layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
                             .activation(Activation.TANH).build())
@@ -140,7 +142,7 @@ public class RNNTestCases {
         @Override
         public MultiDataSetIterator getTrainingData() throws Exception {
             DataSetIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
-            iter = new EarlyTerminationDataSetIterator(iter, 2);   //3 minibatches, 1000/200 = 5 updates per minibatch
+            iter = new EarlyTerminationDataSetIterator(iter, 2);   //2 minibatches, 200/50 = 4 updates per minibatch
             return new MultiDataSetIteratorAdapter(iter);
         }


@@ -72,12 +72,12 @@ public class UnsupervisedTestCases {
                 return new NeuralNetConfiguration.Builder()
                         .dataType(DataType.FLOAT)
                         .seed(12345)
-                        .updater(new Adam(0.05))
+                        .updater(new Adam(1e-3))
                         .weightInit(WeightInit.XAVIER)
                         .l2(1e-4)
                         .list()
                         .layer(0, new VariationalAutoencoder.Builder()
-                                .activation(Activation.LEAKYRELU)
+                                .activation(Activation.TANH)
                                 .encoderLayerSizes(256, 256)        //2 encoder layers, each of size 256
                                 .decoderLayerSizes(256, 256)        //2 decoder layers, each of size 256
                                 .pzxActivationFunction(Activation.IDENTITY)     //p(z|data) activation function


@@ -0,0 +1,398 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import java.util.*;
public class SameDiffCNNCases {
public static TestCase getLenetMnist() {
return new TestCase() {
{
testName = "LenetMnistSD";
testType = TestType.RANDOM_INIT;
testPredictions = true;
testTrainingCurves = true;
testGradients = true;
testParamsPostTraining = true;
testEvaluation = true;
testOverfitting = false;
}
@Override
public ModelType modelType() {
return ModelType.SAMEDIFF;
}
public Object getConfiguration() throws Exception {
Nd4j.getRandom().setSeed(12345);
int nChannels = 1; // Number of input channels
int outputNum = 10; // The number of possible outcomes
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum);
//input [minibatch, channels=1, Height = 28, Width = 28]
SDVariable in4d = in.reshape(-1, nChannels, 28, 28);
int kernelHeight = 5;
int kernelWidth = 5;
// w0 [kernelHeight = 5, kernelWidth = 5 , inputChannels = 1, outputChannels = 20]
// b0 [20]
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, nChannels, 20).muli(0.01));
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 20).muli(0.01));
SDVariable layer0 = sd.nn.relu(sd.cnn.conv2d("layer0", in4d, w0, b0, Conv2DConfig.builder()
.kH(kernelHeight)
.kW(kernelWidth)
.sH(1)
.sW(1)
.dataFormat("NCHW")
.build()), 0);
// outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
// outputsize_H(W) = ( 28 - 5 + 2*0 ) / 1 + 1 = 24
// [minibatch,20,24,24]
SDVariable layer1 = sd.cnn.maxPooling2d("layer1", layer0, Pooling2DConfig.builder()
.kH(2).kW(2)
.sH(2).sW(2)
.isNHWC(false)
.build());
// outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
// outputsize_H(W) = ( 24 - 2 + 2*0 ) / 2 + 1 = 12
// [minibatch,12,12,20]
// w2 [kernelHeight = 5, kernelWidth = 5 , inputChannels = 20, outputChannels = 50]
// b0 [50]
SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, 20, 50).muli(0.01));
SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 50).muli(0.01));
SDVariable layer2 = sd.nn.relu(sd.cnn.conv2d("layer2", layer1, w2, b2, Conv2DConfig.builder()
.kH(kernelHeight)
.kW(kernelWidth)
.sH(1)
.sW(1)
.dataFormat("NCHW")
.build()), 0);
// outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
// outputsize_H(W) = ( 12 - 5 + 2*0 ) / 1 + 1 = 8
// [minibatch,8,8,50]
SDVariable layer3 = sd.cnn.maxPooling2d("layer3", layer2, Pooling2DConfig.builder()
.kH(2).kW(2)
.sH(2).sW(2)
.isNHWC(false)
.build());
// outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
// outputsize_H(W) = ( 8 - 2 + 2*0 ) / 2 + 1 = 4
// [minibatch,4,4,50]
int channels_height_width = 4 * 4 * 50;
SDVariable layer3_reshaped = layer3.reshape(-1, channels_height_width);
SDVariable w4 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width, 500).muli(0.01));
SDVariable b4 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 500).muli(0.01));
SDVariable layer4 = sd.nn.relu("layer4", layer3_reshaped.mmul(w4).add(b4), 0);
SDVariable w5 = sd.var("w5", Nd4j.rand(DataType.FLOAT, 500, outputNum));
SDVariable b5 = sd.var("b5", Nd4j.rand(DataType.FLOAT, outputNum));
SDVariable out = sd.nn.softmax("out", layer4.mmul(w5).add(b5));
SDVariable loss = sd.loss.logLoss("loss", label, out);
//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Adam(1e-3))
.l2(1e-3)
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
.build());
return sd;
}
@Override
public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
Map<String, INDArray> map = new HashMap<>();
map.put("in", ds.getFeatures());
map.put("label", ds.getLabels());
return map;
}
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
iter = new EarlyTerminationDataSetIterator(iter, 60);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
return new MultiDataSetIteratorAdapter(new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10));
}
@Override
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);
List<Map<String, INDArray>> list = new ArrayList<>();
org.nd4j.linalg.dataset.DataSet ds = iter.next();
ds = ds.asList().get(0);
list.add(Collections.singletonMap("in", ds.getFeatures()));
ds = iter.next();
list.add(Collections.singletonMap("in", ds.getFeatures()));
return list;
}
@Override
public List<String> getPredictionsNamesSameDiff() {
return Collections.singletonList("out");
}
@Override
public IEvaluation[] getNewEvaluations() {
return new IEvaluation[]{
new Evaluation(),
new ROCMultiClass(),
new EvaluationCalibration()};
}
@Override
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
sd.evaluate(iter, "out", 0, evaluations);
return evaluations;
}
};
}
public static TestCase getCnn3dSynthetic() {
return new TestCase() {
{
testName = "Cnn3dSynthetic";
testType = TestType.RANDOM_INIT;
testPredictions = true;
testTrainingCurves = true;
testGradients = true;
testParamsPostTraining = true;
testEvaluation = true;
testOverfitting = false;
}
@Override
public ModelType modelType() {
return ModelType.SAMEDIFF;
}
public Object getConfiguration() throws Exception {
Nd4j.getRandom().setSeed(12345);
int nChannels = 3; // Number of input channels
int outputNum = 10; // The number of possible outcomes
SameDiff sd = SameDiff.create();
//input in NCDHW [minibatch, channels=3, Height = 8, Width = 8, Depth = 8]
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, nChannels, 8, 8, 8);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, nChannels, outputNum);
//input in NCDHW [minibatch, channels=3, Height = 8, Width = 8, Depth = 8]
// Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]
// [kernelDepth = 3, kernelHeight = 3, kernelWidth = 3, inputChannels = 3, outputChannels = 8]
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 3, 3, 3, nChannels, 8));
// Optional 1D bias array with shape [outputChannels]. May be null.
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 8));
SDVariable layer0 = sd.nn.relu(sd.cnn.conv3d("layer0", in, w0, b0, Conv3DConfig.builder()
.kH(3)
.kW(3)
.kD(3)
.sH(2)
.sW(2)
.sD(2)
.dataFormat("NCDHW")
.build()), 0);
// outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
// outputsize_H(W)(D) = (8 - 3 + 2*0 ) / 2 + 1 = 3
// [minibatch,8,3,3,3]
SDVariable layer1 = sd.cnn.maxPooling3d("layer1", layer0, Pooling3DConfig.builder()
.kH(2).kW(2).kD(2)
.sH(2).sW(2).sD(2)
.isNCDHW(true)
.build());
// outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
// outputsize_H(W)(D) = ( 3 - 2 + 2*0 ) / 2 + 1 = 1
// [minibatch,8,1,1,1]
int channels_height_width_depth = 8 * 1 * 1 * 1;
SDVariable layer1_reshaped = layer1.reshape(-1, channels_height_width_depth);
SDVariable w1 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width_depth, 10));
SDVariable b1 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 10));
SDVariable out = sd.nn.softmax("out", layer1_reshaped.mmul(w1).add(b1));
SDVariable loss = sd.loss.logLoss("loss", label, out);
//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Nesterovs(0.01, 0.9))
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
.build());
return sd;
}
@Override
public Map<String,INDArray> getGradientsTestDataSameDiff() throws Exception {
Nd4j.getRandom().setSeed(12345);
//NCDHW format
INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);
Map<String, INDArray> map = new HashMap<>();
map.put("in", arr);
map.put("label", labels);
return map;
}
@Override
public List<String> getPredictionsNamesSameDiff() {
return Collections.singletonList("out");
}
@Override
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
Nd4j.getRandom().setSeed(12345);
List<Map<String, INDArray>> list = new ArrayList<>();
INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
list.add(Collections.singletonMap("in", arr));
return list;
}
@Override
public MultiDataSet getGradientsTestData() throws Exception {
Nd4j.getRandom().setSeed(12345);
//NCDHW format
INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);
return new org.nd4j.linalg.dataset.MultiDataSet(arr, labels);
}
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
return new SingletonMultiDataSetIterator(getGradientsTestData());
}
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
return getTrainingData();
}
@Override
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations){
sd.evaluate(iter, "out", 0, evaluations);
return evaluations;
}
@Override
public IEvaluation[] getNewEvaluations(){
return new IEvaluation[]{new Evaluation()};
}
};
}
}
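
As a side note on the output-size comments in the test case above, a tiny self-contained sketch (illustrative only, not part of the PR) of the same arithmetic:

public class ConvOutputSizeSketch {
    // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
    static int outSize(int in, int kernel, int pad, int stride) {
        return (in - kernel + 2 * pad) / stride + 1;
    }

    public static void main(String[] args) {
        int h = 28;                      // MNIST input height/width
        h = outSize(h, 5, 0, 1);         // conv 5x5, stride 1 -> 24
        h = outSize(h, 2, 0, 2);         // max pool 2x2, stride 2 -> 12
        h = outSize(h, 5, 0, 1);         // conv 5x5, stride 1 -> 8
        h = outSize(h, 2, 0, 2);         // max pool 2x2, stride 2 -> 4
        System.out.println(h * h * 50);  // 800 = 4*4*50, the flattened size used in the test case
    }
}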


@@ -15,9 +15,14 @@
  ******************************************************************************/
 package org.deeplearning4j.integration.testcases.samediff;
+import org.datavec.api.records.reader.RecordReader;
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
+import org.datavec.api.split.FileSplit;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
+import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
 import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
 import org.nd4j.autodiff.loss.LossReduce;
@@ -26,21 +31,34 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.TrainingConfig;
 import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.io.ClassPathResource;
 import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.resources.Resources;
+import java.io.File;
 import java.util.*;
+import static org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig.*;

 public class SameDiffMLPTestCases {

-    public static TestCase getMLPMnist(){
+    public static TestCase getMLPMnist() {
         return new TestCase() {
             {
                 testName = "MLPMnistSD";
@@ -69,10 +87,10 @@ public class SameDiffMLPTestCases {
                 SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
                 SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
-                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256));
-                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256));
-                SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10));
-                SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
+                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256).muli(0.1));
+                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256).muli(0.1));
+                SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10).muli(0.1));
+                SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10).muli(0.1));
                 SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0));
                 SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
@@ -91,7 +109,7 @@ public class SameDiffMLPTestCases {
             @Override
             public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
-                List<Map<String,INDArray>> out = new ArrayList<>();
+                List<Map<String, INDArray>> out = new ArrayList<>();
                 DataSetIterator iter = new MnistDataSetIterator(1, true, 12345);
                 out.add(Collections.singletonMap("in", iter.next().getFeatures()));
@@ -110,7 +128,7 @@ public class SameDiffMLPTestCases {
             @Override
             public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
                 DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
-                Map<String,INDArray> map = new HashMap<>();
+                Map<String, INDArray> map = new HashMap<>();
                 map.put("in", ds.getFeatures());
                 map.put("label", ds.getLabels());
                 return map;
@@ -153,4 +171,160 @@ public class SameDiffMLPTestCases {
         };
     }
public static TestCase getMLPMoon() {
return new TestCase() {
{
testName = "MLPMoonSD";
testType = TestType.RANDOM_INIT;
testPredictions = true;
testTrainingCurves = true;
testGradients = true;
testParamsPostTraining = true;
testEvaluation = true;
testOverfitting = true;
maxRelativeErrorOverfit = 2e-2;
minAbsErrorOverfit = 1e-2;
}
@Override
public ModelType modelType() {
return ModelType.SAMEDIFF;
}
@Override
public Object getConfiguration() throws Exception {
int numInputs = 2;
int numOutputs = 2;
int numHiddenNodes = 20;
double learningRate = 0.005;
Nd4j.getRandom().setSeed(12345);
//Define the network structure:
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, numInputs);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, numOutputs);
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, numInputs, numHiddenNodes));
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, numHiddenNodes));
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numHiddenNodes, numOutputs));
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numOutputs));
SDVariable a0 = sd.nn.relu(in.mmul(w0).add(b0), 0);
SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
SDVariable loss = sd.loss.logLoss("loss", label, out);
//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Nesterovs(learningRate, 0.9))
.weightDecay(1e-3, true)
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
.build());
return sd;
}
@Override
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
List<Map<String, INDArray>> out = new ArrayList<>();
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 0, 2);
out.add(Collections.singletonMap("in", iter.next().getFeatures()));
return out;
}
@Override
public List<String> getPredictionsNamesSameDiff() throws Exception {
return Collections.singletonList("out");
}
@Override
public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
org.nd4j.linalg.dataset.DataSet ds = new RecordReaderDataSetIterator(rr, 5, 0, 2).next();
Map<String, INDArray> map = new HashMap<>();
map.put("in", ds.getFeatures());
map.put("label", ds.getLabels());
return map;
}
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_train.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2);
iter = new EarlyTerminationDataSetIterator(iter, 32);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public IEvaluation[] getNewEvaluations() {
return new IEvaluation[]{
new Evaluation(),
new ROCMultiClass(),
new EvaluationCalibration()};
}
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2);
return new MultiDataSetIteratorAdapter(iter);
}
@Override
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
sd.evaluate(iter, "out", 0, evaluations);
return evaluations;
}
@Override
public MultiDataSet getOverfittingData() throws Exception {
File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(f));
return new RecordReaderDataSetIterator(rr, 1, 0, 2).next().toMultiDataSet();
}
@Override
public int getOverfitNumIterations() {
return 200;
}
};
}
}


@@ -0,0 +1,289 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.resources.Resources;
import org.nd4j.shade.guava.io.Files;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class SameDiffRNNTestCases {
public static TestCase getRnnCsvSequenceClassificationTestCase1() {
return new SameDiffRNNTestCases.RnnCsvSequenceClassificationTestCase1();
}
protected static class RnnCsvSequenceClassificationTestCase1 extends TestCase {
protected RnnCsvSequenceClassificationTestCase1() {
testName = "RnnCsvSequenceClassification1";
testType = TestType.RANDOM_INIT;
testPredictions = true;
testTrainingCurves = false;
testGradients = false;
testParamsPostTraining = false;
testEvaluation = true;
testOverfitting = false; //Not much point on this one - it already fits very well...
}
protected MultiDataNormalization normalizer;
protected MultiDataNormalization getNormalizer() throws Exception {
if (normalizer != null) {
return normalizer;
}
normalizer = new MultiNormalizerStandardize();
normalizer.fit(getTrainingDataUnnormalized());
return normalizer;
}
@Override
public ModelType modelType() {
return ModelType.SAMEDIFF;
}
@Override
public Object getConfiguration() throws Exception {
Nd4j.getRandom().setSeed(12345);
int miniBatchSize = 10;
int numLabelClasses = 6;
int nIn = 60;
int numUnits = 7;
int timeSteps = 3;
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, miniBatchSize, timeSteps, nIn);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, miniBatchSize, numLabelClasses);
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits));
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.NTS)
.directionMode(LSTMDirectionMode.FWD)
.gateAct(LSTMActivations.SIGMOID)
.cellAct(LSTMActivations.TANH)
.outAct(LSTMActivations.TANH)
.retFullSequence(true)
.retLastC(true)
.retLastH(true)
.build();
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
.peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
.bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits)))
.build(),
c), c);
// Behaviour with default settings: 3d (time series) input with shape
// [miniBatchSize, vectorSize, timeSeriesLength] -> 2d output [miniBatchSize, vectorSize]
SDVariable layer0 = outputs.getOutput();
SDVariable layer1 = layer0.mean(1);
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numUnits, numLabelClasses));
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numLabelClasses));
SDVariable out = sd.nn.softmax("out", layer1.mmul(w1).add(b1));
SDVariable loss = sd.loss.logLoss("loss", label, out);
//Also set the training configuration:
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Adam(5e-2))
.l1(1e-3).l2(1e-3)
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
.build());
return sd;
}
@Override
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
MultiDataSet mds = getTrainingData().next();
List<Map<String, INDArray>> list = new ArrayList<>();
list.add(Collections.singletonMap("in", mds.getFeatures()[0].reshape(10, 1, 60)));
//[batchsize, insize]
return list;
}
@Override
public List<String> getPredictionsNamesSameDiff() throws Exception {
return Collections.singletonList("out");
}
@Override
public MultiDataSetIterator getTrainingData() throws Exception {
MultiDataSetIterator iter = getTrainingDataUnnormalized();
MultiDataSetPreProcessor pp = multiDataSet -> {
INDArray l = multiDataSet.getLabels(0);
l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
multiDataSet.setLabels(0, l);
multiDataSet.setLabelsMaskArray(0, null);
};
iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));
return iter;
}
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
int miniBatchSize = 10;
int numLabelClasses = 6;
File featuresDirTrain = Files.createTempDir();
File labelsDirTrain = Files.createTempDir();
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain);
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDirTrain);
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData);
return iter;
}
@Override
public IEvaluation[] getNewEvaluations() {
return new IEvaluation[]{
new Evaluation(),
new ROCMultiClass(),
new EvaluationCalibration()
};
}
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
int miniBatchSize = 10;
int numLabelClasses = 6;
// File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile();
// File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile();
File featuresDirTest = Files.createTempDir();
File labelsDirTest = Files.createTempDir();
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest);
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest);
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData);
MultiDataSetPreProcessor pp = multiDataSet -> {
INDArray l = multiDataSet.getLabels(0);
l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
multiDataSet.setLabels(0, l);
multiDataSet.setLabelsMaskArray(0, null);
};
iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));
return iter;
}
@Override
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
sd.evaluate(iter, "out", 0, evaluations);
return evaluations;
}
}
}
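For reference, the preprocessor applied in getTrainingData() and getEvaluationTestData() above reduces the sequence labels to the label at the final time step. A minimal standalone sketch of that indexing (the shapes below are illustrative, not taken from the UCI data):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class LastStepLabelExample {
    public static void main(String[] args) {
        // Hypothetical label array of shape [miniBatch=10, numClasses=6, timeSteps=60]
        INDArray labels = Nd4j.rand(new int[]{10, 6, 60});
        // Keep only the last time step -> shape [10, 6], as done by the MultiDataSetPreProcessor above
        INDArray lastStep = labels.get(NDArrayIndex.all(), NDArrayIndex.all(),
                NDArrayIndex.point(labels.size(2) - 1));
        System.out.println(java.util.Arrays.toString(lastStep.shape())); // [10, 6]
    }
}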


@@ -368,7 +368,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !");
REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!");
REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
-REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");
+REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");
count = 0;
auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output
@@ -464,13 +464,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
}
PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
+const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
+const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
+const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
+const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
+const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
@@ -495,7 +503,15 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
DataType hLType = hL != nullptr ? hL->dataType() : xType;
DataType cLType = cL != nullptr ? cL->dataType() : xType;
-return block.isUseMKLDNN() && (
+auto featuresSupported = (cellClip == 0) //Cell clipping not supported
+    && retFullSeq //Always return full sequence in case of MKL DNN
+    && !hasPH //Peephole connections not supported in MKL DNN
+    && !hasSeqLen //Sequence length array not supported in MKL DNN
+    && dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn]
+    && directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat
+    && retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other)
+return block.isUseMKLDNN() && featuresSupported && (
(xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||
(xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) ||
(xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8))


@@ -2148,7 +2148,7 @@ public class DifferentialFunctionFactory {
public SDVariable gatherNd(SDVariable df, SDVariable indices) {
validateDifferentialFunctionsameDiff(df);
-return new GatherNd(sameDiff(), df, indices, false).outputVariable();
+return new GatherNd(sameDiff(), df, indices).outputVariable();
}
public SDVariable trace(SDVariable in){


@@ -26,6 +26,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
+import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.WeightInitScheme;
import java.io.Serializable;
@@ -244,7 +245,7 @@ public class SDVariable implements Serializable {
* @return new variable
*/
public SDVariable assign(Number value){
-return sameDiff.scalarSet(this, value);
+return sameDiff.scalarSet(this, value.doubleValue());
}
/**
@@ -538,7 +539,7 @@ public class SDVariable implements Serializable {
* @return Output variable (result of mmul)
*/
public SDVariable mmul(String name, SDVariable other, @NonNull MMulTranspose mMulTranspose) {
-return sameDiff.mmul(name, this, other, mMulTranspose);
+return sameDiff.mmul(name, this, other, mMulTranspose.isTransposeA(), mMulTranspose.isTransposeB(), mMulTranspose.isTransposeResult());
}
@@ -1403,7 +1404,7 @@ public class SDVariable implements Serializable {
* @return Output variable
*/
public SDVariable reshape(int... newShape){
-return sameDiff.reshape(this, newShape);
+return sameDiff.reshape(this, ArrayUtil.toLongArray(newShape));
}
/**


@@ -53,6 +53,7 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
@@ -78,6 +79,7 @@ import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.ND4JFileUtils;
import org.nd4j.shade.guava.collect.HashBasedTable;
+import org.nd4j.shade.guava.collect.Sets;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.weightinit.WeightInitScheme;
@@ -104,7 +106,6 @@ import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
* <p>
* In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)}
*/
-@AllArgsConstructor
@Slf4j
public class SameDiff extends SDBaseOps {
protected static final String GRAD_FN_KEY = "grad";
@@ -914,6 +915,8 @@ public class SameDiff extends SDBaseOps {
}
private SameDiff() {
+super(null);
+super.sd = this;
functionFactory = new DifferentialFunctionFactory(this);
sameDiffFunctionInstances = new LinkedHashMap<>();
fieldVariableResolutionMapping = HashBasedTable.create();
@@ -4544,7 +4547,7 @@ public class SameDiff extends SDBaseOps {
}
//Also exclude assert etc ops - doesn't make sense to return these "outputs" to user
-if (v.getOutputOfOp() != null) {
+if (v.getOutputOfOp() != null && v.getVariable().dataType().isFPType()) {
String opName = v.getOutputOfOp();
SameDiffOp o = ops.get(opName);
if (o.getOp() instanceof Assert) {
@@ -4621,12 +4624,6 @@ public class SameDiff extends SDBaseOps {
return varToUpdate;
}
-@Override
-protected SameDiff sd() {
-//Helper method for SDBaseOps etc
-return this;
-}
/**
* Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable.
@@ -5840,7 +5837,6 @@ public class SameDiff extends SDBaseOps {
* See {@link #generateNewVarName(String, int, boolean)}
* existingOp is true.
*/
-@Override
public String generateNewVarName(String base, int argIndex) {
return generateNewVarName(base, argIndex, true);
}
@@ -5868,4 +5864,261 @@ public class SameDiff extends SDBaseOps {
public String toString(){
return "SameDiff(nVars=" + variables.size() + ",nOps=" + ops.size() + ")";
}
/**
* See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
*/
public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond,
@NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
return ifCond(null, null, cond, trueBody, falseBody);
}
/**
* See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
*/
public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond,
@NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
return ifCond(null, ifName, cond, trueBody, falseBody);
}
/**
* Constructs an If statement using the tensorflow style control flow operations (Switch and Merge)
*
* If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody
*
* Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used to evaluate.
*
* See <a href="http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf">Tensorflow Control Flow Implementation</a>
*
* @param outputName Name to give the output variable. If null, doesn't rename
* @param ifName The name of the if block. If null, uses "if"
* @param cond A lambda evaluating to the if condition
* @param trueBody A lambda to be executed if cond is true (the if block)
* @param falseBody A lambda to be executed if cond is false (the else block)
* @return The value of trueBody if cond is true, or falseBody if it isn't
*/
public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond,
@NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
ifName = newBlockName(ifName == null ? "if" : ifName);
NameScope ifScope = sd.withNameScope(ifName);
NameScope condScope = withNameScope("cond");
final SDVariable pred = cond.define(this);
condScope.close();
if (pred.dataType() != DataType.BOOL) {
//cleanup partially added block
for(SDVariable v : getVariablesInScope(ifScope))
this.getVariables().remove(v.name());
for(SameDiffOp op : this.getOpsInScope(ifScope)) {
for(String in : op.getInputsToOp()){
this.removeArgFromOp(in, op.getOp());
}
this.getOps().remove(op.getName());
}
throw new IllegalStateException("Can not use " + pred.name()
+ " as the condition of an If statement, the condition must be a boolean.");
}
final Map<String, SDVariable[]> switches = new HashMap<>();
final Set<String> declared = Sets.newHashSet(this.variableMap().keySet());
this.addArgumentInterceptor(new ArgumentInterceptor() {
@Override
public SDVariable intercept(SDVariable argument) {
// if it's declared in the if, we don't care about it
if(!declared.contains(argument.name()))
return argument;
// if we've already added a switch, move on
if(switches.containsKey(argument.name()))
return switches.get(argument.name())[1];
SDVariable[] s = f().switchOp(argument, pred);
switches.put(argument.name(), s);
return s[1];
}
});
NameScope trueScope = this.withNameScope("trueBody");
SDVariable trueOut = trueBody.define(this);
this.removeArgumentInterceptor();
if(declared.contains(trueOut.name())) {
SDVariable[] s = f().switchOp(trueOut, pred);
switches.put(trueOut.name(), s);
trueOut = s[1];
}
trueScope.close();
final Set<String> declared2 = Sets.newHashSet(variableMap().keySet());
sd.addArgumentInterceptor(new ArgumentInterceptor() {
@Override
public SDVariable intercept(SDVariable argument) {
// if it's declared in the if, we don't care about it
if(!declared2.contains(argument.name()))
return argument;
// if we've already added a switch, move on
if(switches.containsKey(argument.name()))
return switches.get(argument.name())[0];
SDVariable[] s = f().switchOp(argument, pred);
switches.put(argument.name(), s);
return s[0];
}
});
NameScope falseScope = this.withNameScope("falseBody");
SDVariable falseOut = falseBody.define(this);
this.removeArgumentInterceptor();
if(declared2.contains(falseOut.name())) {
SDVariable[] s = f().switchOp(falseOut, pred);
switches.put(falseOut.name(), s);
falseOut = s[0];
}
falseScope.close();
SDVariable output = f().merge(trueOut, falseOut);
ifScope.close();
return updateVariableNameAndReference(output, outputName);
}
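A minimal usage sketch for the new ifCond API (variable names and the scalar values are illustrative assumptions, not taken from this PR):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class IfCondExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable a = sd.var("a", Nd4j.scalar(2.0));
        SDVariable b = sd.var("b", Nd4j.scalar(5.0));
        // Returns a + b if a < b, otherwise a - b; the lambdas run once, at graph construction time
        SDVariable result = sd.ifCond("result", null,
                s -> a.lt(b),   // condition must produce a BOOL variable
                s -> a.add(b),  // true branch
                s -> a.sub(b)); // false branch
        System.out.println(result.eval()); // expect 7.0
    }
}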
/**
* See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
*/
public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars,
@NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
return whileLoop(null, null, loopVars, cond, body);
}
/**
* See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
*/
public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars,
@NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
return whileLoop(null, loopName, loopVars, cond, body);
}
/**
* Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration)
*
* Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false
*
* Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used for further iterations.
*
* See <a href="http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf">Tensorflow Control Flow Implementation</a>
*
* @param outputNames Names to give the output variables. If null, doesn't rename
* @param loopName The name of the loop block and frame (must be unique). If null, uses "while"
* @param loopVars Loop variables' inputs
* @param cond A lambda evaluating to the loop condition
* @param body A lambda doing the loop operation and returning the new loop variable values
* @return The values of the loop variables once condition is false
*/
public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars,
@NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
final String frameName = this.newBlockName(loopName == null ? "while" : loopName);
NameScope loopScope = this.withNameScope(frameName);
//SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0);
SDVariable[] entered = new SDVariable[loopVars.length];
for(int i = 0 ; i < loopVars.length ; i++){
entered[i] = f().enter(loopVars[i], frameName);
}
//counter = SD.f().enter(counter, frameName);
SDVariable[] merged = new SDVariable[loopVars.length];
Merge[] mergeOps = new Merge[loopVars.length];
for(int i = 0 ; i < loopVars.length ; i++){
// the second arg will later be replaced with the output of NextIteration
// but that isn't available yet (and can't be, as it depends on this)
mergeOps[i] = new Merge(this, entered[i], entered[i]);
merged[i] = mergeOps[i].outputVariable();
}
//Merge counterMerge = new Merge(SD, counter, counter);
//counter = counterMerge.outputVariable();
NameScope condScope = this.withNameScope("cond");
SDVariable cond_result = cond.define(this, merged);
condScope.close();
if (cond_result.dataType() != DataType.BOOL)
throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean.");
final Set<String> alreadyEntered = Sets.newHashSet();
SDVariable[] trueSwitches = new SDVariable[loopVars.length];
SDVariable[] exits = new SDVariable[loopVars.length];
for(int i = 0 ; i < loopVars.length ; i++){
SDVariable[] s = f().switchOp(merged[i], cond_result);
trueSwitches[i] = s[1];
alreadyEntered.add(s[1].name());
exits[i] = f().exit(s[0]);
}
//SDVariable[] cs = SD.f().switchOp(counter, cond_result);
//SDVariable counterExit = SD.f().exit(cs[0]);
//counter = cs[1];
final Set<String> declared = Sets.newHashSet(this.variableMap().keySet());
final Map<String, SDVariable> done = new HashMap<>();
this.addArgumentInterceptor(new ArgumentInterceptor() {
@Override
public SDVariable intercept(SDVariable argument) {
if(!declared.contains(argument.name()))
return argument;
if(alreadyEntered.contains(argument.name()))
return argument;
if(done.containsKey(argument.name()))
return done.get(argument.name());
SDVariable e = f().enter(argument, frameName, true);
done.put(argument.name(), e);
return e;
}
});
NameScope bodyScope = this.withNameScope("body");
SDVariable[] outs = body.define(this, trueSwitches);
bodyScope.close();
this.removeArgumentInterceptor();
//counter.add(1);
for(int i = 0 ; i < loopVars.length ; i++){
SDVariable n = f().nextIteration(outs[i]);
mergeOps[i].replaceArg(1,n);
}
//counterMerge.replaceArg(1, counter);
loopScope.close();
return updateVariableNamesAndReferences(exits, outputNames);
}
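And a corresponding sketch for whileLoop, incrementing a single loop variable until the condition fails (values are illustrative):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class WhileLoopExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable start = sd.var("start", Nd4j.scalar(0.0));
        // cond and body are called once to build the graph; iteration happens at execution time
        SDVariable[] out = sd.whileLoop(new SDVariable[]{start},
                (s, vars) -> vars[0].lt(10.0),                    // loop while value < 10
                (s, vars) -> new SDVariable[]{vars[0].add(1.0)}); // increment the loop variable
        System.out.println(out[0].eval()); // expect 10.0
    }
}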
} }


@@ -23,8 +23,8 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.enums.DataFormat;
import org.nd4j.base.Preconditions;
+import org.nd4j.enums.DataFormat;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
@@ -753,6 +753,33 @@ public class SDCNN extends SDOps {
return sd.updateVariableNameAndReference(out, name);
}
/**
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
*/
public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) {
SDValidation.validateNumerical("maxPoolWithArgmax", "input", input);
return new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(sd,input, Pooling2DConfig).outputVariables();
}
/**
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
*
* @param names names May be null. Arrays of names for the output variables.
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
*/
public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input,
Pooling2DConfig Pooling2DConfig) {
SDValidation.validateNumerical("maxPoolWithArgmax", "input", input);
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(sd,input, Pooling2DConfig).outputVariables();
return sd.updateVariableNamesAndReferences(out, names);
}
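A usage sketch for the new maxPoolWithArgmax helper. The Pooling2DConfig builder field names (kH/kW/sH/sW) are assumptions here, not something this hunk defines:

// Sketch only: assumes sd is an existing SameDiff instance and the builder fields shown exist
SDVariable in = sd.var("in", Nd4j.rand(new int[]{1, 3, 8, 8})); // NCHW input
Pooling2DConfig conf = Pooling2DConfig.builder()
        .kH(2).kW(2)   // kernel size (assumed field names)
        .sH(2).sW(2)   // strides (assumed field names)
        .build();
SDVariable[] valuesAndIndices = sd.cnn().maxPoolWithArgmax(new String[]{"max", "argmax"}, in, conf);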
/**
* 2D Convolution layer operation - max pooling 2d <br>
*


@@ -2205,7 +2205,7 @@ public class SDMath extends SDOps {
* @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
-public SDVariable mergeAdd(SDVariable[] inputs) {
+public SDVariable mergeAdd(SDVariable... inputs) {
SDValidation.validateNumerical("mergeAdd", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable();
@@ -2219,7 +2219,7 @@ public class SDMath extends SDOps {
* @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
-public SDVariable mergeAdd(String name, SDVariable[] inputs) {
+public SDVariable mergeAdd(String name, SDVariable... inputs) {
SDValidation.validateNumerical("mergeAdd", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable();
@@ -2233,7 +2233,7 @@ public class SDMath extends SDOps {
* @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
-public SDVariable mergeAvg(SDVariable[] inputs) {
+public SDVariable mergeAvg(SDVariable... inputs) {
SDValidation.validateNumerical("mergeAvg", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable();
@@ -2247,7 +2247,7 @@ public class SDMath extends SDOps {
* @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
-public SDVariable mergeAvg(String name, SDVariable[] inputs) {
+public SDVariable mergeAvg(String name, SDVariable... inputs) {
SDValidation.validateNumerical("mergeAvg", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable();
@@ -2261,7 +2261,7 @@ public class SDMath extends SDOps {
* @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
-public SDVariable mergeMax(SDVariable[] inputs) {
+public SDVariable mergeMax(SDVariable... inputs) {
SDValidation.validateNumerical("mergeMax", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable();
@@ -2275,7 +2275,7 @@ public class SDMath extends SDOps {
* @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
-public SDVariable mergeMax(String name, SDVariable[] inputs) {
+public SDVariable mergeMax(String name, SDVariable... inputs) {
SDValidation.validateNumerical("mergeMax", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable();
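The only change in these hunks is the switch from SDVariable[] parameters to varargs, so callers no longer need to wrap the inputs in an array; both forms below compile against the new signatures (x, y and z are assumed to be existing numeric variables):

SDVariable m1 = sd.math().mergeMax(x, y, z);                    // varargs form
SDVariable m2 = sd.math().mergeMax(new SDVariable[]{x, y, z});  // array form still works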


@@ -18,17 +18,15 @@
package org.nd4j.autodiff.samediff.ops;
+import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
-import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
@@ -43,28 +41,26 @@ public class SDRNN extends SDOps {
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
* @param GRUWeights Configuration Object
-* @return output The cell's outputs. (NUMERIC type)
*/
-public SDVariable gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
+public SDVariable[] gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
SDValidation.validateNumerical("gru", "x", x);
SDValidation.validateNumerical("gru", "hLast", hLast);
-return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariable();
+return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables();
}
/**
* The GRU cell. Does a single time step operation<br>
*
-* @param name name May be null. Name for the output variable
+* @param names names May be null. Arrays of names for the output variables.
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
* @param GRUWeights Configuration Object
-* @return output The cell's outputs. (NUMERIC type)
*/
-public GRUCellOutputs gru(String name, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
+public SDVariable[] gru(String[] names, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
SDValidation.validateNumerical("gru", "x", x);
SDValidation.validateNumerical("gru", "hLast", hLast);
-GRUCell c = new GRUCell(sd,x, hLast, GRUWeights);
-return new GRUCellOutputs(c.outputVariables(name));
+SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables();
+return sd.updateVariableNamesAndReferences(out, names);
}
/**
@@ -75,39 +71,172 @@ public class SDRNN extends SDOps {
* @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
-* @return output The cell's outputs (NUMERIC type)
*/
-public LSTMCellOutputs lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast,
+public SDVariable[] lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast,
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
SDValidation.validateNumerical("lstmCell", "x", x);
SDValidation.validateNumerical("lstmCell", "cLast", cLast);
SDValidation.validateNumerical("lstmCell", "yLast", yLast);
-LSTMBlockCell c = new LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration);
-return new LSTMCellOutputs(c.outputVariables());
+return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariables();
}
/**
* The LSTM cell. Does a single time step operation.<br>
*
-* @param name name May be null. Name for the output variable
+* @param names names May be null. Arrays of names for the output variables.
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
-* @return output The cell's outputs (NUMERIC type)
*/
-public LSTMCellOutputs lstmCell(String name, SDVariable x, SDVariable cLast, SDVariable yLast,
+public SDVariable[] lstmCell(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast,
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
SDValidation.validateNumerical("lstmCell", "x", x);
SDValidation.validateNumerical("lstmCell", "cLast", cLast);
SDValidation.validateNumerical("lstmCell", "yLast", yLast);
-LSTMBlockCell c = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration);
-return new LSTMCellOutputs(c.outputVariables(name));
+SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariables();
+return sd.updateVariableNamesAndReferences(out, names);
}
/**
-* The LSTM layer. Does multiple time steps.<br>
+* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type)
* @param LSTMLayerWeights Configuration Object
* @param LSTMLayerConfig Configuration Object
*/
public SDVariable[] lstmLayer(SDVariable x, SDVariable cLast, SDVariable yLast,
SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) {
SDValidation.validateNumerical("lstmLayer", "x", x);
SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig).outputVariables();
}
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param names names May be null. Arrays of names for the output variables.
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type)
* @param LSTMLayerWeights Configuration Object
* @param LSTMLayerConfig Configuration Object
*/
public SDVariable[] lstmLayer(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast,
SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) {
SDValidation.validateNumerical("lstmLayer", "x", x);
SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig).outputVariables();
return sd.updateVariableNamesAndReferences(out, names);
}
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param LSTMLayerWeights Configuration Object
* @param LSTMLayerConfig Configuration Object
*/
public SDVariable[] lstmLayer(SDVariable x, LSTMLayerWeights LSTMLayerWeights,
LSTMLayerConfig LSTMLayerConfig) {
SDValidation.validateNumerical("lstmLayer", "x", x);
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, null, null, null, LSTMLayerWeights, LSTMLayerConfig).outputVariables();
}
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param names names May be null. Arrays of names for the output variables.
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param LSTMLayerWeights Configuration Object
* @param LSTMLayerConfig Configuration Object
*/
public SDVariable[] lstmLayer(String[] names, SDVariable x, LSTMLayerWeights LSTMLayerWeights,
LSTMLayerConfig LSTMLayerConfig) {
SDValidation.validateNumerical("lstmLayer", "x", x);
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, null, null, null, LSTMLayerWeights, LSTMLayerConfig).outputVariables();
return sd.updateVariableNamesAndReferences(out, names);
}
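A minimal usage sketch for the new lstmLayer overloads. Construction of the weights and config objects is omitted because their builder details are outside this hunk; wts, cfg and x are assumptions:

// Sketch: assumes wts (LSTMLayerWeights) and cfg (LSTMLayerConfig) were built via their builders,
// and x matches the data format chosen in cfg (e.g. NTS -> [batchSize, timeSteps, inSize]).
// The number of outputs depends on which full-sequence / last-output / last-cell-state flags are enabled in cfg.
SDVariable[] lstmOut = sd.rnn().lstmLayer(new String[]{"out", "hLast", "cLast"}, x, wts, cfg);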
/**
* The LSTM block<br>
*
* @param maxTSLength (NUMERIC type)
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
@@ -117,17 +246,17 @@ public class SDRNN extends SDOps {
* @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type)
*/
-public SDVariable lstmLayer(SDVariable maxTSLength, SDVariable x, SDVariable cLast,
+public SDVariable lstmblock(SDVariable maxTSLength, SDVariable x, SDVariable cLast,
SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
-SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
+SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
-SDValidation.validateNumerical("lstmLayer", "x", x);
+SDValidation.validateNumerical("lstmblock", "x", x);
-SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
+SDValidation.validateNumerical("lstmblock", "cLast", cLast);
-SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
+SDValidation.validateNumerical("lstmblock", "yLast", yLast);
-return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
+return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
}
/**
-* The LSTM layer. Does multiple time steps.<br>
+* The LSTM block<br>
*
* @param name name May be null. Name for the output variable
* @param maxTSLength (NUMERIC type)
@@ -138,13 +267,43 @@ public class SDRNN extends SDOps {
* @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type)
*/
-public SDVariable lstmLayer(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast,
+public SDVariable lstmblock(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast,
SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
-SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
+SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
-SDValidation.validateNumerical("lstmLayer", "x", x);
+SDValidation.validateNumerical("lstmblock", "x", x);
-SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
+SDValidation.validateNumerical("lstmblock", "cLast", cLast);
-SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
+SDValidation.validateNumerical("lstmblock", "yLast", yLast);
-SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
+SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* The LSTM block<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type)
*/
public SDVariable lstmblock(SDVariable x, LSTMWeights LSTMWeights,
LSTMConfiguration LSTMConfiguration) {
SDValidation.validateNumerical("lstmblock", "x", x);
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,null, x, null, null, LSTMWeights, LSTMConfiguration).outputVariable();
}
/**
* The LSTM block<br>
*
* @param name name May be null. Name for the output variable
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type)
*/
public SDVariable lstmblock(String name, SDVariable x, LSTMWeights LSTMWeights,
LSTMConfiguration LSTMConfiguration) {
SDValidation.validateNumerical("lstmblock", "x", x);
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,null, x, null, null, LSTMWeights, LSTMConfiguration).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}


@ -0,0 +1,45 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* Activations */
public enum CellAct {
TANH,
RELU,
SIGMOID,
AFFINE,
LEAKY_RELU,
THRESHHOLD_RELU,
SCALED_TAHN,
HARD_SIGMOID,
ELU,
SOFTSIGN,
SOFTPLUS
}


@ -0,0 +1,45 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* Activations */
public enum GateAct {
TANH,
RELU,
SIGMOID,
AFFINE,
LEAKY_RELU,
THRESHHOLD_RELU,
SCALED_TAHN,
HARD_SIGMOID,
ELU,
SOFTSIGN,
SOFTPLUS
}


@ -0,0 +1,36 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* for unidirectional:
* TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
* NST: shape [numExamples, inOutSize, timeLength]<br>
* NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout<br>
* for bidirectional:
* T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */
public enum LSTMDataFormat {
TNS,
NST,
NTS,
T2NS
}
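To make the format choice concrete, a small sketch mapping array shapes to the constants above (shapes are illustrative only):

INDArray tns = Nd4j.rand(new int[]{60, 10, 3}); // [timeLength, numExamples, inOutSize] -> LSTMDataFormat.TNS
INDArray nst = Nd4j.rand(new int[]{10, 3, 60}); // [numExamples, inOutSize, timeLength] -> LSTMDataFormat.NST
INDArray nts = Nd4j.rand(new int[]{10, 60, 3}); // [numExamples, timeLength, inOutSize] -> LSTMDataFormat.NTS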


@ -0,0 +1,38 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* direction <br>
* FWD: 0 = fwd
* BWD: 1 = bwd
* BIDIR_SUM: 2 = bidirectional sum
* BIDIR_CONCAT: 3 = bidirectional concat
* BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */
public enum LSTMDirectionMode {
FWD,
BWD,
BIDIR_SUM,
BIDIR_CONCAT,
BIDIR_EXTRA_DIM
}


@ -0,0 +1,45 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* Activations */
public enum OutAct {
TANH,
RELU,
SIGMOID,
AFFINE,
LEAKY_RELU,
THRESHHOLD_RELU,
SCALED_TAHN,
HARD_SIGMOID,
ELU,
SOFTSIGN,
SOFTPLUS
}


@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* The data format of the input. Input shape depends on data format (in config):<br>
* TNS -> [timeSteps, batchSize, inSize]<br>
* NST -> [batchSize, inSize, timeSteps]<br>
* NTS -> [batchSize, timeSteps, inSize]<br> */
public enum RnnDataFormat {
TNS,
NST,
NTS
}

View File

@ -146,6 +146,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class,
org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss.class, org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss.class,

View File

@ -301,24 +301,27 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
} }
} }
protected void checkForWorkspaces(CustomOp op) { protected void checkForWorkspaces(CustomOp op, OpContext oc) {
for (val input: op.inputArguments()) List<INDArray> inArgs = oc != null ? oc.getInputArrays() : op.inputArguments();
List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
for (val input: inArgs)
checkWorkspace(op.opName(), input); checkWorkspace(op.opName(), input);
for (val output: op.outputArguments()) for (val output: outArgs)
checkWorkspace(op.opName(), output); checkWorkspace(op.opName(), output);
} }
protected void checkForWorkspaces(Op op) { protected void checkForWorkspaces(Op op, OpContext oc) {
val x = op.x(); val x = oc != null ? oc.getInputArray(0) : op.x();
if (x != null) if (x != null)
checkWorkspace(op.opName(), x); checkWorkspace(op.opName(), x);
val y = op.y(); val y = oc != null && oc.getInputArrays().size() > 1 ? oc.getInputArray(1) : op.y();
if (y != null) if (y != null)
checkWorkspace(op.opName(), y); checkWorkspace(op.opName(), y);
val z = op.z(); val z = oc != null ? oc.getOutputArray(0) : op.z();
if (z != null) if (z != null)
checkWorkspace(op.opName(), z); checkWorkspace(op.opName(), z);
} }
@ -346,7 +349,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
OpProfiler.getInstance().processOpCall(op, tadBuffers); OpProfiler.getInstance().processOpCall(op, tadBuffers);
break; break;
case SCOPE_PANIC: case SCOPE_PANIC:
checkForWorkspaces(op); checkForWorkspaces(op, null);
return 0L; return 0L;
case DISABLED: case DISABLED:
default: default:
@ -357,7 +360,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
} }
@Deprecated @Deprecated
public long profilingHookIn(CustomOp op) { public long profilingHookIn(CustomOp op, OpContext oc) {
switch (profilingMode) { switch (profilingMode) {
case ALL: case ALL:
OpProfiler.getInstance().processOpCall(op); OpProfiler.getInstance().processOpCall(op);
@ -368,7 +371,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
OpProfiler.getInstance().processOpCall(op); OpProfiler.getInstance().processOpCall(op);
break; break;
case SCOPE_PANIC: case SCOPE_PANIC:
checkForWorkspaces(op); checkForWorkspaces(op, oc);
return 0L; return 0L;
case DISABLED: case DISABLED:
default: default:
@ -379,7 +382,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
} }
@Deprecated @Deprecated
public void profilingHookOut(Op op, long timeStart) { public void profilingHookOut(Op op, OpContext oc, long timeStart) {
switch (profilingMode) { switch (profilingMode) {
case ALL: case ALL:
OpProfiler.getInstance().processStackCall(op, timeStart); OpProfiler.getInstance().processStackCall(op, timeStart);
@ -392,14 +395,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
OpProfiler.getInstance().timeOpCall(op, timeStart); OpProfiler.getInstance().timeOpCall(op, timeStart);
break; break;
case NAN_PANIC: case NAN_PANIC:
OpExecutionerUtil.checkForNaN(op); OpExecutionerUtil.checkForNaN(op, oc);
break; break;
case INF_PANIC: case INF_PANIC:
OpExecutionerUtil.checkForInf(op); OpExecutionerUtil.checkForInf(op, oc);
break; break;
case ANY_PANIC: case ANY_PANIC:
OpExecutionerUtil.checkForNaN(op); OpExecutionerUtil.checkForNaN(op, oc);
OpExecutionerUtil.checkForInf(op); OpExecutionerUtil.checkForInf(op, oc);
break; break;
case DISABLED: case DISABLED:
default: default:
@ -413,7 +416,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
} }
@Deprecated @Deprecated
public void profilingHookOut(CustomOp op, long timeStart) { public void profilingHookOut(CustomOp op, OpContext oc, long timeStart) {
switch (profilingMode) { switch (profilingMode) {
case ALL: case ALL:
OpProfiler.getInstance().processStackCall(op, timeStart); OpProfiler.getInstance().processStackCall(op, timeStart);
@ -426,14 +429,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
OpProfiler.getInstance().timeOpCall(op, timeStart); OpProfiler.getInstance().timeOpCall(op, timeStart);
break; break;
case NAN_PANIC: case NAN_PANIC:
OpExecutionerUtil.checkForNaN(op); OpExecutionerUtil.checkForNaN(op, oc);
break; break;
case INF_PANIC: case INF_PANIC:
OpExecutionerUtil.checkForInf(op); OpExecutionerUtil.checkForInf(op, oc);
break; break;
case ANY_PANIC: case ANY_PANIC:
OpExecutionerUtil.checkForNaN(op); OpExecutionerUtil.checkForNaN(op, oc);
OpExecutionerUtil.checkForInf(op); OpExecutionerUtil.checkForInf(op, oc);
break; break;
case DISABLED: case DISABLED:
default: default:
@ -442,12 +445,15 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
} }
public long profilingConfigurableHookIn(CustomOp op) { public long profilingConfigurableHookIn(CustomOp op, OpContext oc) {
for (val arr: op.inputArguments()) List<INDArray> inArgs = oc != null ? oc.getInputArrays() : op.inputArguments();
List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
for (val arr: inArgs)
if (arr.wasClosed()) if (arr.wasClosed())
throw new IllegalStateException("One of Input arguments was closed before call"); throw new IllegalStateException("One of Input arguments was closed before call");
for (val arr: op.outputArguments()) for (val arr: outArgs)
if (arr.wasClosed()) if (arr.wasClosed())
throw new IllegalStateException("One of Output arguments was closed before call"); throw new IllegalStateException("One of Output arguments was closed before call");
@ -460,7 +466,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
} }
if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) { if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) {
checkForWorkspaces(op); checkForWorkspaces(op, oc);
} }
return System.nanoTime(); return System.nanoTime();
@ -491,14 +497,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
OpProfiler.getInstance().processOpCall(op, tadBuffers); OpProfiler.getInstance().processOpCall(op, tadBuffers);
} }
if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) { if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) {
checkForWorkspaces(op); checkForWorkspaces(op, null);
} }
return System.nanoTime(); return System.nanoTime();
} }
public void profilingConfigurableHookOut(Op op, long timeStart) { public void profilingConfigurableHookOut(Op op, OpContext oc, long timeStart) {
if (OpProfiler.getInstance().getConfig() == null) if (OpProfiler.getInstance().getConfig() == null)
return; return;
@ -509,10 +515,10 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
OpProfiler.getInstance().timeOpCall(op, timeStart); OpProfiler.getInstance().timeOpCall(op, timeStart);
} }
if (OpProfiler.getInstance().getConfig().isCheckForNAN()) { if (OpProfiler.getInstance().getConfig().isCheckForNAN()) {
OpExecutionerUtil.checkForNaN(op); OpExecutionerUtil.checkForNaN(op, oc);
} }
if (OpProfiler.getInstance().getConfig().isCheckForINF()) { if (OpProfiler.getInstance().getConfig().isCheckForINF()) {
OpExecutionerUtil.checkForInf(op); OpExecutionerUtil.checkForInf(op, oc);
} }
if (OpProfiler.getInstance().getConfig().isNativeStatistics()) { if (OpProfiler.getInstance().getConfig().isNativeStatistics()) {
if (op.z() != null) { if (op.z() != null) {
@ -531,7 +537,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
} }
} }
public void profilingConfigurableHookOut(CustomOp op, long timeStart) { public void profilingConfigurableHookOut(CustomOp op, OpContext oc, long timeStart) {
if (OpProfiler.getInstance().getConfig() == null) if (OpProfiler.getInstance().getConfig() == null)
return; return;
@ -542,10 +548,10 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
OpProfiler.getInstance().timeOpCall(op, timeStart); OpProfiler.getInstance().timeOpCall(op, timeStart);
} }
if (OpProfiler.getInstance().getConfig().isCheckForNAN()) { if (OpProfiler.getInstance().getConfig().isCheckForNAN()) {
OpExecutionerUtil.checkForNaN(op); OpExecutionerUtil.checkForNaN(op, oc);
} }
if (OpProfiler.getInstance().getConfig().isCheckForINF()) { if (OpProfiler.getInstance().getConfig().isCheckForINF()) {
OpExecutionerUtil.checkForInf(op); OpExecutionerUtil.checkForInf(op, oc);
} }
} }
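
The hooks above only run the NaN/Inf and workspace checks when the corresponding profiler flags are enabled. A minimal sketch of turning them on globally, assuming the ProfilerConfig builder exposes these flags under these names:

import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;

Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder()
        .checkForNAN(true)       // routes through OpExecutionerUtil.checkForNaN(op, oc)
        .checkForINF(true)       // routes through OpExecutionerUtil.checkForInf(op, oc)
        .checkWorkspaces(true)   // routes through checkForWorkspaces(op, oc)
        .build());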

View File

@ -22,12 +22,15 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.profiler.OpProfiler; import org.nd4j.linalg.profiler.OpProfiler;
import java.util.List;
/**Utility functions for the DefaultOpExecutioner /**Utility functions for the DefaultOpExecutioner
* @author Alex Black * @author Alex Black
*/ */
@ -58,7 +61,7 @@ public class OpExecutionerUtil {
} }
if (match > 0) if (match > 0)
throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): "); throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s)");
} }
public static void checkForAny(INDArray z) { public static void checkForAny(INDArray z) {
@ -92,44 +95,52 @@ public class OpExecutionerUtil {
} }
public static void checkForNaN(Op op) { public static void checkForNaN(Op op, OpContext oc) {
if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
return; return;
if (op.z() != null && !(op instanceof MatchCondition)) { INDArray z = oc != null ? oc.getOutputArray(0) : op.z();
checkForNaN(op.z()); if (z != null && !(op instanceof MatchCondition)) {
checkForNaN(z);
} }
} }
public static void checkForInf(Op op) { public static void checkForInf(Op op, OpContext oc) {
if (!OpProfiler.getInstance().getConfig().isCheckForINF()) if (!OpProfiler.getInstance().getConfig().isCheckForINF())
return; return;
if (op.z() != null && !(op instanceof MatchCondition)) { INDArray z = oc != null ? oc.getOutputArray(0) : op.z();
checkForInf(op.z()); if (z != null && !(op instanceof MatchCondition)) {
checkForInf(z);
} }
} }
public static void checkForInf(CustomOp op) { public static void checkForInf(CustomOp op, OpContext oc) {
if (!OpProfiler.getInstance().getConfig().isCheckForINF()) if (!OpProfiler.getInstance().getConfig().isCheckForINF())
return; return;
for (val input: op.inputArguments()) List<INDArray> inArgs = oc != null ? oc.getInputArrays() : op.inputArguments();
List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
for (val input: inArgs)
checkForInf(input); checkForInf(input);
for (val output: op.outputArguments()) for (val output: outArgs)
checkForInf(output); checkForInf(output);
} }
public static void checkForNaN(CustomOp op) { public static void checkForNaN(CustomOp op, OpContext oc) {
if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
return; return;
for (val input: op.inputArguments()) List<INDArray> inArgs = oc != null ? oc.getInputArrays() : op.inputArguments();
List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
for (val input: inArgs)
checkForNaN(input); checkForNaN(input);
for (val output: op.outputArguments()) for (val output: outArgs)
checkForNaN(output); checkForNaN(output);
} }
} }

View File

@ -57,8 +57,12 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
addArgs(); addArgs();
} }
public MaxPoolWithArgmax(INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){ public MaxPoolWithArgmax(@NonNull INDArray input, @NonNull Pooling2DConfig config){
super(null, new INDArray[]{input}, new INDArray[]{output, outArgMax}); this(input, null, null, config);
}
public MaxPoolWithArgmax(@NonNull INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){
super(null, new INDArray[]{input}, wrapFilterNull(output, outArgMax));
config.setType(Pooling2D.Pooling2DType.MAX); config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config; this.config = config;

View File

@ -45,7 +45,7 @@ public class SConv2D extends Conv2D {
} }
public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights,
@NonNull SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig); this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig);
} }

View File

@ -0,0 +1,144 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
* LSTM layer implemented as a single operation.
* Implementation of operation for LSTM layer with optional peep hole connections.<br>
* S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and <a href="https://research.google.com/pubs/archive/43905.pdf">https://research.google.com/pubs/archive/43905.pdf</a><br>
* Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.<br>
* See also: <a href="https://arxiv.org/pdf/1503.04069.pdf">https://arxiv.org/pdf/1503.04069.pdf</a><br>
* <p>
* See also {@link LSTMBlockCell} - lstmBlockCell op is used internally at C++ level for computation.<br>
* <br>
* Input arrays:<br>
* 0: max sequence length; long/int64 scalar<br>
* 1: input [seqLength, bS, inSize] at time t<br>
* 2: previous/initial cell state [bS, numUnits]<br>
* 3: previous/initial output [bS, numUnits]<br>
* 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]<br>
* 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits]<br>
* 6: weights - cell peephole (t-1) connections to forget gate, [numUnits]<br>
* 7: weights - cell peephole (t) connections to output gate, [numUnits]<br>
* 8: biases, shape [4*numUnits]<br>
* <br>
* Input integer arguments: set via {@link LSTMConfiguration}<br>
* 0: if not zero, provide peephole connections<br>
* 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]<br>
* <br>
* Input float arguments: set via {@link LSTMConfiguration}<br>
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
* <p>
* Output arrays:<br>
* 0: i - Input modulation gate activations, rank 3, shape as per dataFormat<br>
* 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat<br>
* 2: f - Output - forget gate activations, rank 3, shape as per dataFormat<br>
* 3: o - Output - output gate activations, rank 3, shape as per dataFormat<br>
* 4: z (ci) - Output - block input, rank 3, shape as per dataFormat<br>
* 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat<br>
* 6: y (h) - Current cell output, rank 3, shape as per dataFormat<br>
*
* @author Alex Black
*/
public class LSTMBlock extends DynamicCustomOp {
private LSTMConfiguration configuration;
@Getter
private LSTMWeights weights;
public LSTMBlock() {
}
public LSTMBlock(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
this.configuration = configuration;
this.weights = weights;
addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs());
}
public LSTMBlock(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) {
super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast));
this.configuration = lstmConfiguration;
this.weights = lstmWeights;
addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMBlock, got %s", inputDataTypes);
//7 outputs, all of same type as input. Note that input 0 is max sequence length (int64), input 1 is actual input
DataType dt = inputDataTypes.get(1);
Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
return Arrays.asList(dt, dt, dt, dt, dt, dt, dt);
}
@Override
public List<SDVariable> doDiff(List<SDVariable> grads) {
throw new UnsupportedOperationException("Not yet implemented");
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
configuration = LSTMConfiguration.builder()
.forgetBias(attributesForNode.get("forget_bias").getF())
.clippingCellValue(attributesForNode.get("cell_clip").getF())
.peepHole(attributesForNode.get("use_peephole").getB())
.dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM
.build();
addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs());
}
@Override
public String opName() {
return "lstmBlock";
}
@Override
public Map<String, Object> propertiesForFunction() {
return configuration.toProperties(true);
}
@Override
public String tensorflowName() {
return "BlockLSTM";
}
}

View File

@ -1,5 +1,5 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -13,7 +13,6 @@
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import lombok.Getter; import lombok.Getter;
@ -24,89 +23,103 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.shade.guava.primitives.Booleans;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import javax.xml.crypto.Data;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
/** /**
* LSTM layer implemented as a single operation. * LSTM layer implemented as a single operation.
* Implementation of operation for LSTM layer with optional peep hole connections.<br> * Implementation of operation for LSTM layer with optional peep hole connections.<br>
* S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and <a href="https://research.google.com/pubs/archive/43905.pdf">https://research.google.com/pubs/archive/43905.pdf</a><br> * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and <a href="https://research.google.com/pubs/archive/43905.pdf">https://research.google.com/pubs/archive/43905.pdf</a><br>
* Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.<br> * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.<br>
* See also: <a href="https://arxiv.org/pdf/1503.04069.pdf">https://arxiv.org/pdf/1503.04069.pdf</a><br> * See also: <a href="https://arxiv.org/pdf/1503.04069.pdf">https://arxiv.org/pdf/1503.04069.pdf</a><br>
* <p>
* See also {@link LSTMBlockCell} - lstmBlockCell op is used internally at C++ level for computation.<br>
* <br>
* Input arrays:<br> * Input arrays:<br>
* 0: max sequence length; long/int64 scalar<br> * 0: input <br>
* 1: input [seqLength, bS, inSize] at time t<br> * [sL, bS, nIn] when dataFormat - TNS <br>
* 2: previous/initial cell state [bS, numUnits]<br> * [bS, sL, nIn] when dataFormat - NST <br>
* 3: previous/initial output [bS, numUnits]<br> * [bS, nIn, sL] when dataFormat - NST <br>
* 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]<br> * 1: previous/initial cell state<br>
* 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits]<br> * shapes [nIn, 4*nOut] for FWD, BWD Direction Mode <br>
* 6: weights - cell peephole (t-1) connections to forget gate, [numUnits]<br> * shapes [2, nIn, 4*nOut] BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM Direction Mode <br>
* 7: weights - cell peephole (t) connections to output gate, [numUnits]<br> * 2: previous/initial output [bS, numUnits]<br>
* 8: biases, shape [4*numUnits]<br> * * shapes [nIn, 4*nOut] for FWD, BWD Direction Mode <br>
* <br> * * shapes [2, nIn, 4*nOut] BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM Direction Mode <br>
 * Input integer arguments: set via {@link LSTMConfiguration}<br> * 3: max sequence length [bS] <br>
* 0: if not zero, provide peephole connections<br> * 4: LSTMLayerWeights - {@link LSTMLayerWeights} <br>
* 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]<br> * 5: LSTMLayerConfig - {@link LSTMLayerConfig}<br>
* <br>
* Input float arguments: set via {@link LSTMConfiguration}<br>
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
* <p> * <p>
* Output arrays:<br> * Output arrays:<br>
* 0: i - Input modulation gate activations, rank 3, shape as per dataFormat<br> * 0: output h - rank 3 or 4, depends on DirectionMode and dataFormat<br>
 * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat<br> * 1: output at last step hL - rank 3 or 4, depends on DirectionMode and dataFormat<br>
* 2: f - Output - forget gate activations, rank 3, shape as per dataFormat<br> * 2: cell state at last step cL - same shape as in hL<br>
* 3: o - Output - output gate activations, rank 3, shape as per dataFormat<br>
* 4: z (ci) - Output - block input, rank 3, shape as per dataFormat<br>
* 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat<br>
* 6: y (h) - Current cell output, rank 3, shape as per dataFormat<br>
*
* @author Alex Black
*/ */
public class LSTMLayer extends DynamicCustomOp { public class LSTMLayer extends DynamicCustomOp {
private LSTMConfiguration configuration; @Getter
private LSTMLayerConfig configuration;
@Getter @Getter
private LSTMWeights weights; private LSTMLayerWeights weights;
public LSTMLayer() { public LSTMLayer() {
} }
public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) {
super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast)); super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
this.configuration = configuration; this.configuration = configuration;
this.weights = weights; this.weights = weights;
addIArgument(configuration.iArgs(true)); addIArgument(iArgs());
addTArgument(configuration.tArgs()); addTArgument(tArgs());
addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(),
"You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
} }
public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) { public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights lstmWeights, LSTMLayerConfig LSTMLayerConfig) {
super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast)); super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast));
this.configuration = lstmConfiguration; this.configuration = LSTMLayerConfig;
this.weights = lstmWeights; this.weights = lstmWeights;
addIArgument(configuration.iArgs(true)); addIArgument(iArgs());
addTArgument(configuration.tArgs()); addTArgument(tArgs());
addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(),
"You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
} }
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMLayer, got %s", inputDataTypes); Preconditions.checkState(inputDataTypes != null && 3 <= inputDataTypes.size() && inputDataTypes.size() <= 8, "Expected between 3 (input, Wx, Wr only) and 8 inputs to LSTMLayer, got %s", inputDataTypes);
//7 outputs, all of same type as input. Note that input 0 is max sequence length (int64), input 1 is actual input //7 outputs, all of same type as input. Note that input 0 is max sequence length (int64), input 1 is actual input
DataType dt = inputDataTypes.get(1); DataType dt = inputDataTypes.get(1);
ArrayList<DataType> list = new ArrayList<>();
if (configuration.isRetFullSequence()) {
list.add(dt);
}
if (configuration.isRetLastC()) {
list.add(dt);
}
if (configuration.isRetLastH()){
list.add(dt);
}
Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt); Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
return Arrays.asList(dt, dt, dt, dt, dt, dt, dt); return list;
} }
@Override @Override
@ -114,31 +127,61 @@ public class LSTMLayer extends DynamicCustomOp {
throw new UnsupportedOperationException("Not yet implemented"); throw new UnsupportedOperationException("Not yet implemented");
} }
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
configuration = LSTMConfiguration.builder()
.forgetBias(attributesForNode.get("forget_bias").getF())
.clippingCellValue(attributesForNode.get("cell_clip").getF())
.peepHole(attributesForNode.get("use_peephole").getB())
.dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM
.build();
addIArgument(configuration.iArgs(true));
addTArgument(configuration.tArgs());
}
@Override @Override
public String opName() { public String opName() {
return "lstmBlock"; return "lstmLayer";
} }
@Override @Override
public Map<String, Object> propertiesForFunction() { public Map<String, Object> propertiesForFunction() {
return configuration.toProperties(true); return configuration.toProperties(true, true);
}
public long[] iArgs() {
return new long[]{
configuration.getLstmdataformat().ordinal(),// INT_ARG(0)
configuration.getDirectionMode().ordinal(), // INT_ARG(1)
configuration.getGateAct().ordinal(), // INT_ARG(2)
configuration.getOutAct().ordinal(), // INT_ARG(3)
configuration.getCellAct().ordinal() // INT_ARG(4)
};
}
public double[] tArgs() {
return new double[]{this.configuration.getCellClip()}; // T_ARG(0)
}
public <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
return new boolean[]{
weights.hasBias(), // hasBiases: B_ARG(0)
maxTSLength != null, // hasSeqLen: B_ARG(1)
yLast != null, // hasInitH: B_ARG(2)
cLast != null, // hasInitC: B_ARG(3)
weights.hasPH(), // hasPH: B_ARG(4)
configuration.isRetFullSequence(), //retFullSequence: B_ARG(5)
configuration.isRetLastH(), // retLastH: B_ARG(6)
configuration.isRetLastC() // retLastC: B_ARG(7)
};
} }
@Override @Override
public String tensorflowName() { public int getNumOutputs(){
return "BlockLSTM";
return Booleans.countTrue(
configuration.isRetFullSequence(), //retFullSequence: B_ARG(5)
configuration.isRetLastH(), // retLastH: B_ARG(6)
configuration.isRetLastC() // retLastC: B_ARG(7)
);
} }
} }
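
A minimal, hedged sketch of calling the new op eagerly with INDArrays (FWD direction, TNS format, no peepholes and no sequence-length array). Shapes follow the javadoc above, and it assumes Nd4j.exec(CustomOp) calculates output shapes and returns the enabled outputs in order:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.factory.Nd4j;

int sL = 5, bS = 2, nIn = 3, nOut = 4;
INDArray x     = Nd4j.rand(DataType.FLOAT, sL, bS, nIn);   // TNS input [sL, bS, nIn]
INDArray cLast = Nd4j.zeros(DataType.FLOAT, bS, nOut);     // initial cell state
INDArray yLast = Nd4j.zeros(DataType.FLOAT, bS, nOut);     // initial output

LSTMLayerWeights weights = LSTMLayerWeights.builder()
        .iWeights(Nd4j.rand(DataType.FLOAT, nIn, 4 * nOut))    // input-to-hidden Wx
        .irWeights(Nd4j.rand(DataType.FLOAT, nOut, 4 * nOut))  // hidden-to-hidden Wr
        .iBias(Nd4j.rand(DataType.FLOAT, 4 * nOut))            // bias b
        .build();

LSTMLayerConfig config = LSTMLayerConfig.builder()
        .lstmdataformat(LSTMDataFormat.TNS)
        .directionMode(LSTMDirectionMode.FWD)
        .retFullSequence(true)
        .retLastH(true)
        .retLastC(true)
        .build();

// null maxTSLength -> hasSeqLen is false in bArgs()
INDArray[] out = Nd4j.exec(new LSTMLayer(x, cLast, yLast, null, weights, config));
// out[0]: h for all steps [sL, bS, nOut]; out[1]: hL [bS, nOut]; out[2]: cL [bS, nOut]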

View File

@ -0,0 +1,48 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
/**
* integer numbers corresponding to activations:
* 0=tanh,
* 1=relu,
* 2=sigmoid,
* 3=affine,
* 4=leaky relu,
* 5= thresholded relu,
* 6=scaled tanh,
* 7=hard sigmoid,
* 8=ELU,
* 9=softsign,
* 10=softplus
*/
public enum LSTMActivations {
    //Note: ordinal (order) here matters for C++ level. Any new values should be added at the end
TANH,
RELU,
SIGMOID,
AFFINE,
LEAKY_RELU,
THRESHHOLD_RELU,
SCALED_TAHN,
HARD_SIGMOID,
ELU,
SOFTSIGN,
SOFTPLUS
}

View File

@ -0,0 +1,41 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
/**
* notations <br>
* for unidirectional:
* TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
* NST: shape [numExamples, inOutSize, timeLength]<br>
* NTS: shape [numExamples, timeLength, inOutSize]<br>
* for bidirectional:
* T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
*/
public enum LSTMDataFormat {
    //Note: ordinal (order) here matters for C++ level. Any new formats should be added at the end
TNS,
NTS,
NST,
T2NS
}

View File

@ -0,0 +1,38 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
/**
* direction <br>
* FWD: 0 = fwd
* BWD: 1 = bwd
* BIDIR_SUM: 2 = bidirectional sum
* BIDIR_CONCAT: 3 = bidirectional concat
* BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */
// const auto directionMode = INT_ARG(1); // direction:
public enum LSTMDirectionMode {
    //Note: ordinal (order) here matters for C++ level. Any new modes should be added at the end
FWD,
BWD,
BIDIR_SUM,
BIDIR_CONCAT,
BIDIR_EXTRA_DIM
}

View File

@ -0,0 +1,119 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import java.util.LinkedHashMap;
import java.util.Map;
@Builder
@Data
public class LSTMLayerConfig {
/**
* notations <br>
* for unidirectional:
* TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
* NST: shape [numExamples, inOutSize, timeLength]<br>
* NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout<br>
* for bidirectional:
* T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
*/
@Builder.Default
private LSTMDataFormat lstmdataformat = LSTMDataFormat.TNS; //INT_ARG(0)
/**
* direction <br>
* FWD: 0 = fwd
* BWD: 1 = bwd
 * BIDIR_SUM: 2 = bidirectional sum
 * BIDIR_CONCAT: 3 = bidirectional concat
 * BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with dataFormat = 3)
*/
@Builder.Default
private LSTMDirectionMode directionMode = LSTMDirectionMode.FWD; //INT_ARG(1)
/**
* Activation for input (i), forget (f) and output (o) gates
*/
@Builder.Default
private LSTMActivations gateAct = LSTMActivations.SIGMOID; // INT_ARG(2)
@Builder.Default
private LSTMActivations cellAct = LSTMActivations.TANH; // INT_ARG(3)
@Builder.Default
private LSTMActivations outAct = LSTMActivations.TANH; // INT_ARG(4)
/**
* indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
*/
@Builder.Default
private boolean retFullSequence = true; //B_ARG(5)
/**
* indicates whether to return output at last time step only,
* in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
*/
private boolean retLastH; //B_ARG(6)
/**
 * indicates whether to return the cell state at the last time step only,
* in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
*/
private boolean retLastC; // B_ARG(7)
/**
* Cell clipping value, if it = 0 then do not apply clipping
*/
@Builder.Default
private double cellClip; //T_ARG(0)
public Map<String, Object> toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) {
Map<String, Object> ret = new LinkedHashMap<>();
ret.put("gateAct", gateAct.ordinal());
ret.put("outAct", outAct.ordinal());
ret.put("cellAct", cellAct.ordinal());
ret.put("retFullSequence", retFullSequence);
ret.put("retLastH", retLastH);
ret.put("retLastC", retLastC);
ret.put("cellClip", cellClip);
if (includeLSTMDataFormat)
ret.put("LSTMDataFormat", lstmdataformat.ordinal());
if (includeLSTMDirectionMode)
ret.put("LSTMDirectionMode", directionMode.ordinal());
return ret;
}
}
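
For reference, a short hedged sketch of how the builder defaults and the three boolean return flags interact (only names defined in the fields above are used):

LSTMLayerConfig cfg = LSTMLayerConfig.builder()
        .lstmdataformat(LSTMDataFormat.NTS)          // INT_ARG(0)
        .directionMode(LSTMDirectionMode.BIDIR_SUM)  // INT_ARG(1)
        .retFullSequence(true)                       // B_ARG(5)
        .retLastH(false)                             // B_ARG(6)
        .retLastC(false)                             // B_ARG(7)
        .build();                                    // gateAct/cellAct/outAct/cellClip keep their defaults
// With only retFullSequence enabled, LSTMLayer.getNumOutputs() reports a single output array.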

View File

@ -2,13 +2,18 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import lombok.AccessLevel; import lombok.AccessLevel;
import lombok.Getter; import lombok.Getter;
import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
import org.nd4j.shade.guava.primitives.Booleans;
/** /**
 * The outputs of an LSTM layer ({@link LSTMLayer}). * The outputs of an LSTM layer ({@link LSTMLayer}).
@ -16,165 +21,78 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
@Getter @Getter
public class LSTMLayerOutputs { public class LSTMLayerOutputs {
private RnnDataFormat dataFormat; /**
 * The LSTM layer data format ({@link LSTMDataFormat}).
*/
private LSTMDataFormat dataFormat;
/** /**
* Output - input modulation gate activations. * output h:
* Shape depends on data format (in layer config):<br> * [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
* TNS -> [timeSteps, batchSize, numUnits]<br> * [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
* NST -> [batchSize, numUnits, timeSteps]<br> * [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
* NTS -> [batchSize, timeSteps, numUnits]<br> * [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
* [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1
* [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2
* [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
* numbers mean index in corresponding enums {@link LSTMDataFormat} and {@link LSTMDirectionMode}
*/ */
private SDVariable i; private SDVariable timeSeriesOutput;
/** /**
* Activations, cell state (pre tanh). * cell state at last step cL:
* Shape depends on data format (in layer config):<br> * [bS, nOut] when directionMode FWD or BWD
 * TNS -> [timeSteps, batchSize, numUnits]<br> * [2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/ */
private SDVariable c; private SDVariable lastCellStateOutput;
/** /**
* Output - forget gate activations. * output at last step hL:
* Shape depends on data format (in layer config):<br> * [bS, nOut] when directionMode FWD or BWD
 * TNS -> [timeSteps, batchSize, numUnits]<br> * [2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/ */
private SDVariable f; private SDVariable lastTimeStepOutput;
/**
* Output - output gate activations.
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable o;
/** public LSTMLayerOutputs(SDVariable[] outputs, LSTMLayerConfig lstmLayerConfig) {
* Output - input gate activations. Preconditions.checkArgument(outputs.length > 0 && outputs.length <= 3,
* Shape depends on data format (in layer config):<br> "Must have from 1 to 3 LSTM layer outputs, got %s", outputs.length);
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable z;
/** int i = 0;
* Cell state, post tanh. timeSeriesOutput = lstmLayerConfig.isRetFullSequence() ? outputs[i++] : null;
* Shape depends on data format (in layer config):<br> lastTimeStepOutput = lstmLayerConfig.isRetLastH() ? outputs[i++] : null;
* TNS -> [timeSteps, batchSize, numUnits]<br> lastCellStateOutput = lstmLayerConfig.isRetLastC() ? outputs[i++] : null;
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable h;
/**
* Current cell output.
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
private SDVariable y;
public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){ this.dataFormat = lstmLayerConfig.getLstmdataformat();
Preconditions.checkArgument(outputs.length == 7,
"Must have 7 LSTM layer outputs, got %s", outputs.length);
i = outputs[0];
c = outputs[1];
f = outputs[2];
o = outputs[3];
z = outputs[4];
h = outputs[5];
y = outputs[6];
this.dataFormat = dataFormat;
} }
/**
* Get all outputs returned by the cell.
*/
public List<SDVariable> getAllOutputs(){
return Arrays.asList(i, c, f, o, z, h, y);
}
/** /**
* Get y, the output of the cell for all time steps. * Get h, the output of the cell for all time steps.
* * <p>
* Shape depends on data format (in layer config):<br> * Shape depends on data format defined in {@link LSTMLayerConfig }:<br>
* TNS -> [timeSteps, batchSize, numUnits]<br> * for unidirectional:
* NST -> [batchSize, numUnits, timeSteps]<br> * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
* NTS -> [batchSize, timeSteps, numUnits]<br> * NST: shape [numExamples, inOutSize, timeLength]<br>
* NTS: shape [numExamples, timeLength, inOutSize] <br>
* for bidirectional:
* T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
*/ */
public SDVariable getOutput(){ public SDVariable getOutput() {
return y; Preconditions.checkArgument(timeSeriesOutput != null, "retFullSequence was set to false in LSTMLayerConfig");
return timeSeriesOutput;
} }
/** public SDVariable getLastState() {
* Get c, the cell's state for all time steps. Preconditions.checkArgument(lastCellStateOutput != null, "retLastC was set to false in LSTMLayerConfig");
* return lastCellStateOutput;
* Shape depends on data format (in layer config):<br>
* TNS -> [timeSteps, batchSize, numUnits]<br>
* NST -> [batchSize, numUnits, timeSteps]<br>
* NTS -> [batchSize, timeSteps, numUnits]<br>
*/
public SDVariable getState(){
return c;
} }
private SDVariable lastOutput = null; public SDVariable getLastOutput() {
Preconditions.checkArgument(lastTimeStepOutput != null, "retLastH was set to false in LSTMLayerConfig");
/** return lastTimeStepOutput;
* Get y, the output of the cell, for the last time step.
*
* Has shape [batchSize, numUnits].
*/
public SDVariable getLastOutput(){
if(lastOutput != null)
return lastOutput;
switch (dataFormat){
case TNS:
lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
break;
case NST:
lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
break;
case NTS:
lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
break;
} }
return lastOutput;
}
private SDVariable lastState = null;
/**
* Get c, the state of the cell, for the last time step.
*
* Has shape [batchSize, numUnits].
*/
public SDVariable getLastState(){
if(lastState != null)
return lastState;
switch (dataFormat){
case TNS:
lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
break;
case NST:
lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
break;
case NTS:
lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
break;
}
return lastState;
}
} }
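
A hedged sketch at the SameDiff level of how the variable-count op outputs are unwrapped through this class. It assumes outputVariables() on the op returns the declared outputs in order; all names and shapes are illustrative:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.factory.Nd4j;

SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, 5, 2, 3);      // TNS: [sL, bS, nIn], nOut = 4
LSTMLayerWeights weights = LSTMLayerWeights.builder()
        .weights(sd.var("Wx", Nd4j.rand(DataType.FLOAT, 3, 16)))    // [nIn, 4*nOut]
        .rWeights(sd.var("Wr", Nd4j.rand(DataType.FLOAT, 4, 16)))   // [nOut, 4*nOut]
        .bias(sd.var("b", Nd4j.rand(DataType.FLOAT, 16)))           // [4*nOut]
        .build();
LSTMLayerConfig cfg = LSTMLayerConfig.builder()
        .lstmdataformat(LSTMDataFormat.TNS)
        .directionMode(LSTMDirectionMode.FWD)
        .retFullSequence(true)   // only the full time-series output is requested here
        .build();

LSTMLayer op = new LSTMLayer(sd, in, null, null, null, weights, cfg);
LSTMLayerOutputs outputs = new LSTMLayerOutputs(op.outputVariables(), cfg);
SDVariable h = outputs.getOutput();  // getLastOutput()/getLastState() would throw, those flags are off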

View File

@ -0,0 +1,99 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.util.ArrayUtil;
/**
* The weight configuration of a LSTMLayer. For {@link LSTMLayer}
* @author Alex Black
*/
@EqualsAndHashCode(callSuper = true)
@Data
@Builder
public class LSTMLayerWeights extends RNNWeights {
/**
* Input to hidden weights with a shape of [inSize, 4*numUnits].
*
* Input to hidden and hidden to hidden are concatenated in dimension 0,
* so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :].
*/
private SDVariable weights;
private INDArray iWeights;
/**
* hidden to hidden weights (aka "recurrent weights", with a shape of [numUnits, 4*numUnits].
*
*/
private SDVariable rWeights;
private INDArray irWeights;
/**
* Peephole weights, with a shape of [3*numUnits].
*/
private SDVariable peepholeWeights;
private INDArray iPeepholeWeights;
/**
* Input to hidden and hidden to hidden biases, with shape [4*numUnits].
*/
private SDVariable bias;
private INDArray iBias;
@Override
public SDVariable[] args() {
return filterNonNull(weights, rWeights, peepholeWeights, bias);
}
@Override
public INDArray[] arrayArgs() {
return filterNonNull(iWeights, irWeights, iPeepholeWeights, iBias);
}
@Override
public SDVariable[] argsWithInputs(SDVariable... inputs){
Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length); //Order: x, seqLen, yLast, cLast
//lstmLayer c++ op expects: x, Wx, Wr, Wp, b, seqLen, yLast, cLast
return ArrayUtil.filterNull(inputs[0], weights, rWeights, bias, inputs[1], inputs[2], inputs[3], peepholeWeights);
}
@Override
public INDArray[] argsWithInputs(INDArray... inputs) {
Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length); //Order: x, seqLen, yLast, cLast
//lstmLayer c++ op expects: x, Wx, Wr, Wp, b, seqLen, yLast, cLast
return ArrayUtil.filterNull(inputs[0], iWeights, irWeights, iBias, inputs[1], inputs[2], inputs[3], iPeepholeWeights);
}
public boolean hasBias() {
return (bias!=null||iBias!=null);
}
public boolean hasPH() {
return (peepholeWeights!=null||iPeepholeWeights!=null);
}
}

View File

@ -98,6 +98,7 @@ public class Mmul extends DynamicCustomOp {
addIArgument(ArrayUtil.fromBoolean(transposeX), addIArgument(ArrayUtil.fromBoolean(transposeX),
ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeY),
ArrayUtil.fromBoolean(transposeZ)); ArrayUtil.fromBoolean(transposeZ));
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
} }
public Mmul(INDArray x, INDArray y) { public Mmul(INDArray x, INDArray y) {
@ -110,6 +111,7 @@ public class Mmul extends DynamicCustomOp {
addIArgument(ArrayUtil.fromBoolean(transposeX), addIArgument(ArrayUtil.fromBoolean(transposeX),
ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeY),
ArrayUtil.fromBoolean(transposeZ)); ArrayUtil.fromBoolean(transposeZ));
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
} }
public Mmul() {} public Mmul() {}

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -49,6 +50,9 @@ public class BatchMmul extends DynamicCustomOp {
protected int N; protected int N;
protected int K; protected int K;
public BatchMmul(SameDiff sameDiff, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) {
this(sameDiff, ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB);
}
public BatchMmul(SameDiff sameDiff, public BatchMmul(SameDiff sameDiff,
SDVariable[] matrices, SDVariable[] matrices,
@ -85,6 +89,22 @@ public class BatchMmul extends DynamicCustomOp {
addArgs(); addArgs();
} }
public BatchMmul(INDArray[] matricesA, INDArray[] matricesB, boolean transposeA, boolean transposeB){
super(ArrayUtils.addAll(matricesA, matricesB), null);
this.batchSize = matricesA.length;
this.transposeA = transposeA ? 1 : 0;
this.transposeB = transposeB ? 1 : 0;
long[] firstShape = matricesA[0].shape();
long[] lastShape = matricesB[0].shape();
this.M = transposeA ? (int) firstShape[1]: (int) firstShape[0];
this.N = transposeA ? (int) firstShape[0]: (int) firstShape[1];
this.K = transposeB ? (int) lastShape[0]: (int) lastShape[1];
addArgs();
}
@Override @Override
public int getNumOutputs(){ public int getNumOutputs(){
return batchSize; return batchSize;
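
A brief hedged sketch of the new array-based constructor (shapes illustrative; it assumes the op can be executed eagerly via Nd4j.exec(CustomOp), which allocates one output per matrix pair):

INDArray a1 = Nd4j.rand(DataType.FLOAT, 2, 3), a2 = Nd4j.rand(DataType.FLOAT, 2, 3);
INDArray b1 = Nd4j.rand(DataType.FLOAT, 3, 4), b2 = Nd4j.rand(DataType.FLOAT, 3, 4);
INDArray[] products = Nd4j.exec(new BatchMmul(
        new INDArray[]{a1, a2}, new INDArray[]{b1, b2}, false, false));
// products.length == batchSize (2); each product has shape [2, 4]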

View File

@ -34,17 +34,12 @@ import java.util.List;
@NoArgsConstructor @NoArgsConstructor
public class GatherNd extends DynamicCustomOp { public class GatherNd extends DynamicCustomOp {
public GatherNd(SameDiff sameDiff, SDVariable[] inputs, SDVariable[] indices) { public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices) {
super(null, sameDiff, ArrayUtils.addAll(inputs, indices), false); super(null, sameDiff, new SDVariable[] {input, indices});
} }
public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices, boolean inPlace) { public GatherNd(INDArray df, INDArray indices) {
super(null, sameDiff, new SDVariable[] {input, indices}, inPlace); super(new INDArray[]{df, indices}, null);
}
public GatherNd(INDArray[] df, INDArray[] indices) {
addInputArgument(df);
addInputArgument(indices);
} }
@Override @Override

View File

@ -16,13 +16,16 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.NotImplementedException;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -41,21 +44,27 @@ public class Linspace extends DynamicCustomOp {
private DataType dataType; private DataType dataType;
public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) { public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
super(sameDiff, new SDVariable[0]); this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType);
addTArgument(start,stop);
addIArgument(number);
addDArgument(dataType);
} }
public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){ public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){
super(sameDiff, new SDVariable[]{from, to, length}); super(sameDiff, new SDVariable[]{from, to, length});
this.dataType = dataType; this.dataType = dataType;
addDArgument(dataType);
} }
public Linspace(DataType dataType, double start, double stop, long number) { public Linspace(DataType dataType, double start, double stop, long number) {
this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number));
}
public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) {
this(start, stop, number, dataType);
}
public Linspace(@NonNull INDArray start, @NonNull INDArray stop, @NonNull INDArray number, @NonNull DataType dataType) {
super(new INDArray[]{start, stop, number}, null);
this.dataType = dataType;
addDArgument(dataType); addDArgument(dataType);
addTArgument(start, stop);
addIArgument(number);
} }
public Linspace(){ } public Linspace(){ }

View File

@ -16,9 +16,11 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.ArrayList; import java.util.ArrayList;
@ -41,6 +43,11 @@ public class MeshGrid extends DynamicCustomOp {
this(sd, cartesian, inputs); this(sd, cartesian, inputs);
} }
public MeshGrid(@NonNull INDArray[] inputs, boolean cartesian){
super(inputs, null);
addIArgument(cartesian ? 1 : 0);
}
public MeshGrid(){ } public MeshGrid(){ }
@Override @Override

View File

@ -44,7 +44,6 @@ import java.util.Map;
public class Reshape extends DynamicCustomOp { public class Reshape extends DynamicCustomOp {
private long[] shape; private long[] shape;
private String arrName;
public Reshape(SameDiff sameDiff, SDVariable i_v, long[] shape) { public Reshape(SameDiff sameDiff, SDVariable i_v, long[] shape) {
super(null, sameDiff, new SDVariable[]{i_v}); super(null, sameDiff, new SDVariable[]{i_v});
@ -56,6 +55,12 @@ public class Reshape extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{i_v, shape}); super(null, sameDiff, new SDVariable[]{i_v, shape});
} }
public Reshape(INDArray in, long... shape){
super(new INDArray[]{in}, null);
this.shape = shape;
addIArgument(shape);
}
public Reshape(INDArray in, INDArray shape){ public Reshape(INDArray in, INDArray shape){
this(in, shape, null); this(in, shape, null);
} }

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -64,15 +65,19 @@ public class SequenceMask extends DynamicCustomOp {
addDArgument(dataType); addDArgument(dataType);
} }
public SequenceMask(INDArray input, int maxLen, DataType dataType) { public SequenceMask(@NonNull INDArray input, int maxLen, DataType dataType) {
addInputArgument(input); addInputArgument(input);
addIArgument(maxLen); addIArgument(maxLen);
this.dataType = dataType; this.dataType = dataType;
addDArgument(dataType); addDArgument(dataType);
} }
public SequenceMask(INDArray input, DataType dataType) { public SequenceMask(@NonNull INDArray input, @NonNull DataType dataType) {
addInputArgument(input); this(input, null, dataType);
}
public SequenceMask(@NonNull INDArray input, INDArray maxLength, @NonNull DataType dataType) {
super(wrapFilterNull(input, maxLength), null);
this.dataType = dataType; this.dataType = dataType;
addDArgument(dataType); addDArgument(dataType);
} }

View File

@ -59,6 +59,10 @@ public class Slice extends DynamicCustomOp {
addIArgument(size); addIArgument(size);
} }
public Slice(@NonNull INDArray input, @NonNull INDArray begin, @NonNull INDArray end){
super(new INDArray[]{input, begin, end}, null);
}
@Override @Override
public String opName() { public String opName() {
return "slice"; return "slice";

View File

@ -50,7 +50,7 @@ public class Stack extends DynamicCustomOp {
addArgs(); addArgs();
} }
public Stack(INDArray input, int axis) { public Stack(INDArray[] input, int axis) {
addInputArgument(input); addInputArgument(input);
this.jaxis = axis; this.jaxis = axis;
addArgs(); addArgs();

View File

@ -98,10 +98,16 @@ public class StridedSlice extends DynamicCustomOp {
public StridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, public StridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask,
int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
this(in, ArrayUtil.toLongArray(begin), ArrayUtil.toLongArray(end), ArrayUtil.toLongArray(strides),
beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
}
public StridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask,
int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
addInputArgument(in); addInputArgument(in);
this.begin = ArrayUtil.toLongArray(begin); this.begin = begin;
this.end = ArrayUtil.toLongArray(end); this.end = end;
this.strides = ArrayUtil.toLongArray(strides); this.strides = strides;
this.beginMask = beginMask; this.beginMask = beginMask;
this.endMask = endMask; this.endMask = endMask;
this.ellipsisMask = ellipsisMask; this.ellipsisMask = ellipsisMask;

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -67,6 +68,13 @@ public class Unstack extends DynamicCustomOp {
addArgs(); addArgs();
} }
public Unstack(@NonNull INDArray value, int axis, int num){
super(new INDArray[]{value}, null);
this.jaxis = axis;
this.num = num;
addArgs();
}
public Unstack(INDArray in, INDArray[] out, int axis){ public Unstack(INDArray in, INDArray[] out, int axis){
super(null, new INDArray[]{in}, out, null, (int[])null); super(null, new INDArray[]{in}, out, null, (int[])null);
this.jaxis = axis; this.jaxis = axis;
@ -136,7 +144,8 @@ public class Unstack extends DynamicCustomOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(sameDiff.stack(jaxis, f1.toArray(new SDVariable[f1.size()]))); return Collections.singletonList(sameDiff.stack(jaxis, f1.toArray(new SDVariable[0])));
} }
@Override @Override

View File

@ -58,6 +58,10 @@ public class Pad extends DynamicCustomOp {
this(sd, in, padding, Mode.CONSTANT, padValue); this(sd, in, padding, Mode.CONSTANT, padValue);
} }
public Pad(@NonNull INDArray in, @NonNull INDArray padding, double padValue){
this(in, padding, null, Mode.CONSTANT, padValue);
}
public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){ public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){
super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out}); super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out});
Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType());

View File

@ -66,11 +66,8 @@ public class DynamicPartition extends DynamicCustomOp {
addArgs(); addArgs();
} }
public DynamicPartition(INDArray input, INDArray[] partitions, int numPartitions) { public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) {
addInputArgument(input); addInputArgument(input);
for (INDArray part : partitions)
addInputArgument(part);
addIArgument(numPartitions); addIArgument(numPartitions);
} }

View File

@ -16,9 +16,11 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom; package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays; import java.util.Arrays;
@ -30,10 +32,14 @@ public class ListDiff extends DynamicCustomOp {
// //
} }
public ListDiff(SameDiff sd, SDVariable x, SDVariable y){ public ListDiff(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
super(sd, new SDVariable[]{x, y}); super(sd, new SDVariable[]{x, y});
} }
public ListDiff(@NonNull INDArray x, @NonNull INDArray y){
super(new INDArray[]{x, y}, null);
}
@Override @Override
public String tensorflowName() { public String tensorflowName() {
return "ListDiff"; //Note: Seems to be renamed to tf.setdiff1d in public API? return "ListDiff"; //Note: Seems to be renamed to tf.setdiff1d in public API?

View File

@ -73,12 +73,8 @@ public class XwPlusB extends DynamicCustomOp {
SDVariable dLdOut = gradient.get(0); SDVariable dLdOut = gradient.get(0);
SDVariable dLdb = dLdOut.sum(0); SDVariable dLdb = dLdOut.sum(0);
SDVariable dLdIn = sameDiff.mmul(dLdOut, w, MMulTranspose.builder() SDVariable dLdIn = sameDiff.mmul(dLdOut, w, false, true, false);
.transposeB(true) SDVariable dLdW = sameDiff.mmul(in, dLdOut, true, false, false);
.build());
SDVariable dLdW = sameDiff.mmul(in, dLdOut, MMulTranspose.builder()
.transposeA(true)
.build());
return Arrays.asList(dLdIn, dLdW, dLdb); return Arrays.asList(dLdIn, dLdW, dLdb);
} }

View File

@ -28,6 +28,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter; import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -55,25 +56,12 @@ public class Cast extends BaseDynamicTransformOp {
addArgs(); addArgs();
} }
/* public Cast(@NonNull INDArray arg, @NonNull DataType dataType){
@Override super(new INDArray[]{arg}, null);
public void setValueFor(Field target, Object value) { this.typeDst = dataType;
if(value == null) { addArgs();
throw new ND4JIllegalStateException("Unable to set field " + target + " using null value!");
} }
// FIXME!
if (!(value instanceof DataType))
return;
try {
target.set(this, (DataType) value);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
*/
@Override @Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

View File

@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
@ -73,6 +74,12 @@ public class Range extends DynamicCustomOp {
addDArgument(dataType); addDArgument(dataType);
} }
public Range(INDArray from, INDArray to, INDArray step, DataType dataType){
super(new INDArray[]{from, to, step}, null);
this.dataType = dataType;
addDArgument(dataType);
}
@Override @Override
public int opNum() { public int opNum() {

View File

@ -149,6 +149,60 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions));
} }
/**
 * Matrix multiply a batch of matrices. inputsA and inputsB must be arrays of the same<br>
 * length, and the i-th pair taken from them must have shapes (M, N) and (N, K),<br>
 * respectively. If transposeA is true, matrices from inputsA are expected to have shape (N, M) instead.<br>
 * Likewise, if transposeB is true, matrices from inputsB are expected to have shape (K, N).<br>
* <br>
* The result of this operation will be a batch of multiplied matrices. The<br>
* result has the same length as both input batches and each output matrix is of shape (M, K).<br>
*
* @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type)
* @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type)
* @param transposeA Whether to transpose A arrays or not
* @param transposeB Whether to transpose B arrays or not
*/
public INDArray[] batchMmul(INDArray[] inputsA, INDArray[] inputsB, boolean transposeA,
boolean transposeB) {
NDValidation.validateNumerical("batchMmul", "inputsA", inputsA);
Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length);
NDValidation.validateNumerical("batchMmul", "inputsB", inputsB);
Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, transposeA, transposeB));
}
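
A minimal, hedged sketch of how the new array-based overload might be called. It assumes the generated namespace classes can be instantiated directly via a public no-arg constructor (as shown for NDCNN further below), and uses Nd4j.rand only to produce dummy inputs:

    // Batch of 3 independent (2x4)·(4x5) products; each output has shape [2, 5].
    NDBase base = new NDBase();
    INDArray[] a = new INDArray[]{Nd4j.rand(2, 4), Nd4j.rand(2, 4), Nd4j.rand(2, 4)};
    INDArray[] b = new INDArray[]{Nd4j.rand(4, 5), Nd4j.rand(4, 5), Nd4j.rand(4, 5)};
    INDArray[] out = base.batchMmul(a, b, false, false);   // out.length == 3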
/**
 * Matrix multiply a batch of matrices. inputsA and inputsB must be arrays of the same<br>
 * length, and the i-th pair taken from them must have shapes (M, N) and (N, K),<br>
 * respectively. If transposeA is true, matrices from inputsA are expected to have shape (N, M) instead.<br>
 * Likewise, if transposeB is true, matrices from inputsB are expected to have shape (K, N).<br>
* <br>
* The result of this operation will be a batch of multiplied matrices. The<br>
* result has the same length as both input batches and each output matrix is of shape (M, K).<br>
*
* @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type)
* @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type)
*/
public INDArray[] batchMmul(INDArray[] inputsA, INDArray... inputsB) {
NDValidation.validateNumerical("batchMmul", "inputsA", inputsA);
Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length);
NDValidation.validateNumerical("batchMmul", "inputsB", inputsB);
Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, false, false));
}
/**
* Cast the array to a new datatype - for example, Integer -> Float<br>
*
* @param arg Input variable to cast (NDARRAY type)
* @param datatype Datatype to cast to
* @return output Output array (after casting) (NDARRAY type)
*/
public INDArray castTo(INDArray arg, DataType datatype) {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(arg, datatype))[0];
}
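
A one-line sketch of the new castTo helper, which simply wraps the Cast op constructor added above (array contents are illustrative):

    INDArray ints = Nd4j.createFromArray(1, 2, 3);                   // integer input
    INDArray floats = new NDBase().castTo(ints, DataType.FLOAT);     // [1.0, 2.0, 3.0]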
/** /**
* Concatenate a set of inputs along the specified dimension.<br> * Concatenate a set of inputs along the specified dimension.<br>
* Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.<br> * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.<br>
@ -161,7 +215,7 @@ public class NDBase {
* @param dimension Dimension to concatenate on * @param dimension Dimension to concatenate on
* @return output (NUMERIC type) * @return output (NUMERIC type)
*/ */
public INDArray concat(INDArray[] inputs, int dimension) { public INDArray concat(int dimension, INDArray... inputs) {
NDValidation.validateNumerical("concat", "inputs", inputs); NDValidation.validateNumerical("concat", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype");
@ -274,28 +328,26 @@ public class NDBase {
* @param x Input variable (NUMERIC type) * @param x Input variable (NUMERIC type)
* @param partitions 1D input with values 0 to numPartitions-1 (INT type) * @param partitions 1D input with values 0 to numPartitions-1 (INT type)
* @param numPartitions Number of partitions, >= 1 * @param numPartitions Number of partitions, >= 1
* @return output Output variables (equal in number to numPartitions) (NUMERIC type)
*/ */
public INDArray dynamicPartition(INDArray x, INDArray[] partitions, int numPartitions) { public INDArray[] dynamicPartition(INDArray x, INDArray partitions, int numPartitions) {
NDValidation.validateNumerical("dynamicPartition", "x", x); NDValidation.validateNumerical("dynamicPartition", "x", x);
NDValidation.validateInteger("dynamicPartition", "partitions", partitions); NDValidation.validateInteger("dynamicPartition", "partitions", partitions);
Preconditions.checkArgument(partitions.length >= 1, "partitions has incorrect size/length. Expected: partitions.length >= 1, got %s", partitions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions));
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions))[0];
} }
/** /**
* Dynamically merge the specified input arrays into a single array, using the specified indices<br> * Dynamically merge the specified input arrays into a single array, using the specified indices<br>
* *
* @param x Input variables. (NUMERIC type)
* @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type)
* @param x Input variables. (NUMERIC type)
* @return output Merged output variable (NUMERIC type) * @return output Merged output variable (NUMERIC type)
*/ */
public INDArray dynamicStitch(INDArray[] x, INDArray[] indices) { public INDArray dynamicStitch(INDArray[] indices, INDArray... x) {
NDValidation.validateNumerical("dynamicStitch", "x", x);
Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
NDValidation.validateInteger("dynamicStitch", "indices", indices); NDValidation.validateInteger("dynamicStitch", "indices", indices);
Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(x, indices))[0]; NDValidation.validateNumerical("dynamicStitch", "x", x);
Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0];
} }
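
A hedged round-trip sketch of the two revised signatures: dynamicPartition now takes a single partition array and returns one output per partition, while dynamicStitch takes the index arrays first, followed by the data arrays as varargs:

    NDBase base = new NDBase();
    INDArray x = Nd4j.createFromArray(10f, 20f, 30f, 40f);
    INDArray parts = Nd4j.createFromArray(0, 1, 0, 1);
    INDArray[] split = base.dynamicPartition(x, parts, 2);           // [10, 30] and [20, 40]
    INDArray[] idx = {Nd4j.createFromArray(0, 2), Nd4j.createFromArray(1, 3)};
    INDArray merged = base.dynamicStitch(idx, split[0], split[1]);   // [10, 20, 30, 40]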
/** /**
@ -395,11 +447,9 @@ public class NDBase {
* @param indices (NUMERIC type) * @param indices (NUMERIC type)
* @return output (NUMERIC type) * @return output (NUMERIC type)
*/ */
public INDArray gatherNd(INDArray[] df, INDArray[] indices) { public INDArray gatherNd(INDArray df, INDArray indices) {
NDValidation.validateNumerical("gatherNd", "df", df); NDValidation.validateNumerical("gatherNd", "df", df);
Preconditions.checkArgument(df.length >= 1, "df has incorrect size/length. Expected: df.length >= 1, got %s", df.length);
NDValidation.validateNumerical("gatherNd", "indices", indices); NDValidation.validateNumerical("gatherNd", "indices", indices);
Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0];
} }
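
A small sketch of the simplified single-array gatherNd signature; the indices array is expected to have shape [numElements, rank(df)]:

    INDArray df = Nd4j.createFromArray(new float[][]{{1, 2}, {3, 4}});
    INDArray idx = Nd4j.createFromArray(new int[][]{{0, 0}, {1, 1}});
    INDArray out = new NDBase().gatherNd(df, idx);                   // [1.0, 4.0]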
@ -516,6 +566,23 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0];
} }
/**
* Create a new 1d array with values evenly spaced between values 'start' and 'stop'<br>
* For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]<br>
*
* @param start Start value (NUMERIC type)
* @param stop Stop value (NUMERIC type)
* @param number Number of values to generate (LONG type)
* @param dataType Data type of the output array
* @return output INDArray with linearly spaced elements (NUMERIC type)
*/
public INDArray linspace(INDArray start, INDArray stop, INDArray number, DataType dataType) {
NDValidation.validateNumerical("linspace", "start", start);
NDValidation.validateNumerical("linspace", "stop", stop);
NDValidation.validateInteger("linspace", "number", number);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0];
}
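
Mirroring the javadoc example above with scalar NDArrays (a hedged sketch; the cast just keeps 'number' an integer type, as the validation requires):

    INDArray number = Nd4j.scalar(3.0).castTo(DataType.LONG);
    INDArray out = new NDBase().linspace(Nd4j.scalar(3.0), Nd4j.scalar(4.0), number, DataType.DOUBLE);
    // expected: [3.0, 3.5, 4.0]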
/** /**
* Less than operation: elementwise x < y<br> * Less than operation: elementwise x < y<br>
* *
@ -1071,6 +1138,20 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0];
} }
/**
* Array permutation operation: permute the dimensions according to the specified permutation indices.<br>
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]<br>
*
* @param x Input variable (NUMERIC type)
* @param dimensions Permute dimensions (INT type)
* @return output Output variable (permuted input) (NUMERIC type)
*/
public INDArray permute(INDArray x, INDArray dimensions) {
NDValidation.validateNumerical("permute", "x", x);
NDValidation.validateInteger("permute", "dimensions", dimensions);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0];
}
/** /**
* Array permutation operation: permute the dimensions according to the specified permutation indices.<br> * Array permutation operation: permute the dimensions according to the specified permutation indices.<br>
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]<br> * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]<br>
@ -1141,6 +1222,24 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0];
} }
/**
* Create a new variable with a 1d array, where the values start at from and increment by step<br>
* up to (but not including) limit.<br>
* For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]<br>
*
* @param from Initial/smallest value (NUMERIC type)
* @param to Largest value (exclusive) (NUMERIC type)
* @param step Step size (NUMERIC type)
 * @param dataType Data type of the output array
* @return output INDArray with the specified values (NUMERIC type)
*/
public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataType) {
NDValidation.validateNumerical("range", "from", from);
NDValidation.validateNumerical("range", "to", to);
NDValidation.validateNumerical("range", "step", step);
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0];
}
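
The same example as in the javadoc, expressed through the new INDArray-based overload (a minimal sketch):

    INDArray out = new NDBase().range(Nd4j.scalar(1.0), Nd4j.scalar(3.0), Nd4j.scalar(0.5), DataType.DOUBLE);
    // expected: [1.0, 1.5, 2.0, 2.5]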
/** /**
* Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable<br> * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable<br>
* *
@ -1168,6 +1267,21 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition));
} }
/**
* Element-wise replace where condition:<br>
* out[i] = value if condition(update[i]) is satisfied, or<br>
* out[i] = update[i] if condition(update[i]) is NOT satisfied<br>
*
* @param update Source array (NUMERIC type)
* @param value Value to set at the output, if the condition is satisfied
* @param condition Condition to check on update array elements
* @return output New array with values replaced where condition is satisfied (NUMERIC type)
*/
public INDArray replaceWhere(INDArray update, double value, Condition condition) {
NDValidation.validateNumerical("replaceWhere", "update", update);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(update, value, condition));
}
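
A short sketch of the scalar-value replaceWhere overload (Conditions.lessThan is from org.nd4j.linalg.indexing.conditions; values are illustrative):

    INDArray arr = Nd4j.createFromArray(-1.0, 2.0, -3.0, 4.0);
    INDArray out = new NDBase().replaceWhere(arr, 0.0, Conditions.lessThan(0.0));
    // negative entries replaced: expected [0.0, 2.0, 0.0, 4.0]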
/** /**
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br> * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
* input, but with the specified shape.<br> * input, but with the specified shape.<br>
@ -1183,6 +1297,21 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0];
} }
/**
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
* input, but with the specified shape.<br>
* Note that prod(shape) must match length(input) == prod(input.shape)<br>
*
* @param x Input variable (NUMERIC type)
* @param shape New shape for variable (Size: AtLeast(min=0))
* @return output Output variable (NUMERIC type)
*/
public INDArray reshape(INDArray x, long... shape) {
NDValidation.validateNumerical("reshape", "x", x);
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0];
}
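
A minimal sketch of the long[] overload; the element count must stay the same:

    INDArray x = Nd4j.rand(3, 4);
    INDArray out = new NDBase().reshape(x, 2, 6);                    // same 12 values, shape [2, 6]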
/** /**
* Reverse the values of an array for the specified dimensions<br> * Reverse the values of an array for the specified dimensions<br>
* If input is:<br> * If input is:<br>
@ -1532,6 +1661,21 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0];
} }
/**
* Generate a sequence mask (with values 0 or 1) based on the specified lengths <br>
* Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)<br>
*
* @param lengths Lengths of the sequences (NUMERIC type)
* @param maxLen Maximum sequence length (INT type)
 * @param dataType Data type of the output mask array
* @return output Output variable (NUMERIC type)
*/
public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataType) {
NDValidation.validateNumerical("sequenceMask", "lengths", lengths);
NDValidation.validateInteger("sequenceMask", "maxLen", maxLen);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0];
}
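
A hedged sketch of the overload taking maxLen as an array (the cast keeps maxLen an integer type, as the validation requires):

    INDArray lengths = Nd4j.createFromArray(1, 2, 3);
    INDArray maxLen = Nd4j.scalar(4.0).castTo(DataType.INT);
    INDArray mask = new NDBase().sequenceMask(lengths, maxLen, DataType.FLOAT);
    // expected rows: [1,0,0,0], [1,1,0,0], [1,1,1,0]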
/** /**
* see sequenceMask(String, SDVariable, SDVariable, DataType)<br> * see sequenceMask(String, SDVariable, SDVariable, DataType)<br>
* *
@ -1601,6 +1745,28 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0];
} }
/**
* Get a subset of the specified input, by specifying the first element and the size of the array.<br>
* For example, if input is:<br>
* [a, b, c]<br>
* [d, e, f]<br>
 * then slice(input, begin=[0,1], size=[2,1]) will return:<br>
* [b]<br>
* [e]<br>
* Note that for each dimension i, begin[i] + size[i] <= input.size(i)<br>
*
 * @param input Input variable to get a subset of (NUMERIC type)
* @param begin Beginning index. Must be same length as rank of input array (INT type)
* @param size Size of the output array. Must be same length as rank of input array (INT type)
* @return output Subset of the input (NUMERIC type)
*/
public INDArray slice(INDArray input, INDArray begin, INDArray size) {
NDValidation.validateNumerical("slice", "input", input);
NDValidation.validateInteger("slice", "begin", begin);
NDValidation.validateInteger("slice", "size", size);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0];
}
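
The javadoc example above, written out against the new INDArray-based overload (a minimal sketch):

    INDArray in = Nd4j.createFromArray(new float[][]{{1, 2, 3}, {4, 5, 6}});
    INDArray out = new NDBase().slice(in, Nd4j.createFromArray(0, 1), Nd4j.createFromArray(2, 1));
    // expected: [[2], [5]]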
/** /**
* Squared L2 norm: see norm2(String, SDVariable, boolean, int...)<br> * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)<br>
* *
@ -1668,7 +1834,8 @@ public class NDBase {
* @param axis Axis to stack on * @param axis Axis to stack on
* @return output Output variable (NDARRAY type) * @return output Output variable (NDARRAY type)
*/ */
public INDArray stack(INDArray values, int axis) { public INDArray stack(int axis, INDArray... values) {
Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0];
} }
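
With the axis-first varargs signature, stacking two [3, 4] arrays now reads (a minimal sketch):

    INDArray a = Nd4j.rand(3, 4);
    INDArray b = Nd4j.rand(3, 4);
    INDArray stacked = new NDBase().stack(0, a, b);                  // shape [2, 3, 4]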
@ -1737,7 +1904,7 @@ public class NDBase {
* @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions
* @return output A subset of the input array (NUMERIC type) * @return output A subset of the input array (NUMERIC type)
*/ */
public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask,
int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
NDValidation.validateNumerical("stridedSlice", "in", in); NDValidation.validateNumerical("stridedSlice", "in", in);
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
@ -1762,7 +1929,7 @@ public class NDBase {
* @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1))
* @return output A subset of the input array (NUMERIC type) * @return output A subset of the input array (NUMERIC type)
*/ */
public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int... strides) { public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long... strides) {
NDValidation.validateNumerical("stridedSlice", "in", in); NDValidation.validateNumerical("stridedSlice", "in", in);
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
@ -1999,6 +2166,21 @@ public class NDBase {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0];
} }
/**
* Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.<br>
* If input has shape [a,b,c] then output has shape:<br>
* axis = 0: [b,c]<br>
* axis = 1: [a,c]<br>
* axis = 2: [a,b]<br>
*
* @param value Input variable to unstack (NDARRAY type)
* @param axis Axis to unstack on
* @param num Number of output variables
*/
public INDArray[] unstack(INDArray value, int axis, int num) {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Unstack(value, axis, num));
}
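
A short sketch of the inverse operation (num must match the size of the chosen axis):

    INDArray x = Nd4j.rand(new int[]{2, 3, 4});
    INDArray[] slices = new NDBase().unstack(x, 0, 2);               // two arrays of shape [3, 4]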
/** /**
* Variance array reduction operation, optionally along specified dimensions<br> * Variance array reduction operation, optionally along specified dimensions<br>
* *

View File

@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType; import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.enums.DataFormat;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
@ -32,7 +33,6 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.enums.DataFormat;
public class NDCNN { public class NDCNN {
public NDCNN() { public NDCNN() {
@ -370,6 +370,18 @@ public class NDCNN {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(input, LocalResponseNormalizationConfig))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(input, LocalResponseNormalizationConfig))[0];
} }
/**
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
*/
public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) {
NDValidation.validateNumerical("maxPoolWithArgmax", "input", input);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(input, Pooling2DConfig));
}
/** /**
* 2D Convolution layer operation - max pooling 2d <br> * 2D Convolution layer operation - max pooling 2d <br>
* *

View File

@ -222,15 +222,12 @@ public class NDLoss {
* *
* @param label Label array (NUMERIC type) * @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type) * @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param epsilon epsilon
* @return output Log loss (NUMERIC type) * @return output Log loss (NUMERIC type)
*/ */
public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, double epsilon) { public INDArray logLoss(INDArray label, INDArray predictions) {
NDValidation.validateNumerical("logLoss", "label", label); NDValidation.validateNumerical("logLoss", "label", label);
NDValidation.validateNumerical("logLoss", "predictions", predictions); NDValidation.validateNumerical("logLoss", "predictions", predictions);
NDValidation.validateNumerical("logLoss", "weights", weights); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0];
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, epsilon))[0];
} }
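
With the trimmed signature, weights default to null (i.e. a weight of 1.0) and epsilon to 0.0, so a call is simply (a hedged sketch, instantiating NDLoss directly as with the other generated namespaces; values are illustrative):

    INDArray labels = Nd4j.createFromArray(0.0, 1.0, 1.0);
    INDArray preds = Nd4j.createFromArray(0.1, 0.8, 0.9);
    INDArray out = new NDLoss().logLoss(labels, preds);              // scalar mean log loss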
/** /**

View File

@ -190,6 +190,58 @@ public class NDMath {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(x)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(x));
} }
/**
 * Bit shift (left shift) operation<br>
*
* @param x input (NUMERIC type)
* @param shift shift value (NUMERIC type)
* @return output shifted output (NUMERIC type)
*/
public INDArray bitShift(INDArray x, INDArray shift) {
NDValidation.validateNumerical("bitShift", "x", x);
NDValidation.validateNumerical("bitShift", "shift", shift);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, shift))[0];
}
/**
* Right bit shift operation<br>
*
* @param x Input tensor (NUMERIC type)
* @param shift shift argument (NUMERIC type)
* @return output shifted output (NUMERIC type)
*/
public INDArray bitShiftRight(INDArray x, INDArray shift) {
NDValidation.validateNumerical("bitShiftRight", "x", x);
NDValidation.validateNumerical("bitShiftRight", "shift", shift);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, shift))[0];
}
/**
* Cyclic bit shift operation<br>
*
* @param x Input tensor (NUMERIC type)
 * @param shift shift argument (NUMERIC type)
* @return output shifted output (NUMERIC type)
*/
public INDArray bitShiftRotl(INDArray x, INDArray shift) {
NDValidation.validateNumerical("bitShiftRotl", "x", x);
NDValidation.validateNumerical("bitShiftRotl", "shift", shift);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, shift))[0];
}
/**
* Cyclic right shift operation<br>
*
* @param x Input tensor (NUMERIC type)
* @param shift Shift argument (NUMERIC type)
* @return output Shifted output (NUMERIC type)
*/
public INDArray bitShiftRotr(INDArray x, INDArray shift) {
NDValidation.validateNumerical("bitShiftRotr", "x", x);
NDValidation.validateNumerical("bitShiftRotr", "shift", shift);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, shift))[0];
}
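
A small sketch of the shift family on integer inputs, a left shift followed by the matching right shift (the NDMath instance is assumed to come from its public constructor, as in the earlier sketches):

    NDMath math = new NDMath();
    INDArray x = Nd4j.createFromArray(1, 2, 4);
    INDArray shift = Nd4j.createFromArray(1, 1, 1);
    INDArray left = math.bitShift(x, shift);                         // expected [2, 4, 8]
    INDArray right = math.bitShiftRight(left, shift);                // back to [1, 2, 4]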
/** /**
* Element-wise ceiling function: out = ceil(x).<br> * Element-wise ceiling function: out = ceil(x).<br>
* Rounds each value up to the nearest integer value (if not already an integer)<br> * Rounds each value up to the nearest integer value (if not already an integer)<br>
@ -346,13 +398,13 @@ public class NDMath {
* *
* @param x Input variable x (NUMERIC type) * @param x Input variable x (NUMERIC type)
* @param y Input variable y (NUMERIC type) * @param y Input variable y (NUMERIC type)
* @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=1)) * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0))
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray cosineDistance(INDArray x, INDArray y, int... dimensions) { public INDArray cosineDistance(INDArray x, INDArray y, int... dimensions) {
NDValidation.validateNumerical("cosineDistance", "x", x); NDValidation.validateNumerical("cosineDistance", "x", x);
NDValidation.validateNumerical("cosineDistance", "y", y); NDValidation.validateNumerical("cosineDistance", "y", y);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(x, y, dimensions)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(x, y, dimensions));
} }
@ -363,13 +415,13 @@ public class NDMath {
* *
* @param x Input variable x (NUMERIC type) * @param x Input variable x (NUMERIC type)
* @param y Input variable y (NUMERIC type) * @param y Input variable y (NUMERIC type)
* @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=1)) * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0))
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray cosineSimilarity(INDArray x, INDArray y, int... dimensions) { public INDArray cosineSimilarity(INDArray x, INDArray y, int... dimensions) {
NDValidation.validateNumerical("cosineSimilarity", "x", x); NDValidation.validateNumerical("cosineSimilarity", "x", x);
NDValidation.validateNumerical("cosineSimilarity", "y", y); NDValidation.validateNumerical("cosineSimilarity", "y", y);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(x, y, dimensions)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(x, y, dimensions));
} }
@ -501,13 +553,13 @@ public class NDMath {
* *
* @param x Input variable x (NUMERIC type) * @param x Input variable x (NUMERIC type)
* @param y Input variable y (NUMERIC type) * @param y Input variable y (NUMERIC type)
* @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=1)) * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0))
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray euclideanDistance(INDArray x, INDArray y, int... dimensions) { public INDArray euclideanDistance(INDArray x, INDArray y, int... dimensions) {
NDValidation.validateNumerical("euclideanDistance", "x", x); NDValidation.validateNumerical("euclideanDistance", "x", x);
NDValidation.validateNumerical("euclideanDistance", "y", y); NDValidation.validateNumerical("euclideanDistance", "y", y);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(x, y, dimensions)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(x, y, dimensions));
} }
@ -665,13 +717,13 @@ public class NDMath {
* *
* @param x Input variable x (NUMERIC type) * @param x Input variable x (NUMERIC type)
* @param y Input variable y (NUMERIC type) * @param y Input variable y (NUMERIC type)
* @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=1)) * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0))
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray hammingDistance(INDArray x, INDArray y, int... dimensions) { public INDArray hammingDistance(INDArray x, INDArray y, int... dimensions) {
NDValidation.validateNumerical("hammingDistance", "x", x); NDValidation.validateNumerical("hammingDistance", "x", x);
NDValidation.validateNumerical("hammingDistance", "y", y); NDValidation.validateNumerical("hammingDistance", "y", y);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(x, y, dimensions)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(x, y, dimensions));
} }
@ -817,13 +869,13 @@ public class NDMath {
* *
* @param x Input variable x (NUMERIC type) * @param x Input variable x (NUMERIC type)
* @param y Input variable y (NUMERIC type) * @param y Input variable y (NUMERIC type)
* @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=1)) * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0))
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray jaccardDistance(INDArray x, INDArray y, int... dimensions) { public INDArray jaccardDistance(INDArray x, INDArray y, int... dimensions) {
NDValidation.validateNumerical("jaccardDistance", "x", x); NDValidation.validateNumerical("jaccardDistance", "x", x);
NDValidation.validateNumerical("jaccardDistance", "y", y); NDValidation.validateNumerical("jaccardDistance", "y", y);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(x, y, dimensions)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(x, y, dimensions));
} }
@ -872,6 +924,18 @@ public class NDMath {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, keepDims, condition, dimensions)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, keepDims, condition, dimensions));
} }
/**
 * Calculates the difference between inputs X and Y: returns the values present in X but not in Y, together with their indices in X.<br>
*
* @param x Input variable X (NUMERIC type)
* @param y Input variable Y (NUMERIC type)
*/
public INDArray[] listDiff(INDArray x, INDArray y) {
NDValidation.validateNumerical("listDiff", "x", x);
NDValidation.validateNumerical("listDiff", "y", y);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(x, y));
}
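
A minimal sketch; the op returns two arrays, the values of x that do not occur in y and their indices in x:

    INDArray x = Nd4j.createFromArray(1, 2, 3, 4);
    INDArray y = Nd4j.createFromArray(2, 4);
    INDArray[] diff = new NDMath().listDiff(x, y);                   // diff[0] = [1, 3], diff[1] = [0, 2]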
/** /**
* Element-wise logarithm function (base e - natural logarithm): out = log(x)<br> * Element-wise logarithm function (base e - natural logarithm): out = log(x)<br>
* *
@ -940,13 +1004,13 @@ public class NDMath {
* *
* @param x Input variable x (NUMERIC type) * @param x Input variable x (NUMERIC type)
* @param y Input variable y (NUMERIC type) * @param y Input variable y (NUMERIC type)
* @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=1)) * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0))
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray manhattanDistance(INDArray x, INDArray y, int... dimensions) { public INDArray manhattanDistance(INDArray x, INDArray y, int... dimensions) {
NDValidation.validateNumerical("manhattanDistance", "x", x); NDValidation.validateNumerical("manhattanDistance", "x", x);
NDValidation.validateNumerical("manhattanDistance", "y", y); NDValidation.validateNumerical("manhattanDistance", "y", y);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(x, y, dimensions)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(x, y, dimensions));
} }
@ -983,7 +1047,7 @@ public class NDMath {
* @param inputs Input variables (NUMERIC type) * @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray mergeAdd(INDArray[] inputs) { public INDArray mergeAdd(INDArray... inputs) {
NDValidation.validateNumerical("mergeAdd", "inputs", inputs); NDValidation.validateNumerical("mergeAdd", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(inputs))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(inputs))[0];
@ -996,7 +1060,7 @@ public class NDMath {
* @param inputs Input variables (NUMERIC type) * @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray mergeAvg(INDArray[] inputs) { public INDArray mergeAvg(INDArray... inputs) {
NDValidation.validateNumerical("mergeAvg", "inputs", inputs); NDValidation.validateNumerical("mergeAvg", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(inputs))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(inputs))[0];
@ -1009,12 +1073,24 @@ public class NDMath {
* @param inputs Input variables (NUMERIC type) * @param inputs Input variables (NUMERIC type)
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray mergeMax(INDArray[] inputs) { public INDArray mergeMax(INDArray... inputs) {
NDValidation.validateNumerical("mergeMax", "inputs", inputs); NDValidation.validateNumerical("mergeMax", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMax(inputs))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMax(inputs))[0];
} }
/**
* Broadcasts parameters for evaluation on an N-D grid.<br>
*
 * @param inputs Input arrays, one per grid dimension (NUMERIC type)
 * @param cartesian Whether to use cartesian ('xy') indexing for the first two dimensions
*/
public INDArray[] meshgrid(INDArray[] inputs, boolean cartesian) {
NDValidation.validateNumerical("meshgrid", "inputs", inputs);
Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(inputs, cartesian));
}
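
A hedged sketch of the new array-based meshgrid entry point (the exact output shape ordering depends on the cartesian flag):

    INDArray xs = Nd4j.createFromArray(1f, 2f, 3f);
    INDArray ys = Nd4j.createFromArray(10f, 20f);
    INDArray[] grid = new NDMath().meshgrid(new INDArray[]{xs, ys}, true);
    // grid[0] and grid[1] are 2D coordinate arrays pairing every x with every y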
/** /**
* Calculate the mean and (population) variance for the input variable, for the specified axis<br> * Calculate the mean and (population) variance for the input variable, for the specified axis<br>
* *

View File

@ -237,12 +237,11 @@ public class NDNN {
* Alpha value is most commonly set to 0.01<br> * Alpha value is most commonly set to 0.01<br>
* *
* @param x Input variable (NUMERIC type) * @param x Input variable (NUMERIC type)
* @param alpha Cutoff - commonly 0.01 (NUMERIC type) * @param alpha Cutoff - commonly 0.01
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray leakyRelu(INDArray x, INDArray alpha) { public INDArray leakyRelu(INDArray x, double alpha) {
NDValidation.validateNumerical("leakyRelu", "x", x); NDValidation.validateNumerical("leakyRelu", "x", x);
NDValidation.validateNumerical("leakyRelu", "alpha", alpha);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha));
} }
@ -250,12 +249,11 @@ public class NDNN {
* Leaky ReLU derivative: dOut/dIn given input.<br> * Leaky ReLU derivative: dOut/dIn given input.<br>
* *
* @param x Input variable (NUMERIC type) * @param x Input variable (NUMERIC type)
* @param alpha Cutoff - commonly 0.01 (NUMERIC type) * @param alpha Cutoff - commonly 0.01
* @return output Output variable (NUMERIC type) * @return output Output variable (NUMERIC type)
*/ */
public INDArray leakyReluDerivative(INDArray x, INDArray alpha) { public INDArray leakyReluDerivative(INDArray x, double alpha) {
NDValidation.validateNumerical("leakyReluDerivative", "x", x); NDValidation.validateNumerical("leakyReluDerivative", "x", x);
NDValidation.validateNumerical("leakyReluDerivative", "alpha", alpha);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha));
} }
@ -346,6 +344,20 @@ public class NDNN {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0];
} }
/**
* Padding operation <br>
*
* @param input Input tensor (NUMERIC type)
* @param padding Padding value (NUMERIC type)
* @param constant Padding constant
* @return output Padded input (NUMERIC type)
*/
public INDArray pad(INDArray input, INDArray padding, double constant) {
NDValidation.validateNumerical("pad", "input", input);
NDValidation.validateNumerical("pad", "padding", padding);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, constant))[0];
}
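
A brief sketch of the constant-mode pad helper (NDNN instantiated directly, as in the sketches above); padding is a [rank, 2] integer array of (before, after) amounts per dimension:

    INDArray in = Nd4j.rand(2, 2);
    INDArray padding = Nd4j.createFromArray(new int[][]{{1, 1}, {0, 0}});   // one extra row above and below
    INDArray out = new NDNN().pad(in, padding, 0.0);                 // shape [4, 2], new rows filled with 0.0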
/** /**
* PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:<br> * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:<br>
* out[i] = in[i] if in[i] >= 0<br> * out[i] = in[i] if in[i] >= 0<br>
@ -461,6 +473,17 @@ public class NDNN {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0];
} }
/**
 * Softmax activation, applied along dimension -1 (the last dimension)<br>
*
* @param x Input (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray softmax(INDArray x) {
NDValidation.validateNumerical("softmax", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, -1))[0];
}
/** /**
* Softmax derivative function<br> * Softmax derivative function<br>
* *
@ -519,4 +542,15 @@ public class NDNN {
NDValidation.validateNumerical("swish", "x", x); NDValidation.validateNumerical("swish", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x)); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x));
} }
/**
* Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)<br>
*
* @param x Input variable (NUMERIC type)
* @return output Output variable (NUMERIC type)
*/
public INDArray tanh(INDArray x) {
NDValidation.validateNumerical("tanh", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(x));
}
} }

View File

@ -22,7 +22,9 @@ import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.NDValidation;
@ -38,12 +40,11 @@ public class NDRNN {
* @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
* @param GRUWeights Configuration Object * @param GRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/ */
public INDArray gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) { public INDArray[] gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) {
NDValidation.validateNumerical("gru", "x", x); NDValidation.validateNumerical("gru", "x", x);
NDValidation.validateNumerical("gru", "hLast", hLast); NDValidation.validateNumerical("gru", "hLast", hLast);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights));
} }
/** /**
@ -54,18 +55,83 @@ public class NDRNN {
* @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type) * @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object * @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object * @param LSTMConfiguration Configuration Object
* @return output The cell's outputs (NUMERIC type)
*/ */
public INDArray lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights, public INDArray[] lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights,
LSTMConfiguration LSTMConfiguration) { LSTMConfiguration LSTMConfiguration) {
NDValidation.validateNumerical("lstmCell", "x", x); NDValidation.validateNumerical("lstmCell", "x", x);
NDValidation.validateNumerical("lstmCell", "cLast", cLast); NDValidation.validateNumerical("lstmCell", "cLast", cLast);
NDValidation.validateNumerical("lstmCell", "yLast", yLast); NDValidation.validateNumerical("lstmCell", "yLast", yLast);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration));
} }
/** /**
* The LSTM layer. Does multiple time steps.<br> * Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS the following data formats:<br>
* for unidirectional:<br>
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
* NST: shapes [numExamples, inOutSize, timeLength]<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
* SUPPORTS the following direction modes:<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum<br>
* BIDIR_CONCAT: bidirectional concat<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with dataFormat T2NS)<br>
* You may use different gate configurations:<br>
* specify gate/cell/out alpha/beta and the gate/cell/out activations, as listed in the activations enum<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param maxTSLength Per-example sequence length (maximum time step count), with shape [batchSize] (NUMERIC type)
* @param LSTMLayerWeights Configuration Object
* @param LSTMLayerConfig Configuration Object
*/
public INDArray[] lstmLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength,
LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) {
NDValidation.validateNumerical("lstmLayer", "x", x);
NDValidation.validateNumerical("lstmLayer", "cLast", cLast);
NDValidation.validateNumerical("lstmLayer", "yLast", yLast);
NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig));
}
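To make the data-format and return-flag options concrete, here is a compact SameDiff-level sketch. Shapes, builder fields and enum values are taken from LSTMLayerTestCase1 later in this PR; the variable names and sizes are illustrative only.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.factory.Nd4j;

public class LstmLayerSketch {
    public static void main(String[] args) {
        int bS = 2, nIn = 3, nOut = 4, sL = 5;
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL));      // NST layout: [numExamples, inOutSize, timeLength]
        SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, nOut));
        SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, nOut));
        LSTMLayerConfig conf = LSTMLayerConfig.builder()
                .lstmdataformat(LSTMDataFormat.NST)
                .directionMode(LSTMDirectionMode.FWD)
                .gateAct(LSTMActivations.SIGMOID)
                .cellAct(LSTMActivations.TANH)
                .outAct(LSTMActivations.TANH)
                .retFullSequence(true).retLastC(true).retLastH(true)
                .build();
        LSTMLayerWeights w = LSTMLayerWeights.builder()
                .weights(sd.var("w", Nd4j.rand(DataType.FLOAT, nIn, 4 * nOut)))
                .rWeights(sd.var("rw", Nd4j.rand(DataType.FLOAT, nOut, 4 * nOut)))
                .bias(sd.var("b", Nd4j.rand(DataType.FLOAT, 4 * nOut)))
                .build();
        // wrap the raw output array to access the requested outputs by name
        LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn().lstmLayer(in, cLast, yLast, null, w, conf), conf);
        System.out.println(java.util.Arrays.toString(outputs.getOutput().eval().shape()));               // [bS, nOut, sL] full sequence
        System.out.println(java.util.Arrays.toString(outputs.getLastCellStateOutput().eval().shape()));  // [bS, nOut]
    }
}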
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS the following data formats:<br>
* for unidirectional:<br>
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
* NST: shapes [numExamples, inOutSize, timeLength]<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
* SUPPORTS the following direction modes:<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum<br>
* BIDIR_CONCAT: bidirectional concat<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with dataFormat T2NS)<br>
* You may use different gate configurations:<br>
* specify gate/cell/out alpha/beta and the gate/cell/out activations, as listed in the activations enum<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param LSTMLayerWeights Configuration Object
* @param LSTMLayerConfig Configuration Object
*/
public INDArray[] lstmLayer(INDArray x, LSTMLayerWeights LSTMLayerWeights,
LSTMLayerConfig LSTMLayerConfig) {
NDValidation.validateNumerical("lstmLayer", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(x, null, null, null, LSTMLayerWeights, LSTMLayerConfig));
}
/**
* The LSTM block<br>
* *
* @param maxTSLength (NUMERIC type) * @param maxTSLength (NUMERIC type)
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type) * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
@ -75,13 +141,27 @@ public class NDRNN {
* @param LSTMConfiguration Configuration Object * @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type) * @return output The layer's outputs. (NUMERIC type)
*/ */
public INDArray lstmLayer(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast, public INDArray lstmblock(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast,
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); NDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
NDValidation.validateNumerical("lstmLayer", "x", x); NDValidation.validateNumerical("lstmblock", "x", x);
NDValidation.validateNumerical("lstmLayer", "cLast", cLast); NDValidation.validateNumerical("lstmblock", "cLast", cLast);
NDValidation.validateNumerical("lstmLayer", "yLast", yLast); NDValidation.validateNumerical("lstmblock", "yLast", yLast);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0];
}
/**
* The LSTM block<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type)
*/
public INDArray lstmblock(INDArray x, LSTMWeights LSTMWeights,
LSTMConfiguration LSTMConfiguration) {
NDValidation.validateNumerical("lstmblock", "x", x);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(null, x, null, null, LSTMWeights, LSTMConfiguration))[0];
} }
/** /**

View File

@ -199,7 +199,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, null, st);
return op.z(); return op.z();
} }
@ -436,7 +436,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, null, st);
return op.z(); return op.z();
} }
@ -524,7 +524,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
long st = profilingConfigurableHookIn(op); long st = profilingConfigurableHookIn(op);
naiveExec(op, dimension); naiveExec(op, dimension);
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, null, st);
return op.z(); return op.z();
} }
@ -607,7 +607,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, null, st);
return op.z(); return op.z();
} }
@ -772,7 +772,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
return null; return null;
} }
@ -863,7 +863,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
return null; return null;
@ -1113,7 +1113,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
Nd4j.getExecutioner().commit(); Nd4j.getExecutioner().commit();
@ -1200,7 +1200,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, null, st);
return null; return null;
} }
@ -1296,7 +1296,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
return null; return null;
} }
@ -1460,7 +1460,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (ret != null) if (ret != null)
ret.elementWiseStride(); ret.elementWiseStride();
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
return null; return null;
} }
@ -1579,7 +1579,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
return z; return z;
} }
@ -2292,7 +2292,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public INDArray[] exec(CustomOp op, OpContext context) { public INDArray[] exec(CustomOp op, OpContext context) {
long st = profilingConfigurableHookIn(op); long st = profilingConfigurableHookIn(op, context);
val ctx = AtomicAllocator.getInstance().getDeviceContext(); val ctx = AtomicAllocator.getInstance().getDeviceContext();
((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation());
@ -2304,7 +2304,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
if (status != 0) if (status != 0)
throw new RuntimeException("Op [" + op.opName() + "] execution failed"); throw new RuntimeException("Op [" + op.opName() + "] execution failed");
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, context, st);
if (context.getOutputArrays().isEmpty()) if (context.getOutputArrays().isEmpty())
return new INDArray[0]; return new INDArray[0];

View File

@ -236,7 +236,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
if (loop.lastErrorCode() != 0) if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage()); throw new RuntimeException(loop.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
return getZ(op, oc); return getZ(op, oc);
} }
@ -690,7 +690,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
if (loop.lastErrorCode() != 0) if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage()); throw new RuntimeException(loop.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
return getZ(op, oc); return getZ(op, oc);
} }
@ -774,7 +774,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
if (z == null) if (z == null)
setZ(Nd4j.create(op.resultType(), x.shape()), op, oc); setZ(Nd4j.create(op.resultType(), x.shape()), op, oc);
// op.setZ(Nd4j.create(op.resultType(), op.x().shape()));
op.validateDataTypes(oc, experimentalMode.get()); op.validateDataTypes(oc, experimentalMode.get());
@ -884,7 +883,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
if (loop.lastErrorCode() != 0) if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage()); throw new RuntimeException(loop.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
} }
public INDArray exec(BroadcastOp op) { public INDArray exec(BroadcastOp op) {
@ -1306,7 +1305,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
if (loop.lastErrorCode() != 0) if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage()); throw new RuntimeException(loop.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, oc, st);
return z; return z;
} }
@ -2040,7 +2039,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
@Override @Override
public INDArray[] exec(CustomOp op, @NonNull OpContext context) { public INDArray[] exec(CustomOp op, @NonNull OpContext context) {
long st = profilingConfigurableHookIn(op); long st = profilingConfigurableHookIn(op, context);
boolean mklOverride = false; boolean mklOverride = false;
try { try {
if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) { if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) {
@ -2125,7 +2124,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
} finally { } finally {
if (mklOverride) if (mklOverride)
Nd4jCpu.Environment.getInstance().setUseMKLDNN(true); Nd4jCpu.Environment.getInstance().setUseMKLDNN(true);
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, context, st);
} }
} }

View File

@ -20,8 +20,10 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -36,6 +38,12 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -578,8 +586,6 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable dW = sd.var("dW", depthWeightArr); SDVariable dW = sd.var("dW", depthWeightArr);
SDVariable b = sd.var("b", bArr); SDVariable b = sd.var("b", bArr);
SDVariable[] vars = new SDVariable[]{in, dW, b};
Conv2DConfig c = Conv2DConfig.builder() Conv2DConfig c = Conv2DConfig.builder()
.kH(kH).kW(kW) .kH(kH).kW(kW)
.pH(0).pW(0) .pH(0).pW(0)
@ -588,8 +594,8 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(false) .isSameMode(false)
.build(); .build();
SDVariable out = sd.cnn().separableConv2d(in, dW, b, c); SDVariable out = sd.cnn().separableConv2d(in, dW, null, b, c);
out = sd.f().tanh(out); out = sd.nn().tanh("out", out);
INDArray outArr = out.eval(); INDArray outArr = out.eval();
//Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
@ -623,8 +629,6 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable pW = sd.var("pW", pointWeightArr); SDVariable pW = sd.var("pW", pointWeightArr);
SDVariable b = sd.var("b", bArr); SDVariable b = sd.var("b", bArr);
//SDVariable[] vars = new SDVariable[]{in, dW, pW, b};
Conv2DConfig c = Conv2DConfig.builder() Conv2DConfig c = Conv2DConfig.builder()
.kH(kH).kW(kW) .kH(kH).kW(kW)
.pH(0).pW(0) .pH(0).pW(0)
@ -635,7 +639,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().separableConv2d(in, dW, pW, b, c); SDVariable out = sd.cnn().separableConv2d(in, dW, pW, b, c);
out = sd.nn().tanh(out); out = sd.nn().tanh("out", out);
INDArray outArr = out.eval(); INDArray outArr = out.eval();
//Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7 //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7
@ -675,8 +679,6 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable w = sd.var("W", wArr); SDVariable w = sd.var("W", wArr);
SDVariable b = sd.var("b", bArr); SDVariable b = sd.var("b", bArr);
SDVariable[] vars = new SDVariable[]{in, w, b};
DeConv2DConfig deconv = DeConv2DConfig.builder() DeConv2DConfig deconv = DeConv2DConfig.builder()
.kH(kH).kW(kW) .kH(kH).kW(kW)
.pH(0).pW(0) .pH(0).pW(0)
@ -685,8 +687,8 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(false) .isSameMode(false)
.build(); .build();
SDVariable out = sd.f().deconv2d(vars, deconv); SDVariable out = sd.cnn().deconv2d(in, w, b, deconv);
out = sd.f().tanh(out); out = sd.nn().tanh("out", out);
INDArray outArr = out.eval(); INDArray outArr = out.eval();
//Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9 //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9
@ -723,7 +725,6 @@ public class LayerOpValidation extends BaseOpValidation {
//Order: https://github.com/deeplearning4j/libnd4j/blob/6c41ea5528bb1f454e92a9da971de87b93ff521f/include/ops/declarable/generic/convo/conv2d.cpp#L20-L22 //Order: https://github.com/deeplearning4j/libnd4j/blob/6c41ea5528bb1f454e92a9da971de87b93ff521f/include/ops/declarable/generic/convo/conv2d.cpp#L20-L22
//in, w, b - bias is optional //in, w, b - bias is optional
SDVariable[] vars = new SDVariable[]{in, w, b};
Conv2DConfig c = Conv2DConfig.builder() Conv2DConfig c = Conv2DConfig.builder()
.kH(kH).kW(kW) .kH(kH).kW(kW)
@ -733,8 +734,8 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(false) .isSameMode(false)
.build(); .build();
SDVariable out = sd.f().conv2d(vars, c); SDVariable out = sd.cnn().conv2d("conv", in, w, b, c);
out = sd.f().tanh(out); out = sd.nn().tanh("out", out);
INDArray outArr = out.eval(); INDArray outArr = out.eval();
//Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
@ -767,7 +768,7 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(true) .isSameMode(true)
.build(); .build();
SDVariable[] results = sd.f().maxPoolWithArgmax(/*new String[]{"out","idx"},*/ in, pooling2DConfig); SDVariable[] results = sd.cnn().maxPoolWithArgmax(new String[]{"out", "idx"}, in, pooling2DConfig);
assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[0].eval().shape());
assertArrayEquals(inArr.shape(), results[1].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape());
} }
@ -797,7 +798,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig); SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig);
SDVariable out = sd.f().tanh(/*"out",*/ outPool); SDVariable out = sd.nn().tanh("out", outPool);
INDArray outArr = out.eval(); INDArray outArr = out.eval();
val outShape = outArr.shape(); val outShape = outArr.shape();
@ -855,7 +856,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig); SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig);
SDVariable out = sd.f().tanh(/*"out",*/ outPool); SDVariable out = sd.nn().tanh("out", outPool);
INDArray outArr = out.eval(); INDArray outArr = out.eval();
val outShape = outArr.shape(); val outShape = outArr.shape();
@ -906,7 +907,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig); SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig);
out = sd.f().tanh(/*"loss", */out).shape().rename("out"); out = sd.nn().tanh("loss", out).shape().rename("out");
// oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L); INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L);
@ -942,7 +943,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig); SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig);
out = sd.math().tanh("loss", out).shape().rename("out"); out = sd.nn().tanh("loss", out).shape().rename("out");
sd.setLossVariables("loss"); sd.setLossVariables("loss");
@ -976,8 +977,8 @@ public class LayerOpValidation extends BaseOpValidation {
.paddingMode(PaddingMode.VALID) .paddingMode(PaddingMode.VALID)
.build(); .build();
SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
out = sd.math().tanh("loss", out).shape().rename("out"); out = sd.nn().tanh("loss", out).shape().rename("out");
sd.setLossVariables("loss"); sd.setLossVariables("loss");
@ -996,7 +997,7 @@ public class LayerOpValidation extends BaseOpValidation {
int nOut = 4; int nOut = 4;
int mb = 2; int mb = 2;
for( int k : new int[]{2, 3}) { for (int k : new int[]{2, 3}) {
for (int sz : new int[]{3, 4, 5}) { for (int sz : new int[]{3, 4, 5}) {
for (int s : new int[]{1, 2}) { for (int s : new int[]{1, 2}) {
for (int d : new int[]{1, 2}) { for (int d : new int[]{1, 2}) {
@ -1018,7 +1019,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig); SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig);
SDVariable loss = sd.f().tanh(out).std(true).rename("loss"); SDVariable loss = sd.nn().tanh(out).std(true).rename("loss");
sd.setLossVariables("loss"); sd.setLossVariables("loss");
@ -1039,7 +1040,7 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testConv1dForward(){ public void testConv1dForward() {
int nIn = 2; int nIn = 2;
int nOut = 1; int nOut = 1;
int kernel = 3; int kernel = 3;
@ -1057,7 +1058,7 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("w", wArr); SDVariable w = sd.var("w", wArr);
SDVariable res = sd.cnn.conv1d(in, w, null, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build());
INDArray expected = Nd4j.createFromArray( INDArray expected = Nd4j.createFromArray(
new double[][][]{ new double[][][]{
@ -1113,7 +1114,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig); SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig);
out = sd.math().tanh("loss", out).shape().rename("out"); out = sd.nn().tanh("loss", out).shape().rename("out");
sd.setLossVariables("loss"); sd.setLossVariables("loss");
@ -1156,7 +1157,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig); SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig);
out = sd.math().tanh("loss", out).shape().rename("out"); out = sd.nn().tanh("loss", out).shape().rename("out");
sd.setLossVariables("loss"); sd.setLossVariables("loss");
@ -1201,13 +1202,13 @@ public class LayerOpValidation extends BaseOpValidation {
public void testLayerNorm4d() { public void testLayerNorm4d() {
int mb = 3; int mb = 3;
int ch = 4; int ch = 4;
for(boolean nchw : new boolean[]{true, false}) { for (boolean nchw : new boolean[]{true, false}) {
double eps = 0.0; double eps = 0.0;
INDArray x = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch}); INDArray x = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
INDArray gain4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch}); INDArray gain4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
INDArray bias4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch}); INDArray bias4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
INDArray mean = x.mean(true, 1, 2, 3); INDArray mean = x.mean(true, 1, 2, 3);
INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1); INDArray std = Transforms.sqrt(x.var(false, 1, 2, 3).addi(eps)).reshape(mb, 1, 1, 1);
INDArray standardized = x.sub(mean).div(std); INDArray standardized = x.sub(mean).div(std);
INDArray exp = standardized.mul(gain4d).add(bias4d); INDArray exp = standardized.mul(gain4d).add(bias4d);
@ -1274,7 +1275,7 @@ public class LayerOpValidation extends BaseOpValidation {
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
final INDArray gain = Nd4j.rand(DataType.DOUBLE,4); final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
final INDArray res = standardized.mulRowVector(gain); final INDArray res = standardized.mulRowVector(gain);
final INDArray output = Nd4j.zerosLike(res); final INDArray output = Nd4j.zerosLike(res);
@ -1287,7 +1288,7 @@ public class LayerOpValidation extends BaseOpValidation {
public void testLayerNormNoDeviation() { public void testLayerNormNoDeviation() {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
random.putScalar(1,i, 7); random.putScalar(1, i, 7);
} }
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
@ -1335,7 +1336,7 @@ public class LayerOpValidation extends BaseOpValidation {
.paddingMode(PaddingMode.VALID) .paddingMode(PaddingMode.VALID)
.build(); .build();
SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
} }
@ -1391,16 +1392,16 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNormMixedOrders(){ public void testLayerNormMixedOrders() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
INDArray bias = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); INDArray bias = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'f');
INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'c');
INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'c');
INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'f');
//F in, F out case //F in, F out case
Nd4j.exec(DynamicCustomOp.builder("layer_norm") Nd4j.exec(DynamicCustomOp.builder("layer_norm")
@ -1441,11 +1442,11 @@ public class LayerOpValidation extends BaseOpValidation {
public void testBiasAdd_nchw_nhwc() { public void testBiasAdd_nchw_nhwc() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for(boolean nchw : new boolean[]{true, false}) { for (boolean nchw : new boolean[]{true, false}) {
log.info("Starting test: {}", nchw ? "nchw" : "nhwc"); log.info("Starting test: {}", nchw ? "nchw" : "nhwc");
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4})); SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2, 4, 3, 3} : new long[]{2, 3, 3, 4}));
SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4})); SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4}));
SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw); SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw);
@ -1453,10 +1454,10 @@ public class LayerOpValidation extends BaseOpValidation {
INDArray exp = in.getArr().dup(); INDArray exp = in.getArr().dup();
if(nchw){ if (nchw) {
exp.addi(b.getArr().reshape(1,4,1,1)); exp.addi(b.getArr().reshape(1, 4, 1, 1));
} else { } else {
exp.addi(b.getArr().reshape(1,1,1,4)); exp.addi(b.getArr().reshape(1, 1, 1, 4));
} }
TestCase tc = new TestCase(sameDiff) TestCase tc = new TestCase(sameDiff)
@ -1467,4 +1468,168 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
} }
@Test
public void LSTMLayerTestCase1() {
int bS = 5;
int nIn = 3;
int numUnits = 7;
int sL = 10; //small just for test
SameDiff sd = SameDiff.create();
// notations:
// bS - batch size, numExamples
// sL - sequence length, number of time steps, timeLength
// nIn - input size, inOutSize
// TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
// NST: shape [numExamples, inOutSize, timeLength]
// NTS: shape [numExamples, timeLength, inOutSize]
// for bidirectional:
// T2NS: shape [timeLength, 2, numExamples, inOutSize] (for ONNX)
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL));
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.NST)
.directionMode(LSTMDirectionMode.FWD)
.gateAct(LSTMActivations.SIGMOID)
.cellAct(LSTMActivations.TANH)
.outAct(LSTMActivations.TANH)
.retFullSequence(true)
.retLastC(true)
.retLastH(true)
.build();
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
.peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
.bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))).build(),
c), c);
long[] out = new long[]{bS, numUnits, sL};
long[] hL = new long[]{bS, numUnits};
long[] cL = new long[]{bS, numUnits};
assertArrayEquals(out, outputs.getOutput().eval().shape());
assertArrayEquals(hL, outputs.getLastTimeStepOutput().eval().shape());
assertArrayEquals(cL, outputs.getLastCellStateOutput().eval().shape());
}
@Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824
public void LSTMLayerTestCase2() {
int bS = 5;
int nIn = 3;
int numUnits = 7;
int sL = 10; //small just for test
SameDiff sd = SameDiff.create();
// notations:
// bS - batch size, numExamples
// sL - sequence length, number of time steps, timeLength
// nIn - input size, inOutSize
// TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
// NST: shape [numExamples, inOutSize, timeLength]
// NTS: shape [numExamples, timeLength, inOutSize]
// for bidirectional:
// T2NS: shape [timeLength, 2, numExamples, inOutSize] (for ONNX)
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, sL, bS, nIn));
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.TNS)
.directionMode(LSTMDirectionMode.FWD)
.gateAct(LSTMActivations.SIGMOID)
.cellAct(LSTMActivations.TANH)
.outAct(LSTMActivations.TANH)
.retFullSequence(true)
.retLastC(false)
.retLastH(false)
.build();
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
.build(),
c), c);
long[] out = new long[]{sL, bS, numUnits};
assertArrayEquals(out, outputs.getOutput().eval().shape());
}
@Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824
public void LSTMLayerTestCase3() {
int bS = 5;
int nIn = 3;
int numUnits = 7;
int sL = 10; //small just for test
SameDiff sd = SameDiff.create();
// notations:
// bS - batch size, numExamples
// sL - sequence length, number of time steps, timeLength
// nIn - input size, inOutSize
// TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
// NST: shape [numExamples, inOutSize, timeLength]
// NTS: shape [numExamples, timeLength, inOutSize]
// for bidirectional:
// T2NS: shape [timeLength, 2, numExamples, inOutSize] (for ONNX)
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, sL, nIn));
// when directionMode >= 2 (BIDIR_CONCAT=3)
// Wx, Wr [2, nIn, 4*nOut]
// hI, cI [2, bS, nOut]
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits));
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.NTS)
.directionMode(LSTMDirectionMode.BIDIR_CONCAT)
.gateAct(LSTMActivations.SIGMOID)
.cellAct(LSTMActivations.SOFTPLUS)
.outAct(LSTMActivations.SOFTPLUS)
.retFullSequence(true)
.retLastC(false)
.retLastH(false)
.build();
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(new String[]{"out"},
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, 2, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, 2, numUnits, 4 * numUnits)))
.build(),
c), c);
long[] out = new long[]{bS, sL, 2 * numUnits};
assertArrayEquals(out, outputs.getOutput().eval().shape());
}
} }

View File

@ -548,7 +548,7 @@ public class MiscOpValidation extends BaseOpValidation {
INDArray arr2 = Nd4j.rand(new long[]{2, 2, 2}); INDArray arr2 = Nd4j.rand(new long[]{2, 2, 2});
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
SDVariable y = sameDiff.var("y", arr2); SDVariable y = sameDiff.var("y", arr2);
SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}}); SDVariable result = sameDiff.tensorMmul(x, y, new int[]{0}, new int[]{1});
assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}),
result.eval().shape()); result.eval().shape());
assertEquals(16, sameDiff.numElements()); assertEquals(16, sameDiff.numElements());
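A short sketch of the reworked tensorMmul signature, which now takes the contraction axes as two separate int arrays, as in the updated test above; the shapes are illustrative.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class TensorMmulSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable x = sd.var("x", Nd4j.rand(new long[]{2, 2, 2}));
        SDVariable y = sd.var("y", Nd4j.rand(new long[]{2, 2, 2}));
        // contract dimension 0 of x against dimension 1 of y; the remaining dims give shape [2, 2, 2, 2]
        SDVariable result = sd.tensorMmul(x, y, new int[]{0}, new int[]{1});
        System.out.println(java.util.Arrays.toString(result.eval().shape()));
    }
}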
@ -689,13 +689,7 @@ public class MiscOpValidation extends BaseOpValidation {
SDVariable a = sd.var("a", aArr); SDVariable a = sd.var("a", aArr);
SDVariable b = sd.var("b", bArr); SDVariable b = sd.var("b", bArr);
MMulTranspose mt = MMulTranspose.builder() SDVariable mmul = sd.mmul(a, b, transposeA, transposeB, transposeResult);
.transposeA(transposeA)
.transposeB(transposeB)
.transposeResult(transposeResult)
.build();
SDVariable mmul = sd.mmul(a, b, mt);
INDArray exp = (transposeA ? aArr.transpose() : aArr); INDArray exp = (transposeA ? aArr.transpose() : aArr);
exp = exp.mmul(transposeB ? bArr.transpose() : bArr); exp = exp.mmul(transposeB ? bArr.transpose() : bArr);
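The MMulTranspose builder is replaced by plain boolean flags; a minimal sketch of the new call shape, with illustrative sizes:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class MmulTransposeSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable a = sd.var("a", Nd4j.rand(DataType.DOUBLE, 4, 3));
        SDVariable b = sd.var("b", Nd4j.rand(DataType.DOUBLE, 4, 5));
        // transposeA=true, transposeB=false, transposeResult=false: (3x4) x (4x5) = 3x5
        SDVariable out = sd.mmul(a, b, true, false, false);
        System.out.println(java.util.Arrays.toString(out.eval().shape()));   // [3, 5]
    }
}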

View File

@ -70,7 +70,7 @@ public class RnnOpValidation extends BaseOpValidation {
LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
.inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf)); //Output order: i, c, f, o, z, h, y
List<String> toExec = new ArrayList<>(); List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v.getAllOutputs()){ for(SDVariable sdv : v.getAllOutputs()){
toExec.add(sdv.name()); toExec.add(sdv.name());
@ -173,7 +173,7 @@ public class RnnOpValidation extends BaseOpValidation {
LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
.inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf)); //Output order: i, c, f, o, z, h, y
List<String> toExec = new ArrayList<>(); List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v.getAllOutputs()){ for(SDVariable sdv : v.getAllOutputs()){
toExec.add(sdv.name()); toExec.add(sdv.name());
@ -227,7 +227,7 @@ public class RnnOpValidation extends BaseOpValidation {
.cBias(bc) .cBias(bc)
.build(); .build();
List<SDVariable> v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs(); SDVariable[] v = sd.rnn().gru(x, hLast, weights);
List<String> toExec = new ArrayList<>(); List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v){ for(SDVariable sdv : v){
toExec.add(sdv.name()); toExec.add(sdv.name());

View File

@ -119,7 +119,7 @@ public class ShapeOpValidation extends BaseOpValidation {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (int[] toShape : new int[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) { for (long[] toShape : new long[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) {
for(char order : new char[]{'c','f'}){ for(char order : new char[]{'c','f'}){
INDArray inArr = Nd4j.rand(DataType.DOUBLE, origShape, order).muli(100); INDArray inArr = Nd4j.rand(DataType.DOUBLE, origShape, order).muli(100);
@ -388,10 +388,10 @@ public class ShapeOpValidation extends BaseOpValidation {
@Builder(builderClassName = "Builder") @Builder(builderClassName = "Builder")
@Data @Data
private static class SSCase { private static class SSCase {
private int[] shape; private long[] shape;
private int[] begin; private long[] begin;
private int[] end; private long[] end;
private int[] strides; private long[] strides;
private int beginMask; private int beginMask;
private int endMask; private int endMask;
private int ellipsisMask; private int ellipsisMask;
@ -400,22 +400,22 @@ public class ShapeOpValidation extends BaseOpValidation {
public static class Builder { public static class Builder {
public Builder shape(int... shape) { public Builder shape(long... shape) {
this.shape = shape; this.shape = shape;
return this; return this;
} }
public Builder begin(int... begin) { public Builder begin(long... begin) {
this.begin = begin; this.begin = begin;
return this; return this;
} }
public Builder end(int... end) { public Builder end(long... end) {
this.end = end; this.end = end;
return this; return this;
} }
public Builder strides(int... strides) { public Builder strides(long... strides) {
this.strides = strides; this.strides = strides;
return this; return this;
} }
@ -1571,7 +1571,7 @@ public class ShapeOpValidation extends BaseOpValidation {
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
SDVariable x1 = sameDiff.var("x1", arr1); SDVariable x1 = sameDiff.var("x1", arr1);
SDVariable x2 = sameDiff.var("x2", arr2); SDVariable x2 = sameDiff.var("x2", arr2);
SDVariable result = sameDiff.parallel_stack(new SDVariable[]{x1, x2}); SDVariable result = sameDiff.stack(0, new SDVariable[]{x1, x2});
assertArrayEquals(new long[]{2, 3, 2}, result.eval().shape()); assertArrayEquals(new long[]{2, 3, 2}, result.eval().shape());
assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval()); assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval());
} }
@ -1661,9 +1661,9 @@ public class ShapeOpValidation extends BaseOpValidation {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
SDVariable slice_full = sd.stridedSlice(in, new int[]{0, 0}, new int[]{3, 4}, new int[]{1, 1}); SDVariable slice_full = sd.stridedSlice(in,new long[]{0, 0},new long[]{3, 4},new long[]{1, 1});
SDVariable subPart = sd.stridedSlice(in, new int[]{1, 2}, new int[]{3, 4}, new int[]{1, 1}); SDVariable subPart = sd.stridedSlice(in,new long[]{1, 2},new long[]{3, 4},new long[]{1, 1});
// SDVariable subPart2 = sd.stridedSlice(in, new int[]{0, 0}, new int[]{4, 5}, new int[]{2, 2}); // SDVariable subPart2 = sd.stridedSlice(in,new long[]{0, 0},new long[]{4, 5},new long[]{2, 2});
sd.outputAll(null); sd.outputAll(null);
@ -1679,8 +1679,8 @@ public class ShapeOpValidation extends BaseOpValidation {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
SDVariable slice1 = sd.stridedSlice(in, new int[]{-999, 0}, new int[]{2, 4}, new int[]{1, 1}, 1 << 1, 0, 0, 0, 0); SDVariable slice1 = sd.stridedSlice(in,new long[]{-999, 0},new long[]{2, 4},new long[]{1, 1}, 1 << 1, 0, 0, 0, 0);
SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 0}, new int[]{-999, 4}, new int[]{1, 1}, 0, 1, 0, 0, 0); SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 0},new long[]{-999, 4},new long[]{1, 1}, 0, 1, 0, 0, 0);
sd.outputAll(null); sd.outputAll(null);
@ -1695,9 +1695,9 @@ public class ShapeOpValidation extends BaseOpValidation {
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
//[1:3,...] -> [1:3,:,:] //[1:3,...] -> [1:3,:,:]
SDVariable slice = sd.stridedSlice(in, new int[]{1}, new int[]{3}, new int[]{1}, 0, 0, 1 << 1, 0, 0); SDVariable slice = sd.stridedSlice(in,new long[]{1},new long[]{3},new long[]{1}, 0, 0, 1 << 1, 0, 0);
//[1:3,...,1:4] -> [1:3,:,1:4] //[1:3,...,1:4] -> [1:3,:,1:4]
SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 1}, new int[]{3, 4}, new int[]{1, 1}, 0, 0, 1 << 1, 0, 0); SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 1},new long[]{3, 4},new long[]{1, 1}, 0, 0, 1 << 1, 0, 0);
sd.outputAll(Collections.emptyMap()); sd.outputAll(Collections.emptyMap());
@ -1710,7 +1710,7 @@ public class ShapeOpValidation extends BaseOpValidation {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
SDVariable slice = sd.stridedSlice(in, new int[]{-999, 0, 0, 0}, new int[]{-999, 3, 4, 5}, new int[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0); SDVariable slice = sd.stridedSlice(in,new long[]{-999, 0, 0, 0},new long[]{-999, 3, 4, 5},new long[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0);
INDArray out = slice.eval(); INDArray out = slice.eval();
@ -1723,7 +1723,7 @@ public class ShapeOpValidation extends BaseOpValidation {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
SDVariable slice = sd.stridedSlice(in, new int[]{1, 1, -999, 1}, new int[]{3, 3, -999, 4}, new int[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0); SDVariable slice = sd.stridedSlice(in,new long[]{1, 1, -999, 1},new long[]{3, 3, -999, 4},new long[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0);
INDArray out = slice.eval(); INDArray out = slice.eval();
assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape()); assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape());
@ -1735,9 +1735,9 @@ public class ShapeOpValidation extends BaseOpValidation {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
SDVariable slice = sd.stridedSlice(in, new int[]{0, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1); SDVariable slice = sd.stridedSlice(in,new long[]{0, 0, 0},new long[]{-999, 4, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1);
SDVariable slice2 = sd.stridedSlice(in, new int[]{2, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1); SDVariable slice2 = sd.stridedSlice(in,new long[]{2, 0, 0},new long[]{-999, 4, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1);
SDVariable slice3 = sd.stridedSlice(in, new int[]{1, 2, 1}, new int[]{-999, -999, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1); SDVariable slice3 = sd.stridedSlice(in,new long[]{1, 2, 1},new long[]{-999, -999, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1);
sd.outputAll(null); sd.outputAll(null);
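The begin/end/stride arrays are now long[] throughout; a minimal sketch of the basic overload used above, with an illustrative 3x4 input:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class StridedSliceSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12).reshape('c', 3, 4));
        // rows 1..2 (end is exclusive), all 4 columns, stride 1 in both dimensions
        SDVariable sub = sd.stridedSlice(in, new long[]{1, 0}, new long[]{3, 4}, new long[]{1, 1});
        System.out.println(java.util.Arrays.toString(sub.eval().shape()));   // [2, 4]
    }
}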

View File

@ -1920,7 +1920,7 @@ public class TransformOpValidation extends BaseOpValidation {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable sdA = sd.var("a", a); SDVariable sdA = sd.var("a", a);
SDVariable sdB = sd.var("b", b); SDVariable sdB = sd.var("b", b);
SDVariable t = sd.mmul(sdA, sdB, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).transposeResult(transposeResult).build()); SDVariable t = sd.mmul(sdA, sdB, transposeA, transposeB, transposeResult);
t.norm1("out"); t.norm1("out");
String err = OpValidation.validate(new TestCase(sd) String err = OpValidation.validate(new TestCase(sd)

View File

@ -759,8 +759,7 @@ public class SameDiffTests extends BaseNd4jTest {
val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1); val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1);
val input1 = sd.var("input", matrix); val input1 = sd.var("input", matrix);
val input2 = sd.var("input2", vector); val input2 = sd.var("input2", vector);
val output = sd val output = sd.mmul("output", input1, input2, true, false, false);
.mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build());
INDArray out = output.eval(); INDArray out = output.eval();
assertArrayEquals(new long[]{3, 1}, out.shape()); assertArrayEquals(new long[]{3, 1}, out.shape());
} }
@ -2675,7 +2674,7 @@ public class SameDiffTests extends BaseNd4jTest {
final long timeSteps = sdInput.getShape()[2]; final long timeSteps = sdInput.getShape()[2];
SDVariable[] outputSlices = new SDVariable[(int) timeSteps]; SDVariable[] outputSlices = new SDVariable[(int) timeSteps];
final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2); final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2);
final val x_0 = inputSlices[0]; final val x_0 = inputSlices[0];
outputSlices[0] = x_0; outputSlices[0] = x_0;
@ -2702,7 +2701,7 @@ public class SameDiffTests extends BaseNd4jTest {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
final SDVariable sdInput = sd.var("input", input); final SDVariable sdInput = sd.var("input", input);
final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2); final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2);
final val temp = inputSlices[0].add(inputSlices[1]).div(inputSlices[1]).mul(inputSlices[0]); final val temp = inputSlices[0].add(inputSlices[1]).div(inputSlices[1]).mul(inputSlices[0]);
final val out = temp.add(temp).add(inputSlices[1]); final val out = temp.add(temp).add(inputSlices[1]);
out.norm2("out"); out.norm2("out");
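unstack now takes the number of output slices as an explicit final argument, as in the two updated calls above; a short illustrative sketch:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class UnstackSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 3, 4, 2));
        // unstack along dimension 2 into (explicitly) 2 named outputs
        SDVariable[] slices = sd.unstack(new String[]{"x0", "x1"}, in, 2, 2);
        System.out.println(java.util.Arrays.toString(slices[0].eval().shape()));   // [3, 4]
    }
}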
@ -3242,61 +3241,61 @@ public class SameDiffTests extends BaseNd4jTest {
@Test @Test
public void testNestedIf() throws IOException { public void testNestedIf() throws IOException {
SameDiff SD = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable a = SD.var("a", Nd4j.createFromArray(2.0)); SDVariable a = sd.var("a", Nd4j.createFromArray(2.0));
SDVariable b = SD.var("b", Nd4j.createFromArray(5.0)); SDVariable b = sd.var("b", Nd4j.createFromArray(5.0));
SDVariable c = SD.var("c", Nd4j.createFromArray(9.0)); SDVariable c = sd.var("c", Nd4j.createFromArray(9.0));
SDVariable d = SD.var("d", Nd4j.createFromArray(-7.0)); SDVariable d = sd.var("d", Nd4j.createFromArray(-7.0));
SDVariable output = SD.ifCond("out", null, SDVariable output = sd.ifCond("out", null,
(sd) -> a.lt(b), (s) -> a.lt(b),
(sd) -> sd.ifCond( (s) -> s.ifCond(
(sd2) -> d.lte(0), (sd2) -> d.lte(0),
(sd2) -> c.add(1), (sd2) -> c.add(1),
(sd2) -> d), (sd2) -> d),
(sd) -> c.add(5)); (s) -> c.add(5));
INDArray out = output.eval(); INDArray out = output.eval();
assertEquals(Nd4j.createFromArray(10.0), out); assertEquals(Nd4j.createFromArray(10.0), out);
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));
assertEquals(Nd4j.createFromArray(10.0), SD.output(Collections.emptyMap(), "out").get("out")); assertEquals(Nd4j.createFromArray(10.0), sd.output(Collections.emptyMap(), "out").get("out"));
} }
@Test @Test
public void testWhile() throws IOException { public void testWhile() throws IOException {
SameDiff SD = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable countIn = SD.constant(5); SDVariable countIn = sd.constant(5);
SDVariable sumIn = SD.constant(0); SDVariable sumIn = sd.constant(0);
SDVariable[] sum = SD.whileLoop("while_1", new SDVariable[]{countIn, sumIn}, SDVariable[] sum = sd.whileLoop("while_1", new SDVariable[]{countIn, sumIn},
(sd, vars) -> vars[0].gt(0), (s, vars) -> vars[0].gt(0),
(sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])}); (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])});
INDArray out = sum[1].eval(); INDArray out = sum[1].eval();
assertEquals(15, out.getInt(0)); assertEquals(15, out.getInt(0));
String outName = sum[1].name(); String outName = sum[1].name();
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));
assertEquals(15, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); assertEquals(15, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0));
} }
@Test @Test
@Ignore @Ignore
public void testNestedWhile() throws IOException { public void testNestedWhile() throws IOException {
SameDiff SD = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable countIn = SD.constant(5); SDVariable countIn = sd.constant(5);
SDVariable sumIn = SD.constant(0); SDVariable sumIn = sd.constant(0);
SDVariable sum2 = SD.constant(0); SDVariable sum2 = sd.constant(0);
//TODO creating constant instead of using sum2 causes errors //TODO creating constant instead of using sum2 causes errors
SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn}, SDVariable[] sum = sd.whileLoop(new SDVariable[]{countIn, sumIn},
(sd, vars) -> vars[0].gt(0), (s, vars) -> vars[0].gt(0),
(sd, vars) -> new SDVariable[]{vars[0].sub(1), (s, vars) -> new SDVariable[]{vars[0].sub(1),
vars[1].add(sd.whileLoop(new SDVariable[]{vars[0], sum2}, vars[1].add(s.whileLoop(new SDVariable[]{vars[0], sum2},
(sd2, vars2) -> vars2[0].gt(0), (sd2, vars2) -> vars2[0].gt(0),
(sd2, vars2) -> new SDVariable[]{vars2[0].sub(1), vars2[1].add(vars2[0])})[1])}); (sd2, vars2) -> new SDVariable[]{vars2[0].sub(1), vars2[1].add(vars2[0])})[1])});
@ -3305,23 +3304,23 @@ public class SameDiffTests extends BaseNd4jTest {
String outName = sum[1].name(); String outName = sum[1].name();
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));
assertEquals(35, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); assertEquals(35, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0));
} }
@Test @Test
public void testNestedWhileIf() throws IOException { public void testNestedWhileIf() throws IOException {
SameDiff SD = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable countIn = SD.constant(5); SDVariable countIn = sd.constant(5);
SDVariable sumIn = SD.constant(0); SDVariable sumIn = sd.constant(0);
SDVariable hundred = SD.constant(100); SDVariable hundred = sd.constant(100);
SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn}, SDVariable[] sum = sd.whileLoop(new SDVariable[]{countIn, sumIn},
(sd, vars) -> vars[0].gte(0), (s, vars) -> vars[0].gte(0),
(sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add( (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(
sd.ifCond((sd2) -> vars[0].eq(0), s.ifCond((sd2) -> vars[0].eq(0),
(sd2) -> vars[0].add(100), //TODO replace with hundred and things break (sd2) -> vars[0].add(100), //TODO replace with hundred and things break
(sd2) -> vars[0]) (sd2) -> vars[0])
)}); )});
@@ -3331,9 +3330,9 @@ public class SameDiffTests extends BaseNd4jTest {
String outName = sum[1].name(); String outName = sum[1].name();
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));
assertEquals(115, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); assertEquals(115, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0));
} }
@Test @Test
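A minimal, self-contained sketch of the control-flow pattern exercised by the tests above, using only calls that appear in them (whileLoop with condition/body lambdas, eval, and the asFlatBuffers/fromFlatBuffers round trip); the class and variable names here are illustrative and not part of the diff:

    import java.util.Collections;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class WhileLoopRoundTripSketch {
        public static void main(String[] args) throws Exception {
            SameDiff sd = SameDiff.create();
            SDVariable countIn = sd.constant(5);
            SDVariable sumIn = sd.constant(0);

            // Loop-carried state: [count, sum]; the body decrements count and accumulates it into sum.
            // The lambda's SameDiff parameter is named "s" rather than "sd" to avoid shadowing the outer instance.
            SDVariable[] loop = sd.whileLoop(new SDVariable[]{countIn, sumIn},
                    (s, vars) -> vars[0].gt(0),
                    (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])});

            INDArray direct = loop[1].eval();                   // 5+4+3+2+1 = 15
            String outName = loop[1].name();

            // Serialize to FlatBuffers and back, then evaluate the same output by name
            sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));
            INDArray roundTripped = sd.output(Collections.emptyMap(), outName).get(outName);
            System.out.println(direct + " vs " + roundTripped); // both should contain 15
        }
    }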

View File

@@ -61,7 +61,7 @@ public class OpsMappingTests extends BaseNd4jTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 180000L; //Can be slow on some CI machines such as PPC return 360000L; //Can be very slow on some CI machines (PPC)
} }
@Test @Test

View File

@@ -29,7 +29,10 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@@ -473,6 +476,7 @@ public class OperationProfilerTests extends BaseNd4jTest {
Nd4j.exec(op); //Should trigger NaN panic Nd4j.exec(op); //Should trigger NaN panic
fail(); fail();
} catch (Exception e){ } catch (Exception e){
e.printStackTrace();
assertTrue(e.getMessage(), e.getMessage().contains("Inf")); assertTrue(e.getMessage(), e.getMessage().contains("Inf"));
} }
@@ -488,4 +492,55 @@ public class OperationProfilerTests extends BaseNd4jTest {
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build()); Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build());
} }
} }
@Test
public void testOpProfilerOpContextLegacy(){
for(boolean nan : new boolean[]{true, false}) {
INDArray in = Nd4j.valueArrayOf(10, nan ? -1 : 0).castTo(DataType.FLOAT);
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(nan).checkForINF(!nan).build());
OpContext oc = Nd4j.getExecutioner().buildContext();
oc.setInputArray(0, in);
oc.setOutputArray(0, in.ulike());
try {
Nd4j.exec(new Log(), oc);
System.out.println(oc.getOutputArray(0));
fail("Expected op profiler exception");
} catch (Throwable t) {
//OK
assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf"));
}
}
}
@Test
public void testOpProfilerOpContextCustomOp(){
for(boolean nan : new boolean[]{true, false}) {
INDArray in = Nd4j.create(DataType.DOUBLE, 10).assign(nan ? Double.NaN : Double.POSITIVE_INFINITY);
INDArray in2 = in.dup();
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(nan).checkForINF(!nan).build());
OpContext oc = Nd4j.getExecutioner().buildContext();
oc.setIArguments(0);
oc.setInputArray(0, in);
oc.setInputArray(1, in2);
oc.setOutputArray(0, Nd4j.create(DataType.DOUBLE, 20));
try {
Nd4j.exec(new Concat(), oc);
System.out.println(oc.getOutputArray(0));
fail("Expected op profiler exception");
} catch (Throwable t) {
//OK
assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf"));
}
}
}
} }
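For reference, a minimal sketch of the OpContext-based profiler check these new tests exercise: enable NaN/Inf panic via ProfilerConfig, run a legacy op through an OpContext, and expect the post-execution validation to throw. Only calls already used in the tests above are assumed; the ProfilerConfig import path and the wrapper class name are assumptions for illustration:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.profiler.ProfilerConfig;

    public class OpContextNanCheckSketch {
        public static void main(String[] args) {
            // Turn on NaN and Inf panic checks for subsequent op executions
            Nd4j.getExecutioner().setProfilingConfig(
                    ProfilerConfig.builder().checkForNAN(true).checkForINF(true).build());

            // log(-1) = NaN for every element, so the check should fire after execution
            INDArray in = Nd4j.valueArrayOf(10, -1).castTo(DataType.FLOAT);

            OpContext oc = Nd4j.getExecutioner().buildContext();
            oc.setInputArray(0, in);
            oc.setOutputArray(0, in.ulike());

            try {
                Nd4j.exec(new Log(), oc);   // expected to throw because the output contains NaNs
            } catch (Throwable t) {
                System.out.println("Profiler check triggered: " + t.getMessage());
            }
        }
    }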

View File

@@ -3579,4 +3579,19 @@ public class ArrayUtil {
} }
return false; return false;
} }
public static <T> T[] filterNull(T... in){
int count = 0;
for( int i=0; i<in.length; i++ ) {
if (in[i] != null) count++;
}
T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
int j=0;
for( int i=0; i<in.length; i++ ){
if(in[i] != null){
out[j++] = in[i];
}
}
return out;
}
} }
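A hypothetical usage example for the new ArrayUtil.filterNull helper (the import path is an assumption, as it is not shown in the diff): nulls are removed and the returned array keeps the component type of the input.

    import java.util.Arrays;
    import org.nd4j.linalg.util.ArrayUtil;

    public class FilterNullExample {
        public static void main(String[] args) {
            String[] mixed = {"a", null, "b", null, "c"};
            // Varargs also accepts an explicit array argument
            String[] filtered = ArrayUtil.filterNull(mixed);
            System.out.println(Arrays.toString(filtered)); // prints [a, b, c]
        }
    }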