diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
index 452077238..625912f50 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
@@ -16,21 +16,58 @@
package org.nd4j.autodiff.samediff;
-import org.nd4j.shade.guava.base.Predicates;
-import org.nd4j.shade.guava.collect.HashBasedTable;
-import org.nd4j.shade.guava.collect.Maps;
-import org.nd4j.shade.guava.collect.Table;
-import org.nd4j.shade.guava.primitives.Ints;
+import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
+
import com.google.flatbuffers.FlatBufferBuilder;
-import lombok.*;
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.IdentityHashMap;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.Stack;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Getter;
+import lombok.NonNull;
+import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
+import lombok.val;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
-import org.nd4j.autodiff.listeners.*;
+import org.nd4j.autodiff.listeners.At;
+import org.nd4j.autodiff.listeners.Listener;
+import org.nd4j.autodiff.listeners.ListenerResponse;
+import org.nd4j.autodiff.listeners.Loss;
+import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
@@ -38,14 +75,34 @@ import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
import org.nd4j.autodiff.samediff.config.OutputConfig;
-import org.nd4j.autodiff.samediff.internal.*;
-import org.nd4j.autodiff.samediff.ops.*;
+import org.nd4j.autodiff.samediff.internal.AbstractSession;
+import org.nd4j.autodiff.samediff.internal.DataTypesSession;
+import org.nd4j.autodiff.samediff.internal.InferenceSession;
+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
+import org.nd4j.autodiff.samediff.internal.Variable;
+import org.nd4j.autodiff.samediff.ops.SDBaseOps;
+import org.nd4j.autodiff.samediff.ops.SDBitwise;
+import org.nd4j.autodiff.samediff.ops.SDCNN;
+import org.nd4j.autodiff.samediff.ops.SDImage;
+import org.nd4j.autodiff.samediff.ops.SDLoss;
+import org.nd4j.autodiff.samediff.ops.SDMath;
+import org.nd4j.autodiff.samediff.ops.SDNN;
+import org.nd4j.autodiff.samediff.ops.SDRNN;
+import org.nd4j.autodiff.samediff.ops.SDRandom;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
-import org.nd4j.graph.*;
+import org.nd4j.graph.ExecutionMode;
+import org.nd4j.graph.FlatArray;
+import org.nd4j.graph.FlatConfiguration;
+import org.nd4j.graph.FlatGraph;
+import org.nd4j.graph.FlatNode;
+import org.nd4j.graph.FlatVariable;
+import org.nd4j.graph.IntPair;
+import org.nd4j.graph.OpType;
+import org.nd4j.graph.UpdaterState;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@@ -84,23 +141,17 @@ import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.linalg.util.ND4JFileUtils;
+import org.nd4j.shade.guava.base.Predicates;
+import org.nd4j.shade.guava.collect.HashBasedTable;
+import org.nd4j.shade.guava.collect.Maps;
+import org.nd4j.shade.guava.collect.Table;
+import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.GraphDef;
-import java.io.*;
-import java.lang.reflect.Method;
-import java.nio.ByteBuffer;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
-
-import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
-
/**
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
*
@@ -2064,8 +2115,7 @@ public class SameDiff extends SDBaseOps {
List<Listener> activeListeners = new ArrayList<>();
- if (!history.evaluations().isEmpty())
- activeListeners.add(history);
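+ //Always add the History listener, so fit() returns a populated History (loss curve and any evaluations)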
+ activeListeners.add(history);
for (Listener l : this.listeners)
if (l.isActive(Operation.TRAINING))
@@ -2102,6 +2152,9 @@ public class SameDiff extends SDBaseOps {
requiredVars.addAll(l.requiredVariables(this).trainingVariables());
}
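+ //The user's listeners plus the History listener - used for the validation runs below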
+ List<Listener> listenersWithHistory = new ArrayList<>(listeners);
+ listenersWithHistory.add(history);
+
for (int i = 0; i < numEpochs; i++) {
if (incrementEpochCount && hasListeners) {
@@ -2236,53 +2289,52 @@ public class SameDiff extends SDBaseOps {
}
}
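+ //Accumulate the per-iteration loss values (and regularization scores) unconditionally,
+ //so the loss curve is recorded even when no listeners are attached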
+ double[] d = new double[lossVariables.size() + regScore.size()];
+ List<String> lossVars;
+ if (regScore.size() > 0) {
+ lossVars = new ArrayList<>(lossVariables.size() + regScore.size());
+ lossVars.addAll(lossVariables);
+ int s = lossVariables.size();
+ //Collect regularization losses
+ for (Map.Entry<Class<?>, AtomicDouble> entry : regScore.entrySet()) {
+ lossVars.add(entry.getKey().getSimpleName());
+ d[s++] = entry.getValue().get();
+ }
+ } else {
+ lossVars = lossVariables;
+ }
+
+ //Collect the losses...
+ SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
+ int count = 0;
+ for (String s : lossVariables) {
+ INDArray arr = gradFn.getArrForVarName(s);
+ double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
+ d[count++] = l;
+ }
+
+ Loss loss = new Loss(lossVars, d);
+
+ if (lossNames == null) {
+ lossNames = lossVars;
+ } else {
+ Preconditions.checkState(lossNames.equals(lossVars),
+ "Loss names mismatch, expected: %s, got: %s", lossNames, lossVars);
+ }
+
+ if (lossSums == null) {
+ lossSums = d;
+ } else {
+ Preconditions.checkState(lossSums.length == d.length,
+ "Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length);
+
+ for (int j = 0; j < lossSums.length; j++) {
+ lossSums[j] += d[j];
+ }
+ }
+ lossCount++;
+
if (hasListeners) {
- double[] d = new double[lossVariables.size() + regScore.size()];
- List<String> lossVars;
- if (regScore.size() > 0) {
- lossVars = new ArrayList<>(lossVariables.size() + regScore.size());
- lossVars.addAll(lossVariables);
- int s = regScore.size();
- //Collect regularization losses
- for (Map.Entry<Class<?>, AtomicDouble> entry : regScore.entrySet()) {
- lossVars.add(entry.getKey().getSimpleName());
- d[s] = entry.getValue().get();
- }
- } else {
- lossVars = lossVariables;
- }
-
-
- //Collect the losses...
- SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
- int count = 0;
- for (String s : lossVariables) {
- INDArray arr = gradFn.getArrForVarName(s);
- double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
- d[count++] = l;
- }
-
- Loss loss = new Loss(lossVars, d);
-
- if (lossNames == null) {
- lossNames = lossVars;
- } else {
- Preconditions.checkState(lossNames.equals(lossVars),
- "Loss names mismatch, expected: %s, got: %s", lossNames, lossVars);
- }
-
- if (lossSums == null) {
- lossSums = d;
- } else {
- Preconditions.checkState(lossNames.equals(lossVars),
- "Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length);
-
- for (int j = 0; j < lossSums.length; j++) {
- lossSums[j] += d[j];
- }
- }
- lossCount++;
-
for (Listener l : activeListeners) {
l.iterationDone(this, at, ds, loss);
}
@@ -2294,7 +2346,7 @@ public class SameDiff extends SDBaseOps {
long epochTime = System.currentTimeMillis() - epochStartTime;
- if (incrementEpochCount && hasListeners) {
+ if (incrementEpochCount) {
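+ //Average the accumulated losses over the epoch - this now runs even without listeners, so the loss curve is still built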
for (int j = 0; j < lossSums.length; j++)
lossSums[j] /= lossCount;
@@ -2341,7 +2393,7 @@ public class SameDiff extends SDBaseOps {
long validationStart = System.currentTimeMillis();
outputHelper(validationData, new At(at.epoch(), 0, 0, 0, Operation.TRAINING_VALIDATION),
- listeners);
+ listenersWithHistory);
long validationTime = System.currentTimeMillis() - validationStart;
@@ -2958,7 +3010,7 @@ public class SameDiff extends SDBaseOps {
List<String> neededOutputs;
- if (outputs != null) {
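+ //Treat an empty outputs array the same as null: infer the required outputs from the graph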
+ if (outputs != null && outputs.length != 0) {
neededOutputs = Arrays.asList(outputs);
} else {
neededOutputs = outputs();
@@ -4618,8 +4670,8 @@ public class SameDiff extends SDBaseOps {
}
}
- //Also add loss values - we need these so we can report them to listeners...
- if (!listeners.isEmpty()) {
+ //Also add loss values - we need these to report to listeners and to record the loss curve
+ if (!activeListeners.isEmpty() || op == Operation.TRAINING) {
varGradNames.addAll(lossVariables);
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java
index e56ab5988..376430f54 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java
@@ -31,6 +31,9 @@ import org.nd4j.autodiff.util.TrainingUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
+import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
+import org.nd4j.linalg.dataset.api.DataSet;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@@ -94,6 +97,20 @@ public class OutputConfig {
return this;
}
+ /**
+ * Set the data to use as input.
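+ * <p>This is a convenience overload, equivalent to
+ * {@code data(new SingletonMultiDataSetIterator(data.toMultiDataSet()))}.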
+ */
+ public OutputConfig data(@NonNull DataSet data){
+ return data(new SingletonMultiDataSetIterator(data.toMultiDataSet()));
+ }
+
+ /**
+ * Set the data to use as input.
+ */
+ public OutputConfig data(@NonNull MultiDataSet data){
+ return data(new SingletonMultiDataSetIterator(data));
+ }
+
/**
* Add listeners for this operation
*/
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java
index 04d6e9551..9361c47f1 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java
@@ -16,9 +16,16 @@
package org.nd4j.autodiff.samediff;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
+import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest;
@@ -29,16 +36,14 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
-import org.nd4j.linalg.learning.config.*;
+import org.nd4j.linalg.learning.config.AMSGrad;
+import org.nd4j.linalg.learning.config.AdaMax;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.weightinit.impl.XavierInitScheme;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import static org.junit.Assert.assertTrue;
-
@Slf4j
public class SameDiffTrainingTest extends BaseNd4jTest {
@@ -118,6 +123,110 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
}
+ @Test
+ public void irisTrainingEvalTest() {
+
+ DataSetIterator iter = new IrisDataSetIterator(150, 150);
+ NormalizerStandardize std = new NormalizerStandardize();
+ std.fit(iter);
+ iter.setPreProcessor(std);
+
+ Nd4j.getRandom().setSeed(12345);
+ SameDiff sd = SameDiff.create();
+
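+ //Simple MLP for Iris: 4 inputs -> 10 tanh hidden units -> 3 softmax outputs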
+ SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
+ SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
+
+ SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10);
+ SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10);
+
+ SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3);
+ SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3);
+
+ SDVariable z0 = in.mmul(w0).add(b0);
+ SDVariable a0 = sd.math().tanh(z0);
+ SDVariable z1 = a0.mmul(w1).add("prediction", b1);
+ SDVariable a1 = sd.nn().softmax(z1);
+
+ SDVariable diff = sd.f().squaredDifference(a1, label);
+ SDVariable lossMse = diff.mean();
+
+ TrainingConfig conf = new TrainingConfig.Builder()
+ .l2(1e-4)
+ .updater(new Adam(1e-2))
+ .dataSetFeatureMapping("input")
+ .dataSetLabelMapping("label")
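+ //Evaluate the "prediction" output against labels at index 0 of the DataSet during training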
+ .trainEvaluation("prediction", 0, new Evaluation())
+ .build();
+
+ sd.setTrainingConfig(conf);
+
+ History hist = sd.fit().train(iter, 50).exec();
+
+ Evaluation e = hist.finalTrainingEvaluations().evaluation("prediction");
+
+ System.out.println(e.stats());
+
+ double acc = e.accuracy();
+
+ assertTrue("Accuracy bad: " + acc, acc >= 0.75);
+ }
+
+
+ @Test
+ public void irisTrainingValidationTest() {
+
+ DataSetIterator iter = new IrisDataSetIterator(150, 150);
+ NormalizerStandardize std = new NormalizerStandardize();
+ std.fit(iter);
+ iter.setPreProcessor(std);
+
+ DataSetIterator valIter = new IrisDataSetIterator(30, 60);
+ //Normalize the validation data with the same statistics as the training data
+ valIter.setPreProcessor(std);
+
+ Nd4j.getRandom().setSeed(12345);
+ SameDiff sd = SameDiff.create();
+
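+ //Same 4 -> 10 -> 3 network as in irisTrainingEvalTest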
+ SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
+ SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
+
+ SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10);
+ SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10);
+
+ SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3);
+ SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3);
+
+ SDVariable z0 = in.mmul(w0).add(b0);
+ SDVariable a0 = sd.math().tanh(z0);
+ SDVariable z1 = a0.mmul(w1).add("prediction", b1);
+ SDVariable a1 = sd.nn().softmax(z1);
+
+ SDVariable diff = sd.f().squaredDifference(a1, label);
+ SDVariable lossMse = diff.mean();
+
+ TrainingConfig conf = new TrainingConfig.Builder()
+ .l2(1e-4)
+ .updater(new Adam(1e-2))
+ .dataSetFeatureMapping("input")
+ .dataSetLabelMapping("label")
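+ //Evaluate the "prediction" output on the held-out validation set (label index 0)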
+ .validationEvaluation("prediction", 0, new Evaluation())
+ .build();
+
+ sd.setTrainingConfig(conf);
+
+ History hist = sd.fit().train(iter, 50).validate(valIter, 5).exec();
+
+ Evaluation e = hist.finalValidationEvaluations().evaluation("prediction");
+
+ System.out.println(e.stats());
+
+ double acc = e.accuracy();
+
+ assertTrue("Accuracy bad: " + acc, acc >= 0.75);
+ }
+
@Test
public void testTrainingMixedDtypes(){