DL4J and SameDiff integration tests + LSTMLayer java op class (#353)
* init in this branch (Andrii Tuzhykov)
* LeNet MNIST workflow (Andrii Tuzhykov)
* small fix for calculations (Andrii Tuzhykov)
* for Alex to check placeholder null pointer issue (Andrii Tuzhykov)
* CNN3D workflow (Andrii Tuzhykov)
* state for launching on DGX to regenerate DL4J examples (Andrii Tuzhykov)
* SD RNN test case workflow (Andrii Tuzhykov)
* small fixes (Andrii Tuzhykov)
* checkpoint at lstmBlock "Input array 1 (x) rank must be ... got input with rank 2" issue (Andrii Tuzhykov)
* fix LSTMLayer inputs order (Andrii Tuzhykov)
* LSTM mismatch with the C++ op issue (Andrii Tuzhykov)
* LSTMLayer config draft (Andrii Tuzhykov)
* LSTMLayer config draft v2 (Andrii Tuzhykov)
* not sure this was needed (Andrii Tuzhykov)
* NDRNN generated by codegen (Andrii Tuzhykov)
* LSTMLayerTestCases draft (Andrii Tuzhykov)
* minor fixes again
* added LSTMLayer test cases to nd4j-tests and set Preconditions in LSTMLayer constructors (Andrii Tuzhykov)
* added lost SD CNN test cases (Andrii Tuzhykov)
* overrode getNumOutputs from DynamicCustomOp in LSTMLayer and reorganized LSTMLayerOutputs to match the C++ op (Andrii Tuzhykov)
* finished with LSTMLayerOutputs (Andrii Tuzhykov)
* fix MKLDNN platform checks (i.e., when MKLDNN can be used vs. not) (Alex Black)
* fix LSTMLayerWeights input order (Alex Black)
* more fixes (Alex Black)
* minor fixes (Andrii Tuzhykov)
* fixed LSTMLayer test cases (Andrii Tuzhykov)
* finished SameDiffRNNTestCase (Andrii Tuzhykov)
* finished all test cases + minor fixes (Andrii Tuzhykov)
* multiple generation-related fixes (Alex Black)
* fix multiple issues (Alex Black)
* more fixes (Alex Black)
* LSTM fixes (Alex Black)
* regenerate ND4J namespaces and fix multiple issues (Alex Black)
* changed SameDiffRNNTestCase (Andrii Tuzhykov)
* small fix (Alex Black)
* added Nd4j.getRandom().setSeed(12345) where needed (Andrii Tuzhykov)
* #8828 fix ND4J profiler NaN/Inf checks when using OpContext (Alex Black)
* tweak to weight init for the SameDiff CNN test case (Alex Black)
* tweaks for test cases (Alex Black)
* ignore failing tests until fixed (Alex Black)
* fix (Alex Black)

Co-authored-by: Alex Black <blacka101@gmail.com>
parent ab083b9167
commit d86dd5b131
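The headline addition in this PR is the LSTMLayer java op class together with its helper types (LSTMLayerConfig, LSTMLayerWeights, LSTMLayerOutputs, and the LSTMDataFormat/LSTMDirectionMode/LSTMActivations enums imported by the new SameDiffRNNTestCases at the bottom of this diff). A minimal sketch of how those pieces fit together follows; the builder method names and weight shapes here are assumptions inferred from the imported class names, not verified signatures:

    // Hedged sketch only: builder methods and shapes below are assumptions.
    int nIn = 3;
    int nOut = 10;
    SameDiff sd = SameDiff.create();

    LSTMLayerConfig conf = LSTMLayerConfig.builder()
            .lstmdataformat(LSTMDataFormat.NTS)    // assumed: input as [numExamples, timeLength, inOutSize]
            .directionMode(LSTMDirectionMode.FWD)  // assumed: unidirectional, forward
            .gateAct(LSTMActivations.SIGMOID)      // assumed gate activation
            .cellAct(LSTMActivations.TANH)         // assumed cell-state activation
            .outAct(LSTMActivations.TANH)          // assumed output activation
            .retFullSequence(true)                 // assumed: return the full output sequence
            .build();

    LSTMLayerWeights weights = LSTMLayerWeights.builder()
            .weights(sd.var("w", Nd4j.rand(DataType.FLOAT, nIn, 4 * nOut)))    // input weights; shape assumed
            .rWeights(sd.var("rw", Nd4j.rand(DataType.FLOAT, nOut, 4 * nOut))) // recurrent weights; shape assumed
            .bias(sd.var("b", Nd4j.rand(DataType.FLOAT, 4 * nOut)))            // bias; shape assumed
            .build();

    // conf and weights are then passed, with the input sequence and optional
    // initial state, to the new lstmLayer op; LSTMLayerOutputs (reorganized in
    // this PR to match the C++ op) unpacks the op's multiple outputs.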
@@ -25,7 +25,7 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
  */
 @Deprecated
 @Getter
-@EqualsAndHashCode
+@EqualsAndHashCode(callSuper = true)
 public class EvaluationCalibration extends org.nd4j.evaluation.classification.EvaluationCalibration implements org.deeplearning4j.eval.IEvaluation<org.nd4j.evaluation.classification.EvaluationCalibration> {

 /**
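For context on the one-line change above: Lombok's @EqualsAndHashCode ignores superclass fields unless callSuper = true is set, and EvaluationCalibration keeps all of its state in the nd4j superclass it extends. A minimal sketch with hypothetical Parent/Child classes (not from this codebase):

    import lombok.EqualsAndHashCode;

    @EqualsAndHashCode
    class Parent {
        int inherited;
    }

    // Without callSuper = true, the generated equals()/hashCode() would compare
    // only Child's own fields, so two instances differing solely in inherited
    // state would wrongly compare equal.
    @EqualsAndHashCode(callSuper = true)
    class Child extends Parent {
        int own;
    }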
@@ -185,7 +185,9 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
         final val R = paramTable.get(RECURRENT_WEIGHT_KEY);
         final val b = paramTable.get(BIAS_KEY);

-        SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2);
+        long[] shape = layerInput.getShape();
+        Preconditions.checkState(shape != null, "Null shape for input placeholder");
+        SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2, (int)shape[2]);
         this.timeSteps = inputSlices.length;
         SDVariable[] outputSlices = new SDVariable[timeSteps];
         SDVariable prev = null;
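The fix above exists because unstack cannot infer its number of outputs when the input is a placeholder whose sequence length is only known from the declared shape; the null check fails fast when even that is missing. A self-contained sketch of the same pattern (class and variable names hypothetical; the Preconditions package path is assumed for this era of ND4J):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.base.Preconditions; // package path assumed
    import org.nd4j.linalg.api.buffer.DataType;

    class UnstackSketch {
        static SDVariable[] timeSlices() {
            SameDiff sd = SameDiff.create();
            // [minibatch, nIn, timeSteps] placeholder; the time dimension (10) is fixed
            SDVariable input = sd.placeHolder("input", DataType.FLOAT, 32, 128, 10);

            long[] shape = input.getShape();
            Preconditions.checkState(shape != null, "Null shape for input placeholder");

            // Pass the slice count explicitly: shape[2] = 10 outputs along dimension 2
            return sd.unstack(input, 2, (int) shape[2]);
        }
    }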
@@ -20,7 +20,10 @@ package org.deeplearning4j.integration;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
+import org.deeplearning4j.integration.testcases.dl4j.*;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases;
 import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffRNNTestCases;
 import org.deeplearning4j.nn.api.Model;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@@ -66,14 +69,36 @@ public class IntegrationTestBaselineGenerator {
         }

         runGeneration(
-                SameDiffMLPTestCases.getMLPMnist()
+                // DL4J integration test cases.
+
+                // CNN1DTestCases.getCnn1dTestCaseCharRNN(),
+                // CNN2DTestCases.testLenetTransferDropoutRepeatability(),
+                //// CNN2DTestCases.getCnn2DSynthetic(),
+                // CNN2DTestCases.getLenetMnist(),
+                // CNN2DTestCases.getVGG16TransferTinyImagenet(),
+                // CNN2DTestCases.getYoloHouseNumbers(),
+                // CNN3DTestCases.getCnn3dTestCaseSynthetic(),
+                // MLPTestCases.getMLPMnist(),
+                // MLPTestCases.getMLPMoon(),
+                // RNNTestCases.getRnnCharacterTestCase(),
+                // RNNTestCases.getRnnCsvSequenceClassificationTestCase1(),
+                // RNNTestCases.getRnnCsvSequenceClassificationTestCase2(),
+                // UnsupervisedTestCases.getVAEMnistAnomaly(),
+
+                // Samediff test cases done
+                SameDiffMLPTestCases.getMLPMnist(),
+                SameDiffMLPTestCases.getMLPMoon(),
+                SameDiffCNNCases.getLenetMnist(),
+                SameDiffCNNCases.getCnn3dSynthetic(),
+                SameDiffRNNTestCases.getRnnCsvSequenceClassificationTestCase1()
         );

     }

     private static void runGeneration(TestCase... testCases) throws Exception {

-        for( TestCase tc : testCases ) {
+        for (TestCase tc : testCases) {
             final ModelType modelType = tc.modelType();

             //Basic validation:
@@ -122,18 +147,18 @@ public class IntegrationTestBaselineGenerator {
             mln = new MultiLayerNetwork(mlc);
             mln.init();
             m = mln;
-        } else if (config instanceof ComputationGraphConfiguration){
+        } else if (config instanceof ComputationGraphConfiguration) {
             ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
             json = cgc.toJson();
             cg = new ComputationGraph(cgc);
             cg.init();
             m = cg;
         } else {
-            sd = (SameDiff)config;
+            sd = (SameDiff) config;
         }

         File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME);
-        if(modelType != ModelType.SAMEDIFF) {
+        if (modelType != ModelType.SAMEDIFF) {
             File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? "mlc.json" : "cgc.json"));
             FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8);
             log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath());
@@ -147,10 +172,10 @@ public class IntegrationTestBaselineGenerator {
             m = tc.getPretrainedModel();
             if (m instanceof MultiLayerNetwork) {
                 mln = (MultiLayerNetwork) m;
-            } else if(m instanceof ComputationGraph){
+            } else if (m instanceof ComputationGraph) {
                 cg = (ComputationGraph) m;
             } else {
-                sd = (SameDiff)m;
+                sd = (SameDiff) m;
             }
         }

@@ -158,7 +183,7 @@ public class IntegrationTestBaselineGenerator {
         //Generate predictions to compare against
         if (tc.isTestPredictions()) {
             List<Pair<INDArray[], INDArray[]>> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null;
-            List<Map<String,INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
+            List<Map<String, INDArray>> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null;
             // Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName());


@@ -178,7 +203,7 @@ public class IntegrationTestBaselineGenerator {
                         Nd4j.write(out, dos);
                     }
                 }
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 for (Pair<INDArray[], INDArray[]> p : inputs) {
                     INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null);

@@ -192,11 +217,11 @@ public class IntegrationTestBaselineGenerator {
                 }
             } else {
                 List<String> outNames = tc.getPredictionsNamesSameDiff();
-                for( Map<String,INDArray> ph : inputsSd ){
-                    Map<String,INDArray> out = sd.output(ph, outNames);
+                for (Map<String, INDArray> ph : inputsSd) {
+                    Map<String, INDArray> out = sd.output(ph, outNames);

                     //Save the output...
-                    for(String s : outNames){
+                    for (String s : outNames) {
                         File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin");
                         try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) {
                             Nd4j.write(out.get(s), dos);
@@ -211,7 +236,7 @@ public class IntegrationTestBaselineGenerator {
         //Compute and save gradients:
         if (tc.isTestGradients()) {
             INDArray gradientFlat = null;
-            Map<String,INDArray> grad;
+            Map<String, INDArray> grad;
             if (modelType == ModelType.MLN) {
                 MultiDataSet data = tc.getGradientsTestData();
                 mln.setInput(data.getFeatures(0));
@@ -220,7 +245,7 @@ public class IntegrationTestBaselineGenerator {
                 mln.computeGradientAndScore();
                 gradientFlat = mln.getFlattenedGradients();
                 grad = m.gradient().gradientForVariable();
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 MultiDataSet data = tc.getGradientsTestData();
                 cg.setInputs(data.getFeatures());
                 cg.setLabels(data.getLabels());
@@ -229,17 +254,17 @@ public class IntegrationTestBaselineGenerator {
                 gradientFlat = cg.getFlattenedGradients();
                 grad = m.gradient().gradientForVariable();
             } else {
-                Map<String,INDArray> ph = tc.getGradientsTestDataSameDiff();
+                Map<String, INDArray> ph = tc.getGradientsTestDataSameDiff();
                 List<String> allVars = new ArrayList<>();
-                for(SDVariable v : sd.variables()){
-                    if(v.getVariableType() == VariableType.VARIABLE){
+                for (SDVariable v : sd.variables()) {
+                    if (v.getVariableType() == VariableType.VARIABLE) {
                         allVars.add(v.name());
                     }
                 }
                 grad = sd.calculateGradients(ph, allVars);
             }

-            if(modelType != ModelType.SAMEDIFF) {
+            if (modelType != ModelType.SAMEDIFF) {
                 File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME);
                 IntegrationTestRunner.write(gradientFlat, gFlatFile);
             }
@@ -254,25 +279,25 @@ public class IntegrationTestBaselineGenerator {
         }

         //Test pretraining
-        if(tc.isTestUnsupervisedTraining()){
+        if (tc.isTestUnsupervisedTraining()) {
             log.info("Performing layerwise pretraining");
             MultiDataSetIterator iter = tc.getUnsupervisedTrainData();

             INDArray paramsPostTraining;
-            if(modelType == ModelType.MLN){
+            if (modelType == ModelType.MLN) {
                 int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
                 Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
                 DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);

-                for( int i : layersToTrain){
+                for (int i : layersToTrain) {
                     mln.pretrainLayer(i, dsi);
                 }
                 paramsPostTraining = mln.params();
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
                 Preconditions.checkState(layersToTrain != null, "Layer names must not be null");

-                for( String i : layersToTrain){
+                for (String i : layersToTrain) {
                     cg.pretrainLayer(i, iter);
                 }
                 paramsPostTraining = cg.params();
@@ -290,20 +315,20 @@ public class IntegrationTestBaselineGenerator {
             MultiDataSetIterator trainData = tc.getTrainingData();

             CollectScoresListener l = new CollectScoresListener(1);
-            if(modelType != ModelType.SAMEDIFF)
+            if (modelType != ModelType.SAMEDIFF)
                 m.setListeners(l);

             History h = null;
             if (modelType == ModelType.MLN) {
                 mln.fit(trainData);
-            } else if(modelType == ModelType.CG) {
+            } else if (modelType == ModelType.CG) {
                 cg.fit(trainData);
             } else {
                 h = sd.fit(trainData, 1);
             }

             double[] scores;
-            if(modelType != ModelType.SAMEDIFF){
+            if (modelType != ModelType.SAMEDIFF) {
                 scores = l.getListScore().toDoubleArray();
             } else {
                 scores = h.lossCurve().getLossValues().toDoubleVector();
@@ -314,11 +339,11 @@ public class IntegrationTestBaselineGenerator {
             FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8);

             if (tc.isTestParamsPostTraining()) {
-                if(modelType == ModelType.SAMEDIFF){
+                if (modelType == ModelType.SAMEDIFF) {
                     File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR);
                     p.mkdirs();
-                    for(SDVariable v : sd.variables()){
-                        if(v.getVariableType() == VariableType.VARIABLE){
+                    for (SDVariable v : sd.variables()) {
+                        if (v.getVariableType() == VariableType.VARIABLE) {
                             INDArray arr = v.getArr();
                             File p2 = new File(p, v.name() + ".bin");
                             IntegrationTestRunner.write(arr, p2);
@@ -331,7 +356,6 @@ public class IntegrationTestBaselineGenerator {
             }
         }

-
         if (tc.isTestEvaluation()) {
             IEvaluation[] evals = tc.getNewEvaluations();
             MultiDataSetIterator iter = tc.getEvaluationTestData();
@@ -339,7 +363,7 @@ public class IntegrationTestBaselineGenerator {
             if (modelType == ModelType.MLN) {
                 DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
                 mln.doEvaluation(dsi, evals);
-            } else if(modelType == ModelType.CG){
+            } else if (modelType == ModelType.CG) {
                 cg.doEvaluation(iter, evals);
             } else {
                 evals = tc.doEvaluationSameDiff(sd, iter, evals);
@@ -16,6 +16,7 @@
 package org.deeplearning4j.integration;

 import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases;
 import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases;
 import org.junit.Rule;
 import org.junit.Test;
@@ -37,4 +38,20 @@ public class IntegrationTestsSameDiff extends BaseDL4JTest {
         IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir);
     }

+    @Test
+    public void testMLPMoon() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMoon(), testDir);
+    }
+
+    @Test
+    public void testLenetMnist() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffCNNCases.getLenetMnist(), testDir);
+    }
+
+    @Test
+    public void testCnn3dSynthetic() throws Exception {
+        IntegrationTestRunner.runTest(SameDiffCNNCases.getCnn3dSynthetic(), testDir);
+    }
+
+
 }
@@ -194,6 +194,8 @@ public class CNN2DTestCases {
                 testParamsPostTraining = false; //Skip - requires saving all params (approx 500mb)
                 testEvaluation = false;
                 testOverfitting = false;
+                maxRelativeErrorOutput = 0.2;
+                minAbsErrorOutput = 0.05; //Max value is around 0.22
             }

             @Override
@@ -314,6 +316,7 @@ public class CNN2DTestCases {
                 ComputationGraph model = new TransferLearning.GraphBuilder(pretrained)
                         .fineTuneConfiguration(fineTuneConf)
                         .removeVertexKeepConnections("conv2d_9")
+                        .removeVertexAndConnections("outputs")
                         .addLayer("convolution2d_9",
                                 new ConvolutionLayer.Builder(1,1)
                                         .nIn(1024)
@@ -393,7 +396,7 @@ public class CNN2DTestCases {

             @Override
             public ModelType modelType() {
-                return ModelType.CG;
+                return ModelType.MLN;
             }

             @Override
@@ -77,6 +77,10 @@ public class MLPTestCases {
                 testOverfitting = true;
                 maxRelativeErrorOverfit = 2e-2;
                 minAbsErrorOverfit = 1e-2;
+                maxRelativeErrorGradients = 0.01;
+                minAbsErrorGradients = 0.05;
+                maxRelativeErrorParamsPostTraining = 0.01;
+                minAbsErrorParamsPostTraining = 0.05;
             }

             @Override
@@ -135,8 +139,7 @@ public class MLPTestCases {
             public IEvaluation[] getNewEvaluations(){
                 return new IEvaluation[]{
                         new Evaluation(),
-                        new ROCMultiClass(),
-                        new EvaluationCalibration()
+                        new ROCMultiClass()
                 };
             }

@@ -24,6 +24,7 @@ import org.nd4j.evaluation.classification.EvaluationCalibration;
 import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.shade.guava.io.Files;
 import org.deeplearning4j.integration.TestCase;
 import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator;
@@ -91,7 +92,7 @@ public class RNNTestCases {
         }

         private int miniBatchSize = 32;
-        private int exampleLength = 1000;
+        private int exampleLength = 200;


         @Override
@@ -101,6 +102,7 @@ public class RNNTestCases {

         @Override
         public Object getConfiguration() throws Exception {
+            Nd4j.getRandom().setSeed(12345);

             CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
             int nOut = iter.totalOutcomes();
@@ -113,7 +115,7 @@ public class RNNTestCases {
                     .seed(12345)
                     .l2(0.001)
                     .weightInit(WeightInit.XAVIER)
-                    .updater(new RmsProp(0.1))
+                    .updater(new Adam(1e-3))
                     .list()
                     .layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
                             .activation(Activation.TANH).build())
@@ -140,7 +142,7 @@ public class RNNTestCases {
         @Override
         public MultiDataSetIterator getTrainingData() throws Exception {
             DataSetIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength);
-            iter = new EarlyTerminationDataSetIterator(iter, 2); //3 minibatches, 1000/200 = 5 updates per minibatch
+            iter = new EarlyTerminationDataSetIterator(iter, 2); //2 minibatches, 200/50 = 4 updates per minibatch
             return new MultiDataSetIteratorAdapter(iter);
         }

@@ -72,12 +72,12 @@ public class UnsupervisedTestCases {
             return new NeuralNetConfiguration.Builder()
                     .dataType(DataType.FLOAT)
                     .seed(12345)
-                    .updater(new Adam(0.05))
+                    .updater(new Adam(1e-3))
                     .weightInit(WeightInit.XAVIER)
                     .l2(1e-4)
                     .list()
                     .layer(0, new VariationalAutoencoder.Builder()
-                            .activation(Activation.LEAKYRELU)
+                            .activation(Activation.TANH)
                            .encoderLayerSizes(256, 256) //2 encoder layers, each of size 256
                            .decoderLayerSizes(256, 256) //2 decoder layers, each of size 256
                            .pzxActivationFunction(Activation.IDENTITY) //p(z|data) activation function
@@ -0,0 +1,398 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.integration.testcases.samediff;
+
+import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
+import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
+import org.deeplearning4j.integration.ModelType;
+import org.deeplearning4j.integration.TestCase;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.TrainingConfig;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROCMultiClass;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
+import org.nd4j.linalg.dataset.api.DataSet;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.Nesterovs;
+
+import java.util.*;
+
+public class SameDiffCNNCases {
+
+
+    public static TestCase getLenetMnist() {
+        return new TestCase() {
+            {
+                testName = "LenetMnistSD";
+                testType = TestType.RANDOM_INIT;
+                testPredictions = true;
+                testTrainingCurves = true;
+                testGradients = true;
+                testParamsPostTraining = true;
+                testEvaluation = true;
+                testOverfitting = false;
+            }
+
+            @Override
+            public ModelType modelType() {
+                return ModelType.SAMEDIFF;
+            }
+
+            public Object getConfiguration() throws Exception {
+                Nd4j.getRandom().setSeed(12345);
+
+                int nChannels = 1; // Number of input channels
+                int outputNum = 10; // The number of possible outcomes
+
+                SameDiff sd = SameDiff.create();
+                SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
+                SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum);
+
+                //input [minibatch, channels=1, Height = 28, Width = 28]
+                SDVariable in4d = in.reshape(-1, nChannels, 28, 28);
+
+                int kernelHeight = 5;
+                int kernelWidth = 5;
+
+
+                // w0 [kernelHeight = 5, kernelWidth = 5 , inputChannels = 1, outputChannels = 20]
+                // b0 [20]
+                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, nChannels, 20).muli(0.01));
+                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 20).muli(0.01));
+
+
+                SDVariable layer0 = sd.nn.relu(sd.cnn.conv2d("layer0", in4d, w0, b0, Conv2DConfig.builder()
+                        .kH(kernelHeight)
+                        .kW(kernelWidth)
+                        .sH(1)
+                        .sW(1)
+                        .dataFormat("NCHW")
+                        .build()), 0);
+
+                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
+                // outputsize_H(W) = ( 28 - 5 + 2*0 ) / 1 + 1 = 24
+                // [minibatch,20,24,24]
+
+
+                SDVariable layer1 = sd.cnn.maxPooling2d("layer1", layer0, Pooling2DConfig.builder()
+                        .kH(2).kW(2)
+                        .sH(2).sW(2)
+                        .isNHWC(false)
+                        .build());
+
+                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
+                // outputsize_H(W) = ( 24 - 2 + 2*0 ) / 2 + 1 = 12
+                // [minibatch,12,12,20]
+
+
+                // w2 [kernelHeight = 5, kernelWidth = 5 , inputChannels = 20, outputChannels = 50]
+                // b0 [50]
+                SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, 20, 50).muli(0.01));
+                SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 50).muli(0.01));
+
+
+                SDVariable layer2 = sd.nn.relu(sd.cnn.conv2d("layer2", layer1, w2, b2, Conv2DConfig.builder()
+                        .kH(kernelHeight)
+                        .kW(kernelWidth)
+                        .sH(1)
+                        .sW(1)
+                        .dataFormat("NCHW")
+                        .build()), 0);
+
+                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
+                // outputsize_H(W) = ( 12 - 5 + 2*0 ) / 1 + 1 = 8
+                // [minibatch,8,8,50]
+
+
+                SDVariable layer3 = sd.cnn.maxPooling2d("layer3", layer2, Pooling2DConfig.builder()
+                        .kH(2).kW(2)
+                        .sH(2).sW(2)
+                        .isNHWC(false)
+                        .build());
+
+
+                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
+                // outputsize_H(W) = ( 8 - 2 + 2*0 ) / 2 + 1 = 4
+                // [minibatch,4,4,50]
+
+                int channels_height_width = 4 * 4 * 50;
+                SDVariable layer3_reshaped = layer3.reshape(-1, channels_height_width);
+
+                SDVariable w4 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width, 500).muli(0.01));
+                SDVariable b4 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 500).muli(0.01));
+
+
+                SDVariable layer4 = sd.nn.relu("layer4", layer3_reshaped.mmul(w4).add(b4), 0);
+
+                SDVariable w5 = sd.var("w5", Nd4j.rand(DataType.FLOAT, 500, outputNum));
+                SDVariable b5 = sd.var("b5", Nd4j.rand(DataType.FLOAT, outputNum));
+
+                SDVariable out = sd.nn.softmax("out", layer4.mmul(w5).add(b5));
+                SDVariable loss = sd.loss.logLoss("loss", label, out);
+
+                //Also set the training configuration:
+                sd.setTrainingConfig(TrainingConfig.builder()
+                        .updater(new Adam(1e-3))
+                        .l2(1e-3)
+                        .dataSetFeatureMapping("in") //features[0] -> "in" placeholder
+                        .dataSetLabelMapping("label") //labels[0] -> "label" placeholder
+                        .build());
+
+
+                return sd;
+
+            }
+
+            @Override
+            public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
+                DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
+                Map<String, INDArray> map = new HashMap<>();
+                map.put("in", ds.getFeatures());
+                map.put("label", ds.getLabels());
+                return map;
+            }
+
+            @Override
+            public MultiDataSetIterator getTrainingData() throws Exception {
+                DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
+
+                iter = new EarlyTerminationDataSetIterator(iter, 60);
+                return new MultiDataSetIteratorAdapter(iter);
+            }
+
+            @Override
+            public MultiDataSetIterator getEvaluationTestData() throws Exception {
+                return new MultiDataSetIteratorAdapter(new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10));
+            }
+
+            @Override
+            public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
+                DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);
+
+                List<Map<String, INDArray>> list = new ArrayList<>();
+
+                org.nd4j.linalg.dataset.DataSet ds = iter.next();
+                ds = ds.asList().get(0);
+
+                list.add(Collections.singletonMap("in", ds.getFeatures()));
+                ds = iter.next();
+                list.add(Collections.singletonMap("in", ds.getFeatures()));
+                return list;
+            }
+
+            @Override
+            public List<String> getPredictionsNamesSameDiff() {
+                return Collections.singletonList("out");
+            }
+
+            @Override
+            public IEvaluation[] getNewEvaluations() {
+                return new IEvaluation[]{
+                        new Evaluation(),
+                        new ROCMultiClass(),
+                        new EvaluationCalibration()};
+            }
+
+
+            @Override
+            public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
+                sd.evaluate(iter, "out", 0, evaluations);
+                return evaluations;
+            }
+
+        };
+    }
+
+
+    public static TestCase getCnn3dSynthetic() {
+        return new TestCase() {
+            {
+                testName = "Cnn3dSynthetic";
+                testType = TestType.RANDOM_INIT;
+                testPredictions = true;
+                testTrainingCurves = true;
+                testGradients = true;
+                testParamsPostTraining = true;
+                testEvaluation = true;
+                testOverfitting = false;
+            }
+
+            @Override
+            public ModelType modelType() {
+                return ModelType.SAMEDIFF;
+            }
+
+            public Object getConfiguration() throws Exception {
+                Nd4j.getRandom().setSeed(12345);
+
+                int nChannels = 3; // Number of input channels
+                int outputNum = 10; // The number of possible outcomes
+
+                SameDiff sd = SameDiff.create();
+
+
+                //input in NCDHW [minibatch, channels=3, Height = 8, Width = 8, Depth = 8]
+                SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, nChannels, 8, 8, 8);
+
+                SDVariable label = sd.placeHolder("label", DataType.FLOAT, nChannels, outputNum);
+
+                //input in NCDHW [minibatch, channels=3, Height = 8, Width = 8, Depth = 8]
+
+                // Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]
+                // [kernelDepth = 3, kernelHeight = 3, kernelWidth = 3, inputChannels = 3, outputChannels = 8]
+                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 3, 3, 3, nChannels, 8));
+                // Optional 1D bias array with shape [outputChannels]. May be null.
+                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 8));
+
+
+                SDVariable layer0 = sd.nn.relu(sd.cnn.conv3d("layer0", in, w0, b0, Conv3DConfig.builder()
+                        .kH(3)
+                        .kW(3)
+                        .kD(3)
+                        .sH(2)
+                        .sW(2)
+                        .sD(2)
+                        .dataFormat("NCDHW")
+                        .build()), 0);
+
+                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
+                // outputsize_H(W)(D) = (8 - 3 + 2*0 ) / 2 + 1 = 3
+                // [minibatch,8,3,3,3]
+
+
+                SDVariable layer1 = sd.cnn.maxPooling3d("layer1", layer0, Pooling3DConfig.builder()
+                        .kH(2).kW(2).kD(2)
+                        .sH(2).sW(2).sD(2)
+                        .isNCDHW(true)
+                        .build());
+
+                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
+                // outputsize_H(W)(D) = ( 3 - 2 + 2*0 ) / 2 + 1 = 1
+                // [minibatch,8,1,1,1]
+
+
+                int channels_height_width_depth = 8 * 1 * 1 * 1;
+
+                SDVariable layer1_reshaped = layer1.reshape(-1, channels_height_width_depth);
+
+                SDVariable w1 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width_depth, 10));
+                SDVariable b1 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 10));
+
+
+                SDVariable out = sd.nn.softmax("out", layer1_reshaped.mmul(w1).add(b1));
+                SDVariable loss = sd.loss.logLoss("loss", label, out);
+
+                //Also set the training configuration:
+                sd.setTrainingConfig(TrainingConfig.builder()
+                        .updater(new Nesterovs(0.01, 0.9))
+                        .dataSetFeatureMapping("in") //features[0] -> "in" placeholder
+                        .dataSetLabelMapping("label") //labels[0] -> "label" placeholder
+                        .build());
+
+                return sd;
+
+            }
+
+            @Override
+            public Map<String,INDArray> getGradientsTestDataSameDiff() throws Exception {
+                Nd4j.getRandom().setSeed(12345);
+                //NCDHW format
+                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
+                INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);
+
+                Map<String, INDArray> map = new HashMap<>();
+                map.put("in", arr);
+                map.put("label", labels);
+                return map;
+
+            }
+
+
+            @Override
+            public List<String> getPredictionsNamesSameDiff() {
+
+                return Collections.singletonList("out");
+
+            }
+
+
+            @Override
+            public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
+                Nd4j.getRandom().setSeed(12345);
+
+                List<Map<String, INDArray>> list = new ArrayList<>();
+                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
+
+                list.add(Collections.singletonMap("in", arr));
+
+                return list;
+            }
+
+            @Override
+            public MultiDataSet getGradientsTestData() throws Exception {
+                Nd4j.getRandom().setSeed(12345);
+                //NCDHW format
+                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
+                INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);
+                return new org.nd4j.linalg.dataset.MultiDataSet(arr, labels);
+            }
+
+            @Override
+            public MultiDataSetIterator getTrainingData() throws Exception {
+                return new SingletonMultiDataSetIterator(getGradientsTestData());
+            }
+
+
+            @Override
+            public MultiDataSetIterator getEvaluationTestData() throws Exception {
+                return getTrainingData();
+            }
+
+            @Override
+            public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations){
+                sd.evaluate(iter, "out", 0, evaluations);
+                return evaluations;
+            }
+
+            @Override
+            public IEvaluation[] getNewEvaluations(){
+                return new IEvaluation[]{new Evaluation()};
+            }
+
+
+        };
+    }
+}
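As a cross-check on the repeated outputSize comments in the new file above, the arithmetic can be packed into a one-line helper (hypothetical, not part of the PR):

    class ConvSizeSketch {
        // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
        static int convOutputSize(int in, int kernel, int pad, int stride) {
            return (in - kernel + 2 * pad) / stride + 1;
        }
        // LeNet path from getLenetMnist(): 28 -> 24 -> 12 -> 8 -> 4
        //   convOutputSize(28, 5, 0, 1) == 24  (conv2d,       layer0)
        //   convOutputSize(24, 2, 0, 2) == 12  (maxPooling2d, layer1)
        //   convOutputSize(12, 5, 0, 1) == 8   (conv2d,       layer2)
        //   convOutputSize(8, 2, 0, 2)  == 4   (maxPooling2d, layer3)
    }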
@@ -15,9 +15,14 @@
 ******************************************************************************/
 package org.deeplearning4j.integration.testcases.samediff;

+import org.datavec.api.records.reader.RecordReader;
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
+import org.datavec.api.split.FileSplit;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
+import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
 import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
 import org.nd4j.autodiff.loss.LossReduce;
@@ -26,21 +31,34 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.TrainingConfig;
 import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.io.ClassPathResource;
 import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.resources.Resources;
+
+import java.io.File;
 import java.util.*;
+
+import static org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig.*;

 public class SameDiffMLPTestCases {


-    public static TestCase getMLPMnist(){
+    public static TestCase getMLPMnist() {
         return new TestCase() {
             {
                 testName = "MLPMnistSD";
@@ -69,10 +87,10 @@ public class SameDiffMLPTestCases {
                 SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
                 SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);

-                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256));
-                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256));
-                SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10));
-                SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10));
+                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256).muli(0.1));
+                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256).muli(0.1));
+                SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10).muli(0.1));
+                SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10).muli(0.1));

                 SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0));
                 SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
@@ -91,7 +109,7 @@ public class SameDiffMLPTestCases {

             @Override
             public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
-                List<Map<String,INDArray>> out = new ArrayList<>();
+                List<Map<String, INDArray>> out = new ArrayList<>();

                 DataSetIterator iter = new MnistDataSetIterator(1, true, 12345);
                 out.add(Collections.singletonMap("in", iter.next().getFeatures()));
@@ -110,7 +128,7 @@ public class SameDiffMLPTestCases {
             @Override
             public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
                 DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
-                Map<String,INDArray> map = new HashMap<>();
+                Map<String, INDArray> map = new HashMap<>();
                 map.put("in", ds.getFeatures());
                 map.put("label", ds.getLabels());
                 return map;
@ -153,4 +171,160 @@ public class SameDiffMLPTestCases {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static TestCase getMLPMoon() {
|
||||||
|
return new TestCase() {
|
||||||
|
{
|
||||||
|
testName = "MLPMoonSD";
|
||||||
|
testType = TestType.RANDOM_INIT;
|
||||||
|
testPredictions = true;
|
||||||
|
testTrainingCurves = true;
|
||||||
|
testGradients = true;
|
||||||
|
testParamsPostTraining = true;
|
||||||
|
testEvaluation = true;
|
||||||
|
testOverfitting = true;
|
||||||
|
maxRelativeErrorOverfit = 2e-2;
|
||||||
|
minAbsErrorOverfit = 1e-2;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ModelType modelType() {
|
||||||
|
return ModelType.SAMEDIFF;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object getConfiguration() throws Exception {
|
||||||
|
|
||||||
|
int numInputs = 2;
|
||||||
|
int numOutputs = 2;
|
||||||
|
int numHiddenNodes = 20;
|
||||||
|
double learningRate = 0.005;
|
||||||
|
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
//Define the network structure:
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, numInputs);
|
||||||
|
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, numOutputs);
|
||||||
|
|
||||||
|
SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, numInputs, numHiddenNodes));
|
||||||
|
SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, numHiddenNodes));
|
||||||
|
SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numHiddenNodes, numOutputs));
|
||||||
|
SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numOutputs));
|
||||||
|
|
||||||
|
SDVariable a0 = sd.nn.relu(in.mmul(w0).add(b0), 0);
|
||||||
|
SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1));
|
||||||
|
SDVariable loss = sd.loss.logLoss("loss", label, out);
|
||||||
|
|
||||||
|
//Also set the training configuration:
|
||||||
|
sd.setTrainingConfig(TrainingConfig.builder()
|
||||||
|
.updater(new Nesterovs(learningRate, 0.9))
|
||||||
|
.weightDecay(1e-3, true)
|
||||||
|
.dataSetFeatureMapping("in") //features[0] -> "in" placeholder
|
||||||
|
.dataSetLabelMapping("label") //labels[0] -> "label" placeholder
|
||||||
|
.build());
|
||||||
|
|
||||||
|
return sd;
|
||||||
|
}
        @Override
        public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
            List<Map<String, INDArray>> out = new ArrayList<>();

            File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");

            RecordReader rr = new CSVRecordReader();
            rr.initialize(new FileSplit(f));
            DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 0, 2);

            out.add(Collections.singletonMap("in", iter.next().getFeatures()));

            return out;
        }

        @Override
        public List<String> getPredictionsNamesSameDiff() throws Exception {
            return Collections.singletonList("out");
        }

        @Override
        public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
            File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
            RecordReader rr = new CSVRecordReader();
            rr.initialize(new FileSplit(f));
            org.nd4j.linalg.dataset.DataSet ds = new RecordReaderDataSetIterator(rr, 5, 0, 2).next();

            Map<String, INDArray> map = new HashMap<>();
            map.put("in", ds.getFeatures());
            map.put("label", ds.getLabels());
            return map;
        }

        @Override
        public MultiDataSetIterator getTrainingData() throws Exception {
            File f = Resources.asFile("dl4j-integration-tests/data/moon_data_train.csv");
            RecordReader rr = new CSVRecordReader();
            rr.initialize(new FileSplit(f));
            DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2);

            iter = new EarlyTerminationDataSetIterator(iter, 32);
            return new MultiDataSetIteratorAdapter(iter);
        }

        @Override
        public IEvaluation[] getNewEvaluations() {
            return new IEvaluation[]{
                    new Evaluation(),
                    new ROCMultiClass(),
                    new EvaluationCalibration()};
        }

        @Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
            RecordReader rr = new CSVRecordReader();
            rr.initialize(new FileSplit(f));
            DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2);
            return new MultiDataSetIteratorAdapter(iter);
        }

        @Override
        public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
            sd.evaluate(iter, "out", 0, evaluations);
            return evaluations;
        }

        @Override
        public MultiDataSet getOverfittingData() throws Exception {
            File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv");
            RecordReader rr = new CSVRecordReader();
            rr.initialize(new FileSplit(f));
            return new RecordReaderDataSetIterator(rr, 1, 0, 2).next().toMultiDataSet();
        }

        @Override
        public int getOverfitNumIterations() {
            return 200;
        }
    };
    }
}
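Taken together, the methods above give the integration harness everything it needs: a graph plus its TrainingConfig, training data, and evaluation hooks. A minimal sketch of how they fit together (the `tc` handle is illustrative, not harness API):

    // Minimal sketch, assuming `tc` is the anonymous TestCase defined above
    SameDiff sd = (SameDiff) tc.getConfiguration();     // graph + TrainingConfig
    sd.fit(tc.getTrainingData(), 1);                    // one epoch with the configured Nesterovs updater
    IEvaluation[] evals = tc.doEvaluationSameDiff(sd, tc.getEvaluationTestData(), tc.getNewEvaluations());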
@ -0,0 +1,289 @@
/* ******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.integration.testcases.samediff;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.integration.ModelType;
import org.deeplearning4j.integration.TestCase;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.resources.Resources;
import org.nd4j.shade.guava.io.Files;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class SameDiffRNNTestCases {

    public static TestCase getRnnCsvSequenceClassificationTestCase1() {
        return new SameDiffRNNTestCases.RnnCsvSequenceClassificationTestCase1();
    }

    protected static class RnnCsvSequenceClassificationTestCase1 extends TestCase {
        protected RnnCsvSequenceClassificationTestCase1() {
            testName = "RnnCsvSequenceClassification1";
            testType = TestType.RANDOM_INIT;
            testPredictions = true;
            testTrainingCurves = false;
            testGradients = false;
            testParamsPostTraining = false;
            testEvaluation = true;
            testOverfitting = false;            //Not much point on this one - it already fits very well...
        }

        protected MultiDataNormalization normalizer;

        protected MultiDataNormalization getNormalizer() throws Exception {
            if (normalizer != null) {
                return normalizer;
            }

            normalizer = new MultiNormalizerStandardize();
            normalizer.fit(getTrainingDataUnnormalized());

            return normalizer;
        }

        @Override
        public ModelType modelType() {
            return ModelType.SAMEDIFF;
        }

        @Override
        public Object getConfiguration() throws Exception {
            Nd4j.getRandom().setSeed(12345);

            int miniBatchSize = 10;
            int numLabelClasses = 6;
            int nIn = 60;
            int numUnits = 7;
            int timeSteps = 3;

            SameDiff sd = SameDiff.create();

            SDVariable in = sd.placeHolder("in", DataType.FLOAT, miniBatchSize, timeSteps, nIn);
            SDVariable label = sd.placeHolder("label", DataType.FLOAT, miniBatchSize, numLabelClasses);

            SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits));
            SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits));

            LSTMLayerConfig c = LSTMLayerConfig.builder()
                    .lstmdataformat(LSTMDataFormat.NTS)
                    .directionMode(LSTMDirectionMode.FWD)
                    .gateAct(LSTMActivations.SIGMOID)
                    .cellAct(LSTMActivations.TANH)
                    .outAct(LSTMActivations.TANH)
                    .retFullSequence(true)
                    .retLastC(true)
                    .retLastH(true)
                    .build();

            LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
                    in, cLast, yLast, null,
                    LSTMLayerWeights.builder()
                            .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
                            .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
                            .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
                            .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits)))
                            .build(),
                    c), c);

            //With NTS data format and retFullSequence=true, the LSTM output is 3d with shape
            //[miniBatchSize, timeSteps, numUnits]; the mean over the time dimension gives 2d [miniBatchSize, numUnits]
            SDVariable layer0 = outputs.getOutput();

            SDVariable layer1 = layer0.mean(1);

            SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numUnits, numLabelClasses));
            SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numLabelClasses));

            SDVariable out = sd.nn.softmax("out", layer1.mmul(w1).add(b1));
            SDVariable loss = sd.loss.logLoss("loss", label, out);

            //Also set the training configuration:
            sd.setTrainingConfig(TrainingConfig.builder()
                    .updater(new Adam(5e-2))
                    .l1(1e-3).l2(1e-3)
                    .dataSetFeatureMapping("in")            //features[0] -> "in" placeholder
                    .dataSetLabelMapping("label")           //labels[0] -> "label" placeholder
                    .build());

            return sd;
        }
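Because retFullSequence, retLastH and retLastC are all enabled, the underlying lstmLayer op produces three arrays, and LSTMLayerOutputs wraps them so each can be retrieved by role rather than by position. A minimal sketch (the accessor names besides getOutput() are assumptions based on this PR's LSTMLayerOutputs class, not verified):

    SDVariable fullSeq = outputs.getOutput();       // full sequence: [miniBatchSize, timeSteps, numUnits] for NTS
    SDVariable lastH = outputs.getLastOutput();     // assumed accessor: output at the last time step
    SDVariable lastC = outputs.getLastState();      // assumed accessor: cell state at the last time step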
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
|
||||||
|
|
||||||
|
MultiDataSet mds = getTrainingData().next();
|
||||||
|
|
||||||
|
List<Map<String, INDArray>> list = new ArrayList<>();
|
||||||
|
|
||||||
|
list.add(Collections.singletonMap("in", mds.getFeatures()[0].reshape(10, 1, 60)));
|
||||||
|
//[batchsize, insize]
|
||||||
|
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> getPredictionsNamesSameDiff() throws Exception {
|
||||||
|
return Collections.singletonList("out");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MultiDataSetIterator getTrainingData() throws Exception {
|
||||||
|
MultiDataSetIterator iter = getTrainingDataUnnormalized();
|
||||||
|
MultiDataSetPreProcessor pp = multiDataSet -> {
|
||||||
|
INDArray l = multiDataSet.getLabels(0);
|
||||||
|
l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
|
||||||
|
multiDataSet.setLabels(0, l);
|
||||||
|
multiDataSet.setLabelsMaskArray(0, null);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));
|
||||||
|
|
||||||
|
return iter;
|
||||||
|
}
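CompositeMultiDataSetPreProcessor applies its preprocessors in order, so standardization runs before the lambda that slices the labels down to the final time step. The shape effect of that slice, roughly:

    // labels before pp: [miniBatchSize, numLabelClasses, timeSeriesLength]  (3d sequence labels, ALIGN_END padded)
    // NDArrayIndex.point(l.size(2) - 1) keeps only the last step, producing the 2d
    // [miniBatchSize, numLabelClasses] array that the "label" placeholder expects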
|
||||||
|
|
||||||
|
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
|
||||||
|
int miniBatchSize = 10;
|
||||||
|
int numLabelClasses = 6;
|
||||||
|
|
||||||
|
File featuresDirTrain = Files.createTempDir();
|
||||||
|
File labelsDirTrain = Files.createTempDir();
|
||||||
|
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain);
|
||||||
|
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDirTrain);
|
||||||
|
|
||||||
|
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
|
||||||
|
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
|
||||||
|
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
|
||||||
|
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
|
||||||
|
|
||||||
|
DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
|
||||||
|
false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
|
||||||
|
|
||||||
|
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData);
|
||||||
|
|
||||||
|
return iter;
|
||||||
|
}
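NumberedFileInputSplit expands the %d placeholder over the inclusive index range, so each reader above expects one CSV per sequence, named 0.csv through 449.csv, in its temp directory. For instance:

    // Resolves <dir>/0.csv, <dir>/1.csv, ..., <dir>/449.csv
    new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449);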
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IEvaluation[] getNewEvaluations() {
|
||||||
|
return new IEvaluation[]{
|
||||||
|
new Evaluation(),
|
||||||
|
new ROCMultiClass(),
|
||||||
|
new EvaluationCalibration()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MultiDataSetIterator getEvaluationTestData() throws Exception {
|
||||||
|
int miniBatchSize = 10;
|
||||||
|
int numLabelClasses = 6;
|
||||||
|
|
||||||
|
// File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile();
|
||||||
|
// File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile();
|
||||||
|
File featuresDirTest = Files.createTempDir();
|
||||||
|
File labelsDirTest = Files.createTempDir();
|
||||||
|
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest);
|
||||||
|
Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest);
|
||||||
|
|
||||||
|
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
|
||||||
|
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
|
||||||
|
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
|
||||||
|
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
|
||||||
|
|
||||||
|
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
|
||||||
|
false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
|
||||||
|
|
||||||
|
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData);
|
||||||
|
|
||||||
|
MultiDataSetPreProcessor pp = multiDataSet -> {
|
||||||
|
INDArray l = multiDataSet.getLabels(0);
|
||||||
|
l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
|
||||||
|
multiDataSet.setLabels(0, l);
|
||||||
|
multiDataSet.setLabelsMaskArray(0, null);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));
|
||||||
|
|
||||||
|
return iter;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
|
||||||
|
sd.evaluate(iter, "out", 0, evaluations);
|
||||||
|
return evaluations;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -368,7 +368,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
     REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !");
     REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!");
     REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
-    REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");
+    REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");

     count = 0;
     auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr;           // output
@ -464,13 +464,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
 }

 PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
+    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
+    const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)

     const auto hasBiases  = B_ARG(0);   // indicates whether biases array is provided
+    const auto hasSeqLen  = B_ARG(1);   // indicates whether seqLen array is provided
     const auto hasInitH   = B_ARG(2);   // indicates whether initial output is provided
     const auto hasInitC   = B_ARG(3);   // indicates whether initial cell state is provided
+    const auto hasPH      = B_ARG(4);   // indicates whether peephole connections are present
     const auto retFullSeq = B_ARG(5);   // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
     const auto retLastH   = B_ARG(6);   // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
     const auto retLastC   = B_ARG(7);   // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)

+    const auto cellClip = T_ARG(0);     // cell clipping value, if it = 0 then do not apply clipping

     const auto x  = INPUT_VARIABLE(0);  // input
     const auto Wx = INPUT_VARIABLE(1);  // input weights
     const auto Wr = INPUT_VARIABLE(2);  // recurrent weights

@ -495,7 +503,15 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
     DataType hLType = hL != nullptr ? hL->dataType() : xType;
     DataType cLType = cL != nullptr ? cL->dataType() : xType;

-    return block.isUseMKLDNN() && (
+    auto featuresSupported = (cellClip == 0)    //Cell clipping not supported
+        && retFullSeq                           //Always return full sequence in case of MKL DNN
+        && !hasPH                               //Peephole connections not supported in MKL DNN
+        && !hasSeqLen                           //Sequence length array not supported in MKL DNN
+        && dataFormat < 2                       //Data format - only 0 and 1 supported in MKL DNN - 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn]
+        && directionMode < 4                    //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat
+        && retLastH == retLastC;                //Return both lastH and lastC, or return neither (not just 1 or other)

+    return block.isUseMKLDNN() && featuresSupported && (
         (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||
         (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) ||
         (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8))
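In practice these checks decide, per op invocation, whether the MKL-DNN kernel or the generic C++ implementation runs. From the Java API the fallback is triggered simply by using an unsupported feature; a hedged illustration using the builders from earlier in this PR (the fallback claim follows from the checks above, not from separate testing):

    // Providing peephole weights sets hasPH, so PLATFORM_CHECK returns false and
    // the op falls back to the generic implementation:
    LSTMLayerWeights w = LSTMLayerWeights.builder()
            .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
            .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
            .peepholeWeights(sd.var("pWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
            .build();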
@ -2148,7 +2148,7 @@ public class DifferentialFunctionFactory {

     public SDVariable gatherNd(SDVariable df, SDVariable indices) {
         validateDifferentialFunctionsameDiff(df);
-        return new GatherNd(sameDiff(), df, indices, false).outputVariable();
+        return new GatherNd(sameDiff(), df, indices).outputVariable();
     }

     public SDVariable trace(SDVariable in){
@ -26,6 +26,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
+import org.nd4j.linalg.util.ArrayUtil;
 import org.nd4j.weightinit.WeightInitScheme;

 import java.io.Serializable;
@ -244,7 +245,7 @@ public class SDVariable implements Serializable {
      * @return new variable
      */
     public SDVariable assign(Number value){
-        return sameDiff.scalarSet(this, value);
+        return sameDiff.scalarSet(this, value.doubleValue());
     }

     /**
@ -538,7 +539,7 @@ public class SDVariable implements Serializable {
      * @return Output variable (result of mmul)
      */
     public SDVariable mmul(String name, SDVariable other, @NonNull MMulTranspose mMulTranspose) {
-        return sameDiff.mmul(name, this, other, mMulTranspose);
+        return sameDiff.mmul(name, this, other, mMulTranspose.isTransposeA(), mMulTranspose.isTransposeB(), mMulTranspose.isTransposeResult());
     }

@ -1403,7 +1404,7 @@ public class SDVariable implements Serializable {
      * @return Output variable
      */
     public SDVariable reshape(int... newShape){
-        return sameDiff.reshape(this, newShape);
+        return sameDiff.reshape(this, ArrayUtil.toLongArray(newShape));
     }

     /**
@ -53,6 +53,7 @@ import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
 import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
 import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
@ -78,6 +79,7 @@ import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;
 import org.nd4j.linalg.util.ND4JFileUtils;
 import org.nd4j.shade.guava.collect.HashBasedTable;
+import org.nd4j.shade.guava.collect.Sets;
 import org.nd4j.shade.guava.collect.Table;
 import org.nd4j.shade.guava.primitives.Ints;
 import org.nd4j.weightinit.WeightInitScheme;
@ -104,7 +106,6 @@ import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
  * <p>
  * In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)}
  */
-@AllArgsConstructor
 @Slf4j
 public class SameDiff extends SDBaseOps {
     protected static final String GRAD_FN_KEY = "grad";
@ -914,6 +915,8 @@ public class SameDiff extends SDBaseOps {
     }

     private SameDiff() {
+        super(null);
+        super.sd = this;
         functionFactory = new DifferentialFunctionFactory(this);
         sameDiffFunctionInstances = new LinkedHashMap<>();
         fieldVariableResolutionMapping = HashBasedTable.create();
@ -4544,7 +4547,7 @@ public class SameDiff extends SDBaseOps {
     }

     //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user
-    if (v.getOutputOfOp() != null) {
+    if (v.getOutputOfOp() != null && v.getVariable().dataType().isFPType()) {
         String opName = v.getOutputOfOp();
         SameDiffOp o = ops.get(opName);
         if (o.getOp() instanceof Assert) {
@ -4621,12 +4624,6 @@ public class SameDiff extends SDBaseOps {
         return varToUpdate;
     }

-    @Override
-    protected SameDiff sd() {
-        //Helper method for SDBaseOps etc
-        return this;
-    }

     /**
      * Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable.
@ -5840,7 +5837,6 @@ public class SameDiff extends SDBaseOps {
      * See {@link #generateNewVarName(String, int, boolean)}
      * existingOp is true.
      */
-    @Override
     public String generateNewVarName(String base, int argIndex) {
         return generateNewVarName(base, argIndex, true);
     }
@ -5868,4 +5864,261 @@
     public String toString(){
         return "SameDiff(nVars=" + variables.size() + ",nOps=" + ops.size() + ")";
     }

+    /**
+     * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
+     */
+    public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond,
+                             @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
+        return ifCond(null, null, cond, trueBody, falseBody);
+    }
+
+    /**
+     * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
+     */
+    public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond,
+                             @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
+        return ifCond(null, ifName, cond, trueBody, falseBody);
+    }
+
+    /**
+     * Constructs an If statement using the tensorflow style control flow operations (Switch and Merge)
+     *
+     * If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody
+     *
+     * Note that cond and body lambdas are only called once to construct the graph.  The constructed graph is used to evaluate.
+     *
+     * See <a href="http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf">Tensorflow Control Flow Implementation</a>
+     *
+     * @param outputName Name to give the output variable.  If null, doesn't rename
+     * @param ifName     The name of the if block.  If null, uses "if"
+     * @param cond       A lambda evaluating to the if condition
+     * @param trueBody   A lambda to be executed if cond is true (the if block)
+     * @param falseBody  A lambda to be executed if cond is false (the else block)
+     * @return The value of trueBody if cond is true, or falseBody if it isn't
+     */
+    public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond,
+                             @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
+
+        ifName = newBlockName(ifName == null ? "if" : ifName);
+
+        NameScope ifScope = sd.withNameScope(ifName);
+
+        NameScope condScope = withNameScope("cond");
+        final SDVariable pred = cond.define(this);
+        condScope.close();
+
+        if (pred.dataType() != DataType.BOOL) {
+            //cleanup partially added block
+            for(SDVariable v : getVariablesInScope(ifScope))
+                this.getVariables().remove(v.name());
+
+            for(SameDiffOp op : this.getOpsInScope(ifScope)) {
+                for(String in : op.getInputsToOp()){
+                    this.removeArgFromOp(in, op.getOp());
+                }
+                this.getOps().remove(op.getName());
+            }
+
+            throw new IllegalStateException("Can not use " + pred.name()
+                    + " as the condition of an If statement, the condition must be a boolean.");
+        }
+
+        final Map<String, SDVariable[]> switches = new HashMap<>();
+
+        final Set<String> declared = Sets.newHashSet(this.variableMap().keySet());
+
+        this.addArgumentInterceptor(new ArgumentInterceptor() {
+            @Override
+            public SDVariable intercept(SDVariable argument) {
+
+                // if it's declared in the if, we don't care about it
+                if(!declared.contains(argument.name()))
+                    return argument;
+
+                // if we've already added a switch, move on
+                if(switches.containsKey(argument.name()))
+                    return switches.get(argument.name())[1];
+
+                SDVariable[] s = f().switchOp(argument, pred);
+                switches.put(argument.name(), s);
+                return s[1];
+            }
+        });
+        NameScope trueScope = this.withNameScope("trueBody");
+        SDVariable trueOut = trueBody.define(this);
+        this.removeArgumentInterceptor();
+
+        if(declared.contains(trueOut.name())) {
+            SDVariable[] s = f().switchOp(trueOut, pred);
+            switches.put(trueOut.name(), s);
+            trueOut = s[1];
+        }
+
+        trueScope.close();
+
+        final Set<String> declared2 = Sets.newHashSet(variableMap().keySet());
+        sd.addArgumentInterceptor(new ArgumentInterceptor() {
+            @Override
+            public SDVariable intercept(SDVariable argument) {
+
+                // if it's declared in the if, we don't care about it
+                if(!declared2.contains(argument.name()))
+                    return argument;
+
+                // if we've already added a switch, move on
+                if(switches.containsKey(argument.name()))
+                    return switches.get(argument.name())[0];
+
+                SDVariable[] s = f().switchOp(argument, pred);
+                switches.put(argument.name(), s);
+                return s[0];
+            }
+        });
+        NameScope falseScope = this.withNameScope("falseBody");
+        SDVariable falseOut = falseBody.define(this);
+        this.removeArgumentInterceptor();
+
+        if(declared2.contains(falseOut.name())) {
+            SDVariable[] s = f().switchOp(falseOut, pred);
+            switches.put(falseOut.name(), s);
+            falseOut = s[0];
+        }
+        falseScope.close();
+
+        SDVariable output = f().merge(trueOut, falseOut);
+
+        ifScope.close();
+
+        return updateVariableNameAndReference(output, outputName);
+    }
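A hedged usage sketch of the new method (variable names illustrative; the lambdas match the SameDiffNoArgSingleLambda shape used above):

    SameDiff sd = SameDiff.create();
    SDVariable a = sd.var("a", Nd4j.scalar(2.0f));
    SDVariable result = sd.ifCond("out", null,
            s -> a.gt(1.0),         // condition: must evaluate to a BOOL variable
            s -> a.mul(2),          // true branch
            s -> a.sub(2));         // false branch

Note the cleanup path above: if the condition lambda produces a non-boolean variable, the partially built block is removed from the graph before the exception is thrown, leaving the SameDiff instance usable.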
+
+    /**
+     * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
+     */
+    public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars,
+                                  @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
+        return whileLoop(null, null, loopVars, cond, body);
+    }
+
+    /**
+     * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
+     */
+    public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars,
+                                  @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
+        return whileLoop(null, loopName, loopVars, cond, body);
+    }
+
+    /**
+     * Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration)
+     *
+     * Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false
+     *
+     * Note that cond and body lambdas are only called once to construct the graph.  The constructed graph is used for further iterations.
+     *
+     * See <a href="http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf">Tensorflow Control Flow Implementation</a>
+     *
+     * @param outputNames Names to give the output variables.  If null, doesn't rename
+     * @param loopName    The name of the loop block and frame (must be unique).  If null, uses "while"
+     * @param loopVars    Loop variables' inputs
+     * @param cond        A lambda evaluating to the loop condition
+     * @param body        A lambda doing the loop operation and returning the new loop variable values
+     * @return The values of the loop variables once condition is false
+     */
+    public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars,
+                                  @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
+
+        final String frameName = this.newBlockName(loopName == null ? "while" : loopName);
+
+        NameScope loopScope = this.withNameScope(frameName);
+
+        //SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0);
+
+        SDVariable[] entered = new SDVariable[loopVars.length];
+        for(int i = 0 ; i < loopVars.length ; i++){
+            entered[i] = f().enter(loopVars[i], frameName);
+        }
+
+        //counter = SD.f().enter(counter, frameName);
+
+        SDVariable[] merged = new SDVariable[loopVars.length];
+        Merge[] mergeOps = new Merge[loopVars.length];
+        for(int i = 0 ; i < loopVars.length ; i++){
+            // the second arg will later be replaced with the output of NextIteration
+            // but that isn't available yet (and can't be, as it depends on this)
+            mergeOps[i] = new Merge(this, entered[i], entered[i]);
+            merged[i] = mergeOps[i].outputVariable();
+        }
+
+        //Merge counterMerge = new Merge(SD, counter, counter);
+        //counter = counterMerge.outputVariable();
+
+        NameScope condScope = this.withNameScope("cond");
+        SDVariable cond_result = cond.define(this, merged);
+        condScope.close();
+
+        if (cond_result.dataType() != DataType.BOOL)
+            throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of a While loop, the condition must be a boolean.");
+
+        final Set<String> alreadyEntered = Sets.newHashSet();
+        SDVariable[] trueSwitches = new SDVariable[loopVars.length];
+        SDVariable[] exits = new SDVariable[loopVars.length];
+        for(int i = 0 ; i < loopVars.length ; i++){
+            SDVariable[] s = f().switchOp(merged[i], cond_result);
+            trueSwitches[i] = s[1];
+            alreadyEntered.add(s[1].name());
+            exits[i] = f().exit(s[0]);
+        }
+
+        //SDVariable[] cs = SD.f().switchOp(counter, cond_result);
+        //SDVariable counterExit = SD.f().exit(cs[0]);
+        //counter = cs[1];
+
+        final Set<String> declared = Sets.newHashSet(this.variableMap().keySet());
+        final Map<String, SDVariable> done = new HashMap<>();
+
+        this.addArgumentInterceptor(new ArgumentInterceptor() {
+            @Override
+            public SDVariable intercept(SDVariable argument) {
+
+                if(!declared.contains(argument.name()))
+                    return argument;
+
+                if(alreadyEntered.contains(argument.name()))
+                    return argument;
+
+                if(done.containsKey(argument.name()))
+                    return done.get(argument.name());
+
+                SDVariable e = f().enter(argument, frameName, true);
+                done.put(argument.name(), e);
+                return e;
+            }
+        });
+
+        NameScope bodyScope = this.withNameScope("body");
+        SDVariable[] outs = body.define(this, trueSwitches);
+        bodyScope.close();
+        this.removeArgumentInterceptor();
+
+        //counter.add(1);
+
+        for(int i = 0 ; i < loopVars.length ; i++){
+            SDVariable n = f().nextIteration(outs[i]);
+            mergeOps[i].replaceArg(1,n);
+        }
+
+        //counterMerge.replaceArg(1, counter);
+
+        loopScope.close();
+        return updateVariableNamesAndReferences(exits, outputNames);
+    }
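And a hedged sketch for whileLoop (names illustrative): summing the integers 0..9 with two loop variables:

    SDVariable i = sd.var("i", Nd4j.scalar(0.0f));
    SDVariable sum = sd.var("sum", Nd4j.scalar(0.0f));
    SDVariable[] finals = sd.whileLoop(
            new String[]{"iOut", "sumOut"}, "sumLoop",
            new SDVariable[]{i, sum},
            (s, vars) -> vars[0].lt(10),                                           // loop while i < 10
            (s, vars) -> new SDVariable[]{vars[0].add(1), vars[1].add(vars[0])});  // i+1, sum+i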
 }

File diff suppressed because it is too large
@ -23,8 +23,8 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
 import java.lang.String;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.enums.DataFormat;
 import org.nd4j.base.Preconditions;
+import org.nd4j.enums.DataFormat;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
@ -753,6 +753,33 @@ public class SDCNN extends SDOps {
     return sd.updateVariableNameAndReference(out, name);
   }

+  /**
+   * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
+   *
+   * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
+   * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
+   * @param Pooling2DConfig Configuration Object
+   */
+  public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) {
+    SDValidation.validateNumerical("maxPoolWithArgmax", "input", input);
+    return new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(sd,input, Pooling2DConfig).outputVariables();
+  }
+
+  /**
+   * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
+   *
+   * @param names names May be null. Arrays of names for the output variables.
+   * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
+   * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
+   * @param Pooling2DConfig Configuration Object
+   */
+  public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input,
+      Pooling2DConfig Pooling2DConfig) {
+    SDValidation.validateNumerical("maxPoolWithArgmax", "input", input);
+    SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(sd,input, Pooling2DConfig).outputVariables();
+    return sd.updateVariableNamesAndReferences(out, names);
+  }
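A hedged usage sketch for the new op (the Pooling2DConfig builder fields kH/kW/sH/sW are assumptions based on the existing config class, and the output order max-values-then-indices follows the op's name, not verified here):

    SDVariable img = sd.placeHolder("img", DataType.FLOAT, 1, 3, 8, 8);        // NCHW input
    Pooling2DConfig conf = Pooling2DConfig.builder().kH(2).kW(2).sH(2).sW(2).build();
    SDVariable[] maxAndIdx = sd.cnn.maxPoolWithArgmax(new String[]{"max", "idx"}, img, conf);
    // maxAndIdx[0]: pooled max values; maxAndIdx[1]: argmax indices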

   /**
    * 2D Convolution layer operation - max pooling 2d <br>
    *
@ -2205,7 +2205,7 @@ public class SDMath extends SDOps {
    * @param inputs Input variables (NUMERIC type)
    * @return output Output variable (NUMERIC type)
    */
-  public SDVariable mergeAdd(SDVariable[] inputs) {
+  public SDVariable mergeAdd(SDVariable... inputs) {
     SDValidation.validateNumerical("mergeAdd", "inputs", inputs);
     Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
     return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable();

@ -2219,7 +2219,7 @@ public class SDMath extends SDOps {
    * @param inputs Input variables (NUMERIC type)
    * @return output Output variable (NUMERIC type)
    */
-  public SDVariable mergeAdd(String name, SDVariable[] inputs) {
+  public SDVariable mergeAdd(String name, SDVariable... inputs) {
     SDValidation.validateNumerical("mergeAdd", "inputs", inputs);
     Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
     SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable();

@ -2233,7 +2233,7 @@ public class SDMath extends SDOps {
    * @param inputs Input variables (NUMERIC type)
    * @return output Output variable (NUMERIC type)
    */
-  public SDVariable mergeAvg(SDVariable[] inputs) {
+  public SDVariable mergeAvg(SDVariable... inputs) {
     SDValidation.validateNumerical("mergeAvg", "inputs", inputs);
     Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
     return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable();

@ -2247,7 +2247,7 @@ public class SDMath extends SDOps {
    * @param inputs Input variables (NUMERIC type)
    * @return output Output variable (NUMERIC type)
    */
-  public SDVariable mergeAvg(String name, SDVariable[] inputs) {
+  public SDVariable mergeAvg(String name, SDVariable... inputs) {
     SDValidation.validateNumerical("mergeAvg", "inputs", inputs);
     Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
     SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable();

@ -2261,7 +2261,7 @@ public class SDMath extends SDOps {
    * @param inputs Input variables (NUMERIC type)
    * @return output Output variable (NUMERIC type)
    */
-  public SDVariable mergeMax(SDVariable[] inputs) {
+  public SDVariable mergeMax(SDVariable... inputs) {
     SDValidation.validateNumerical("mergeMax", "inputs", inputs);
     Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
     return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable();

@ -2275,7 +2275,7 @@ public class SDMath extends SDOps {
    * @param inputs Input variables (NUMERIC type)
    * @return output Output variable (NUMERIC type)
    */
-  public SDVariable mergeMax(String name, SDVariable[] inputs) {
+  public SDVariable mergeMax(String name, SDVariable... inputs) {
     SDValidation.validateNumerical("mergeMax", "inputs", inputs);
     Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
     SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable();
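The SDVariable[] to SDVariable... change is source-compatible with existing array callers but lets new code pass inputs directly. For instance (x, y, z are illustrative variables):

    SDVariable summed = sd.math.mergeAdd(x, y, z);                      // varargs call
    SDVariable summed2 = sd.math.mergeAdd(new SDVariable[]{x, y, z});   // still compiles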
@ -18,17 +18,15 @@

 package org.nd4j.autodiff.samediff.ops;

-import java.lang.String;
+import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;

-import lombok.NonNull;
+import java.lang.String;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
 import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
-import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
 import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
 import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
 import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
@ -43,28 +41,26 @@ public class SDRNN extends SDOps {
    * @param x Input, with shape [batchSize, inSize] (NUMERIC type)
    * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
    * @param GRUWeights Configuration Object
-   * @return output The cell's outputs. (NUMERIC type)
    */
-  public SDVariable gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
+  public SDVariable[] gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
     SDValidation.validateNumerical("gru", "x", x);
     SDValidation.validateNumerical("gru", "hLast", hLast);
-    return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables();
   }

   /**
    * The GRU cell.  Does a single time step operation<br>
    *
-   * @param name name May be null. Name for the output variable
+   * @param names names May be null. Arrays of names for the output variables.
    * @param x Input, with shape [batchSize, inSize] (NUMERIC type)
    * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
    * @param GRUWeights Configuration Object
-   * @return output The cell's outputs. (NUMERIC type)
    */
-  public GRUCellOutputs gru(String name, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
+  public SDVariable[] gru(String[] names, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
     SDValidation.validateNumerical("gru", "x", x);
     SDValidation.validateNumerical("gru", "hLast", hLast);
-    GRUCell c = new GRUCell(sd,x, hLast, GRUWeights);
-    return new GRUCellOutputs(c.outputVariables(name));
+    SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables();
+    return sd.updateVariableNamesAndReferences(out, names);
   }

   /**
@ -75,39 +71,172 @@ public class SDRNN extends SDOps {
|
||||||
* @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
* @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
||||||
* @param LSTMWeights Configuration Object
|
* @param LSTMWeights Configuration Object
|
||||||
* @param LSTMConfiguration Configuration Object
|
* @param LSTMConfiguration Configuration Object
|
||||||
* @return output The cell's outputs (NUMERIC type)
|
|
||||||
*/
|
*/
|
||||||
public LSTMCellOutputs lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast,
|
public SDVariable[] lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast,
|
||||||
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
|
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
|
||||||
SDValidation.validateNumerical("lstmCell", "x", x);
|
SDValidation.validateNumerical("lstmCell", "x", x);
|
||||||
SDValidation.validateNumerical("lstmCell", "cLast", cLast);
|
SDValidation.validateNumerical("lstmCell", "cLast", cLast);
|
||||||
SDValidation.validateNumerical("lstmCell", "yLast", yLast);
|
SDValidation.validateNumerical("lstmCell", "yLast", yLast);
|
||||||
LSTMBlockCell c = new LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration);
|
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariables();
|
||||||
return new LSTMCellOutputs(c.outputVariables());
|
|
||||||
}
|
}
|
||||||

 /**
  * The LSTM cell. Does a single time step operation.<br>
  *
- * @param name name May be null. Name for the output variable
+ * @param names names May be null. Arrays of names for the output variables.
  * @param x Input, with shape [batchSize, inSize] (NUMERIC type)
  * @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type)
  * @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type)
  * @param LSTMWeights Configuration Object
  * @param LSTMConfiguration Configuration Object
- * @return output The cell's outputs (NUMERIC type)
  */
-public LSTMCellOutputs lstmCell(String name, SDVariable x, SDVariable cLast, SDVariable yLast,
+public SDVariable[] lstmCell(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast,
     LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
   SDValidation.validateNumerical("lstmCell", "x", x);
   SDValidation.validateNumerical("lstmCell", "cLast", cLast);
   SDValidation.validateNumerical("lstmCell", "yLast", yLast);
-  LSTMBlockCell c = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration);
-  return new LSTMCellOutputs(c.outputVariables(name));
+  SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariables();
+  return sd.updateVariableNamesAndReferences(out, names);
 }

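[Editor's note: a minimal usage sketch of the regenerated lstmCell entry point above, as a reading aid. The LSTMConfiguration builder calls mirror the ones used in LSTMBlock.initFromTensorFlow further down in this diff; the LSTMWeights builder field names and the sd.rnn() accessor are assumptions, not confirmed here (imports from org.nd4j.* elided).]

    SameDiff sd = SameDiff.create();
    int bS = 2, nIn = 3, nOut = 4;
    SDVariable x = sd.placeHolder("x", DataType.FLOAT, bS, nIn);
    SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, nOut));
    SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, nOut));
    LSTMConfiguration conf = LSTMConfiguration.builder()
            .peepHole(false).forgetBias(1.0).clippingCellValue(0.0).build();
    LSTMWeights w = LSTMWeights.builder()                                          // builder field names assumed
            .weights(sd.var("w", Nd4j.rand(DataType.FLOAT, nIn + nOut, 4 * nOut))) // concatenated input+recurrent weights, per the LSTMBlock javadoc below
            .bias(sd.var("b", Nd4j.zeros(DataType.FLOAT, 4 * nOut)))               // assumed field name; peephole fields omitted
            .build();
    SDVariable[] out = sd.rnn().lstmCell(x, cLast, yLast, w, conf);                // 7 outputs: i, c, f, o, z, h, y
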
 /**
- * The LSTM layer. Does multiple time steps.<br>
+ * Long Short-Term Memory layer - Hochreiter 1997.<br>
+ * SUPPORTS following data formats:<br>
+ * for unidirectional:<br>
+ * TNS: shapes [timeLength, numExamples, inOutSize]<br>
+ * NST: shapes [numExamples, inOutSize, timeLength]<br>
+ * NTS: shapes [numExamples, timeLength, inOutSize]<br>
+ * for bidirectional:<br>
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
+ * SUPPORTS following direction modes:<br>
+ * FWD: forward<br>
+ * BWD: backward<br>
+ * BIDIR_SUM: bidirectional sum<br>
+ * BIDIR_CONCAT: bidirectional concat<br>
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
+ * You may use different gate configurations:<br>
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum<br>
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
+ *
+ * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
+ * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
+ * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
+ * @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type)
+ * @param LSTMLayerWeights Configuration Object
+ * @param LSTMLayerConfig Configuration Object
+ */
+public SDVariable[] lstmLayer(SDVariable x, SDVariable cLast, SDVariable yLast,
+    SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) {
+  SDValidation.validateNumerical("lstmLayer", "x", x);
+  SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
+  SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
+  SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
+  return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig).outputVariables();
+}

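[Editor's note: a usage sketch for the new layer op above. The enum values come from the new org.nd4j.enums files added later in this diff, and the retFullSequence/retLastH/retLastC flag mirrors the isRet* checks in the reworked LSTMLayer op class; the exact LSTMLayerConfig/LSTMLayerWeights builder method names are assumptions (imports elided).]

    SameDiff sd = SameDiff.create();
    int bS = 2, sL = 5, nIn = 3, nOut = 4;
    SDVariable x = sd.placeHolder("x", DataType.FLOAT, bS, sL, nIn);               // NTS: [numExamples, timeLength, inOutSize]
    LSTMLayerConfig conf = LSTMLayerConfig.builder()                               // builder method names assumed
            .lstmdataformat(LSTMDataFormat.NTS)
            .directionMode(LSTMDirectionMode.FWD)
            .gateAct(GateAct.SIGMOID).cellAct(CellAct.TANH).outAct(OutAct.TANH)
            .retFullSequence(true)                                                 // at least one isRet* flag must be set
            .build();
    LSTMLayerWeights w = LSTMLayerWeights.builder()                                // field names assumed
            .weights(sd.var("Wx", Nd4j.rand(DataType.FLOAT, nIn, 4 * nOut)))       // [nIn, 4*nOut] for FWD mode
            .rWeights(sd.var("Wr", Nd4j.rand(DataType.FLOAT, nOut, 4 * nOut)))     // recurrent weights
            .build();
    SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, nOut));
    SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, nOut));
    SDVariable maxTS = sd.constant("maxTS", Nd4j.createFromArray(sL, sL));         // per-example sequence lengths
    SDVariable[] out = sd.rnn().lstmLayer(x, cLast, yLast, maxTS, w, conf);
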
+/**
+ * Long Short-Term Memory layer - Hochreiter 1997.<br>
+ * SUPPORTS following data formats:<br>
+ * for unidirectional:<br>
+ * TNS: shapes [timeLength, numExamples, inOutSize]<br>
+ * NST: shapes [numExamples, inOutSize, timeLength]<br>
+ * NTS: shapes [numExamples, timeLength, inOutSize]<br>
+ * for bidirectional:<br>
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
+ * SUPPORTS following direction modes:<br>
+ * FWD: forward<br>
+ * BWD: backward<br>
+ * BIDIR_SUM: bidirectional sum<br>
+ * BIDIR_CONCAT: bidirectional concat<br>
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
+ * You may use different gate configurations:<br>
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum<br>
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
+ *
+ * @param names names May be null. Arrays of names for the output variables.
+ * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
+ * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
+ * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
+ * @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type)
+ * @param LSTMLayerWeights Configuration Object
+ * @param LSTMLayerConfig Configuration Object
+ */
+public SDVariable[] lstmLayer(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast,
+    SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) {
+  SDValidation.validateNumerical("lstmLayer", "x", x);
+  SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
+  SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
+  SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
+  SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig).outputVariables();
+  return sd.updateVariableNamesAndReferences(out, names);
+}
+
+/**
+ * Long Short-Term Memory layer - Hochreiter 1997.<br>
+ * SUPPORTS following data formats:<br>
+ * for unidirectional:<br>
+ * TNS: shapes [timeLength, numExamples, inOutSize]<br>
+ * NST: shapes [numExamples, inOutSize, timeLength]<br>
+ * NTS: shapes [numExamples, timeLength, inOutSize]<br>
+ * for bidirectional:<br>
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
+ * SUPPORTS following direction modes:<br>
+ * FWD: forward<br>
+ * BWD: backward<br>
+ * BIDIR_SUM: bidirectional sum<br>
+ * BIDIR_CONCAT: bidirectional concat<br>
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
+ * You may use different gate configurations:<br>
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum<br>
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
+ *
+ * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
+ * @param LSTMLayerWeights Configuration Object
+ * @param LSTMLayerConfig Configuration Object
+ */
+public SDVariable[] lstmLayer(SDVariable x, LSTMLayerWeights LSTMLayerWeights,
+    LSTMLayerConfig LSTMLayerConfig) {
+  SDValidation.validateNumerical("lstmLayer", "x", x);
+  return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, null, null, null, LSTMLayerWeights, LSTMLayerConfig).outputVariables();
+}
+
+/**
+ * Long Short-Term Memory layer - Hochreiter 1997.<br>
+ * SUPPORTS following data formats:<br>
+ * for unidirectional:<br>
+ * TNS: shapes [timeLength, numExamples, inOutSize]<br>
+ * NST: shapes [numExamples, inOutSize, timeLength]<br>
+ * NTS: shapes [numExamples, timeLength, inOutSize]<br>
+ * for bidirectional:<br>
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
+ * SUPPORTS following direction modes:<br>
+ * FWD: forward<br>
+ * BWD: backward<br>
+ * BIDIR_SUM: bidirectional sum<br>
+ * BIDIR_CONCAT: bidirectional concat<br>
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
+ * You may use different gate configurations:<br>
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum<br>
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
+ *
+ * @param names names May be null. Arrays of names for the output variables.
+ * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
+ * @param LSTMLayerWeights Configuration Object
+ * @param LSTMLayerConfig Configuration Object
+ */
+public SDVariable[] lstmLayer(String[] names, SDVariable x, LSTMLayerWeights LSTMLayerWeights,
+    LSTMLayerConfig LSTMLayerConfig) {
+  SDValidation.validateNumerical("lstmLayer", "x", x);
+  SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, null, null, null, LSTMLayerWeights, LSTMLayerConfig).outputVariables();
+  return sd.updateVariableNamesAndReferences(out, names);
+}
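[Editor's note: relative to the earlier sketch, the state-free overloads above pass cLast/yLast/maxTSLength through as null, so the minimal call reduces to:]

    SDVariable[] out = sd.rnn().lstmLayer(x, w, conf);   // same assumed builders as the earlier sketch; names overload takes a String[] sized to the requested outputs
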

+/**
+ * The LSTM block<br>
  *
  * @param maxTSLength (NUMERIC type)
  * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
@@ -117,17 +246,17 @@ public class SDRNN extends SDOps {
  * @param LSTMConfiguration Configuration Object
  * @return output The layer's outputs. (NUMERIC type)
  */
-public SDVariable lstmLayer(SDVariable maxTSLength, SDVariable x, SDVariable cLast,
+public SDVariable lstmblock(SDVariable maxTSLength, SDVariable x, SDVariable cLast,
     SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
-  SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
+  SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
-  SDValidation.validateNumerical("lstmLayer", "x", x);
+  SDValidation.validateNumerical("lstmblock", "x", x);
-  SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
+  SDValidation.validateNumerical("lstmblock", "cLast", cLast);
-  SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
+  SDValidation.validateNumerical("lstmblock", "yLast", yLast);
-  return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
+  return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
 }

 /**
- * The LSTM layer. Does multiple time steps.<br>
+ * The LSTM block<br>
  *
  * @param name name May be null. Name for the output variable
  * @param maxTSLength (NUMERIC type)
@@ -138,13 +267,43 @@ public class SDRNN extends SDOps {
  * @param LSTMConfiguration Configuration Object
  * @return output The layer's outputs. (NUMERIC type)
  */
-public SDVariable lstmLayer(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast,
+public SDVariable lstmblock(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast,
     SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
-  SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
+  SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
-  SDValidation.validateNumerical("lstmLayer", "x", x);
+  SDValidation.validateNumerical("lstmblock", "x", x);
-  SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
+  SDValidation.validateNumerical("lstmblock", "cLast", cLast);
-  SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
+  SDValidation.validateNumerical("lstmblock", "yLast", yLast);
-  SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
+  SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
+  return sd.updateVariableNameAndReference(out, name);
+}
+
+/**
+ * The LSTM block<br>
+ *
+ * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
+ * @param LSTMWeights Configuration Object
+ * @param LSTMConfiguration Configuration Object
+ * @return output The layer's outputs. (NUMERIC type)
+ */
+public SDVariable lstmblock(SDVariable x, LSTMWeights LSTMWeights,
+    LSTMConfiguration LSTMConfiguration) {
+  SDValidation.validateNumerical("lstmblock", "x", x);
+  return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,null, x, null, null, LSTMWeights, LSTMConfiguration).outputVariable();
+}
+
+/**
+ * The LSTM block<br>
+ *
+ * @param name name May be null. Name for the output variable
+ * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
+ * @param LSTMWeights Configuration Object
+ * @param LSTMConfiguration Configuration Object
+ * @return output The layer's outputs. (NUMERIC type)
+ */
+public SDVariable lstmblock(String name, SDVariable x, LSTMWeights LSTMWeights,
+    LSTMConfiguration LSTMConfiguration) {
+  SDValidation.validateNumerical("lstmblock", "x", x);
+  SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,null, x, null, null, LSTMWeights, LSTMConfiguration).outputVariable();
   return sd.updateVariableNameAndReference(out, name);
 }

@@ -0,0 +1,45 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
+package org.nd4j.enums;
+
+/**
+ * Activations */
+public enum CellAct {
+  TANH,
+
+  RELU,
+
+  SIGMOID,
+
+  AFFINE,
+
+  LEAKY_RELU,
+
+  THRESHHOLD_RELU,
+
+  SCALED_TAHN,
+
+  HARD_SIGMOID,
+
+  ELU,
+
+  SOFTSIGN,
+
+  SOFTPLUS
+}
@@ -0,0 +1,45 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
+package org.nd4j.enums;
+
+/**
+ * Activations */
+public enum GateAct {
+  TANH,
+
+  RELU,
+
+  SIGMOID,
+
+  AFFINE,
+
+  LEAKY_RELU,
+
+  THRESHHOLD_RELU,
+
+  SCALED_TAHN,
+
+  HARD_SIGMOID,
+
+  ELU,
+
+  SOFTSIGN,
+
+  SOFTPLUS
+}
@@ -0,0 +1,36 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
+package org.nd4j.enums;
+
+/**
+ * for unidirectional:
+ * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
+ * NST: shape [numExamples, inOutSize, timeLength]<br>
+ * NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout<br>
+ * for bidirectional:
+ * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */
+public enum LSTMDataFormat {
+  TNS,
+
+  NST,
+
+  NTS,
+
+  T2NS
+}
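[Editor's note: to make these layouts concrete, the same random batch in each unidirectional layout (sketch; shapes only):]

    int sL = 5, bS = 2, nIn = 3;
    INDArray tns = Nd4j.rand(DataType.FLOAT, sL, bS, nIn);   // TNS - "time major"
    INDArray nst = Nd4j.rand(DataType.FLOAT, bS, nIn, sL);   // NST
    INDArray nts = Nd4j.rand(DataType.FLOAT, bS, sL, nIn);   // NTS - TF time_major=false
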
@@ -0,0 +1,38 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
+package org.nd4j.enums;
+
+/**
+ * direction <br>
+ * FWD: 0 = fwd
+ * BWD: 1 = bwd
+ * BIDIR_SUM: 2 = bidirectional sum
+ * BIDIR_CONCAT: 3 = bidirectional concat
+ * BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */
+public enum LSTMDirectionMode {
+  FWD,
+
+  BWD,
+
+  BIDIR_SUM,
+
+  BIDIR_CONCAT,
+
+  BIDIR_EXTRA_DIM
+}
@@ -0,0 +1,45 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
+package org.nd4j.enums;
+
+/**
+ * Activations */
+public enum OutAct {
+  TANH,
+
+  RELU,
+
+  SIGMOID,
+
+  AFFINE,
+
+  LEAKY_RELU,
+
+  THRESHHOLD_RELU,
+
+  SCALED_TAHN,
+
+  HARD_SIGMOID,
+
+  ELU,
+
+  SOFTSIGN,
+
+  SOFTPLUS
+}
@@ -0,0 +1,32 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
+package org.nd4j.enums;
+
+/**
+ * The data format of the input. Input shape depends on data format (in config):<br>
+ * TNS -> [timeSteps, batchSize, inSize]<br>
+ * NST -> [batchSize, inSize, timeSteps]<br>
+ * NTS -> [batchSize, timeSteps, inSize]<br> */
+public enum RnnDataFormat {
+  TNS,
+
+  NST,
+
+  NTS
+}
@@ -146,6 +146,7 @@ public class ImportClassMapping {
         org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class,
         org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class,
         org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class,
+        org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class,
         org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class,
         org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class,
         org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss.class,

@@ -301,24 +301,27 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
         }
     }

-    protected void checkForWorkspaces(CustomOp op) {
-        for (val input: op.inputArguments())
+    protected void checkForWorkspaces(CustomOp op, OpContext oc) {
+        List<INDArray> inArgs = oc != null ? oc.getInputArrays() : op.inputArguments();
+        List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
+
+        for (val input: inArgs)
            checkWorkspace(op.opName(), input);

-        for (val output: op.outputArguments())
+        for (val output: outArgs)
            checkWorkspace(op.opName(), output);
    }

-    protected void checkForWorkspaces(Op op) {
-        val x = op.x();
+    protected void checkForWorkspaces(Op op, OpContext oc) {
+        val x = oc != null ? oc.getInputArray(0) : op.x();
        if (x != null)
            checkWorkspace(op.opName(), x);

-        val y = op.y();
+        val y = oc != null && oc.getInputArrays().size() > 1 ? oc.getInputArray(1) : op.y();
        if (y != null)
            checkWorkspace(op.opName(), y);

-        val z = op.z();
+        val z = oc != null ? oc.getOutputArray(0) : op.z();
        if (z != null)
            checkWorkspace(op.opName(), z);
    }
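[Editor's note: the idiom introduced here, and repeated through the remaining hunks of this class, is that every check resolves the arrays context-first when an OpContext is supplied and falls back to the op's own fields otherwise; condensed from the methods above:]

    // When executing via an OpContext the op object may not carry the arrays itself:
    List<INDArray> inArgs  = oc != null ? oc.getInputArrays()  : op.inputArguments();
    List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
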

@@ -346,7 +349,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
                OpProfiler.getInstance().processOpCall(op, tadBuffers);
                break;
            case SCOPE_PANIC:
-                checkForWorkspaces(op);
+                checkForWorkspaces(op, null);
                return 0L;
            case DISABLED:
            default:
@@ -357,7 +360,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
    }

    @Deprecated
-    public long profilingHookIn(CustomOp op) {
+    public long profilingHookIn(CustomOp op, OpContext oc) {
        switch (profilingMode) {
            case ALL:
                OpProfiler.getInstance().processOpCall(op);
@@ -368,7 +371,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
                OpProfiler.getInstance().processOpCall(op);
                break;
            case SCOPE_PANIC:
-                checkForWorkspaces(op);
+                checkForWorkspaces(op, oc);
                return 0L;
            case DISABLED:
            default:
@@ -379,7 +382,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
    }

    @Deprecated
-    public void profilingHookOut(Op op, long timeStart) {
+    public void profilingHookOut(Op op, OpContext oc, long timeStart) {
        switch (profilingMode) {
            case ALL:
                OpProfiler.getInstance().processStackCall(op, timeStart);
@@ -392,14 +395,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
                OpProfiler.getInstance().timeOpCall(op, timeStart);
                break;
            case NAN_PANIC:
-                OpExecutionerUtil.checkForNaN(op);
+                OpExecutionerUtil.checkForNaN(op, oc);
                break;
            case INF_PANIC:
-                OpExecutionerUtil.checkForInf(op);
+                OpExecutionerUtil.checkForInf(op, oc);
                break;
            case ANY_PANIC:
-                OpExecutionerUtil.checkForNaN(op);
-                OpExecutionerUtil.checkForInf(op);
+                OpExecutionerUtil.checkForNaN(op, oc);
+                OpExecutionerUtil.checkForInf(op, oc);
                break;
            case DISABLED:
            default:
@@ -413,7 +416,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
    }

    @Deprecated
-    public void profilingHookOut(CustomOp op, long timeStart) {
+    public void profilingHookOut(CustomOp op, OpContext oc, long timeStart) {
        switch (profilingMode) {
            case ALL:
                OpProfiler.getInstance().processStackCall(op, timeStart);
@@ -426,14 +429,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
                OpProfiler.getInstance().timeOpCall(op, timeStart);
                break;
            case NAN_PANIC:
-                OpExecutionerUtil.checkForNaN(op);
+                OpExecutionerUtil.checkForNaN(op, oc);
                break;
            case INF_PANIC:
-                OpExecutionerUtil.checkForInf(op);
+                OpExecutionerUtil.checkForInf(op, oc);
                break;
            case ANY_PANIC:
-                OpExecutionerUtil.checkForNaN(op);
-                OpExecutionerUtil.checkForInf(op);
+                OpExecutionerUtil.checkForNaN(op, oc);
+                OpExecutionerUtil.checkForInf(op, oc);
                break;
            case DISABLED:
            default:
@@ -442,12 +445,15 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
    }


-    public long profilingConfigurableHookIn(CustomOp op) {
-        for (val arr: op.inputArguments())
+    public long profilingConfigurableHookIn(CustomOp op, OpContext oc) {
+        List<INDArray> inArgs = oc != null ? oc.getInputArrays() : op.inputArguments();
+        List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
+
+        for (val arr: inArgs)
            if (arr.wasClosed())
                throw new IllegalStateException("One of Input arguments was closed before call");

-        for (val arr: op.outputArguments())
+        for (val arr: outArgs)
            if (arr.wasClosed())
                throw new IllegalStateException("One of Output arguments was closed before call");
@@ -460,7 +466,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
        }

        if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) {
-            checkForWorkspaces(op);
+            checkForWorkspaces(op, oc);
        }

        return System.nanoTime();
@@ -491,14 +497,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
            OpProfiler.getInstance().processOpCall(op, tadBuffers);
        }
        if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) {
-            checkForWorkspaces(op);
+            checkForWorkspaces(op, null);
        }

        return System.nanoTime();
    }


-    public void profilingConfigurableHookOut(Op op, long timeStart) {
+    public void profilingConfigurableHookOut(Op op, OpContext oc, long timeStart) {
        if (OpProfiler.getInstance().getConfig() == null)
            return;

@@ -509,10 +515,10 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
            OpProfiler.getInstance().timeOpCall(op, timeStart);
        }
        if (OpProfiler.getInstance().getConfig().isCheckForNAN()) {
-            OpExecutionerUtil.checkForNaN(op);
+            OpExecutionerUtil.checkForNaN(op, oc);
        }
        if (OpProfiler.getInstance().getConfig().isCheckForINF()) {
-            OpExecutionerUtil.checkForInf(op);
+            OpExecutionerUtil.checkForInf(op, oc);
        }
        if (OpProfiler.getInstance().getConfig().isNativeStatistics()) {
            if (op.z() != null) {
@@ -531,7 +537,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
        }
    }

-    public void profilingConfigurableHookOut(CustomOp op, long timeStart) {
+    public void profilingConfigurableHookOut(CustomOp op, OpContext oc, long timeStart) {
        if (OpProfiler.getInstance().getConfig() == null)
            return;

@@ -542,10 +548,10 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
            OpProfiler.getInstance().timeOpCall(op, timeStart);
        }
        if (OpProfiler.getInstance().getConfig().isCheckForNAN()) {
-            OpExecutionerUtil.checkForNaN(op);
+            OpExecutionerUtil.checkForNaN(op, oc);
        }
        if (OpProfiler.getInstance().getConfig().isCheckForINF()) {
-            OpExecutionerUtil.checkForInf(op);
+            OpExecutionerUtil.checkForInf(op, oc);
        }
    }

@@ -22,12 +22,15 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 import org.nd4j.linalg.exception.ND4JOpProfilerException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.conditions.Conditions;
 import org.nd4j.linalg.profiler.OpProfiler;

+import java.util.List;
+
 /**Utility functions for the DefaultOpExecutioner
  * @author Alex Black
  */
@@ -58,7 +61,7 @@ public class OpExecutionerUtil {
        }

        if (match > 0)
-            throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
+            throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s)");
    }

    public static void checkForAny(INDArray z) {
@@ -92,44 +95,52 @@ public class OpExecutionerUtil {

    }

-    public static void checkForNaN(Op op) {
+    public static void checkForNaN(Op op, OpContext oc) {
        if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
            return;

-        if (op.z() != null && !(op instanceof MatchCondition)) {
-            checkForNaN(op.z());
+        INDArray z = oc != null ? oc.getOutputArray(0) : op.z();
+        if (z != null && !(op instanceof MatchCondition)) {
+            checkForNaN(z);
        }
    }

-    public static void checkForInf(Op op) {
+    public static void checkForInf(Op op, OpContext oc) {
        if (!OpProfiler.getInstance().getConfig().isCheckForINF())
            return;

-        if (op.z() != null && !(op instanceof MatchCondition)) {
-            checkForInf(op.z());
+        INDArray z = oc != null ? oc.getOutputArray(0) : op.z();
+        if (z != null && !(op instanceof MatchCondition)) {
+            checkForInf(z);
        }
    }

-    public static void checkForInf(CustomOp op) {
+    public static void checkForInf(CustomOp op, OpContext oc) {
        if (!OpProfiler.getInstance().getConfig().isCheckForINF())
            return;

-        for (val input: op.inputArguments())
+        List<INDArray> inArgs = oc != null ? oc.getInputArrays() : op.inputArguments();
+        List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
+
+        for (val input: inArgs)
            checkForInf(input);

-        for (val output: op.outputArguments())
+        for (val output: outArgs)
            checkForInf(output);
    }


-    public static void checkForNaN(CustomOp op) {
+    public static void checkForNaN(CustomOp op, OpContext oc) {
        if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
            return;

-        for (val input: op.inputArguments())
+        List<INDArray> inArgs = oc != null ? oc.getInputArrays() : op.inputArguments();
+        List<INDArray> outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments();
+
+        for (val input: inArgs)
            checkForNaN(input);

-        for (val output: op.outputArguments())
+        for (val output: outArgs)
            checkForNaN(output);
    }
 }

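[Editor's note: for reference, these checks are usually switched on through the profiler configuration; a sketch along these lines, with the ProfilerConfig builder flag names recalled from ND4J rather than shown in this diff:]

    Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder()
            .checkForNAN(true)        // NAN_PANIC-style check, now OpContext-aware
            .checkForINF(true)
            .checkWorkspaces(true)
            .build());
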
@@ -57,8 +57,12 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
        addArgs();
    }

-    public MaxPoolWithArgmax(INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){
-        super(null, new INDArray[]{input}, new INDArray[]{output, outArgMax});
+    public MaxPoolWithArgmax(@NonNull INDArray input, @NonNull Pooling2DConfig config){
+        this(input, null, null, config);
+    }
+
+    public MaxPoolWithArgmax(@NonNull INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){
+        super(null, new INDArray[]{input}, wrapFilterNull(output, outArgMax));
        config.setType(Pooling2D.Pooling2DType.MAX);

        this.config = config;
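[Editor's note: the new two-argument constructor simply defers output allocation; a hedged sketch, with the Pooling2DConfig builder field names assumed:]

    INDArray in = Nd4j.rand(DataType.FLOAT, 1, 1, 4, 4);     // NCHW
    MaxPoolWithArgmax op = new MaxPoolWithArgmax(in,
            Pooling2DConfig.builder().kH(2).kW(2).sH(2).sW(2).build());
    INDArray[] out = Nd4j.exec(op);                          // out[0] pooled values, out[1] argmax indices
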
@@ -45,7 +45,7 @@ public class SConv2D extends Conv2D {
    }

    public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights,
-            @NonNull SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
+            SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
        this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig);
    }

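[Editor's note: with @NonNull dropped, pointWeights (and bias) may now be null, and wrapFilterNull simply omits the nulls from the op's inputs; a sketch under those assumptions:]

    SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3, 8, 8);        // NCHW input
    SDVariable dw = sd.var("dw", Nd4j.rand(DataType.FLOAT, 2, 2, 3, 2));     // depth-wise kernel (assumed layout)
    SConv2D depthwiseOnly = new SConv2D(sd, in, dw, null, null,
            Conv2DConfig.builder().kH(2).kW(2).build());                     // assumed builder fields
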
@@ -0,0 +1,144 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.api.ops.impl.layers.recurrent;
+
+import lombok.Getter;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * LSTM layer implemented as a single operation.
+ * Implementation of operation for LSTM layer with optional peep hole connections.<br>
+ * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and <a href="https://research.google.com/pubs/archive/43905.pdf">https://research.google.com/pubs/archive/43905.pdf</a><br>
+ * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.<br>
+ * See also: <a href="https://arxiv.org/pdf/1503.04069.pdf">https://arxiv.org/pdf/1503.04069.pdf</a><br>
+ * <p>
+ * See also {@link LSTMBlockCell} - lstmBlockCell op is used internally at C++ level for computation.<br>
+ * <br>
+ * Input arrays:<br>
+ * 0: max sequence length; long/int64 scalar<br>
+ * 1: input [seqLength, bS, inSize] at time t<br>
+ * 2: previous/initial cell state [bS, numUnits]<br>
+ * 3: previous/initial output [bS, numUnits]<br>
+ * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]<br>
+ * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits]<br>
+ * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits]<br>
+ * 7: weights - cell peephole (t) connections to output gate, [numUnits]<br>
+ * 8: biases, shape [4*numUnits]<br>
+ * <br>
+ * Input integer arguments: set via {@link LSTMConfiguration}<br>
+ * 0: if not zero, provide peephole connections<br>
+ * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]<br>
+ * <br>
+ * Input float arguments: set via {@link LSTMConfiguration}<br>
+ * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
+ * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
+ * <p>
+ * Output arrays:<br>
+ * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat<br>
+ * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat<br>
+ * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat<br>
+ * 3: o - Output - output gate activations, rank 3, shape as per dataFormat<br>
+ * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat<br>
+ * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat<br>
+ * 6: y (h) - Current cell output, rank 3, shape as per dataFormat<br>
+ *
+ * @author Alex Black
+ */
+public class LSTMBlock extends DynamicCustomOp {
+
+    private LSTMConfiguration configuration;
+
+    @Getter
+    private LSTMWeights weights;
+
+    public LSTMBlock() {
+    }
+
+    public LSTMBlock(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
+        super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
+        this.configuration = configuration;
+        this.weights = weights;
+        addIArgument(configuration.iArgs(true));
+        addTArgument(configuration.tArgs());
+    }
+
+    public LSTMBlock(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) {
+        super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast));
+        this.configuration = lstmConfiguration;
+        this.weights = lstmWeights;
+        addIArgument(configuration.iArgs(true));
+        addTArgument(configuration.tArgs());
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMBlock, got %s", inputDataTypes);
+        //7 outputs, all of same type as input. Note that input 0 is max sequence length (int64), input 1 is actual input
+        DataType dt = inputDataTypes.get(1);
+        Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
+        return Arrays.asList(dt, dt, dt, dt, dt, dt, dt);
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> grads) {
+        throw new UnsupportedOperationException("Not yet implemented");
+    }
+
+    @Override
+    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
+        configuration = LSTMConfiguration.builder()
+                .forgetBias(attributesForNode.get("forget_bias").getF())
+                .clippingCellValue(attributesForNode.get("cell_clip").getF())
+                .peepHole(attributesForNode.get("use_peephole").getB())
+                .dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM
+                .build();
+        addIArgument(configuration.iArgs(true));
+        addTArgument(configuration.tArgs());
+    }
+
+    @Override
+    public String opName() {
+        return "lstmBlock";
+    }
+
+    @Override
+    public Map<String, Object> propertiesForFunction() {
+        return configuration.toProperties(true);
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "BlockLSTM";
+    }
+
+}
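[Editor's note: LSTMBlock keeps the old 9-input/7-output contract verbatim, so the datatype inference above can be exercised directly:]

    LSTMBlock op = new LSTMBlock();
    List<DataType> in = Arrays.asList(DataType.LONG,                 // input 0: max sequence length
            DataType.FLOAT, DataType.FLOAT, DataType.FLOAT, DataType.FLOAT,
            DataType.FLOAT, DataType.FLOAT, DataType.FLOAT, DataType.FLOAT);
    List<DataType> out = op.calculateOutputDataTypes(in);            // 7 entries, all FLOAT
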
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/* ******************************************************************************
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -13,7 +13,6 @@
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
@ -24,89 +23,103 @@ import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
|
import org.nd4j.shade.guava.primitives.Booleans;
|
||||||
import org.tensorflow.framework.AttrValue;
|
|
||||||
import org.tensorflow.framework.GraphDef;
|
|
||||||
import org.tensorflow.framework.NodeDef;
|
|
||||||
|
|
||||||
|
import javax.xml.crypto.Data;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* LSTM layer implemented as a single operation.
|
* LSTM layer implemented as a single operation.
|
||||||
* Implementation of operation for LSTM layer with optional peep hole connections.<br>
|
* Implementation of operation for LSTM layer with optional peep hole connections.<br>
|
||||||
* S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and <a href="https://research.google.com/pubs/archive/43905.pdf">https://research.google.com/pubs/archive/43905.pdf</a><br>
|
* S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and <a href="https://research.google.com/pubs/archive/43905.pdf">https://research.google.com/pubs/archive/43905.pdf</a><br>
|
||||||
* Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.<br>
|
* Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.<br>
|
||||||
* See also: <a href="https://arxiv.org/pdf/1503.04069.pdf">https://arxiv.org/pdf/1503.04069.pdf</a><br>
|
* See also: <a href="https://arxiv.org/pdf/1503.04069.pdf">https://arxiv.org/pdf/1503.04069.pdf</a><br>
|
||||||
* <p>
|
|
||||||
* See also {@link LSTMBlockCell} - lstmBlockCell op is used internally at C++ level for computation.<br>
|
|
||||||
* <br>
|
|
||||||
* Input arrays:<br>
|
* Input arrays:<br>
|
||||||
* 0: max sequence length; long/int64 scalar<br>
|
* 0: input <br>
|
||||||
* 1: input [seqLength, bS, inSize] at time t<br>
|
* [sL, bS, nIn] when dataFormat - TNS <br>
|
||||||
* 2: previous/initial cell state [bS, numUnits]<br>
|
* [bS, sL, nIn] when dataFormat - NST <br>
|
||||||
* 3: previous/initial output [bS, numUnits]<br>
|
* [bS, nIn, sL] when dataFormat - NST <br>
|
||||||
* 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]<br>
|
* 1: previous/initial cell state<br>
|
||||||
* 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits]<br>
|
* shapes [nIn, 4*nOut] for FWD, BWD Direction Mode <br>
|
||||||
* 6: weights - cell peephole (t-1) connections to forget gate, [numUnits]<br>
|
* shapes [2, nIn, 4*nOut] BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM Direction Mode <br>
|
||||||
* 7: weights - cell peephole (t) connections to output gate, [numUnits]<br>
|
* 2: previous/initial output [bS, numUnits]<br>
|
||||||
* 8: biases, shape [4*numUnits]<br>
|
* * shapes [nIn, 4*nOut] for FWD, BWD Direction Mode <br>
|
||||||
* <br>
|
* * shapes [2, nIn, 4*nOut] BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM Direction Mode <br>
|
||||||
* Input integer arguments: set via {@link LSTMConfiguration}<br>
|
* 3 max sequence length [bS] <br>
|
||||||
* 0: if not zero, provide peephole connections<br>
|
* 4: LSTMLayerWeights - {@link LSTMLayerWeights} <br>
|
||||||
* 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]<br>
|
* 5: LSTMLayerConfig - {@link LSTMLayerConfig}<br>
|
||||||
* <br>
|
|
||||||
* Input float arguments: set via {@link LSTMConfiguration}<br>
|
|
||||||
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
|
|
||||||
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
|
|
||||||
* <p>
|
* <p>
|
||||||
* Output arrays:<br>
|
* Output arrays:<br>
|
||||||
* 0: i - Input modulation gate activations, rank 3, shape as per dataFormat<br>
|
* 0: output h - rank 3 or 4, depends on DirectionMode and dataFormat<br>
|
||||||
* 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat<br>
|
* 1: output at last step hL - rank 3 or 4, depends on DirectionMode and dataFormat<<br>
|
||||||
* 2: f - Output - forget gate activations, rank 3, shape as per dataFormat<br>
|
* 2: cell state at last step cL - same shape as in hL<br>
|
||||||
* 3: o - Output - output gate activations, rank 3, shape as per dataFormat<br>
|
|
||||||
* 4: z (ci) - Output - block input, rank 3, shape as per dataFormat<br>
|
|
||||||
* 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat<br>
|
|
||||||
* 6: y (h) - Current cell output, rank 3, shape as per dataFormat<br>
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
*/
|
||||||
public class LSTMLayer extends DynamicCustomOp {
|
public class LSTMLayer extends DynamicCustomOp {
|
||||||
|
|
||||||
private LSTMConfiguration configuration;
|
@Getter
|
||||||
|
private LSTMLayerConfig configuration;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private LSTMWeights weights;
|
private LSTMLayerWeights weights;
|
||||||
|
|
||||||
|
|
||||||
public LSTMLayer() {
|
public LSTMLayer() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
|
public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) {
|
||||||
super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast));
|
super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
|
||||||
this.configuration = configuration;
|
this.configuration = configuration;
|
||||||
this.weights = weights;
|
this.weights = weights;
|
||||||
addIArgument(configuration.iArgs(true));
|
addIArgument(iArgs());
|
||||||
addTArgument(configuration.tArgs());
|
addTArgument(tArgs());
|
||||||
|
addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
|
||||||
|
|
||||||
|
Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(),
|
||||||
|
"You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) {
|
public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights lstmWeights, LSTMLayerConfig LSTMLayerConfig) {
|
||||||
super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast));
|
super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast));
|
||||||
this.configuration = lstmConfiguration;
|
this.configuration = LSTMLayerConfig;
|
||||||
this.weights = lstmWeights;
|
this.weights = lstmWeights;
|
||||||
addIArgument(configuration.iArgs(true));
|
addIArgument(iArgs());
|
||||||
addTArgument(configuration.tArgs());
|
addTArgument(tArgs());
|
||||||
|
addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
|
||||||
|
|
||||||
|
Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(),
|
||||||
|
"You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMLayer, got %s", inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && 3 <= inputDataTypes.size() && inputDataTypes.size() <= 8, "Expected amount of inputs to LSTMLayer between 3 inputs minimum (input, Wx, Wr only) or 8 maximum, got %s", inputDataTypes);
|
||||||
//7 outputs, all of same type as input. Note that input 0 is max sequence length (int64), input 1 is actual input
|
//1 to 3 outputs (depending on the retFullSequence/retLastH/retLastC flags), all of the same floating point type as input 1
|
||||||
DataType dt = inputDataTypes.get(1);
|
DataType dt = inputDataTypes.get(1);
|
||||||
|
ArrayList<DataType> list = new ArrayList<>();
|
||||||
|
if (configuration.isRetFullSequence()) {
|
||||||
|
|
||||||
|
list.add(dt);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (configuration.isRetLastC()) {
|
||||||
|
|
||||||
|
list.add(dt);
|
||||||
|
}
|
||||||
|
if (configuration.isRetLastH()){
|
||||||
|
|
||||||
|
list.add(dt);
|
||||||
|
}
|
||||||
|
|
||||||
Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
|
Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
|
||||||
return Arrays.asList(dt, dt, dt, dt, dt, dt, dt);
|
return list;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -114,31 +127,61 @@ public class LSTMLayer extends DynamicCustomOp {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
throw new UnsupportedOperationException("Not yet implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
|
||||||
configuration = LSTMConfiguration.builder()
|
|
||||||
.forgetBias(attributesForNode.get("forget_bias").getF())
|
|
||||||
.clippingCellValue(attributesForNode.get("cell_clip").getF())
|
|
||||||
.peepHole(attributesForNode.get("use_peephole").getB())
|
|
||||||
.dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM
|
|
||||||
.build();
|
|
||||||
addIArgument(configuration.iArgs(true));
|
|
||||||
addTArgument(configuration.tArgs());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "lstmBlock";
|
return "lstmLayer";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Object> propertiesForFunction() {
|
public Map<String, Object> propertiesForFunction() {
|
||||||
return configuration.toProperties(true);
|
return configuration.toProperties(true, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public long[] iArgs() {
|
||||||
|
return new long[]{
|
||||||
|
configuration.getLstmdataformat().ordinal(),// INT_ARG(0)
|
||||||
|
configuration.getDirectionMode().ordinal(), // INT_ARG(1)
|
||||||
|
configuration.getGateAct().ordinal(), // INT_ARG(2)
|
||||||
|
configuration.getOutAct().ordinal(), // INT_ARG(3)
|
||||||
|
configuration.getCellAct().ordinal() // INT_ARG(4)
|
||||||
|
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
public double[] tArgs() {
|
||||||
|
return new double[]{this.configuration.getCellClip()}; // T_ARG(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
|
||||||
|
return new boolean[]{
|
||||||
|
weights.hasBias(), // hasBiases: B_ARG(0)
|
||||||
|
maxTSLength != null, // hasSeqLen: B_ARG(1)
|
||||||
|
yLast != null, // hasInitH: B_ARG(2)
|
||||||
|
cLast != null, // hasInitC: B_ARG(3)
|
||||||
|
weights.hasPH(), // hasPH: B_ARG(4)
|
||||||
|
configuration.isRetFullSequence(), //retFullSequence: B_ARG(5)
|
||||||
|
configuration.isRetLastH(), // retLastH: B_ARG(6)
|
||||||
|
configuration.isRetLastC() // retLastC: B_ARG(7)
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public int getNumOutputs(){
|
||||||
return "BlockLSTM";
|
|
||||||
|
return Booleans.countTrue(
|
||||||
|
configuration.isRetFullSequence(), //retFullSequence: B_ARG(5)
|
||||||
|
configuration.isRetLastH(), // retLastH: B_ARG(6)
|
||||||
|
configuration.isRetLastC() // retLastC: B_ARG(7)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
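To see how the pieces above fit together, here is a minimal usage sketch in SameDiff. It assumes the constructors and builders exactly as committed in this diff; the dimension values and variable names (x, Wx, Wr, b) are illustrative only, not part of the PR.

    // Sketch only - assumes the LSTMLayer/LSTMLayerConfig/LSTMLayerWeights API shown in this diff.
    // Assumed imports: org.nd4j.autodiff.samediff.*, org.nd4j.linalg.factory.Nd4j,
    // org.nd4j.linalg.api.buffer.DataType, and the recurrent config/weights classes below.
    int sL = 10, bS = 4, nIn = 8, nOut = 16;
    SameDiff sd = SameDiff.create();
    SDVariable x = sd.placeHolder("x", DataType.FLOAT, sL, bS, nIn);   // TNS layout
    LSTMLayerConfig conf = LSTMLayerConfig.builder()
            .lstmdataformat(LSTMDataFormat.TNS)
            .directionMode(LSTMDirectionMode.FWD)
            .retFullSequence(true).retLastH(true).retLastC(true)
            .build();
    LSTMLayerWeights w = LSTMLayerWeights.builder()
            .weights(sd.var("Wx", Nd4j.rand(DataType.FLOAT, nIn, 4 * nOut)))
            .rWeights(sd.var("Wr", Nd4j.rand(DataType.FLOAT, nOut, 4 * nOut)))
            .bias(sd.var("b", Nd4j.rand(DataType.FLOAT, 4 * nOut)))
            .build();
    // cLast, yLast and maxTSLength are optional and may be null; bArgs() records their absence.
    LSTMLayer op = new LSTMLayer(sd, x, null, null, null, w, conf);
    LSTMLayerOutputs outs = new LSTMLayerOutputs(op.outputVariables(), conf);
    SDVariable h = outs.getOutput();   // full sequence: [sL, bS, nOut] for TNS + FWD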
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* integer numbers corresponding to activations:
|
||||||
|
* 0=tanh,
|
||||||
|
* 1=relu,
|
||||||
|
* 2=sigmoid,
|
||||||
|
* 3=affine,
|
||||||
|
* 4=leaky relu,
|
||||||
|
* 5=thresholded relu,
|
||||||
|
* 6=scaled tanh,
|
||||||
|
* 7=hard sigmoid,
|
||||||
|
* 8=ELU,
|
||||||
|
* 9=softsign,
|
||||||
|
* 10=softplus
|
||||||
|
*/
|
||||||
|
public enum LSTMActivations {
|
||||||
|
//Note: ordinal (order) here matters for the C++ level. Any new values should be added at the end
|
||||||
|
|
||||||
|
TANH,
|
||||||
|
RELU,
|
||||||
|
SIGMOID,
|
||||||
|
AFFINE,
|
||||||
|
LEAKY_RELU,
|
||||||
|
THRESHHOLD_RELU,
|
||||||
|
SCALED_TAHN,
|
||||||
|
HARD_SIGMOID,
|
||||||
|
ELU,
|
||||||
|
SOFTSIGN,
|
||||||
|
SOFTPLUS
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
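Since only the ordinal of these constants crosses the JNI boundary (see LSTMLayer.iArgs() above), the enum order is effectively a binary contract with the C++ op. A quick sketch of that mapping:

    // Sketch: the integers the C++ level receives are plain enum ordinals.
    assert LSTMActivations.TANH.ordinal() == 0;
    assert LSTMActivations.SIGMOID.ordinal() == 2;
    assert LSTMActivations.SOFTPLUS.ordinal() == 10;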
@ -0,0 +1,41 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* notations <br>
|
||||||
|
* for unidirectional:
|
||||||
|
* TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
|
||||||
|
* NST: shape [numExamples, inOutSize, timeLength]<br>
|
||||||
|
* NTS: shape [numExamples, timeLength, inOutSize]<br>
|
||||||
|
* for bidirectional:
|
||||||
|
* T2NS: shape [timeLength, 2, numExamples, inOutSize] (dataFormat = 3, for ONNX)
|
||||||
|
*/
|
||||||
|
|
||||||
|
public enum LSTMDataFormat {
|
||||||
|
//Note: ordinal (order) here matters for the C++ level. Any new formats should be added at the end
|
||||||
|
|
||||||
|
|
||||||
|
TNS,
|
||||||
|
NTS,
|
||||||
|
NST,
|
||||||
|
T2NS
|
||||||
|
|
||||||
|
}
|
||||||
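To make the three unidirectional layouts concrete, the sketch below converts an NTS batch to TNS with a plain permute; the shape values are illustrative (assumes org.nd4j.linalg.factory.Nd4j and org.nd4j.linalg.api.buffer.DataType imports).

    // [numExamples, timeLength, inOutSize] (NTS) -> [timeLength, numExamples, inOutSize] (TNS)
    INDArray nts = Nd4j.rand(DataType.FLOAT, 4, 10, 8);
    INDArray tns = nts.permute(1, 0, 2);   // shape [10, 4, 8]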
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* direction <br>
|
||||||
|
* FWD: 0 = fwd
|
||||||
|
* BWD: 1 = bwd
|
||||||
|
* BIDIR_SUM: 2 = bidirectional sum
|
||||||
|
* BIDIR_CONCAT: 3 = bidirectional concat
|
||||||
|
* BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */
|
||||||
|
|
||||||
|
// const auto directionMode = INT_ARG(1); // direction:
|
||||||
|
|
||||||
|
public enum LSTMDirectionMode {
|
||||||
|
//Note: ordinal (order) here matters for the C++ level. Any new modes should be added at the end
|
||||||
|
|
||||||
|
|
||||||
|
FWD,
|
||||||
|
BWD,
|
||||||
|
BIDIR_SUM,
|
||||||
|
BIDIR_CONCAT,
|
||||||
|
BIDIR_EXTRA_DIM
|
||||||
|
|
||||||
|
}
|
|
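In practice the mode mainly changes the shape of the h output; a rough sketch of the last dimension, assuming the shape rules quoted in the LSTMLayerOutputs javadoc later in this diff:

    // Sketch: expected width of the full-sequence output per direction mode.
    int nOut = 16;
    LSTMDirectionMode mode = LSTMDirectionMode.BIDIR_CONCAT;
    int lastDim = (mode == LSTMDirectionMode.BIDIR_CONCAT) ? 2 * nOut : nOut;
    // BIDIR_EXTRA_DIM keeps nOut but adds a separate direction axis of size 2 (T2NS layout).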
@ -0,0 +1,119 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
|
||||||
|
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
@Data
|
||||||
|
public class LSTMLayerConfig {
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* notations <br>
|
||||||
|
* for unidirectional:
|
||||||
|
* TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
|
||||||
|
* NST: shape [numExamples, inOutSize, timeLength]<br>
|
||||||
|
* NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout<br>
|
||||||
|
* for bidirectional:
|
||||||
|
* T2NS: shape [timeLength, 2, numExamples, inOutSize] (dataFormat = 3, for ONNX)
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private LSTMDataFormat lstmdataformat = LSTMDataFormat.TNS; //INT_ARG(0)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* direction <br>
|
||||||
|
* FWD: 0 = fwd
|
||||||
|
* BWD: 1 = bwd
|
||||||
|
* BIDIR_SUM: 2 = bidirectional sum
|
||||||
|
* BIDIR_CONCAT: 3 = bidirectional concat
|
||||||
|
* BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with dataFormat = 3)
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private LSTMDirectionMode directionMode = LSTMDirectionMode.FWD; //INT_ARG(1)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Activation for input (i), forget (f) and output (o) gates
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private LSTMActivations gateAct = LSTMActivations.SIGMOID; // INT_ARG(2)
|
||||||
|
|
||||||
|
@Builder.Default
|
||||||
|
private LSTMActivations cellAct = LSTMActivations.TANH; // INT_ARG(3)
|
||||||
|
|
||||||
|
@Builder.Default
|
||||||
|
private LSTMActivations outAct = LSTMActivations.TANH; // INT_ARG(4)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* indicates whether to return the whole time sequence h {h_0, h_1, ..., h_(sL-1)}
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private boolean retFullSequence = true; //B_ARG(5)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* indicates whether to return the output at the last time step only,
|
||||||
|
* in which case the shape would be [bS, nOut] (exact shape depends on the dataFormat argument)
|
||||||
|
*/
|
||||||
|
private boolean retLastH; //B_ARG(6)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* indicates whether to return the cell state at the last time step only,
|
||||||
|
* in which case the shape would be [bS, nOut] (exact shape depends on the dataFormat argument)
|
||||||
|
*/
|
||||||
|
private boolean retLastC; // B_ARG(7)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell clipping value; if 0, no clipping is applied
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private double cellClip = 0.0; //T_ARG(0)
|
||||||
|
|
||||||
|
|
||||||
|
public Map<String, Object> toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) {
|
||||||
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
|
ret.put("gateAct", gateAct.ordinal());
|
||||||
|
ret.put("outAct", outAct.ordinal());
|
||||||
|
ret.put("cellAct", cellAct.ordinal());
|
||||||
|
ret.put("retFullSequence", retFullSequence);
|
||||||
|
ret.put("retLastH", retLastH);
|
||||||
|
ret.put("retLastC", retLastC);
|
||||||
|
ret.put("cellClip", cellClip);
|
||||||
|
|
||||||
|
if (includeLSTMDataFormat)
|
||||||
|
ret.put("LSTMDataFormat", lstmdataformat.ordinal());
|
||||||
|
if (includeLSTMDirectionMode)
|
||||||
|
ret.put("LSTMDirectionMode", directionMode.ordinal());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
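For reference, here is where a default-constructed config ends up in LSTMLayer.iArgs()/bArgs() above; this is a sketch of the wiring, not a test from this PR:

    LSTMLayerConfig conf = LSTMLayerConfig.builder().build();
    // lstmdataformat = TNS     -> INT_ARG(0) = 0
    // directionMode  = FWD     -> INT_ARG(1) = 0
    // gateAct        = SIGMOID -> INT_ARG(2) = 2
    // retFullSequence defaults to true (B_ARG(5)); retLastH/retLastC default to false.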
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,13 +2,18 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import lombok.AccessLevel;
|
import lombok.AccessLevel;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.autodiff.samediff.SDIndex;
|
import org.nd4j.autodiff.samediff.SDIndex;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
|
||||||
|
import org.nd4j.shade.guava.primitives.Booleans;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The outputs of an LSTM layer ({@link LSTMLayer}).
|
* The outputs of an LSTM layer ({@link LSTMLayer}).
|
||||||
|
@ -16,165 +21,78 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
|
||||||
@Getter
|
@Getter
|
||||||
public class LSTMLayerOutputs {
|
public class LSTMLayerOutputs {
|
||||||
|
|
||||||
private RnnDataFormat dataFormat;
|
/**
|
||||||
|
* The LSTM layer data format ({@link LSTMDataFormat}).
|
||||||
|
*/
|
||||||
|
private LSTMDataFormat dataFormat;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Output - input modulation gate activations.
|
* output h:
|
||||||
* Shape depends on data format (in layer config):<br>
|
* [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
* [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
* [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
* [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
|
||||||
|
* [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1
|
||||||
|
* [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2
|
||||||
|
* [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
|
||||||
|
* the numbers are the indices of the corresponding {@link LSTMDataFormat} and {@link LSTMDirectionMode} enum values
|
||||||
*/
|
*/
|
||||||
private SDVariable i;
|
private SDVariable timeSeriesOutput;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Activations, cell state (pre tanh).
|
* cell state at last step cL:
|
||||||
* Shape depends on data format (in layer config):<br>
|
* [bS, nOut] when directionMode FWD or BWD
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
* [2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
|
||||||
*/
|
*/
|
||||||
private SDVariable c;
|
private SDVariable lastCellStateOutput;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Output - forget gate activations.
|
* output at last step hL:
|
||||||
* Shape depends on data format (in layer config):<br>
|
* [bS, nOut] when directionMode FWD or BWD
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
* [2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
|
||||||
*/
|
*/
|
||||||
private SDVariable f;
|
private SDVariable lastTimeStepOutput;
|
||||||
|
|
||||||
/**
|
|
||||||
* Output - output gate activations.
|
|
||||||
* Shape depends on data format (in layer config):<br>
|
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
|
||||||
*/
|
|
||||||
private SDVariable o;
|
|
||||||
|
|
||||||
/**
|
public LSTMLayerOutputs(SDVariable[] outputs, LSTMLayerConfig lstmLayerConfig) {
|
||||||
* Output - input gate activations.
|
Preconditions.checkArgument(outputs.length > 0 && outputs.length <= 3,
|
||||||
* Shape depends on data format (in layer config):<br>
|
"Must have from 1 to 3 LSTM layer outputs, got %s", outputs.length);
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
|
||||||
*/
|
|
||||||
private SDVariable z;
|
|
||||||
|
|
||||||
/**
|
int i = 0;
|
||||||
* Cell state, post tanh.
|
timeSeriesOutput = lstmLayerConfig.isRetFullSequence() ? outputs[i++] : null;
|
||||||
* Shape depends on data format (in layer config):<br>
|
lastTimeStepOutput = lstmLayerConfig.isRetLastH() ? outputs[i++] : null;
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
lastCellStateOutput = lstmLayerConfig.isRetLastC() ? outputs[i++] : null;
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
|
||||||
*/
|
|
||||||
private SDVariable h;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Current cell output.
|
|
||||||
* Shape depends on data format (in layer config):<br>
|
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
|
||||||
*/
|
|
||||||
private SDVariable y;
|
|
||||||
|
|
||||||
public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){
|
this.dataFormat = lstmLayerConfig.getLstmdataformat();
|
||||||
Preconditions.checkArgument(outputs.length == 7,
|
|
||||||
"Must have 7 LSTM layer outputs, got %s", outputs.length);
|
|
||||||
|
|
||||||
i = outputs[0];
|
|
||||||
c = outputs[1];
|
|
||||||
f = outputs[2];
|
|
||||||
o = outputs[3];
|
|
||||||
z = outputs[4];
|
|
||||||
h = outputs[5];
|
|
||||||
y = outputs[6];
|
|
||||||
this.dataFormat = dataFormat;
|
|
||||||
}
|
}
|
||||||
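Because the outputs array shrinks to match whichever flags are enabled, positional indexing is fragile; callers should go through the getters below. A sketch, assuming outs is an LSTMLayerOutputs built as above with retFullSequence disabled and retLastH/retLastC enabled:

    // With retFullSequence=false, retLastH=true, retLastC=true, outputs[0] is hL and
    // outputs[1] is cL - so read them via the getters rather than raw indices:
    SDVariable hL = outs.getLastOutput();   // throws if retLastH was false
    SDVariable cL = outs.getLastState();    // throws if retLastC was false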
|
|
||||||
/**
|
|
||||||
* Get all outputs returned by the cell.
|
|
||||||
*/
|
|
||||||
public List<SDVariable> getAllOutputs(){
|
|
||||||
return Arrays.asList(i, c, f, o, z, h, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get y, the output of the cell for all time steps.
|
* Get h, the output of the cell for all time steps.
|
||||||
*
|
* <p>
|
||||||
* Shape depends on data format (in layer config):<br>
|
* Shape depends on data format defined in {@link LSTMLayerConfig }:<br>
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
* for unidirectional:
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
* TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
* NST: shape [numExamples, inOutSize, timeLength]<br>
|
||||||
|
* NTS: shape [numExamples, timeLength, inOutSize] <br>
|
||||||
|
* for bidirectional:
|
||||||
|
* T2NS: shape [timeLength, 2, numExamples, inOutSize] (dataFormat = 3, for ONNX)
|
||||||
*/
|
*/
|
||||||
public SDVariable getOutput(){
|
public SDVariable getOutput() {
|
||||||
return y;
|
Preconditions.checkArgument(timeSeriesOutput != null, "retFullSequence was set to false in LSTMLayerConfig");
|
||||||
|
return timeSeriesOutput;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
public SDVariable getLastState() {
|
||||||
* Get c, the cell's state for all time steps.
|
Preconditions.checkArgument(lastCellStateOutput != null, "retLastC was set to false in LSTMLayerConfig");
|
||||||
*
|
return lastCellStateOutput;
|
||||||
* Shape depends on data format (in layer config):<br>
|
|
||||||
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
|
||||||
* NST -> [batchSize, numUnits, timeSteps]<br>
|
|
||||||
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
|
||||||
*/
|
|
||||||
public SDVariable getState(){
|
|
||||||
return c;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private SDVariable lastOutput = null;
|
public SDVariable getLastOutput() {
|
||||||
|
Preconditions.checkArgument(lastTimeStepOutput != null, "retLastH was set to false in LSTMLayerConfig");
|
||||||
/**
|
return lastTimeStepOutput;
|
||||||
* Get y, the output of the cell, for the last time step.
|
|
||||||
*
|
|
||||||
* Has shape [batchSize, numUnits].
|
|
||||||
*/
|
|
||||||
public SDVariable getLastOutput(){
|
|
||||||
if(lastOutput != null)
|
|
||||||
return lastOutput;
|
|
||||||
|
|
||||||
switch (dataFormat){
|
|
||||||
case TNS:
|
|
||||||
lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
|
|
||||||
break;
|
|
||||||
case NST:
|
|
||||||
lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
|
|
||||||
break;
|
|
||||||
case NTS:
|
|
||||||
lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return lastOutput;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private SDVariable lastState = null;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get c, the state of the cell, for the last time step.
|
|
||||||
*
|
|
||||||
* Has shape [batchSize, numUnits].
|
|
||||||
*/
|
|
||||||
public SDVariable getLastState(){
|
|
||||||
if(lastState != null)
|
|
||||||
return lastState;
|
|
||||||
|
|
||||||
switch (dataFormat){
|
|
||||||
case TNS:
|
|
||||||
lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
|
|
||||||
break;
|
|
||||||
case NST:
|
|
||||||
lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
|
|
||||||
break;
|
|
||||||
case NTS:
|
|
||||||
lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return lastState;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
|
||||||
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The weight configuration of a LSTMLayer. For {@link LSTMLayer}
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public class LSTMLayerWeights extends RNNWeights {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Input to hidden weights with a shape of [inSize, 4*numUnits].
|
||||||
|
*
|
||||||
|
* Note: unlike the lstmBlock op, the input to hidden and hidden to hidden weights are not concatenated here;
|
||||||
|
* they are kept as separate arrays (see rWeights below) and passed to the lstmLayer op individually.
|
||||||
|
*/
|
||||||
|
private SDVariable weights;
|
||||||
|
private INDArray iWeights;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hidden to hidden weights (aka "recurrent weights"), with a shape of [numUnits, 4*numUnits].
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
private SDVariable rWeights;
|
||||||
|
private INDArray irWeights;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Peephole weights, with a shape of [3*numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable peepholeWeights;
|
||||||
|
private INDArray iPeepholeWeights;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Input to hidden and hidden to hidden biases, with shape [4*numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable bias;
|
||||||
|
private INDArray iBias;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SDVariable[] args() {
|
||||||
|
return filterNonNull(weights, rWeights, peepholeWeights, bias);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray[] arrayArgs() {
|
||||||
|
return filterNonNull(iWeights, irWeights, iPeepholeWeights, iBias);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SDVariable[] argsWithInputs(SDVariable... inputs){
|
||||||
|
Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length); //Order: x, seqLen, yLast, cLast
|
||||||
|
//lstmLayer c++ op expects: x, Wx, Wr, Wp, b, seqLen, yLast, cLast
|
||||||
|
return ArrayUtil.filterNull(inputs[0], weights, rWeights, bias, inputs[1], inputs[2], inputs[3], peepholeWeights);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray[] argsWithInputs(INDArray... inputs) {
|
||||||
|
Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length); //Order: x, seqLen, yLast, cLast
|
||||||
|
//lstmLayer c++ op expects: x, Wx, Wr, Wp, b, seqLen, yLast, cLast
|
||||||
|
return ArrayUtil.filterNull(inputs[0], iWeights, irWeights, iBias, inputs[1], inputs[2], inputs[3], iPeepholeWeights);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public boolean hasBias() {
|
||||||
|
return (bias!=null||iBias!=null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean hasPH() {
|
||||||
|
return (peepholeWeights!=null||iPeepholeWeights!=null);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
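The null-filtering in argsWithInputs() means the C++ op always receives a dense list of whichever inputs exist, with the boolean args recording what was present. A minimal sketch (no bias, no peepholes; shapes illustrative):

    int nIn = 8, nOut = 16;
    // Only Wx and Wr are mandatory for the lstmLayer op.
    LSTMLayerWeights w = LSTMLayerWeights.builder()
            .iWeights(Nd4j.rand(DataType.FLOAT, nIn, 4 * nOut))    // Wx: input to hidden
            .irWeights(Nd4j.rand(DataType.FLOAT, nOut, 4 * nOut))  // Wr: hidden to hidden
            .build();
    // w.hasBias() == false and w.hasPH() == false, so B_ARG(0)/B_ARG(4) are false and
    // argsWithInputs(x, null, null, null) collapses to {x, Wx, Wr}.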
@ -98,6 +98,7 @@ public class Mmul extends DynamicCustomOp {
|
||||||
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
||||||
ArrayUtil.fromBoolean(transposeY),
|
ArrayUtil.fromBoolean(transposeY),
|
||||||
ArrayUtil.fromBoolean(transposeZ));
|
ArrayUtil.fromBoolean(transposeZ));
|
||||||
|
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Mmul(INDArray x, INDArray y) {
|
public Mmul(INDArray x, INDArray y) {
|
||||||
|
@ -110,6 +111,7 @@ public class Mmul extends DynamicCustomOp {
|
||||||
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
||||||
ArrayUtil.fromBoolean(transposeY),
|
ArrayUtil.fromBoolean(transposeY),
|
||||||
ArrayUtil.fromBoolean(transposeZ));
|
ArrayUtil.fromBoolean(transposeZ));
|
||||||
|
mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Mmul() {}
|
public Mmul() {}
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -49,6 +50,9 @@ public class BatchMmul extends DynamicCustomOp {
|
||||||
protected int N;
|
protected int N;
|
||||||
protected int K;
|
protected int K;
|
||||||
|
|
||||||
|
public BatchMmul(SameDiff sameDiff, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) {
|
||||||
|
this(sameDiff, ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB);
|
||||||
|
}
|
||||||
|
|
||||||
public BatchMmul(SameDiff sameDiff,
|
public BatchMmul(SameDiff sameDiff,
|
||||||
SDVariable[] matrices,
|
SDVariable[] matrices,
|
||||||
|
@ -85,6 +89,22 @@ public class BatchMmul extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BatchMmul(INDArray[] matricesA, INDArray[] matricesB, boolean transposeA, boolean transposeB){
|
||||||
|
super(ArrayUtils.addAll(matricesA, matricesB), null);
|
||||||
|
this.batchSize = matricesA.length;
|
||||||
|
|
||||||
|
this.transposeA = transposeA ? 1 : 0;
|
||||||
|
this.transposeB = transposeB ? 1 : 0;
|
||||||
|
|
||||||
|
long[] firstShape = matricesA[0].shape();
|
||||||
|
long[] lastShape = matricesB[0].shape();
|
||||||
|
|
||||||
|
this.M = transposeA ? (int) firstShape[1]: (int) firstShape[0];
|
||||||
|
this.N = transposeA ? (int) firstShape[0]: (int) firstShape[1];
|
||||||
|
this.K = transposeB ? (int) lastShape[0]: (int) lastShape[1];
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getNumOutputs(){
|
public int getNumOutputs(){
|
||||||
return batchSize;
|
return batchSize;
|
||||||
|
|
|
@ -34,17 +34,12 @@ import java.util.List;
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class GatherNd extends DynamicCustomOp {
|
public class GatherNd extends DynamicCustomOp {
|
||||||
|
|
||||||
public GatherNd(SameDiff sameDiff, SDVariable[] inputs, SDVariable[] indices) {
|
public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices) {
|
||||||
super(null, sameDiff, ArrayUtils.addAll(inputs, indices), false);
|
super(null, sameDiff, new SDVariable[] {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices, boolean inPlace) {
|
public GatherNd(INDArray df, INDArray indices) {
|
||||||
super(null, sameDiff, new SDVariable[] {input, indices}, inPlace);
|
super(new INDArray[]{df, indices}, null);
|
||||||
}
|
|
||||||
|
|
||||||
public GatherNd(INDArray[] df, INDArray[] indices) {
|
|
||||||
addInputArgument(df);
|
|
||||||
addInputArgument(indices);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,13 +16,16 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.apache.commons.lang3.NotImplementedException;
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -41,21 +44,27 @@ public class Linspace extends DynamicCustomOp {
|
||||||
private DataType dataType;
|
private DataType dataType;
|
||||||
|
|
||||||
public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
|
public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
|
||||||
super(sameDiff, new SDVariable[0]);
|
this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType);
|
||||||
addTArgument(start,stop);
|
|
||||||
addIArgument(number);
|
|
||||||
addDArgument(dataType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){
|
public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){
|
||||||
super(sameDiff, new SDVariable[]{from, to, length});
|
super(sameDiff, new SDVariable[]{from, to, length});
|
||||||
this.dataType = dataType;
|
this.dataType = dataType;
|
||||||
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Linspace(DataType dataType, double start, double stop, long number) {
|
public Linspace(DataType dataType, double start, double stop, long number) {
|
||||||
|
this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) {
|
||||||
|
this(start, stop, number, dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Linspace(@NonNull INDArray start, @NonNull INDArray stop, @NonNull INDArray number, @NonNull DataType dataType) {
|
||||||
|
super(new INDArray[]{start, stop, number}, null);
|
||||||
|
this.dataType = dataType;
|
||||||
addDArgument(dataType);
|
addDArgument(dataType);
|
||||||
addTArgument(start, stop);
|
|
||||||
addIArgument(number);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Linspace(){ }
|
public Linspace(){ }
|
||||||
|
|
|
@ -16,9 +16,11 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -41,6 +43,11 @@ public class MeshGrid extends DynamicCustomOp {
|
||||||
this(sd, cartesian, inputs);
|
this(sd, cartesian, inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MeshGrid(@NonNull INDArray[] inputs, boolean cartesian){
|
||||||
|
super(inputs, null);
|
||||||
|
addIArgument(cartesian ? 1 : 0);
|
||||||
|
}
|
||||||
|
|
||||||
public MeshGrid(){ }
|
public MeshGrid(){ }
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -44,7 +44,6 @@ import java.util.Map;
|
||||||
public class Reshape extends DynamicCustomOp {
|
public class Reshape extends DynamicCustomOp {
|
||||||
|
|
||||||
private long[] shape;
|
private long[] shape;
|
||||||
private String arrName;
|
|
||||||
|
|
||||||
public Reshape(SameDiff sameDiff, SDVariable i_v, long[] shape) {
|
public Reshape(SameDiff sameDiff, SDVariable i_v, long[] shape) {
|
||||||
super(null, sameDiff, new SDVariable[]{i_v});
|
super(null, sameDiff, new SDVariable[]{i_v});
|
||||||
|
@ -56,6 +55,12 @@ public class Reshape extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{i_v, shape});
|
super(null, sameDiff, new SDVariable[]{i_v, shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Reshape(INDArray in, long... shape){
|
||||||
|
super(new INDArray[]{in}, null);
|
||||||
|
this.shape = shape;
|
||||||
|
addIArgument(shape);
|
||||||
|
}
|
||||||
|
|
||||||
public Reshape(INDArray in, INDArray shape){
|
public Reshape(INDArray in, INDArray shape){
|
||||||
this(in, shape, null);
|
this(in, shape, null);
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -64,15 +65,19 @@ public class SequenceMask extends DynamicCustomOp {
|
||||||
addDArgument(dataType);
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SequenceMask(INDArray input, int maxLen, DataType dataType) {
|
public SequenceMask(@NonNull INDArray input, int maxLen, DataType dataType) {
|
||||||
addInputArgument(input);
|
addInputArgument(input);
|
||||||
addIArgument(maxLen);
|
addIArgument(maxLen);
|
||||||
this.dataType = dataType;
|
this.dataType = dataType;
|
||||||
addDArgument(dataType);
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SequenceMask(INDArray input, DataType dataType) {
|
public SequenceMask(@NonNull INDArray input, @NonNull DataType dataType) {
|
||||||
addInputArgument(input);
|
this(input, null, dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SequenceMask(@NonNull INDArray input, INDArray maxLength, @NonNull DataType dataType) {
|
||||||
|
super(wrapFilterNull(input, maxLength), null);
|
||||||
this.dataType = dataType;
|
this.dataType = dataType;
|
||||||
addDArgument(dataType);
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,6 +59,10 @@ public class Slice extends DynamicCustomOp {
|
||||||
addIArgument(size);
|
addIArgument(size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Slice(@NonNull INDArray input, @NonNull INDArray begin, @NonNull INDArray end){
|
||||||
|
super(new INDArray[]{input, begin, end}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "slice";
|
return "slice";
|
||||||
|
|
|
@ -50,7 +50,7 @@ public class Stack extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Stack(INDArray input, int axis) {
|
public Stack(INDArray[] input, int axis) {
|
||||||
addInputArgument(input);
|
addInputArgument(input);
|
||||||
this.jaxis = axis;
|
this.jaxis = axis;
|
||||||
addArgs();
|
addArgs();
|
||||||
|
|
|
@ -98,10 +98,16 @@ public class StridedSlice extends DynamicCustomOp {
|
||||||
|
|
||||||
public StridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask,
|
public StridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask,
|
||||||
int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
|
int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
|
||||||
|
this(in, ArrayUtil.toLongArray(begin), ArrayUtil.toLongArray(end), ArrayUtil.toLongArray(strides),
|
||||||
|
beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
|
||||||
|
}
|
||||||
|
|
||||||
|
public StridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask,
|
||||||
|
int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
|
||||||
addInputArgument(in);
|
addInputArgument(in);
|
||||||
this.begin = ArrayUtil.toLongArray(begin);
|
this.begin = begin;
|
||||||
this.end = ArrayUtil.toLongArray(end);
|
this.end = end;
|
||||||
this.strides = ArrayUtil.toLongArray(strides);
|
this.strides = strides;
|
||||||
this.beginMask = beginMask;
|
this.beginMask = beginMask;
|
||||||
this.endMask = endMask;
|
this.endMask = endMask;
|
||||||
this.ellipsisMask = ellipsisMask;
|
this.ellipsisMask = ellipsisMask;
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -67,6 +68,13 @@ public class Unstack extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Unstack(@NonNull INDArray value, int axis, int num){
|
||||||
|
super(new INDArray[]{value}, null);
|
||||||
|
this.jaxis = axis;
|
||||||
|
this.num = num;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
public Unstack(INDArray in, INDArray[] out, int axis){
|
public Unstack(INDArray in, INDArray[] out, int axis){
|
||||||
super(null, new INDArray[]{in}, out, null, (int[])null);
|
super(null, new INDArray[]{in}, out, null, (int[])null);
|
||||||
this.jaxis = axis;
|
this.jaxis = axis;
|
||||||
|
@ -136,7 +144,8 @@ public class Unstack extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return Collections.singletonList(sameDiff.stack(jaxis, f1.toArray(new SDVariable[f1.size()])));
|
return Collections.singletonList(sameDiff.stack(jaxis, f1.toArray(new SDVariable[0])));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -58,6 +58,10 @@ public class Pad extends DynamicCustomOp {
|
||||||
this(sd, in, padding, Mode.CONSTANT, padValue);
|
this(sd, in, padding, Mode.CONSTANT, padValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Pad(@NonNull INDArray in, @NonNull INDArray padding, double padValue){
|
||||||
|
this(in, padding, null, Mode.CONSTANT, padValue);
|
||||||
|
}
|
||||||
|
|
||||||
public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){
|
public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){
|
||||||
super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out});
|
super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out});
|
||||||
Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType());
|
Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType());
|
||||||
|
|
|
@ -66,11 +66,8 @@ public class DynamicPartition extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public DynamicPartition(INDArray input, INDArray[] partitions, int numPartitions) {
|
public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) {
|
||||||
addInputArgument(input);
|
addInputArgument(input);
|
||||||
for (INDArray part : partitions)
|
|
||||||
addInputArgument(part);
|
|
||||||
|
|
||||||
addIArgument(numPartitions);
|
addIArgument(numPartitions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,9 +16,11 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -30,10 +32,14 @@ public class ListDiff extends DynamicCustomOp {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
||||||
public ListDiff(SameDiff sd, SDVariable x, SDVariable y){
|
public ListDiff(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
|
||||||
super(sd, new SDVariable[]{x, y});
|
super(sd, new SDVariable[]{x, y});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ListDiff(@NonNull INDArray x, @NonNull INDArray y){
|
||||||
|
super(new INDArray[]{x, y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "ListDiff"; //Note: Seems to be renamed to tf.setdiff1d in public API?
|
return "ListDiff"; //Note: Seems to be renamed to tf.setdiff1d in public API?
|
||||||
|
|
|
@ -73,12 +73,8 @@ public class XwPlusB extends DynamicCustomOp {
|
||||||
SDVariable dLdOut = gradient.get(0);
|
SDVariable dLdOut = gradient.get(0);
|
||||||
|
|
||||||
SDVariable dLdb = dLdOut.sum(0);
|
SDVariable dLdb = dLdOut.sum(0);
|
||||||
SDVariable dLdIn = sameDiff.mmul(dLdOut, w, MMulTranspose.builder()
|
SDVariable dLdIn = sameDiff.mmul(dLdOut, w, false, true, false);
|
||||||
.transposeB(true)
|
SDVariable dLdW = sameDiff.mmul(in, dLdOut, true, false, false);
|
||||||
.build());
|
|
||||||
SDVariable dLdW = sameDiff.mmul(in, dLdOut, MMulTranspose.builder()
|
|
||||||
.transposeA(true)
|
|
||||||
.build());
|
|
||||||
|
|
||||||
return Arrays.asList(dLdIn, dLdW, dLdb);
|
return Arrays.asList(dLdIn, dLdW, dLdb);
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter;
|
import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -55,24 +56,11 @@ public class Cast extends BaseDynamicTransformOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
public Cast(@NonNull INDArray arg, @NonNull DataType dataType){
|
||||||
@Override
|
super(new INDArray[]{arg}, null);
|
||||||
public void setValueFor(Field target, Object value) {
|
this.typeDst = dataType;
|
||||||
if(value == null) {
|
addArgs();
|
||||||
throw new ND4JIllegalStateException("Unable to set field " + target + " using null value!");
|
|
||||||
}
|
|
||||||
|
|
||||||
// FIXME!
|
|
||||||
if (!(value instanceof DataType))
|
|
||||||
return;
|
|
||||||
|
|
||||||
try {
|
|
||||||
target.set(this, (DataType) value);
|
|
||||||
} catch (IllegalAccessException e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.Op;
|
import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
@ -73,6 +74,12 @@ public class Range extends DynamicCustomOp {
|
||||||
addDArgument(dataType);
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Range(INDArray from, INDArray to, INDArray step, DataType dataType){
|
||||||
|
super(new INDArray[]{from, to, step}, null);
|
||||||
|
this.dataType = dataType;
|
||||||
|
addDArgument(dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int opNum() {
|
public int opNum() {
|
||||||
|
|
|
@ -149,6 +149,60 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Matrix multiply a batch of matrices. matricesA and matricesB must be arrays of the same<br>
|
||||||
|
* length, and each pair taken from these arrays must have dimensions (M, N) and (N, K),<br>
|
||||||
|
* respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.<br>
|
||||||
|
* Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).<br>
|
||||||
|
* <br>
|
||||||
|
* The result of this operation will be a batch of multiplied matrices. The<br>
|
||||||
|
* result has the same length as both input batches and each output matrix is of shape (M, K).<br>
|
||||||
|
*
|
||||||
|
* @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type)
|
||||||
|
* @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type)
|
||||||
|
* @param transposeA Whether to transpose A arrays or not
|
||||||
|
* @param transposeB Whether to transpose B arrays or not
|
||||||
|
*/
|
||||||
|
public INDArray[] batchMmul(INDArray[] inputsA, INDArray[] inputsB, boolean transposeA,
|
||||||
|
boolean transposeB) {
|
||||||
|
NDValidation.validateNumerical("batchMmul", "inputsA", inputsA);
|
||||||
|
Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length);
|
||||||
|
NDValidation.validateNumerical("batchMmul", "inputsB", inputsB);
|
||||||
|
Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, transposeA, transposeB));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Matrix multiply a batch of matrices. matricesA and matricesB must be arrays of the same<br>
|
||||||
|
* length, and each pair taken from these arrays must have dimensions (M, N) and (N, K),<br>
|
||||||
|
* respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.<br>
|
||||||
|
* Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).<br>
|
||||||
|
* <br>
|
||||||
|
* The result of this operation will be a batch of multiplied matrices. The<br>
|
||||||
|
* result has the same length as both input batches and each output matrix is of shape (M, K).<br>
|
||||||
|
*
|
||||||
|
* @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type)
|
||||||
|
* @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray[] batchMmul(INDArray[] inputsA, INDArray... inputsB) {
|
||||||
|
NDValidation.validateNumerical("batchMmul", "inputsA", inputsA);
|
||||||
|
Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length);
|
||||||
|
NDValidation.validateNumerical("batchMmul", "inputsB", inputsB);
|
||||||
|
Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, false, false));
|
||||||
|
}
|
||||||
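
A minimal usage sketch for the two batchMmul overloads above, assuming NDBase is constructed directly like the other generated namespace classes:

NDBase base = new NDBase();
INDArray a1 = Nd4j.rand(DataType.FLOAT, 2, 3);   // (M, N) = (2, 3)
INDArray a2 = Nd4j.rand(DataType.FLOAT, 2, 3);
INDArray b1 = Nd4j.rand(DataType.FLOAT, 3, 4);   // (N, K) = (3, 4)
INDArray b2 = Nd4j.rand(DataType.FLOAT, 3, 4);
INDArray[] out = base.batchMmul(new INDArray[]{a1, a2}, b1, b2);  // varargs overload: no transpose
// out.length == 2; each output has shape (M, K) = (2, 4)
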
|
|
||||||
|
/**
|
||||||
|
* Cast the array to a new datatype - for example, Integer -> Float<br>
|
||||||
|
*
|
||||||
|
* @param arg Input variable to cast (NDARRAY type)
|
||||||
|
* @param datatype Datatype to cast to
|
||||||
|
* @return output Output array (after casting) (NDARRAY type)
|
||||||
|
*/
|
||||||
|
public INDArray castTo(INDArray arg, DataType datatype) {
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(arg, datatype))[0];
|
||||||
|
}
|
||||||
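
castTo is a thin wrapper around the Cast transform; a one-line sketch:

INDArray ints = Nd4j.createFromArray(1, 2, 3);
INDArray floats = new NDBase().castTo(ints, DataType.FLOAT);  // [1.0, 2.0, 3.0]
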
|
|
||||||
/**
|
/**
|
||||||
* Concatenate a set of inputs along the specified dimension.<br>
|
* Concatenate a set of inputs along the specified dimension.<br>
|
||||||
* Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.<br>
|
* Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.<br>
|
||||||
|
@ -161,7 +215,7 @@ public class NDBase {
|
||||||
* @param dimension Dimension to concatenate on
|
* @param dimension Dimension to concatenate on
|
||||||
* @return output (NUMERIC type)
|
* @return output (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray concat(INDArray[] inputs, int dimension) {
|
public INDArray concat(int dimension, INDArray... inputs) {
|
||||||
NDValidation.validateNumerical("concat", "inputs", inputs);
|
NDValidation.validateNumerical("concat", "inputs", inputs);
|
||||||
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
|
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
|
||||||
Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype");
|
Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype");
|
||||||
|
@ -274,28 +328,26 @@ public class NDBase {
|
||||||
* @param x Input variable (NUMERIC type)
|
* @param x Input variable (NUMERIC type)
|
||||||
* @param partitions 1D input with values 0 to numPartitions-1 (INT type)
|
* @param partitions 1D input with values 0 to numPartitions-1 (INT type)
|
||||||
* @param numPartitions Number of partitions, >= 1
|
* @param numPartitions Number of partitions, >= 1
|
||||||
* @return output Output variables (equal in number to numPartitions) (NUMERIC type)
|
|
||||||
*/
|
*/
|
||||||
public INDArray dynamicPartition(INDArray x, INDArray[] partitions, int numPartitions) {
|
public INDArray[] dynamicPartition(INDArray x, INDArray partitions, int numPartitions) {
|
||||||
NDValidation.validateNumerical("dynamicPartition", "x", x);
|
NDValidation.validateNumerical("dynamicPartition", "x", x);
|
||||||
NDValidation.validateInteger("dynamicPartition", "partitions", partitions);
|
NDValidation.validateInteger("dynamicPartition", "partitions", partitions);
|
||||||
Preconditions.checkArgument(partitions.length >= 1, "partitions has incorrect size/length. Expected: partitions.length >= 1, got %s", partitions.length);
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions));
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions))[0];
|
|
||||||
}
|
}
|
||||||
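
The corrected signature (a single INT partitions array, all numPartitions outputs returned) in use; a minimal sketch:

INDArray x = Nd4j.createFromArray(1f, 2f, 3f, 4f);
INDArray partitions = Nd4j.createFromArray(0, 1, 0, 1);
INDArray[] out = new NDBase().dynamicPartition(x, partitions, 2);
// out[0] = [1.0, 3.0], out[1] = [2.0, 4.0]
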
|
|
||||||
/**
|
/**
|
||||||
* Dynamically merge the specified input arrays into a single array, using the specified indices<br>
|
* Dynamically merge the specified input arrays into a single array, using the specified indices<br>
|
||||||
*
|
*
|
||||||
* @param x Input variables. (NUMERIC type)
|
|
||||||
* @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type)
|
* @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type)
|
||||||
|
* @param x Input variables. (NUMERIC type)
|
||||||
* @return output Merged output variable (NUMERIC type)
|
* @return output Merged output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray dynamicStitch(INDArray[] x, INDArray[] indices) {
|
public INDArray dynamicStitch(INDArray[] indices, INDArray... x) {
|
||||||
NDValidation.validateNumerical("dynamicStitch", "x", x);
|
|
||||||
Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
|
|
||||||
NDValidation.validateInteger("dynamicStitch", "indices", indices);
|
NDValidation.validateInteger("dynamicStitch", "indices", indices);
|
||||||
Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length);
|
Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(x, indices))[0];
|
NDValidation.validateNumerical("dynamicStitch", "x", x);
|
||||||
|
Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0];
|
||||||
}
|
}
|
||||||
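
With the reordered dynamicStitch signature, the indices come first and the data arrays are varargs; a sketch:

INDArray[] indices = {Nd4j.createFromArray(0, 2), Nd4j.createFromArray(1, 3)};
INDArray merged = new NDBase().dynamicStitch(indices,
        Nd4j.createFromArray(10f, 30f), Nd4j.createFromArray(20f, 40f));
// merged = [10.0, 20.0, 30.0, 40.0]
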
|
|
||||||
/**
|
/**
|
||||||
|
@ -395,11 +447,9 @@ public class NDBase {
|
||||||
* @param indices (NUMERIC type)
|
* @param indices (NUMERIC type)
|
||||||
* @return output (NUMERIC type)
|
* @return output (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray gatherNd(INDArray[] df, INDArray[] indices) {
|
public INDArray gatherNd(INDArray df, INDArray indices) {
|
||||||
NDValidation.validateNumerical("gatherNd", "df", df);
|
NDValidation.validateNumerical("gatherNd", "df", df);
|
||||||
Preconditions.checkArgument(df.length >= 1, "df has incorrect size/length. Expected: df.length >= 1, got %s", df.length);
|
|
||||||
NDValidation.validateNumerical("gatherNd", "indices", indices);
|
NDValidation.validateNumerical("gatherNd", "indices", indices);
|
||||||
Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length);
|
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0];
|
||||||
}
|
}
|
||||||
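
gatherNd now takes single arrays rather than arrays of arrays; each row of indices addresses one element of df:

INDArray df = Nd4j.createFromArray(new float[][]{{1, 2}, {3, 4}});
INDArray indices = Nd4j.createFromArray(new int[][]{{0, 1}, {1, 0}});
INDArray out = new NDBase().gatherNd(df, indices);  // [2.0, 3.0]
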
|
|
||||||
|
@ -516,6 +566,23 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new 1d array with values evenly spaced between values 'start' and 'stop'<br>
|
||||||
|
* For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]<br>
|
||||||
|
*
|
||||||
|
* @param start Start value (NUMERIC type)
|
||||||
|
* @param stop Stop value (NUMERIC type)
|
||||||
|
* @param number Number of values to generate (LONG type)
|
||||||
|
* @param dataType Data type of the output array
|
||||||
|
* @return output INDArray with linearly spaced elements (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray linspace(INDArray start, INDArray stop, INDArray number, DataType dataType) {
|
||||||
|
NDValidation.validateNumerical("linspace", "start", start);
|
||||||
|
NDValidation.validateNumerical("linspace", "stop", stop);
|
||||||
|
NDValidation.validateInteger("linspace", "number", number);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0];
|
||||||
|
}
|
||||||
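
A sketch of the new INDArray-based linspace overload, with scalar inputs:

INDArray out = new NDBase().linspace(Nd4j.scalar(3.0), Nd4j.scalar(4.0), Nd4j.scalar(3L), DataType.DOUBLE);
// out = [3.0, 3.5, 4.0]
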
|
|
||||||
/**
|
/**
|
||||||
* Less than operation: elementwise x < y<br>
|
* Less than operation: elementwise x < y<br>
|
||||||
*
|
*
|
||||||
|
@ -1071,6 +1138,20 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Array permutation operation: permute the dimensions according to the specified permutation indices.<br>
|
||||||
|
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @param dimensions Permute dimensions (INT type)
|
||||||
|
* @return output Output variable (permuted input) (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray permute(INDArray x, INDArray dimensions) {
|
||||||
|
NDValidation.validateNumerical("permute", "x", x);
|
||||||
|
NDValidation.validateInteger("permute", "dimensions", dimensions);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0];
|
||||||
|
}
|
||||||
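
The INDArray-dimensions permute overload in use:

INDArray x = Nd4j.rand(DataType.FLOAT, 2, 3, 4);
INDArray out = new NDBase().permute(x, Nd4j.createFromArray(2, 0, 1));
// out has shape [4, 2, 3]
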
|
|
||||||
/**
|
/**
|
||||||
* Array permutation operation: permute the dimensions according to the specified permutation indices.<br>
|
* Array permutation operation: permute the dimensions according to the specified permutation indices.<br>
|
||||||
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]<br>
|
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]<br>
|
||||||
|
@ -1141,6 +1222,24 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new variable with a 1d array, where the values start at from and increment by step<br>
|
||||||
|
* up to (but not including) limit.<br>
|
||||||
|
* For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]<br>
|
||||||
|
*
|
||||||
|
* @param from Initial/smallest value (NUMERIC type)
|
||||||
|
* @param to Largest value (exclusive) (NUMERIC type)
|
||||||
|
* @param step Step size (NUMERIC type)
|
||||||
|
* @param dataType Data type of the output array
|
||||||
|
* @return output INDArray with the specified values (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataType) {
|
||||||
|
NDValidation.validateNumerical("range", "from", from);
|
||||||
|
NDValidation.validateNumerical("range", "to", to);
|
||||||
|
NDValidation.validateNumerical("range", "step", step);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0];
|
||||||
|
}
|
||||||
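
The INDArray-based range overload, with scalar from/to/step:

INDArray out = new NDBase().range(Nd4j.scalar(1.0), Nd4j.scalar(3.0), Nd4j.scalar(0.5), DataType.DOUBLE);
// out = [1.0, 1.5, 2.0, 2.5]
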
|
|
||||||
/**
|
/**
|
||||||
* Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable<br>
|
* Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable<br>
|
||||||
*
|
*
|
||||||
|
@ -1168,6 +1267,21 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-wise replace where condition:<br>
|
||||||
|
* out[i] = value if condition(update[i]) is satisfied, or<br>
|
||||||
|
* out[i] = update[i] if condition(update[i]) is NOT satisfied<br>
|
||||||
|
*
|
||||||
|
* @param update Source array (NUMERIC type)
|
||||||
|
* @param value Value to set at the output, if the condition is satisfied
|
||||||
|
* @param condition Condition to check on update array elements
|
||||||
|
* @return output New array with values replaced where condition is satisfied (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray replaceWhere(INDArray update, double value, Condition condition) {
|
||||||
|
NDValidation.validateNumerical("replaceWhere", "update", update);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(update, value, condition));
|
||||||
|
}
|
||||||
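
A sketch of the scalar-value replaceWhere, assuming the Conditions factory from org.nd4j.linalg.indexing.conditions:

INDArray x = Nd4j.createFromArray(-1f, 2f, -3f);
INDArray out = new NDBase().replaceWhere(x, 0.0, Conditions.lessThan(0));
// out = [0.0, 2.0, 0.0]
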
|
|
||||||
/**
|
/**
|
||||||
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
|
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
|
||||||
* input, but with the specified shape.<br>
|
* input, but with the specified shape.<br>
|
||||||
|
@ -1183,6 +1297,21 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
|
||||||
|
* input, but with the specified shape.<br>
|
||||||
|
* Note that prod(shape) must match length(input) == prod(input.shape)<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @param shape New shape for variable (Size: AtLeast(min=0))
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray reshape(INDArray x, long... shape) {
|
||||||
|
NDValidation.validateNumerical("reshape", "x", x);
|
||||||
|
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0];
|
||||||
|
}
|
||||||
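
The long-varargs reshape overload in use:

INDArray x = Nd4j.rand(DataType.FLOAT, 2, 6);
INDArray out = new NDBase().reshape(x, 3, 4);  // same 12 values, now shape [3, 4]
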
|
|
||||||
/**
|
/**
|
||||||
* Reverse the values of an array for the specified dimensions<br>
|
* Reverse the values of an array for the specified dimensions<br>
|
||||||
* If input is:<br>
|
* If input is:<br>
|
||||||
|
@ -1532,6 +1661,21 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a sequence mask (with values 0 or 1) based on the specified lengths <br>
|
||||||
|
* Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)<br>
|
||||||
|
*
|
||||||
|
* @param lengths Lengths of the sequences (NUMERIC type)
|
||||||
|
* @param maxLen Maximum sequence length (INT type)
|
||||||
|
* @param dataType Data type of the output mask
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataType) {
|
||||||
|
NDValidation.validateNumerical("sequenceMask", "lengths", lengths);
|
||||||
|
NDValidation.validateInteger("sequenceMask", "maxLen", maxLen);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0];
|
||||||
|
}
|
||||||
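
A sequenceMask sketch, passing maxLen as an INT scalar:

INDArray lengths = Nd4j.createFromArray(1, 3);
INDArray mask = new NDBase().sequenceMask(lengths, Nd4j.scalar(4), DataType.FLOAT);
// mask = [[1, 0, 0, 0],
//         [1, 1, 1, 0]]
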
|
|
||||||
/**
|
/**
|
||||||
* see sequenceMask(String, SDVariable, SDVariable, DataType)<br>
|
* see sequenceMask(String, SDVariable, SDVariable, DataType)<br>
|
||||||
*
|
*
|
||||||
|
@ -1601,6 +1745,28 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a subset of the specified input, by specifying the first element and the size of the array.<br>
|
||||||
|
* For example, if input is:<br>
|
||||||
|
* [a, b, c]<br>
|
||||||
|
* [d, e, f]<br>
|
||||||
|
* then slice(input, begin=[0,1], size=[2,1]) will return:<br>
|
||||||
|
* [b]<br>
|
||||||
|
* [e]<br>
|
||||||
|
* Note that for each dimension i, begin[i] + size[i] <= input.size(i)<br>
|
||||||
|
*
|
||||||
|
* @param input input Variable to get subset of (NUMERIC type)
|
||||||
|
* @param begin Beginning index. Must be same length as rank of input array (INT type)
|
||||||
|
* @param size Size of the output array. Must be same length as rank of input array (INT type)
|
||||||
|
* @return output Subset of the input (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray slice(INDArray input, INDArray begin, INDArray size) {
|
||||||
|
NDValidation.validateNumerical("slice", "input", input);
|
||||||
|
NDValidation.validateInteger("slice", "begin", begin);
|
||||||
|
NDValidation.validateInteger("slice", "size", size);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0];
|
||||||
|
}
|
||||||
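
The INDArray begin/size slice overload, matching the worked example in the Javadoc:

INDArray in = Nd4j.createFromArray(new float[][]{{1, 2, 3}, {4, 5, 6}});
INDArray out = new NDBase().slice(in, Nd4j.createFromArray(0, 1), Nd4j.createFromArray(2, 1));
// out = [[2.0], [5.0]]
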
|
|
||||||
/**
|
/**
|
||||||
* Squared L2 norm: see norm2(String, SDVariable, boolean, int...)<br>
|
* Squared L2 norm: see norm2(String, SDVariable, boolean, int...)<br>
|
||||||
*
|
*
|
||||||
|
@ -1668,7 +1834,8 @@ public class NDBase {
|
||||||
* @param axis Axis to stack on
|
* @param axis Axis to stack on
|
||||||
* @return output Output variable (NDARRAY type)
|
* @return output Output variable (NDARRAY type)
|
||||||
*/
|
*/
|
||||||
public INDArray stack(INDArray values, int axis) {
|
public INDArray stack(int axis, INDArray... values) {
|
||||||
|
Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0];
|
||||||
}
|
}
|
||||||
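
With the reordered stack signature, the axis comes first and the values are varargs:

INDArray a = Nd4j.createFromArray(1f, 2f);
INDArray b = Nd4j.createFromArray(3f, 4f);
INDArray out = new NDBase().stack(0, a, b);  // shape [2, 2]
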
|
|
||||||
|
@ -1737,7 +1904,7 @@ public class NDBase {
|
||||||
* @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions
|
* @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions
|
||||||
* @return output A subset of the input array (NUMERIC type)
|
* @return output A subset of the input array (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask,
|
public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask,
|
||||||
int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
|
int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
|
||||||
NDValidation.validateNumerical("stridedSlice", "in", in);
|
NDValidation.validateNumerical("stridedSlice", "in", in);
|
||||||
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
|
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
|
||||||
|
@ -1762,7 +1929,7 @@ public class NDBase {
|
||||||
* @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1))
|
* @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1))
|
||||||
* @return output A subset of the input array (NUMERIC type)
|
* @return output A subset of the input array (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int... strides) {
|
public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long... strides) {
|
||||||
NDValidation.validateNumerical("stridedSlice", "in", in);
|
NDValidation.validateNumerical("stridedSlice", "in", in);
|
||||||
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
|
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
|
||||||
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
|
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
|
||||||
|
@ -1999,6 +2166,21 @@ public class NDBase {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.<br>
|
||||||
|
* If input has shape [a,b,c] then output has shape:<br>
|
||||||
|
* axis = 0: [b,c]<br>
|
||||||
|
* axis = 1: [a,c]<br>
|
||||||
|
* axis = 2: [a,b]<br>
|
||||||
|
*
|
||||||
|
* @param value Input variable to unstack (NDARRAY type)
|
||||||
|
* @param axis Axis to unstack on
|
||||||
|
* @param num Number of output variables
|
||||||
|
*/
|
||||||
|
public INDArray[] unstack(INDArray value, int axis, int num) {
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Unstack(value, axis, num));
|
||||||
|
}
|
||||||
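
unstack is the inverse of stack; num must equal the size of the unstacked axis:

INDArray in = Nd4j.rand(DataType.FLOAT, 3, 4);
INDArray[] rows = new NDBase().unstack(in, 0, 3);  // 3 arrays, each of shape [4]
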
|
|
||||||
/**
|
/**
|
||||||
* Variance array reduction operation, optionally along specified dimensions<br>
|
* Variance array reduction operation, optionally along specified dimensions<br>
|
||||||
*
|
*
|
||||||
|
|
|
@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops;
|
||||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.enums.DataFormat;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
@ -32,7 +33,6 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
||||||
import org.nd4j.linalg.factory.NDValidation;
|
import org.nd4j.linalg.factory.NDValidation;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.enums.DataFormat;
|
|
||||||
|
|
||||||
public class NDCNN {
|
public class NDCNN {
|
||||||
public NDCNN() {
|
public NDCNN() {
|
||||||
|
@ -370,6 +370,18 @@ public class NDCNN {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(input, LocalResponseNormalizationConfig))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(input, LocalResponseNormalizationConfig))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 2D max pooling operation that returns both the max values and the corresponding indices (argmax)<br>
|
||||||
|
*
|
||||||
|
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||||
|
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||||
|
* @param Pooling2DConfig Configuration Object
|
||||||
|
*/
|
||||||
|
public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) {
|
||||||
|
NDValidation.validateNumerical("maxPoolWithArgmax", "input", input);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(input, Pooling2DConfig));
|
||||||
|
}
|
||||||
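
A hedged maxPoolWithArgmax sketch; the Pooling2DConfig builder fields below are assumed to follow the existing nd4j convolution configs:

INDArray in = Nd4j.rand(DataType.FLOAT, 1, 1, 4, 4);  // NCHW
Pooling2DConfig conf = Pooling2DConfig.builder()
        .kH(2).kW(2).sH(2).sW(2).pH(0).pW(0).build();
INDArray[] out = new NDCNN().maxPoolWithArgmax(in, conf);
// out[0] = pooled max values, out[1] = argmax indices
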
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - max pooling 2d <br>
|
* 2D Convolution layer operation - max pooling 2d <br>
|
||||||
*
|
*
|
||||||
|
|
|
@ -222,15 +222,12 @@ public class NDLoss {
|
||||||
*
|
*
|
||||||
* @param label Label array (NUMERIC type)
|
* @param label Label array (NUMERIC type)
|
||||||
* @param predictions Predictions array (NUMERIC type)
|
* @param predictions Predictions array (NUMERIC type)
|
||||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
|
||||||
* @param epsilon epsilon
|
|
||||||
* @return output Log loss (NUMERIC type)
|
* @return output Log loss (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, double epsilon) {
|
public INDArray logLoss(INDArray label, INDArray predictions) {
|
||||||
NDValidation.validateNumerical("logLoss", "label", label);
|
NDValidation.validateNumerical("logLoss", "label", label);
|
||||||
NDValidation.validateNumerical("logLoss", "predictions", predictions);
|
NDValidation.validateNumerical("logLoss", "predictions", predictions);
|
||||||
NDValidation.validateNumerical("logLoss", "weights", weights);
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0];
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, epsilon))[0];
|
|
||||||
}
|
}
|
||||||
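
The new two-argument logLoss overload is equivalent to passing null weights and epsilon = 0.0:

INDArray labels = Nd4j.createFromArray(1f, 0f);
INDArray predictions = Nd4j.createFromArray(0.9f, 0.1f);
INDArray out = new NDLoss().logLoss(labels, predictions);
// scalar loss, reduced with MEAN_BY_NONZERO_WEIGHT_COUNT
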
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -190,6 +190,58 @@ public class NDMath {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(x));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Left bit shift operation<br>
|
||||||
|
*
|
||||||
|
* @param x input (NUMERIC type)
|
||||||
|
* @param shift shift value (NUMERIC type)
|
||||||
|
* @return output shifted output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray bitShift(INDArray x, INDArray shift) {
|
||||||
|
NDValidation.validateNumerical("bitShift", "x", x);
|
||||||
|
NDValidation.validateNumerical("bitShift", "shift", shift);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, shift))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Right bit shift operation<br>
|
||||||
|
*
|
||||||
|
* @param x Input tensor (NUMERIC type)
|
||||||
|
* @param shift shift argument (NUMERIC type)
|
||||||
|
* @return output shifted output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray bitShiftRight(INDArray x, INDArray shift) {
|
||||||
|
NDValidation.validateNumerical("bitShiftRight", "x", x);
|
||||||
|
NDValidation.validateNumerical("bitShiftRight", "shift", shift);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, shift))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cyclic (rotate left) bit shift operation<br>
|
||||||
|
*
|
||||||
|
* @param x Input tensor (NUMERIC type)
|
||||||
|
* @param shift shift argument (NUMERIC type)
|
||||||
|
* @return output shifted output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray bitShiftRotl(INDArray x, INDArray shift) {
|
||||||
|
NDValidation.validateNumerical("bitShiftRotl", "x", x);
|
||||||
|
NDValidation.validateNumerical("bitShiftRotl", "shift", shift);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, shift))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cyclic right shift operation<br>
|
||||||
|
*
|
||||||
|
* @param x Input tensor (NUMERIC type)
|
||||||
|
* @param shift Shift argument (NUMERIC type)
|
||||||
|
* @return output Shifted output (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray bitShiftRotr(INDArray x, INDArray shift) {
|
||||||
|
NDValidation.validateNumerical("bitShiftRotr", "x", x);
|
||||||
|
NDValidation.validateNumerical("bitShiftRotr", "shift", shift);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, shift))[0];
|
||||||
|
}
|
||||||
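
A sketch of the shift family on integer inputs:

NDMath math = new NDMath();
INDArray x = Nd4j.createFromArray(8, 16);
INDArray shift = Nd4j.createFromArray(1, 2);
INDArray left = math.bitShift(x, shift);        // [16, 64]
INDArray right = math.bitShiftRight(x, shift);  // [4, 4]
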
|
|
||||||
/**
|
/**
|
||||||
* Element-wise ceiling function: out = ceil(x).<br>
|
* Element-wise ceiling function: out = ceil(x).<br>
|
||||||
* Rounds each value up to the nearest integer value (if not already an integer)<br>
|
* Rounds each value up to the nearest integer value (if not already an integer)<br>
|
||||||
|
@ -346,13 +398,13 @@ public class NDMath {
|
||||||
*
|
*
|
||||||
* @param x Input variable x (NUMERIC type)
|
* @param x Input variable x (NUMERIC type)
|
||||||
* @param y Input variable y (NUMERIC type)
|
* @param y Input variable y (NUMERIC type)
|
||||||
* @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=1))
|
* @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0))
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray cosineDistance(INDArray x, INDArray y, int... dimensions) {
|
public INDArray cosineDistance(INDArray x, INDArray y, int... dimensions) {
|
||||||
NDValidation.validateNumerical("cosineDistance", "x", x);
|
NDValidation.validateNumerical("cosineDistance", "x", x);
|
||||||
NDValidation.validateNumerical("cosineDistance", "y", y);
|
NDValidation.validateNumerical("cosineDistance", "y", y);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(x, y, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(x, y, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -363,13 +415,13 @@ public class NDMath {
|
||||||
*
|
*
|
||||||
* @param x Input variable x (NUMERIC type)
|
* @param x Input variable x (NUMERIC type)
|
||||||
* @param y Input variable y (NUMERIC type)
|
* @param y Input variable y (NUMERIC type)
|
||||||
* @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=1))
|
* @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0))
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray cosineSimilarity(INDArray x, INDArray y, int... dimensions) {
|
public INDArray cosineSimilarity(INDArray x, INDArray y, int... dimensions) {
|
||||||
NDValidation.validateNumerical("cosineSimilarity", "x", x);
|
NDValidation.validateNumerical("cosineSimilarity", "x", x);
|
||||||
NDValidation.validateNumerical("cosineSimilarity", "y", y);
|
NDValidation.validateNumerical("cosineSimilarity", "y", y);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(x, y, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(x, y, dimensions));
|
||||||
}
|
}
|
||||||
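
With the relaxed precondition (min=0), the dimensions varargs may now be empty, giving a full-array reduction:

INDArray x = Nd4j.createFromArray(1f, 0f);
INDArray y = Nd4j.createFromArray(0f, 1f);
INDArray sim = new NDMath().cosineSimilarity(x, y);  // scalar, 0.0 for orthogonal vectors
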
|
|
||||||
|
@ -501,13 +553,13 @@ public class NDMath {
|
||||||
*
|
*
|
||||||
* @param x Input variable x (NUMERIC type)
|
* @param x Input variable x (NUMERIC type)
|
||||||
* @param y Input variable y (NUMERIC type)
|
* @param y Input variable y (NUMERIC type)
|
||||||
* @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=1))
|
* @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0))
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray euclideanDistance(INDArray x, INDArray y, int... dimensions) {
|
public INDArray euclideanDistance(INDArray x, INDArray y, int... dimensions) {
|
||||||
NDValidation.validateNumerical("euclideanDistance", "x", x);
|
NDValidation.validateNumerical("euclideanDistance", "x", x);
|
||||||
NDValidation.validateNumerical("euclideanDistance", "y", y);
|
NDValidation.validateNumerical("euclideanDistance", "y", y);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(x, y, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(x, y, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -665,13 +717,13 @@ public class NDMath {
|
||||||
*
|
*
|
||||||
* @param x Input variable x (NUMERIC type)
|
* @param x Input variable x (NUMERIC type)
|
||||||
* @param y Input variable y (NUMERIC type)
|
* @param y Input variable y (NUMERIC type)
|
||||||
* @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=1))
|
* @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0))
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray hammingDistance(INDArray x, INDArray y, int... dimensions) {
|
public INDArray hammingDistance(INDArray x, INDArray y, int... dimensions) {
|
||||||
NDValidation.validateNumerical("hammingDistance", "x", x);
|
NDValidation.validateNumerical("hammingDistance", "x", x);
|
||||||
NDValidation.validateNumerical("hammingDistance", "y", y);
|
NDValidation.validateNumerical("hammingDistance", "y", y);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(x, y, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(x, y, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -817,13 +869,13 @@ public class NDMath {
|
||||||
*
|
*
|
||||||
* @param x Input variable x (NUMERIC type)
|
* @param x Input variable x (NUMERIC type)
|
||||||
* @param y Input variable y (NUMERIC type)
|
* @param y Input variable y (NUMERIC type)
|
||||||
* @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=1))
|
* @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0))
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray jaccardDistance(INDArray x, INDArray y, int... dimensions) {
|
public INDArray jaccardDistance(INDArray x, INDArray y, int... dimensions) {
|
||||||
NDValidation.validateNumerical("jaccardDistance", "x", x);
|
NDValidation.validateNumerical("jaccardDistance", "x", x);
|
||||||
NDValidation.validateNumerical("jaccardDistance", "y", y);
|
NDValidation.validateNumerical("jaccardDistance", "y", y);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(x, y, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(x, y, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -872,6 +924,18 @@ public class NDMath {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, keepDims, condition, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, keepDims, condition, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the set difference between inputs X and Y: the values present in X but not in Y, along with their indices in X.<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable X (NUMERIC type)
|
||||||
|
* @param y Input variable Y (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray[] listDiff(INDArray x, INDArray y) {
|
||||||
|
NDValidation.validateNumerical("listDiff", "x", x);
|
||||||
|
NDValidation.validateNumerical("listDiff", "y", y);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(x, y));
|
||||||
|
}
|
||||||
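
A sketch of listDiff's two outputs:

INDArray x = Nd4j.createFromArray(1, 2, 3, 4);
INDArray y = Nd4j.createFromArray(2, 4);
INDArray[] out = new NDMath().listDiff(x, y);
// out[0] = values in x but not in y: [1, 3]; out[1] = their indices in x: [0, 2]
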
|
|
||||||
/**
|
/**
|
||||||
* Element-wise logarithm function (base e - natural logarithm): out = log(x)<br>
|
* Element-wise logarithm function (base e - natural logarithm): out = log(x)<br>
|
||||||
*
|
*
|
||||||
|
@ -940,13 +1004,13 @@ public class NDMath {
|
||||||
*
|
*
|
||||||
* @param x Input variable x (NUMERIC type)
|
* @param x Input variable x (NUMERIC type)
|
||||||
* @param y Input variable y (NUMERIC type)
|
* @param y Input variable y (NUMERIC type)
|
||||||
* @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=1))
|
* @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0))
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray manhattanDistance(INDArray x, INDArray y, int... dimensions) {
|
public INDArray manhattanDistance(INDArray x, INDArray y, int... dimensions) {
|
||||||
NDValidation.validateNumerical("manhattanDistance", "x", x);
|
NDValidation.validateNumerical("manhattanDistance", "x", x);
|
||||||
NDValidation.validateNumerical("manhattanDistance", "y", y);
|
NDValidation.validateNumerical("manhattanDistance", "y", y);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(x, y, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(x, y, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -983,7 +1047,7 @@ public class NDMath {
|
||||||
* @param inputs Input variables (NUMERIC type)
|
* @param inputs Input variables (NUMERIC type)
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray mergeAdd(INDArray[] inputs) {
|
public INDArray mergeAdd(INDArray... inputs) {
|
||||||
NDValidation.validateNumerical("mergeAdd", "inputs", inputs);
|
NDValidation.validateNumerical("mergeAdd", "inputs", inputs);
|
||||||
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
|
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(inputs))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(inputs))[0];
|
||||||
|
@ -996,7 +1060,7 @@ public class NDMath {
|
||||||
* @param inputs Input variables (NUMERIC type)
|
* @param inputs Input variables (NUMERIC type)
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray mergeAvg(INDArray[] inputs) {
|
public INDArray mergeAvg(INDArray... inputs) {
|
||||||
NDValidation.validateNumerical("mergeAvg", "inputs", inputs);
|
NDValidation.validateNumerical("mergeAvg", "inputs", inputs);
|
||||||
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
|
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(inputs))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(inputs))[0];
|
||||||
|
@ -1009,12 +1073,24 @@ public class NDMath {
|
||||||
* @param inputs Input variables (NUMERIC type)
|
* @param inputs Input variables (NUMERIC type)
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray mergeMax(INDArray[] inputs) {
|
public INDArray mergeMax(INDArray... inputs) {
|
||||||
NDValidation.validateNumerical("mergeMax", "inputs", inputs);
|
NDValidation.validateNumerical("mergeMax", "inputs", inputs);
|
||||||
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
|
Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMax(inputs))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMax(inputs))[0];
|
||||||
}
|
}
|
||||||
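
The merge ops with their new varargs signatures:

NDMath math = new NDMath();
INDArray a = Nd4j.createFromArray(1f, 5f);
INDArray b = Nd4j.createFromArray(4f, 2f);
INDArray sum = math.mergeAdd(a, b);  // [5.0, 7.0]
INDArray max = math.mergeMax(a, b);  // [4.0, 5.0]
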
|
|
||||||
|
/**
|
||||||
|
* Broadcasts parameters for evaluation on an N-D grid.<br>
|
||||||
|
*
|
||||||
|
* @param inputs (NUMERIC type)
|
||||||
|
* @param cartesian
|
||||||
|
*/
|
||||||
|
public INDArray[] meshgrid(INDArray[] inputs, boolean cartesian) {
|
||||||
|
NDValidation.validateNumerical("meshgrid", "inputs", inputs);
|
||||||
|
Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(inputs, cartesian));
|
||||||
|
}
|
||||||
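
A meshgrid sketch; the exact output orientation depends on the cartesian flag:

INDArray[] grid = new NDMath().meshgrid(
        new INDArray[]{Nd4j.createFromArray(1f, 2f), Nd4j.createFromArray(10f, 20f, 30f)}, true);
// grid[0] and grid[1] are the two inputs broadcast to a common 2d grid
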
|
|
||||||
/**
|
/**
|
||||||
* Calculate the mean and (population) variance for the input variable, for the specified axis<br>
|
* Calculate the mean and (population) variance for the input variable, for the specified axis<br>
|
||||||
*
|
*
|
||||||
|
|
|
@ -237,12 +237,11 @@ public class NDNN {
|
||||||
* Alpha value is most commonly set to 0.01<br>
|
* Alpha value is most commonly set to 0.01<br>
|
||||||
*
|
*
|
||||||
* @param x Input variable (NUMERIC type)
|
* @param x Input variable (NUMERIC type)
|
||||||
* @param alpha Cutoff - commonly 0.01 (NUMERIC type)
|
* @param alpha Cutoff - commonly 0.01
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray leakyRelu(INDArray x, INDArray alpha) {
|
public INDArray leakyRelu(INDArray x, double alpha) {
|
||||||
NDValidation.validateNumerical("leakyRelu", "x", x);
|
NDValidation.validateNumerical("leakyRelu", "x", x);
|
||||||
NDValidation.validateNumerical("leakyRelu", "alpha", alpha);
|
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha));
|
||||||
}
|
}
|
||||||
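
With alpha now a plain double:

INDArray x = Nd4j.createFromArray(-2f, 3f);
INDArray out = new NDNN().leakyRelu(x, 0.01);  // [-0.02, 3.0]
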
|
|
||||||
|
@ -250,12 +249,11 @@ public class NDNN {
|
||||||
* Leaky ReLU derivative: dOut/dIn given input.<br>
|
* Leaky ReLU derivative: dOut/dIn given input.<br>
|
||||||
*
|
*
|
||||||
* @param x Input variable (NUMERIC type)
|
* @param x Input variable (NUMERIC type)
|
||||||
* @param alpha Cutoff - commonly 0.01 (NUMERIC type)
|
* @param alpha Cutoff - commonly 0.01
|
||||||
* @return output Output variable (NUMERIC type)
|
* @return output Output variable (NUMERIC type)
|
||||||
*/
|
*/
|
||||||
public INDArray leakyReluDerivative(INDArray x, INDArray alpha) {
|
public INDArray leakyReluDerivative(INDArray x, double alpha) {
|
||||||
NDValidation.validateNumerical("leakyReluDerivative", "x", x);
|
NDValidation.validateNumerical("leakyReluDerivative", "x", x);
|
||||||
NDValidation.validateNumerical("leakyReluDerivative", "alpha", alpha);
|
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -346,6 +344,20 @@ public class NDNN {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Padding operation <br>
|
||||||
|
*
|
||||||
|
* @param input Input tensor (NUMERIC type)
|
||||||
|
* @param padding Padding amounts per dimension, as an array of shape [rank(input), 2] (NUMERIC type)
|
||||||
|
* @param constant Padding constant
|
||||||
|
* @return output Padded input (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray pad(INDArray input, INDArray padding, double constant) {
|
||||||
|
NDValidation.validateNumerical("pad", "input", input);
|
||||||
|
NDValidation.validateNumerical("pad", "padding", padding);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, constant))[0];
|
||||||
|
}
|
||||||
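
A pad sketch; the padding array holds the before/after amounts per dimension:

INDArray in = Nd4j.createFromArray(1f, 2f);
INDArray padding = Nd4j.createFromArray(new int[][]{{1, 1}});  // one element on each side of dim 0
INDArray out = new NDNN().pad(in, padding, 0.0);  // [0.0, 1.0, 2.0, 0.0]
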
|
|
||||||
/**
|
/**
|
||||||
* PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:<br>
|
* PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:<br>
|
||||||
* out[i] = in[i] if in[i] >= 0<br>
|
* out[i] = in[i] if in[i] >= 0<br>
|
||||||
|
@ -461,6 +473,17 @@ public class NDNN {
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Softmax activation, along the specified dimension<br>
|
||||||
|
*
|
||||||
|
* @param x Input (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray softmax(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("softmax", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, -1))[0];
|
||||||
|
}
|
||||||
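
The single-argument softmax overload defaults to dimension -1 (the last axis):

INDArray logits = Nd4j.createFromArray(1f, 2f, 3f);
INDArray probs = new NDNN().softmax(logits);  // sums to 1.0 along the last dimension
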
|
|
||||||
/**
|
/**
|
||||||
* Softmax derivative function<br>
|
* Softmax derivative function<br>
|
||||||
*
|
*
|
||||||
|
@ -519,4 +542,15 @@ public class NDNN {
|
||||||
NDValidation.validateNumerical("swish", "x", x);
|
NDValidation.validateNumerical("swish", "x", x);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)<br>
|
||||||
|
*
|
||||||
|
* @param x Input variable (NUMERIC type)
|
||||||
|
* @return output Output variable (NUMERIC type)
|
||||||
|
*/
|
||||||
|
public INDArray tanh(INDArray x) {
|
||||||
|
NDValidation.validateNumerical("tanh", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(x));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,9 @@ import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
|
||||||
import org.nd4j.linalg.factory.NDValidation;
|
import org.nd4j.linalg.factory.NDValidation;
|
||||||
|
@ -38,12 +40,11 @@ public class NDRNN {
|
||||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||||
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
|
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
|
||||||
* @param GRUWeights Configuration Object
|
* @param GRUWeights Configuration Object
|
||||||
* @return output The cell's outputs. (NUMERIC type)
|
|
||||||
*/
|
*/
|
||||||
public INDArray gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) {
|
public INDArray[] gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) {
|
||||||
NDValidation.validateNumerical("gru", "x", x);
|
NDValidation.validateNumerical("gru", "x", x);
|
||||||
NDValidation.validateNumerical("gru", "hLast", hLast);
|
NDValidation.validateNumerical("gru", "hLast", hLast);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights));
|
||||||
}
|
}
|
||||||
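
A hedged gru sketch; the GRUWeights builder fields and shapes below ([nIn+nOut, 2*nOut] for ruWeight, [nIn+nOut, nOut] for cWeight) are assumed from the existing recurrent weights classes:

int nIn = 3, nOut = 2, batch = 1;
GRUWeights w = GRUWeights.builder()
        .ruWeight(Nd4j.rand(DataType.FLOAT, nIn + nOut, 2 * nOut))
        .cWeight(Nd4j.rand(DataType.FLOAT, nIn + nOut, nOut))
        .ruBias(Nd4j.rand(DataType.FLOAT, 2 * nOut))
        .cBias(Nd4j.rand(DataType.FLOAT, nOut))
        .build();
INDArray[] out = new NDRNN().gru(Nd4j.rand(DataType.FLOAT, batch, nIn),
        Nd4j.rand(DataType.FLOAT, batch, nOut), w);
// all GRUCell outputs are now returned, not just the first
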
|
|
||||||
/**
|
/**
|
||||||
|
@ -54,18 +55,83 @@ public class NDRNN {
|
||||||
* @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
* @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
||||||
* @param LSTMWeights Configuration Object
|
* @param LSTMWeights Configuration Object
|
||||||
* @param LSTMConfiguration Configuration Object
|
* @param LSTMConfiguration Configuration Object
|
||||||
* @return output The cell's outputs (NUMERIC type)
|
|
||||||
*/
|
*/
|
||||||
public INDArray lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights,
|
public INDArray[] lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights,
|
||||||
LSTMConfiguration LSTMConfiguration) {
|
LSTMConfiguration LSTMConfiguration) {
|
||||||
NDValidation.validateNumerical("lstmCell", "x", x);
|
NDValidation.validateNumerical("lstmCell", "x", x);
|
||||||
NDValidation.validateNumerical("lstmCell", "cLast", cLast);
|
NDValidation.validateNumerical("lstmCell", "cLast", cLast);
|
||||||
NDValidation.validateNumerical("lstmCell", "yLast", yLast);
|
NDValidation.validateNumerical("lstmCell", "yLast", yLast);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0];
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The LSTM layer. Does multiple time steps.<br>
|
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||||
|
* Supports the following data formats:<br>
|
||||||
|
* for unidirectional:<br>
|
||||||
|
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||||
|
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||||
|
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||||
|
* for bidirectional:<br>
|
||||||
|
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||||
|
* Supports the following direction modes:<br>
|
||||||
|
* FWD: forward<br>
|
||||||
|
* BWD: backward<br>
|
||||||
|
* BIDIR_SUM: bidirectional sum<br>
|
||||||
|
* BIDIR_CONCAT: bidirectional concat<br>
|
||||||
|
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with the T2NS data format)<br>
|
||||||
|
* You may use different gate configurations:<br>
|
||||||
|
* specify gate/cell/out alpha/beta and the gate/cell/out activations from the activations enum<br>
|
||||||
|
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||||
|
* This layer also supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||||
|
*
|
||||||
|
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||||
|
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
|
||||||
|
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
||||||
|
* @param maxTSLength Maximum time step lengths per example, with shape [batchSize] (NUMERIC type)
|
||||||
|
* @param LSTMLayerWeights Configuration Object
|
||||||
|
* @param LSTMLayerConfig Configuration Object
|
||||||
|
*/
|
||||||
|
public INDArray[] lstmLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength,
|
||||||
|
LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) {
|
||||||
|
NDValidation.validateNumerical("lstmLayer", "x", x);
|
||||||
|
NDValidation.validateNumerical("lstmLayer", "cLast", cLast);
|
||||||
|
NDValidation.validateNumerical("lstmLayer", "yLast", yLast);
|
||||||
|
NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||||
|
* Supports the following data formats:<br>
|
||||||
|
* for unidirectional:<br>
|
||||||
|
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||||
|
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||||
|
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||||
|
* for bidirectional:<br>
|
||||||
|
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||||
|
* Supports the following direction modes:<br>
|
||||||
|
* FWD: forward<br>
|
||||||
|
* BWD: backward<br>
|
||||||
|
* BIDIR_SUM: bidirectional sum<br>
|
||||||
|
* BIDIR_CONCAT: bidirectional concat<br>
|
||||||
|
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with the T2NS data format)<br>
|
||||||
|
* You may use different gate configurations:<br>
|
||||||
|
* specify gate/cell/out alpha/beta and the gate/cell/out activations from the activations enum<br>
|
||||||
|
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||||
|
* This layer also supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||||
|
*
|
||||||
|
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||||
|
* @param LSTMLayerWeights Configuration Object
|
||||||
|
* @param LSTMLayerConfig Configuration Object
|
||||||
|
*/
|
||||||
|
public INDArray[] lstmLayer(INDArray x, LSTMLayerWeights LSTMLayerWeights,
|
||||||
|
LSTMLayerConfig LSTMLayerConfig) {
|
||||||
|
NDValidation.validateNumerical("lstmLayer", "x", x);
|
||||||
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(x, null, null, null, LSTMLayerWeights, LSTMLayerConfig));
|
||||||
|
}
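
For orientation, a minimal usage sketch of these new ND-level overloads. The builder calls mirror the SameDiff test cases later in this diff; the Nd4j.rnn() namespace accessor and the INDArray-accepting weight builders are assumptions here, not something this hunk confirms:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
    import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
    import org.nd4j.linalg.factory.Nd4j;

    // Sketch only: builder methods copied from the SameDiff test cases below;
    // the INDArray overloads of the weights builder are assumed.
    int bS = 2, nIn = 3, numUnits = 4, sL = 5;
    INDArray x = Nd4j.rand(DataType.FLOAT, bS, nIn, sL);   // NST: [numExamples, inOutSize, timeLength]
    LSTMLayerConfig conf = LSTMLayerConfig.builder()
            .lstmdataformat(LSTMDataFormat.NST)
            .directionMode(LSTMDirectionMode.FWD)
            .gateAct(LSTMActivations.SIGMOID)
            .cellAct(LSTMActivations.TANH)
            .outAct(LSTMActivations.TANH)
            .retFullSequence(true)
            .build();
    LSTMLayerWeights w = LSTMLayerWeights.builder()
            .weights(Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))        // assumed INDArray overload
            .rWeights(Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))  // (tests below use sd.var(...))
            .build();
    INDArray[] outs = Nd4j.rnn().lstmLayer(x, w, conf);    // x-only overload: no initial state, no maxTSLength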
+
+  /**
+   * The LSTM block<br>
    *
    * @param maxTSLength (NUMERIC type)
    * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)

@@ -75,13 +141,27 @@ public class NDRNN {
    * @param LSTMConfiguration Configuration Object
    * @return output The layer's outputs. (NUMERIC type)
    */
-  public INDArray lstmLayer(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast,
+  public INDArray lstmblock(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast,
       LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
-    NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
+    NDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
-    NDValidation.validateNumerical("lstmLayer", "x", x);
+    NDValidation.validateNumerical("lstmblock", "x", x);
-    NDValidation.validateNumerical("lstmLayer", "cLast", cLast);
+    NDValidation.validateNumerical("lstmblock", "cLast", cLast);
-    NDValidation.validateNumerical("lstmLayer", "yLast", yLast);
+    NDValidation.validateNumerical("lstmblock", "yLast", yLast);
-    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0];
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0];
   }
 
+  /**
+   * The LSTM block<br>
+   *
+   * @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
+   * @param LSTMWeights Configuration Object
+   * @param LSTMConfiguration Configuration Object
+   * @return output The layer's outputs. (NUMERIC type)
+   */
+  public INDArray lstmblock(INDArray x, LSTMWeights LSTMWeights,
+      LSTMConfiguration LSTMConfiguration) {
+    NDValidation.validateNumerical("lstmblock", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(null, x, null, null, LSTMWeights, LSTMConfiguration))[0];
+  }
 }
 
 /**
@@ -199,7 +199,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, null, st);
 
         return op.z();
     }
@@ -436,7 +436,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, null, st);
 
         return op.z();
     }
@@ -524,7 +524,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         long st = profilingConfigurableHookIn(op);
         naiveExec(op, dimension);
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, null, st);
 
         return op.z();
     }
@@ -607,7 +607,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, null, st);
 
         return op.z();
     }
@@ -772,7 +772,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
 
         return null;
     }
@@ -863,7 +863,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
 
         return null;
 
@@ -1113,7 +1113,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
 
         Nd4j.getExecutioner().commit();
 
@@ -1200,7 +1200,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, null, st);
 
         return null;
     }
@@ -1296,7 +1296,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
 
         return null;
     }
@@ -1460,7 +1460,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (ret != null)
             ret.elementWiseStride();
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
 
         return null;
     }
@@ -1579,7 +1579,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (nativeOps.lastErrorCode() != 0)
             throw new RuntimeException(nativeOps.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
 
         return z;
     }
@@ -2292,7 +2292,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 
     @Override
     public INDArray[] exec(CustomOp op, OpContext context) {
-        long st = profilingConfigurableHookIn(op);
+        long st = profilingConfigurableHookIn(op, context);
 
         val ctx = AtomicAllocator.getInstance().getDeviceContext();
         ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation());
@@ -2304,7 +2304,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
         if (status != 0)
             throw new RuntimeException("Op [" + op.opName() + "] execution failed");
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, context, st);
 
         if (context.getOutputArrays().isEmpty())
             return new INDArray[0];
@@ -236,7 +236,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
         if (loop.lastErrorCode() != 0)
             throw new RuntimeException(loop.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
         return getZ(op, oc);
     }
 
@@ -690,7 +690,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
         if (loop.lastErrorCode() != 0)
             throw new RuntimeException(loop.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
 
         return getZ(op, oc);
     }
@@ -774,7 +774,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
 
         if (z == null)
             setZ(Nd4j.create(op.resultType(), x.shape()), op, oc);
-        // op.setZ(Nd4j.create(op.resultType(), op.x().shape()));
 
 
         op.validateDataTypes(oc, experimentalMode.get());
@@ -884,7 +883,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
         if (loop.lastErrorCode() != 0)
             throw new RuntimeException(loop.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
     }
 
     public INDArray exec(BroadcastOp op) {
@@ -1306,7 +1305,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
         if (loop.lastErrorCode() != 0)
             throw new RuntimeException(loop.lastErrorMessage());
 
-        profilingConfigurableHookOut(op, st);
+        profilingConfigurableHookOut(op, oc, st);
 
         return z;
     }
@@ -2040,7 +2039,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
 
     @Override
     public INDArray[] exec(CustomOp op, @NonNull OpContext context) {
-        long st = profilingConfigurableHookIn(op);
+        long st = profilingConfigurableHookIn(op, context);
         boolean mklOverride = false;
         try {
             if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) {
@@ -2125,7 +2124,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
         } finally {
             if (mklOverride)
                 Nd4jCpu.Environment.getInstance().setUseMKLDNN(true);
-            profilingConfigurableHookOut(op, st);
+            profilingConfigurableHookOut(op, context, st);
         }
     }
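
The recurring change in the executioner hunks above threads the OpContext through the profiling hooks, so NaN/Inf panic checks can inspect the arrays the op actually ran with: an op executed via exec(CustomOp, OpContext) may never have its own x/y/z fields set. A rough sketch of the intent, with a hypothetical method name; the real hooks live in DefaultOpExecutioner and differ in detail:

    import java.util.List;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.CustomOp;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.conditions.Conditions;

    // Hypothetical sketch of the OpContext-aware check, not the actual hook body:
    static void checkOutputsForNaN(CustomOp op, OpContext oc) {
        // Prefer the context's outputs; fall back to the op's own arguments otherwise.
        List<INDArray> outputs = (oc != null) ? oc.getOutputArrays() : op.outputArguments();
        for (INDArray arr : outputs) {
            int nans = Nd4j.getExecutioner()
                    .exec(new MatchCondition(arr, Conditions.isNan()))
                    .getInt(0);
            if (nans > 0)
                throw new IllegalStateException("NaN detected in output of op " + op.opName());
        }
    }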
@@ -20,8 +20,10 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.OpValidationSuite;
 import org.nd4j.autodiff.samediff.SDVariable;
@@ -36,6 +38,12 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
 import org.nd4j.linalg.factory.Nd4j;
@@ -257,7 +265,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW);
                 inSize = inSizeNCHW;
                 in = sd.var("in", inSize);
                 out = sd.cnn().upsampling2d(in, 2, 2, true);
                 break;
             default:
                 throw new RuntimeException();
@@ -578,8 +586,6 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable dW = sd.var("dW", depthWeightArr);
         SDVariable b = sd.var("b", bArr);
 
-        SDVariable[] vars = new SDVariable[]{in, dW, b};
-
 
         Conv2DConfig c = Conv2DConfig.builder()
                 .kH(kH).kW(kW)
                 .pH(0).pW(0)
@@ -588,8 +594,8 @@ public class LayerOpValidation extends BaseOpValidation {
                 .isSameMode(false)
                 .build();
 
-        SDVariable out = sd.cnn().separableConv2d(in, dW, b, c);
+        SDVariable out = sd.cnn().separableConv2d(in, dW, null, b, c);
-        out = sd.f().tanh(out);
+        out = sd.nn().tanh("out", out);
 
         INDArray outArr = out.eval();
         //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
@@ -623,8 +629,6 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable pW = sd.var("pW", pointWeightArr);
         SDVariable b = sd.var("b", bArr);
 
-        //SDVariable[] vars = new SDVariable[]{in, dW, pW, b};
-
 
         Conv2DConfig c = Conv2DConfig.builder()
                 .kH(kH).kW(kW)
                 .pH(0).pW(0)
@@ -635,7 +639,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().separableConv2d(in, dW, pW, b, c);
-        out = sd.nn().tanh(out);
+        out = sd.nn().tanh("out", out);
 
         INDArray outArr = out.eval();
         //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7
@@ -675,8 +679,6 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable w = sd.var("W", wArr);
         SDVariable b = sd.var("b", bArr);
 
-        SDVariable[] vars = new SDVariable[]{in, w, b};
-
 
         DeConv2DConfig deconv = DeConv2DConfig.builder()
                 .kH(kH).kW(kW)
                 .pH(0).pW(0)
@@ -685,8 +687,8 @@ public class LayerOpValidation extends BaseOpValidation {
                 .isSameMode(false)
                 .build();
 
-        SDVariable out = sd.f().deconv2d(vars, deconv);
+        SDVariable out = sd.cnn().deconv2d(in, w, b, deconv);
-        out = sd.f().tanh(out);
+        out = sd.nn().tanh("out", out);
 
         INDArray outArr = out.eval();
         //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9
@@ -723,7 +725,6 @@ public class LayerOpValidation extends BaseOpValidation {
 
         //Order: https://github.com/deeplearning4j/libnd4j/blob/6c41ea5528bb1f454e92a9da971de87b93ff521f/include/ops/declarable/generic/convo/conv2d.cpp#L20-L22
         //in, w, b - bias is optional
-        SDVariable[] vars = new SDVariable[]{in, w, b};
 
         Conv2DConfig c = Conv2DConfig.builder()
                 .kH(kH).kW(kW)
@@ -733,8 +734,8 @@ public class LayerOpValidation extends BaseOpValidation {
                 .isSameMode(false)
                 .build();
 
-        SDVariable out = sd.f().conv2d(vars, c);
+        SDVariable out = sd.cnn().conv2d("conv", in, w, b, c);
-        out = sd.f().tanh(out);
+        out = sd.nn().tanh("out", out);
 
         INDArray outArr = out.eval();
         //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
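
The conv/pooling hunks above all follow one migration: calls through the internal op factory (sd.f()) move to the public namespaces (sd.cnn(), sd.nn()), which take a named output and explicit inputs instead of a packed SDVariable[]. Before/after in miniature, with variable names as in the surrounding tests:

    // Before: internal factory, inputs packed into an array
    //   SDVariable out = sd.f().conv2d(new SDVariable[]{in, w, b}, c);
    //   out = sd.f().tanh(out);
    // After: public namespace, explicit inputs and named result
    SDVariable out = sd.cnn().conv2d("conv", in, w, b, c);
    out = sd.nn().tanh("out", out);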
@@ -767,7 +768,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .isSameMode(true)
                 .build();
 
-        SDVariable[] results = sd.f().maxPoolWithArgmax(/*new String[]{"out","idx"},*/ in, pooling2DConfig);
+        SDVariable[] results = sd.cnn().maxPoolWithArgmax(new String[]{"out", "idx"}, in, pooling2DConfig);
         assertArrayEquals(inArr.shape(), results[0].eval().shape());
         assertArrayEquals(inArr.shape(), results[1].eval().shape());
     }
@@ -797,7 +798,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig);
-        SDVariable out = sd.f().tanh(/*"out",*/ outPool);
+        SDVariable out = sd.nn().tanh("out", outPool);
 
         INDArray outArr = out.eval();
         val outShape = outArr.shape();
@@ -855,7 +856,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig);
-        SDVariable out = sd.f().tanh(/*"out",*/ outPool);
+        SDVariable out = sd.nn().tanh("out", outPool);
 
         INDArray outArr = out.eval();
         val outShape = outArr.shape();
@@ -906,7 +907,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig);
-        out = sd.f().tanh(/*"loss", */out).shape().rename("out");
+        out = sd.nn().tanh("loss", out).shape().rename("out");
 
         // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
         INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L);
@@ -942,7 +943,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig);
-        out = sd.math().tanh("loss", out).shape().rename("out");
+        out = sd.nn().tanh("loss", out).shape().rename("out");
 
         sd.setLossVariables("loss");
 
@@ -976,8 +977,8 @@ public class LayerOpValidation extends BaseOpValidation {
                 .paddingMode(PaddingMode.VALID)
                 .build();
 
-        SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig);
+        SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
-        out = sd.math().tanh("loss", out).shape().rename("out");
+        out = sd.nn().tanh("loss", out).shape().rename("out");
 
         sd.setLossVariables("loss");
 
@@ -996,7 +997,7 @@ public class LayerOpValidation extends BaseOpValidation {
         int nOut = 4;
         int mb = 2;
 
-        for( int k : new int[]{2, 3}) {
+        for (int k : new int[]{2, 3}) {
             for (int sz : new int[]{3, 4, 5}) {
                 for (int s : new int[]{1, 2}) {
                     for (int d : new int[]{1, 2}) {
@@ -1018,7 +1019,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig);
-        SDVariable loss = sd.f().tanh(out).std(true).rename("loss");
+        SDVariable loss = sd.nn().tanh(out).std(true).rename("loss");
 
         sd.setLossVariables("loss");
 
@@ -1039,7 +1040,7 @@ public class LayerOpValidation extends BaseOpValidation {
 
 
     @Test
-    public void testConv1dForward(){
+    public void testConv1dForward() {
         int nIn = 2;
         int nOut = 1;
         int kernel = 3;
@@ -1057,7 +1058,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable in = sd.var("in", inArr);
         SDVariable w = sd.var("w", wArr);
 
-        SDVariable res = sd.cnn.conv1d(in, w, null, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build());
+        SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build());
 
         INDArray expected = Nd4j.createFromArray(
                 new double[][][]{
@@ -1113,7 +1114,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig);
-        out = sd.math().tanh("loss", out).shape().rename("out");
+        out = sd.nn().tanh("loss", out).shape().rename("out");
 
         sd.setLossVariables("loss");
 
@@ -1156,7 +1157,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig);
-        out = sd.math().tanh("loss", out).shape().rename("out");
+        out = sd.nn().tanh("loss", out).shape().rename("out");
 
         sd.setLossVariables("loss");
 
@@ -1201,13 +1202,13 @@ public class LayerOpValidation extends BaseOpValidation {
     public void testLayerNorm4d() {
         int mb = 3;
         int ch = 4;
-        for(boolean nchw : new boolean[]{true, false}) {
+        for (boolean nchw : new boolean[]{true, false}) {
             double eps = 0.0;
             INDArray x = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
             INDArray gain4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
             INDArray bias4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
             INDArray mean = x.mean(true, 1, 2, 3);
-            INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1);
+            INDArray std = Transforms.sqrt(x.var(false, 1, 2, 3).addi(eps)).reshape(mb, 1, 1, 1);
 
             INDArray standardized = x.sub(mean).div(std);
             INDArray exp = standardized.mul(gain4d).add(bias4d);
@@ -1274,7 +1275,7 @@ public class LayerOpValidation extends BaseOpValidation {
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
 
-        final INDArray gain = Nd4j.rand(DataType.DOUBLE,4);
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain);
 
         final INDArray output = Nd4j.zerosLike(res);
@@ -1287,7 +1288,7 @@ public class LayerOpValidation extends BaseOpValidation {
     public void testLayerNormNoDeviation() {
         final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         for (int i = 0; i < 4; i++) {
-            random.putScalar(1,i, 7);
+            random.putScalar(1, i, 7);
         }
 
         final INDArray standardized = random.ulike();
@@ -1335,7 +1336,7 @@ public class LayerOpValidation extends BaseOpValidation {
                 .paddingMode(PaddingMode.VALID)
                 .build();
 
-        SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig);
+        SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
 
     }
 
@@ -1391,16 +1392,16 @@ public class LayerOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testLayerNormMixedOrders(){
+    public void testLayerNormMixedOrders() {
        Nd4j.getRandom().setSeed(12345);
        INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
        INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
        INDArray bias = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
 
-       INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f');
+       INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'f');
-       INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c');
+       INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'c');
-       INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c');
+       INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'c');
-       INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f');
+       INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'f');
 
        //F in, F out case
        Nd4j.exec(DynamicCustomOp.builder("layer_norm")
@@ -1441,11 +1442,11 @@ public class LayerOpValidation extends BaseOpValidation {
     public void testBiasAdd_nchw_nhwc() {
         Nd4j.getRandom().setSeed(12345);
 
-        for(boolean nchw : new boolean[]{true, false}) {
+        for (boolean nchw : new boolean[]{true, false}) {
             log.info("Starting test: {}", nchw ? "nchw" : "nhwc");
             SameDiff sameDiff = SameDiff.create();
 
-            SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4}));
+            SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2, 4, 3, 3} : new long[]{2, 3, 3, 4}));
             SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4}));
 
             SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw);
@@ -1453,10 +1454,10 @@ public class LayerOpValidation extends BaseOpValidation {
 
 
             INDArray exp = in.getArr().dup();
-            if(nchw){
+            if (nchw) {
-                exp.addi(b.getArr().reshape(1,4,1,1));
+                exp.addi(b.getArr().reshape(1, 4, 1, 1));
             } else {
-                exp.addi(b.getArr().reshape(1,1,1,4));
+                exp.addi(b.getArr().reshape(1, 1, 1, 4));
             }
 
             TestCase tc = new TestCase(sameDiff)
@@ -1467,4 +1468,168 @@ public class LayerOpValidation extends BaseOpValidation {
             assertNull(err);
         }
     }
+
+
+    @Test
+    public void LSTMLayerTestCase1() {
+
+        int bS = 5;
+        int nIn = 3;
+        int numUnits = 7;
+        int sL = 10; //small just for test
+
+        SameDiff sd = SameDiff.create();
+
+        // notations:
+        // bS - batch size, numExamples
+        // sL - sequence length, number of time steps, timeLength
+        // nIn - input size, inOutSize
+
+        // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
+        // NST: shape [numExamples, inOutSize, timeLength]<br>
+        // NTS: shape [numExamples, timeLength, inOutSize]<br>
+        // for bidirectional:
+        // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
+
+        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL));
+
+        SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
+        SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
+
+        LSTMLayerConfig c = LSTMLayerConfig.builder()
+                .lstmdataformat(LSTMDataFormat.NST)
+                .directionMode(LSTMDirectionMode.FWD)
+                .gateAct(LSTMActivations.SIGMOID)
+                .cellAct(LSTMActivations.TANH)
+                .outAct(LSTMActivations.TANH)
+                .retFullSequence(true)
+                .retLastC(true)
+                .retLastH(true)
+                .build();
+
+        LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
+                in, cLast, yLast, null,
+                LSTMLayerWeights.builder()
+                        .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
+                        .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
+                        .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
+                        .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))).build(),
+                c), c);
+
+        long[] out = new long[]{bS, numUnits, sL};
+        long[] hL = new long[]{bS, numUnits};
+        long[] cL = new long[]{bS, numUnits};
+
+        assertArrayEquals(out, outputs.getOutput().eval().shape());
+        assertArrayEquals(hL, outputs.getLastTimeStepOutput().eval().shape());
+        assertArrayEquals(cL, outputs.getLastCellStateOutput().eval().shape());
+    }
+
+
+    @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824
+    public void LSTMLayerTestCase2() {
+        int bS = 5;
+        int nIn = 3;
+        int numUnits = 7;
+        int sL = 10; //small just for test
+
+        SameDiff sd = SameDiff.create();
+
+        // notations:
+        // bS - batch size, numExamples
+        // sL - sequence length, number of time steps, timeLength
+        // nIn - input size, inOutSize
+
+        // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
+        // NST: shape [numExamples, inOutSize, timeLength]<br>
+        // NTS: shape [numExamples, timeLength, inOutSize]<br>
+        // for bidirectional:
+        // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
+        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, sL, bS, nIn));
+
+        SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
+        SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
+
+        LSTMLayerConfig c = LSTMLayerConfig.builder()
+                .lstmdataformat(LSTMDataFormat.TNS)
+                .directionMode(LSTMDirectionMode.FWD)
+                .gateAct(LSTMActivations.SIGMOID)
+                .cellAct(LSTMActivations.TANH)
+                .outAct(LSTMActivations.TANH)
+                .retFullSequence(true)
+                .retLastC(false)
+                .retLastH(false)
+                .build();
+
+        LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
+                in, cLast, yLast, null,
+                LSTMLayerWeights.builder()
+                        .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
+                        .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
+                        .build(),
+                c), c);
+
+        long[] out = new long[]{sL, bS, numUnits};
+        assertArrayEquals(out, outputs.getOutput().eval().shape());
+    }
+
+    @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824
+    public void LSTMLayerTestCase3() {
+        int bS = 5;
+        int nIn = 3;
+        int numUnits = 7;
+        int sL = 10; //small just for test
+
+        SameDiff sd = SameDiff.create();
+
+        // notations:
+        // bS - batch size, numExamples
+        // sL - sequence length, number of time steps, timeLength
+        // nIn - input size, inOutSize
+
+        // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
+        // NST: shape [numExamples, inOutSize, timeLength]<br>
+        // NTS: shape [numExamples, timeLength, inOutSize]<br>
+        // for bidirectional:
+        // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
+        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, sL, nIn));
+
+        // when directionMode >= 2 (BIDIR_CONCAT=3)
+        // Wx, Wr [2, nIn, 4*nOut]
+        // hI, cI [2, bS, nOut]
+        SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits));
+        SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits));
+
+        LSTMLayerConfig c = LSTMLayerConfig.builder()
+                .lstmdataformat(LSTMDataFormat.NTS)
+                .directionMode(LSTMDirectionMode.BIDIR_CONCAT)
+                .gateAct(LSTMActivations.SIGMOID)
+                .cellAct(LSTMActivations.SOFTPLUS)
+                .outAct(LSTMActivations.SOFTPLUS)
+                .retFullSequence(true)
+                .retLastC(false)
+                .retLastH(false)
+                .build();
+
+        LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(new String[]{"out"},
+                in, cLast, yLast, null,
+                LSTMLayerWeights.builder()
+                        .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, 2, nIn, 4 * numUnits)))
+                        .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, 2, numUnits, 4 * numUnits)))
+                        .build(),
+                c), c);
+
+        long[] out = new long[]{bS, sL, 2 * numUnits};
+
+        assertArrayEquals(out, outputs.getOutput().eval().shape());
+    }
 }
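
Taken together, the three test cases pin down how the full-sequence output shape follows the configured data format and direction mode; summarizing the assertions above:

    // Full-sequence output shapes asserted by LSTMLayerTestCase1-3:
    //   NST + FWD          -> [bS, numUnits, sL]
    //   TNS + FWD          -> [sL, bS, numUnits]
    //   NTS + BIDIR_CONCAT -> [bS, sL, 2 * numUnits]   (concat doubles the feature dimension)
    // For bidirectional modes the initial states and weights also gain a leading dimension
    // of 2, as in test case 3: cLast/yLast [2, bS, numUnits], weights [2, nIn, 4 * numUnits].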
@@ -548,7 +548,7 @@ public class MiscOpValidation extends BaseOpValidation {
         INDArray arr2 = Nd4j.rand(new long[]{2, 2, 2});
         SDVariable x = sameDiff.var("x", arr);
         SDVariable y = sameDiff.var("y", arr2);
-        SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}});
+        SDVariable result = sameDiff.tensorMmul(x, y, new int[]{0}, new int[]{1});
         assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}),
                 result.eval().shape());
         assertEquals(16, sameDiff.numElements());
@@ -689,13 +689,7 @@ public class MiscOpValidation extends BaseOpValidation {
         SDVariable a = sd.var("a", aArr);
         SDVariable b = sd.var("b", bArr);
 
-        MMulTranspose mt = MMulTranspose.builder()
-                .transposeA(transposeA)
-                .transposeB(transposeB)
-                .transposeResult(transposeResult)
-                .build();
-
-        SDVariable mmul = sd.mmul(a, b, mt);
+        SDVariable mmul = sd.mmul(a, b, transposeA, transposeB, transposeResult);
 
         INDArray exp = (transposeA ? aArr.transpose() : aArr);
         exp = exp.mmul(transposeB ? bArr.transpose() : bArr);
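
Both MiscOpValidation changes reflect the same API simplification: matmul transpose options become plain boolean arguments instead of an MMulTranspose config object, and tensorMmul takes the contraction axes as two separate arrays. A minimal sketch (array contents arbitrary):

    SameDiff sd = SameDiff.create();
    SDVariable a = sd.var("a", Nd4j.rand(3, 4));
    SDVariable b = sd.var("b", Nd4j.rand(3, 5));

    // transposeA = true, transposeB = false, transposeResult = false
    SDVariable mm = sd.mmul(a, b, true, false, false);              // shape [4, 5]

    // contract dimension 0 of a with dimension 0 of b
    SDVariable tm = sd.tensorMmul(a, b, new int[]{0}, new int[]{0}); // shape [4, 5]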
@@ -70,7 +70,7 @@ public class RnnOpValidation extends BaseOpValidation {
         LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
                 .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
 
-        LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
+        LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf)); //Output order: i, c, f, o, z, h, y
         List<String> toExec = new ArrayList<>();
         for(SDVariable sdv : v.getAllOutputs()){
             toExec.add(sdv.name());
@@ -173,7 +173,7 @@ public class RnnOpValidation extends BaseOpValidation {
         LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
                 .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
 
-        LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
+        LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf)); //Output order: i, c, f, o, z, h, y
         List<String> toExec = new ArrayList<>();
         for(SDVariable sdv : v.getAllOutputs()){
             toExec.add(sdv.name());
@@ -227,7 +227,7 @@ public class RnnOpValidation extends BaseOpValidation {
                 .cBias(bc)
                 .build();
 
-        List<SDVariable> v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs();
+        SDVariable[] v = sd.rnn().gru(x, hLast, weights);
         List<String> toExec = new ArrayList<>();
         for(SDVariable sdv : v){
             toExec.add(sdv.name());
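
The RNN test updates follow from the regenerated namespaces: ops now return a raw SDVariable[] and the typed output wrapper is applied by the caller when convenient. A minimal sketch of the new style (variable names as in the surrounding tests; gruWeights is a placeholder name):

    // lstmCell now returns SDVariable[]; wrap it when named accessors are wanted:
    SDVariable[] raw = sd.rnn().lstmCell(x, cLast, yLast, weights, conf);
    LSTMCellOutputs v = new LSTMCellOutputs(raw);   // e.g. v.getAllOutputs()
    // gru likewise returns the output array directly:
    SDVariable[] g = sd.rnn().gru(x, hLast, gruWeights);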
@@ -119,7 +119,7 @@ public class ShapeOpValidation extends BaseOpValidation {
 
         List<String> failed = new ArrayList<>();
 
-        for (int[] toShape : new int[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) {
+        for (long[] toShape : new long[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) {
             for(char order : new char[]{'c','f'}){
                 INDArray inArr = Nd4j.rand(DataType.DOUBLE, origShape, order).muli(100);
 
@@ -388,10 +388,10 @@ public class ShapeOpValidation extends BaseOpValidation {
     @Builder(builderClassName = "Builder")
     @Data
     private static class SSCase {
-        private int[] shape;
+        private long[] shape;
-        private int[] begin;
+        private long[] begin;
-        private int[] end;
+        private long[] end;
-        private int[] strides;
+        private long[] strides;
         private int beginMask;
         private int endMask;
         private int ellipsisMask;
@@ -400,22 +400,22 @@ public class ShapeOpValidation extends BaseOpValidation {
 
         public static class Builder {
 
-            public Builder shape(int... shape) {
+            public Builder shape(long... shape) {
                 this.shape = shape;
                 return this;
             }
 
-            public Builder begin(int... begin) {
+            public Builder begin(long... begin) {
                 this.begin = begin;
                 return this;
             }
 
-            public Builder end(int... end) {
+            public Builder end(long... end) {
                 this.end = end;
                 return this;
             }
 
-            public Builder strides(int... strides) {
+            public Builder strides(long... strides) {
                 this.strides = strides;
                 return this;
             }
@@ -1571,7 +1571,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
         SDVariable x1 = sameDiff.var("x1", arr1);
         SDVariable x2 = sameDiff.var("x2", arr2);
-        SDVariable result = sameDiff.parallel_stack(new SDVariable[]{x1, x2});
+        SDVariable result = sameDiff.stack(0, new SDVariable[]{x1, x2});
         assertArrayEquals(new long[]{2, 3, 2}, result.eval().shape());
         assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval());
     }
@@ -1661,9 +1661,9 @@ public class ShapeOpValidation extends BaseOpValidation {
 
         SameDiff sd = SameDiff.create();
         SDVariable in = sd.var("in", inArr);
-        SDVariable slice_full = sd.stridedSlice(in, new int[]{0, 0}, new int[]{3, 4}, new int[]{1, 1});
+        SDVariable slice_full = sd.stridedSlice(in, new long[]{0, 0}, new long[]{3, 4}, new long[]{1, 1});
-        SDVariable subPart = sd.stridedSlice(in, new int[]{1, 2}, new int[]{3, 4}, new int[]{1, 1});
+        SDVariable subPart = sd.stridedSlice(in, new long[]{1, 2}, new long[]{3, 4}, new long[]{1, 1});
-        // SDVariable subPart2 = sd.stridedSlice(in, new int[]{0, 0}, new int[]{4, 5}, new int[]{2, 2});
+        // SDVariable subPart2 = sd.stridedSlice(in, new long[]{0, 0}, new long[]{4, 5}, new long[]{2, 2});
 
         sd.outputAll(null);
 
@@ -1679,8 +1679,8 @@ public class ShapeOpValidation extends BaseOpValidation {
 
         SameDiff sd = SameDiff.create();
         SDVariable in = sd.var("in", inArr);
-        SDVariable slice1 = sd.stridedSlice(in, new int[]{-999, 0}, new int[]{2, 4}, new int[]{1, 1}, 1 << 1, 0, 0, 0, 0);
+        SDVariable slice1 = sd.stridedSlice(in, new long[]{-999, 0}, new long[]{2, 4}, new long[]{1, 1}, 1 << 1, 0, 0, 0, 0);
-        SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 0}, new int[]{-999, 4}, new int[]{1, 1}, 0, 1, 0, 0, 0);
+        SDVariable slice2 = sd.stridedSlice(in, new long[]{1, 0}, new long[]{-999, 4}, new long[]{1, 1}, 0, 1, 0, 0, 0);
 
         sd.outputAll(null);
 
@@ -1695,9 +1695,9 @@ public class ShapeOpValidation extends BaseOpValidation {
         SDVariable in = sd.var("in", inArr);
 
         //[1:3,...] -> [1:3,:,:]
-        SDVariable slice = sd.stridedSlice(in, new int[]{1}, new int[]{3}, new int[]{1}, 0, 0, 1 << 1, 0, 0);
+        SDVariable slice = sd.stridedSlice(in, new long[]{1}, new long[]{3}, new long[]{1}, 0, 0, 1 << 1, 0, 0);
         //[1:3,...,1:4] -> [1:3,:,1:4]
-        SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 1}, new int[]{3, 4}, new int[]{1, 1}, 0, 0, 1 << 1, 0, 0);
+        SDVariable slice2 = sd.stridedSlice(in, new long[]{1, 1}, new long[]{3, 4}, new long[]{1, 1}, 0, 0, 1 << 1, 0, 0);
 
         sd.outputAll(Collections.emptyMap());
 
@@ -1710,7 +1710,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
         SameDiff sd = SameDiff.create();
         SDVariable in = sd.var("in", inArr);
-        SDVariable slice = sd.stridedSlice(in, new int[]{-999, 0, 0, 0}, new int[]{-999, 3, 4, 5}, new int[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0);
+        SDVariable slice = sd.stridedSlice(in, new long[]{-999, 0, 0, 0}, new long[]{-999, 3, 4, 5}, new long[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0);
 
         INDArray out = slice.eval();
 
@@ -1723,7 +1723,7 @@ public class ShapeOpValidation extends BaseOpValidation {
         INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
         SameDiff sd = SameDiff.create();
         SDVariable in = sd.var("in", inArr);
-        SDVariable slice = sd.stridedSlice(in, new int[]{1, 1, -999, 1}, new int[]{3, 3, -999, 4}, new int[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0);
+        SDVariable slice = sd.stridedSlice(in, new long[]{1, 1, -999, 1}, new long[]{3, 3, -999, 4}, new long[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0);
         INDArray out = slice.eval();
 
         assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape());
@@ -1735,9 +1735,9 @@ public class ShapeOpValidation extends BaseOpValidation {
         INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
         SameDiff sd = SameDiff.create();
         SDVariable in = sd.var("in", inArr);
-        SDVariable slice = sd.stridedSlice(in, new int[]{0, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1);
+        SDVariable slice = sd.stridedSlice(in, new long[]{0, 0, 0}, new long[]{-999, 4, 5}, new long[]{1, 1, 1}, 0, 0, 0, 0, 1);
-        SDVariable slice2 = sd.stridedSlice(in, new int[]{2, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1);
+        SDVariable slice2 = sd.stridedSlice(in, new long[]{2, 0, 0}, new long[]{-999, 4, 5}, new long[]{1, 1, 1}, 0, 0, 0, 0, 1);
-        SDVariable slice3 = sd.stridedSlice(in, new int[]{1, 2, 1}, new int[]{-999, -999, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1);
+        SDVariable slice3 = sd.stridedSlice(in, new long[]{1, 2, 1}, new long[]{-999, -999, 5}, new long[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1);
 
         sd.outputAll(null);
 
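
The ShapeOpValidation hunks are one mechanical migration: shapes, begin/end/stride arrays, and the SSCase builder all move from int[] to long[], matching INDArray's long-typed shapes; int[] call sites are updated accordingly. For example:

    // Index and shape arguments are long[] now:
    SDVariable slice = sd.stridedSlice(in, new long[]{1, 2}, new long[]{3, 4}, new long[]{1, 1});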
@@ -1920,7 +1920,7 @@ public class TransformOpValidation extends BaseOpValidation {
         SameDiff sd = SameDiff.create();
         SDVariable sdA = sd.var("a", a);
         SDVariable sdB = sd.var("b", b);
-        SDVariable t = sd.mmul(sdA, sdB, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).transposeResult(transposeResult).build());
+        SDVariable t = sd.mmul(sdA, sdB, transposeA, transposeB, transposeResult);
         t.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)
@@ -759,8 +759,7 @@ public class SameDiffTests extends BaseNd4jTest {
         val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1);
         val input1 = sd.var("input", matrix);
         val input2 = sd.var("input2", vector);
-        val output = sd
-                .mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build());
+        val output = sd.mmul("output", input1, input2, true, false, false);
         INDArray out = output.eval();
         assertArrayEquals(new long[]{3, 1}, out.shape());
     }
@@ -2675,7 +2674,7 @@ public class SameDiffTests extends BaseNd4jTest {
 
         final long timeSteps = sdInput.getShape()[2];
         SDVariable[] outputSlices = new SDVariable[(int) timeSteps];
-        final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2);
+        final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2);
 
         final val x_0 = inputSlices[0];
         outputSlices[0] = x_0;
@@ -2702,7 +2701,7 @@ public class SameDiffTests extends BaseNd4jTest {
         SameDiff sd = SameDiff.create();
         final SDVariable sdInput = sd.var("input", input);
 
-        final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2);
+        final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2);
         final val temp = inputSlices[0].add(inputSlices[1]).div(inputSlices[1]).mul(inputSlices[0]);
         final val out = temp.add(temp).add(inputSlices[1]);
         out.norm2("out");
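
Both unstack call sites now pass the number of slices explicitly (the final 2) alongside the axis, so the output variables can be declared without inspecting the input's shape. In miniature:

    // unstack(names, input, axis, num): split 'input' along axis 2 into 2 named slices
    SDVariable[] slices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2);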
@@ -3242,61 +3241,61 @@ public class SameDiffTests extends BaseNd4jTest {

     @Test
     public void testNestedIf() throws IOException {
-        SameDiff SD = SameDiff.create();
-        SDVariable a = SD.var("a", Nd4j.createFromArray(2.0));
-        SDVariable b = SD.var("b", Nd4j.createFromArray(5.0));
-        SDVariable c = SD.var("c", Nd4j.createFromArray(9.0));
-        SDVariable d = SD.var("d", Nd4j.createFromArray(-7.0));
+        SameDiff sd = SameDiff.create();
+        SDVariable a = sd.var("a", Nd4j.createFromArray(2.0));
+        SDVariable b = sd.var("b", Nd4j.createFromArray(5.0));
+        SDVariable c = sd.var("c", Nd4j.createFromArray(9.0));
+        SDVariable d = sd.var("d", Nd4j.createFromArray(-7.0));

-        SDVariable output = SD.ifCond("out", null,
-                (sd) -> a.lt(b),
-                (sd) -> sd.ifCond(
+        SDVariable output = sd.ifCond("out", null,
+                (s) -> a.lt(b),
+                (s) -> s.ifCond(
                         (sd2) -> d.lte(0),
                         (sd2) -> c.add(1),
                         (sd2) -> d),
-                (sd) -> c.add(5));
+                (s) -> c.add(5));
         INDArray out = output.eval();
         assertEquals(Nd4j.createFromArray(10.0), out);

-        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+        sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));

-        assertEquals(Nd4j.createFromArray(10.0), SD.output(Collections.emptyMap(), "out").get("out"));
+        assertEquals(Nd4j.createFromArray(10.0), sd.output(Collections.emptyMap(), "out").get("out"));
     }

     @Test
     public void testWhile() throws IOException {

-        SameDiff SD = SameDiff.create();
-        SDVariable countIn = SD.constant(5);
-        SDVariable sumIn = SD.constant(0);
+        SameDiff sd = SameDiff.create();
+        SDVariable countIn = sd.constant(5);
+        SDVariable sumIn = sd.constant(0);

-        SDVariable[] sum = SD.whileLoop("while_1", new SDVariable[]{countIn, sumIn},
-                (sd, vars) -> vars[0].gt(0),
-                (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])});
+        SDVariable[] sum = sd.whileLoop("while_1", new SDVariable[]{countIn, sumIn},
+                (s, vars) -> vars[0].gt(0),
+                (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])});

         INDArray out = sum[1].eval();
         assertEquals(15, out.getInt(0));

         String outName = sum[1].name();

-        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+        sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));

-        assertEquals(15, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0));
+        assertEquals(15, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0));
     }

     @Test
     @Ignore
     public void testNestedWhile() throws IOException {
-        SameDiff SD = SameDiff.create();
-        SDVariable countIn = SD.constant(5);
-        SDVariable sumIn = SD.constant(0);
-        SDVariable sum2 = SD.constant(0);
+        SameDiff sd = SameDiff.create();
+        SDVariable countIn = sd.constant(5);
+        SDVariable sumIn = sd.constant(0);
+        SDVariable sum2 = sd.constant(0);
         //TODO creating constant instead of using sum2 causes errors

-        SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn},
-                (sd, vars) -> vars[0].gt(0),
-                (sd, vars) -> new SDVariable[]{vars[0].sub(1),
-                        vars[1].add(sd.whileLoop(new SDVariable[]{vars[0], sum2},
+        SDVariable[] sum = sd.whileLoop(new SDVariable[]{countIn, sumIn},
+                (s, vars) -> vars[0].gt(0),
+                (s, vars) -> new SDVariable[]{vars[0].sub(1),
+                        vars[1].add(s.whileLoop(new SDVariable[]{vars[0], sum2},
                                 (sd2, vars2) -> vars2[0].gt(0),
                                 (sd2, vars2) -> new SDVariable[]{vars2[0].sub(1), vars2[1].add(vars2[0])})[1])});

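The renames here are two separate cleanups: the `SameDiff SD` local becomes the conventional lowercase `sd`, and the lambda parameters become `s`/`sd2` so they no longer shadow the enclosing `sd` instance. The control-flow builders pass the (sub-)graph's SameDiff instance into each lambda; for readers unfamiliar with the API, a minimal `ifCond` sketch (values are illustrative):

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class IfCondExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable a = sd.var("a", Nd4j.createFromArray(2.0));
        SDVariable b = sd.var("b", Nd4j.createFromArray(5.0));
        // ifCond(outputName, frameName, condition, trueBody, falseBody);
        // each lambda receives the SameDiff instance it should build into
        SDVariable out = sd.ifCond("out", null,
                s -> a.lt(b),    // predicate: 2 < 5 -> true branch taken
                s -> a.add(1),   // true branch: 3.0
                s -> b.add(1));  // false branch (not taken here)
        System.out.println(out.eval()); // 3.0
    }
}
```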
@@ -3305,23 +3304,23 @@ public class SameDiffTests extends BaseNd4jTest {

         String outName = sum[1].name();

-        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+        sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));

-        assertEquals(35, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0));
+        assertEquals(35, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0));

     }

     @Test
     public void testNestedWhileIf() throws IOException {
-        SameDiff SD = SameDiff.create();
-        SDVariable countIn = SD.constant(5);
-        SDVariable sumIn = SD.constant(0);
-        SDVariable hundred = SD.constant(100);
+        SameDiff sd = SameDiff.create();
+        SDVariable countIn = sd.constant(5);
+        SDVariable sumIn = sd.constant(0);
+        SDVariable hundred = sd.constant(100);

-        SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn},
-                (sd, vars) -> vars[0].gte(0),
-                (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(
-                        sd.ifCond((sd2) -> vars[0].eq(0),
+        SDVariable[] sum = sd.whileLoop(new SDVariable[]{countIn, sumIn},
+                (s, vars) -> vars[0].gte(0),
+                (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(
+                        s.ifCond((sd2) -> vars[0].eq(0),
                                 (sd2) -> vars[0].add(100), //TODO replace with hundred and things break
                                 (sd2) -> vars[0])
                         )});

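The `asFlatBuffers`/`fromFlatBuffers` assertions in these tests check that control-flow ops (while loops, nested conditionals) survive a serialization round trip and still execute by output name. The pattern, reduced to a standalone sketch on a trivial graph:

```java
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class RoundTripExample {
    public static void main(String[] args) throws Exception {
        SameDiff sd = SameDiff.create();
        SDVariable x = sd.var("x", Nd4j.createFromArray(3.0));
        SDVariable y = x.mul("y", 2.0);

        // Serialize to FlatBuffers (false = omit updater state) and re-import
        SameDiff restored = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));

        // Execute the restored graph by output name
        System.out.println(restored.output(Collections.emptyMap(), "y").get("y")); // 6.0
    }
}
```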
@@ -3331,9 +3330,9 @@ public class SameDiffTests extends BaseNd4jTest {

         String outName = sum[1].name();

-        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+        sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false));

-        assertEquals(115, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0));
+        assertEquals(115, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0));
     }

     @Test
@@ -61,7 +61,7 @@ public class OpsMappingTests extends BaseNd4jTest {

     @Override
     public long getTimeoutMilliseconds() {
-        return 180000L; //Can be slow on some CI machines such as PPC
+        return 360000L; //Can be very slow on some CI machines (PPC)
     }

     @Test
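`getTimeoutMilliseconds()` is the per-class timeout hook in these test base classes; doubling it here accommodates slow PPC CI machines. A hypothetical subclass (the class name is illustrative, and the exact `BaseNd4jTest` constructor/abstract-method requirements are assumed from the tests above) would override it the same way:

```java
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.factory.Nd4jBackend;

public class MySlowOpTests extends BaseNd4jTest {
    public MySlowOpTests(Nd4jBackend backend) {
        super(backend);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    // Per-class test timeout, in milliseconds; raise it for slow CI hardware
    @Override
    public long getTimeoutMilliseconds() {
        return 360000L;
    }
}
```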
@@ -29,7 +29,10 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
+import org.nd4j.linalg.api.ops.impl.shape.Concat;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.factory.Nd4jBackend;
@@ -473,6 +476,7 @@ public class OperationProfilerTests extends BaseNd4jTest {
             Nd4j.exec(op); //Should trigger NaN panic
             fail();
         } catch (Exception e){
+            e.printStackTrace();
             assertTrue(e.getMessage(), e.getMessage().contains("Inf"));
         }

@@ -488,4 +492,55 @@ public class OperationProfilerTests extends BaseNd4jTest {
             Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build());
         }
     }
+
+    @Test
+    public void testOpProfilerOpContextLegacy(){
+
+        for(boolean nan : new boolean[]{true, false}) {
+
+            INDArray in = Nd4j.valueArrayOf(10, nan ? -1 : 0).castTo(DataType.FLOAT);
+
+            Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(nan).checkForINF(!nan).build());
+
+            OpContext oc = Nd4j.getExecutioner().buildContext();
+            oc.setInputArray(0, in);
+            oc.setOutputArray(0, in.ulike());
+            try {
+                Nd4j.exec(new Log(), oc);
+                System.out.println(oc.getOutputArray(0));
+                fail("Expected op profiler exception");
+            } catch (Throwable t) {
+                //OK
+                assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf"));
+            }
+        }
+    }
+
+    @Test
+    public void testOpProfilerOpContextCustomOp(){
+
+        for(boolean nan : new boolean[]{true, false}) {
+
+            INDArray in = Nd4j.create(DataType.DOUBLE, 10).assign(nan ? Double.NaN : Double.POSITIVE_INFINITY);
+            INDArray in2 = in.dup();
+
+
+            Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(nan).checkForINF(!nan).build());
+
+            OpContext oc = Nd4j.getExecutioner().buildContext();
+            oc.setIArguments(0);
+            oc.setInputArray(0, in);
+            oc.setInputArray(1, in2);
+            oc.setOutputArray(0, Nd4j.create(DataType.DOUBLE, 20));
+            try {
+                Nd4j.exec(new Concat(), oc);
+                System.out.println(oc.getOutputArray(0));
+                fail("Expected op profiler exception");
+            } catch (Throwable t) {
+                //OK
+                assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf"));
+            }
+        }
+    }
 }
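These two tests pin down the #8828 fix: the NaN/Inf panic checks now also fire for ops executed through an `OpContext`, not just the legacy exec path. In user code the same checks are toggled via `ProfilerConfig`; a short sketch (assumes the single-`INDArray` `Log` constructor, and note the reset in `finally`, without which unrelated later ops would also panic):

```java
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;

public class NanPanicExample {
    public static void main(String[] args) {
        // Enable panic mode: any op producing NaN/Inf throws instead of silently propagating
        Nd4j.getExecutioner().setProfilingConfig(
                ProfilerConfig.builder().checkForNAN(true).checkForINF(true).build());
        try {
            INDArray in = Nd4j.zeros(DataType.FLOAT, 10);
            Nd4j.exec(new Log(in)); // log(0) = -Infinity -> triggers the INF check
        } catch (Exception e) {
            System.out.println("Panic triggered: " + e.getMessage());
        } finally {
            // Always reset the profiler config afterwards
            Nd4j.getExecutioner().setProfilingConfig(
                    ProfilerConfig.builder().checkForNAN(false).checkForINF(false).build());
        }
    }
}
```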
@@ -3579,4 +3579,19 @@ public class ArrayUtil {
         }
         return false;
     }
+
+    public static <T> T[] filterNull(T... in){
+        int count = 0;
+        for( int i=0; i<in.length; i++ ) {
+            if (in[i] != null) count++;
+        }
+        T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
+        int j=0;
+        for( int i=0; i<in.length; i++ ){
+            if(in[i] != null){
+                out[j++] = in[i];
+            }
+        }
+        return out;
+    }
 }
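The new `filterNull` helper returns a compacted copy of a varargs array with nulls removed, preserving the runtime component type via `java.lang.reflect.Array.newInstance` (so it relies on that import being present in `ArrayUtil`). Illustrative usage, assuming the usual `org.nd4j.linalg.util.ArrayUtil` location (not part of the patch):

```java
import java.util.Arrays;
import org.nd4j.linalg.util.ArrayUtil;

public class FilterNullExample {
    public static void main(String[] args) {
        String[] in = {"a", null, "b", null, "c"};
        // Result is a String[] of length 3, not Object[]
        String[] out = ArrayUtil.filterNull(in);
        System.out.println(Arrays.toString(out)); // [a, b, c]
    }
}
```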
|
|
Loading…
Reference in New Issue