diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 452077238..625912f50 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -16,21 +16,58 @@ package org.nd4j.autodiff.samediff; -import org.nd4j.shade.guava.base.Predicates; -import org.nd4j.shade.guava.collect.HashBasedTable; -import org.nd4j.shade.guava.collect.Maps; -import org.nd4j.shade.guava.collect.Table; -import org.nd4j.shade.guava.primitives.Ints; +import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; + import com.google.flatbuffers.FlatBufferBuilder; -import lombok.*; +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.IdentityHashMap; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.Stack; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunctionFactory; -import org.nd4j.autodiff.listeners.*; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.listeners.ListenerResponse; +import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.listeners.impl.HistoryListener; import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.LossCurve; @@ -38,14 +75,34 @@ import org.nd4j.autodiff.samediff.config.BatchOutputConfig; import org.nd4j.autodiff.samediff.config.EvaluationConfig; import org.nd4j.autodiff.samediff.config.FitConfig; import org.nd4j.autodiff.samediff.config.OutputConfig; -import org.nd4j.autodiff.samediff.internal.*; -import org.nd4j.autodiff.samediff.ops.*; +import org.nd4j.autodiff.samediff.internal.AbstractSession; +import org.nd4j.autodiff.samediff.internal.DataTypesSession; +import org.nd4j.autodiff.samediff.internal.InferenceSession; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.autodiff.samediff.ops.SDBaseOps; +import org.nd4j.autodiff.samediff.ops.SDBitwise; +import org.nd4j.autodiff.samediff.ops.SDCNN; +import org.nd4j.autodiff.samediff.ops.SDImage; +import org.nd4j.autodiff.samediff.ops.SDLoss; +import org.nd4j.autodiff.samediff.ops.SDMath; +import org.nd4j.autodiff.samediff.ops.SDNN; +import org.nd4j.autodiff.samediff.ops.SDRNN; +import org.nd4j.autodiff.samediff.ops.SDRandom; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROC; -import org.nd4j.graph.*; +import org.nd4j.graph.ExecutionMode; +import org.nd4j.graph.FlatArray; +import org.nd4j.graph.FlatConfiguration; +import org.nd4j.graph.FlatGraph; +import org.nd4j.graph.FlatNode; +import org.nd4j.graph.FlatVariable; +import org.nd4j.graph.IntPair; +import org.nd4j.graph.OpType; +import org.nd4j.graph.UpdaterState; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -84,23 +141,17 @@ import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.DeviceLocalNDArray; import org.nd4j.linalg.util.ND4JFileUtils; +import org.nd4j.shade.guava.base.Predicates; +import org.nd4j.shade.guava.collect.HashBasedTable; +import org.nd4j.shade.guava.collect.Maps; +import org.nd4j.shade.guava.collect.Table; +import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.weightinit.WeightInitScheme; import org.nd4j.weightinit.impl.ConstantInitScheme; import org.nd4j.weightinit.impl.NDArraySupplierInitScheme; import org.nd4j.weightinit.impl.ZeroInitScheme; import org.tensorflow.framework.GraphDef; -import java.io.*; -import java.lang.reflect.Method; -import java.nio.ByteBuffer; -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; - /** * SameDiff is the entrypoint for ND4J's automatic differentiation functionality. *

@@ -2064,8 +2115,7 @@ public class SameDiff extends SDBaseOps { List activeListeners = new ArrayList<>(); - if (!history.evaluations().isEmpty()) - activeListeners.add(history); + activeListeners.add(history); for (Listener l : this.listeners) if (l.isActive(Operation.TRAINING)) @@ -2102,6 +2152,9 @@ public class SameDiff extends SDBaseOps { requiredVars.addAll(l.requiredVariables(this).trainingVariables()); } + ArrayList listenersWitHistory = new ArrayList<>(listeners); + listenersWitHistory.add(history); + for (int i = 0; i < numEpochs; i++) { if (incrementEpochCount && hasListeners) { @@ -2236,53 +2289,52 @@ public class SameDiff extends SDBaseOps { } } + double[] d = new double[lossVariables.size() + regScore.size()]; + List lossVars; + if (regScore.size() > 0) { + lossVars = new ArrayList<>(lossVariables.size() + regScore.size()); + lossVars.addAll(lossVariables); + int s = regScore.size(); + //Collect regularization losses + for (Map.Entry, AtomicDouble> entry : regScore.entrySet()) { + lossVars.add(entry.getKey().getSimpleName()); + d[s] = entry.getValue().get(); + } + } else { + lossVars = lossVariables; + } + + //Collect the losses... + SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY); + int count = 0; + for (String s : lossVariables) { + INDArray arr = gradFn.getArrForVarName(s); + double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue(); + d[count++] = l; + } + + Loss loss = new Loss(lossVars, d); + + if (lossNames == null) { + lossNames = lossVars; + } else { + Preconditions.checkState(lossNames.equals(lossVars), + "Loss names mismatch, expected: %s, got: %s", lossNames, lossVars); + } + + if (lossSums == null) { + lossSums = d; + } else { + Preconditions.checkState(lossNames.equals(lossVars), + "Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length); + + for (int j = 0; j < lossSums.length; j++) { + lossSums[j] += d[j]; + } + } + lossCount++; + if (hasListeners) { - double[] d = new double[lossVariables.size() + regScore.size()]; - List lossVars; - if (regScore.size() > 0) { - lossVars = new ArrayList<>(lossVariables.size() + regScore.size()); - lossVars.addAll(lossVariables); - int s = regScore.size(); - //Collect regularization losses - for (Map.Entry, AtomicDouble> entry : regScore.entrySet()) { - lossVars.add(entry.getKey().getSimpleName()); - d[s] = entry.getValue().get(); - } - } else { - lossVars = lossVariables; - } - - - //Collect the losses... - SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY); - int count = 0; - for (String s : lossVariables) { - INDArray arr = gradFn.getArrForVarName(s); - double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue(); - d[count++] = l; - } - - Loss loss = new Loss(lossVars, d); - - if (lossNames == null) { - lossNames = lossVars; - } else { - Preconditions.checkState(lossNames.equals(lossVars), - "Loss names mismatch, expected: %s, got: %s", lossNames, lossVars); - } - - if (lossSums == null) { - lossSums = d; - } else { - Preconditions.checkState(lossNames.equals(lossVars), - "Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length); - - for (int j = 0; j < lossSums.length; j++) { - lossSums[j] += d[j]; - } - } - lossCount++; - for (Listener l : activeListeners) { l.iterationDone(this, at, ds, loss); } @@ -2294,7 +2346,7 @@ public class SameDiff extends SDBaseOps { long epochTime = System.currentTimeMillis() - epochStartTime; - if (incrementEpochCount && hasListeners) { + if (incrementEpochCount) { for (int j = 0; j < lossSums.length; j++) lossSums[j] /= lossCount; @@ -2341,7 +2393,7 @@ public class SameDiff extends SDBaseOps { long validationStart = System.currentTimeMillis(); outputHelper(validationData, new At(at.epoch(), 0, 0, 0, Operation.TRAINING_VALIDATION), - listeners); + listenersWitHistory); long validationTime = System.currentTimeMillis() - validationStart; @@ -2958,7 +3010,7 @@ public class SameDiff extends SDBaseOps { List neededOutputs; - if (outputs != null) { + if (outputs != null && outputs.length != 0) { neededOutputs = Arrays.asList(outputs); } else { neededOutputs = outputs(); @@ -4618,8 +4670,8 @@ public class SameDiff extends SDBaseOps { } } - //Also add loss values - we need these so we can report them to listeners... - if (!listeners.isEmpty()) { + //Also add loss values - we need these so we can report them to listeners or loss curves... + if (!activeListeners.isEmpty() || op == Operation.TRAINING) { varGradNames.addAll(lossVariables); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java index e56ab5988..376430f54 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java @@ -31,6 +31,9 @@ import org.nd4j.autodiff.util.TrainingUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; +import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; @@ -94,6 +97,20 @@ public class OutputConfig { return this; } + /** + * Set the data to use as input. + */ + public OutputConfig data(@NonNull DataSet data){ + return data(new SingletonMultiDataSetIterator(data.toMultiDataSet())); + } + + /** + * Set the data to use as input. + */ + public OutputConfig data(@NonNull MultiDataSet data){ + return data(new SingletonMultiDataSetIterator(data)); + } + /** * Add listeners for this operation */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java index 04d6e9551..9361c47f1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java @@ -16,9 +16,16 @@ package org.nd4j.autodiff.samediff; +import static org.junit.Assert.assertTrue; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.extern.slf4j.Slf4j; import org.junit.Test; import org.nd4j.autodiff.listeners.impl.ScoreListener; +import org.nd4j.autodiff.listeners.records.History; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.BaseNd4jTest; @@ -29,16 +36,14 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import org.nd4j.linalg.learning.config.*; +import org.nd4j.linalg.learning.config.AMSGrad; +import org.nd4j.linalg.learning.config.AdaMax; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.config.Nesterovs; +import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.weightinit.impl.XavierInitScheme; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertTrue; - @Slf4j public class SameDiffTrainingTest extends BaseNd4jTest { @@ -118,6 +123,110 @@ public class SameDiffTrainingTest extends BaseNd4jTest { } + @Test + public void irisTrainingEvalTest() { + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + NormalizerStandardize std = new NormalizerStandardize(); + std.fit(iter); + iter.setPreProcessor(std); + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + + SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + + SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10); + SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10); + + SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3); + SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3); + + SDVariable z0 = in.mmul(w0).add(b0); + SDVariable a0 = sd.math().tanh(z0); + SDVariable z1 = a0.mmul(w1).add("prediction", b1); + SDVariable a1 = sd.nn().softmax(z1); + + SDVariable diff = sd.f().squaredDifference(a1, label); + SDVariable lossMse = diff.mul(diff).mean(); + + TrainingConfig conf = new TrainingConfig.Builder() + .l2(1e-4) + .updater(new Adam(1e-2)) + .dataSetFeatureMapping("input") + .dataSetLabelMapping("label") + .trainEvaluation("prediction", 0, new Evaluation()) + .build(); + + sd.setTrainingConfig(conf); + + History hist = sd.fit().train(iter, 50).exec(); + + Evaluation e = hist.finalTrainingEvaluations().evaluation("prediction"); + + System.out.println(e.stats()); + + double acc = e.accuracy(); + + assertTrue("Accuracy bad: " + acc, acc >= 0.75); + } + + + @Test + public void irisTrainingValidationTest() { + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + NormalizerStandardize std = new NormalizerStandardize(); + std.fit(iter); + iter.setPreProcessor(std); + + DataSetIterator valIter = new IrisDataSetIterator(30, 60); + NormalizerStandardize valStd = new NormalizerStandardize(); + valStd.fit(valIter); + valIter.setPreProcessor(std); + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + + SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + + SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10); + SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10); + + SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3); + SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3); + + SDVariable z0 = in.mmul(w0).add(b0); + SDVariable a0 = sd.math().tanh(z0); + SDVariable z1 = a0.mmul(w1).add("prediction", b1); + SDVariable a1 = sd.nn().softmax(z1); + + SDVariable diff = sd.f().squaredDifference(a1, label); + SDVariable lossMse = diff.mul(diff).mean(); + + TrainingConfig conf = new TrainingConfig.Builder() + .l2(1e-4) + .updater(new Adam(1e-2)) + .dataSetFeatureMapping("input") + .dataSetLabelMapping("label") + .validationEvaluation("prediction", 0, new Evaluation()) + .build(); + + sd.setTrainingConfig(conf); + + History hist = sd.fit().train(iter, 50).validate(valIter, 5).exec(); + + Evaluation e = hist.finalValidationEvaluations().evaluation("prediction"); + + System.out.println(e.stats()); + + double acc = e.accuracy(); + + assertTrue("Accuracy bad: " + acc, acc >= 0.75); + } + @Test public void testTrainingMixedDtypes(){