Fix a couple SameDiff training issues (#253)

* fix execBackwards training issue

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix validation not specifying outputs

Signed-off-by: Ryan Nett <rnett@skymind.io>

* another fix for validation listeners and history

Signed-off-by: Ryan Nett <rnett@skymind.io>

* tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* add single batch dataset output methods

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-09-10 20:38:23 -07:00 committed by GitHub
parent 4f7b35ac82
commit 8a05ec2a97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 260 additions and 82 deletions

View File

@ -16,21 +16,58 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.nd4j.shade.guava.base.Predicates; import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Maps;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import lombok.*; import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.Stack;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.HistoryListener; import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve; import org.nd4j.autodiff.listeners.records.LossCurve;
@ -38,14 +75,34 @@ import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig; import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig; import org.nd4j.autodiff.samediff.config.FitConfig;
import org.nd4j.autodiff.samediff.config.OutputConfig; import org.nd4j.autodiff.samediff.config.OutputConfig;
import org.nd4j.autodiff.samediff.internal.*; import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.ops.*; import org.nd4j.autodiff.samediff.internal.DataTypesSession;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
import org.nd4j.autodiff.samediff.ops.SDBitwise;
import org.nd4j.autodiff.samediff.ops.SDCNN;
import org.nd4j.autodiff.samediff.ops.SDImage;
import org.nd4j.autodiff.samediff.ops.SDLoss;
import org.nd4j.autodiff.samediff.ops.SDMath;
import org.nd4j.autodiff.samediff.ops.SDNN;
import org.nd4j.autodiff.samediff.ops.SDRNN;
import org.nd4j.autodiff.samediff.ops.SDRandom;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROC;
import org.nd4j.graph.*; import org.nd4j.graph.ExecutionMode;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.graph.OpType;
import org.nd4j.graph.UpdaterState;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -84,23 +141,17 @@ import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray; import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.linalg.util.ND4JFileUtils; import org.nd4j.linalg.util.ND4JFileUtils;
import org.nd4j.shade.guava.base.Predicates;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Maps;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.weightinit.WeightInitScheme; import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme; import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme; import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme; import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import java.io.*;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
/** /**
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality. * SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
* <p> * <p>
@ -2064,8 +2115,7 @@ public class SameDiff extends SDBaseOps {
List<Listener> activeListeners = new ArrayList<>(); List<Listener> activeListeners = new ArrayList<>();
if (!history.evaluations().isEmpty()) activeListeners.add(history);
activeListeners.add(history);
for (Listener l : this.listeners) for (Listener l : this.listeners)
if (l.isActive(Operation.TRAINING)) if (l.isActive(Operation.TRAINING))
@ -2102,6 +2152,9 @@ public class SameDiff extends SDBaseOps {
requiredVars.addAll(l.requiredVariables(this).trainingVariables()); requiredVars.addAll(l.requiredVariables(this).trainingVariables());
} }
ArrayList<Listener> listenersWitHistory = new ArrayList<>(listeners);
listenersWitHistory.add(history);
for (int i = 0; i < numEpochs; i++) { for (int i = 0; i < numEpochs; i++) {
if (incrementEpochCount && hasListeners) { if (incrementEpochCount && hasListeners) {
@ -2236,53 +2289,52 @@ public class SameDiff extends SDBaseOps {
} }
} }
double[] d = new double[lossVariables.size() + regScore.size()];
List<String> lossVars;
if (regScore.size() > 0) {
lossVars = new ArrayList<>(lossVariables.size() + regScore.size());
lossVars.addAll(lossVariables);
int s = regScore.size();
//Collect regularization losses
for (Map.Entry<Class<?>, AtomicDouble> entry : regScore.entrySet()) {
lossVars.add(entry.getKey().getSimpleName());
d[s] = entry.getValue().get();
}
} else {
lossVars = lossVariables;
}
//Collect the losses...
SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
int count = 0;
for (String s : lossVariables) {
INDArray arr = gradFn.getArrForVarName(s);
double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
d[count++] = l;
}
Loss loss = new Loss(lossVars, d);
if (lossNames == null) {
lossNames = lossVars;
} else {
Preconditions.checkState(lossNames.equals(lossVars),
"Loss names mismatch, expected: %s, got: %s", lossNames, lossVars);
}
if (lossSums == null) {
lossSums = d;
} else {
Preconditions.checkState(lossNames.equals(lossVars),
"Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length);
for (int j = 0; j < lossSums.length; j++) {
lossSums[j] += d[j];
}
}
lossCount++;
if (hasListeners) { if (hasListeners) {
double[] d = new double[lossVariables.size() + regScore.size()];
List<String> lossVars;
if (regScore.size() > 0) {
lossVars = new ArrayList<>(lossVariables.size() + regScore.size());
lossVars.addAll(lossVariables);
int s = regScore.size();
//Collect regularization losses
for (Map.Entry<Class<?>, AtomicDouble> entry : regScore.entrySet()) {
lossVars.add(entry.getKey().getSimpleName());
d[s] = entry.getValue().get();
}
} else {
lossVars = lossVariables;
}
//Collect the losses...
SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
int count = 0;
for (String s : lossVariables) {
INDArray arr = gradFn.getArrForVarName(s);
double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
d[count++] = l;
}
Loss loss = new Loss(lossVars, d);
if (lossNames == null) {
lossNames = lossVars;
} else {
Preconditions.checkState(lossNames.equals(lossVars),
"Loss names mismatch, expected: %s, got: %s", lossNames, lossVars);
}
if (lossSums == null) {
lossSums = d;
} else {
Preconditions.checkState(lossNames.equals(lossVars),
"Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length);
for (int j = 0; j < lossSums.length; j++) {
lossSums[j] += d[j];
}
}
lossCount++;
for (Listener l : activeListeners) { for (Listener l : activeListeners) {
l.iterationDone(this, at, ds, loss); l.iterationDone(this, at, ds, loss);
} }
@ -2294,7 +2346,7 @@ public class SameDiff extends SDBaseOps {
long epochTime = System.currentTimeMillis() - epochStartTime; long epochTime = System.currentTimeMillis() - epochStartTime;
if (incrementEpochCount && hasListeners) { if (incrementEpochCount) {
for (int j = 0; j < lossSums.length; j++) for (int j = 0; j < lossSums.length; j++)
lossSums[j] /= lossCount; lossSums[j] /= lossCount;
@ -2341,7 +2393,7 @@ public class SameDiff extends SDBaseOps {
long validationStart = System.currentTimeMillis(); long validationStart = System.currentTimeMillis();
outputHelper(validationData, new At(at.epoch(), 0, 0, 0, Operation.TRAINING_VALIDATION), outputHelper(validationData, new At(at.epoch(), 0, 0, 0, Operation.TRAINING_VALIDATION),
listeners); listenersWitHistory);
long validationTime = System.currentTimeMillis() - validationStart; long validationTime = System.currentTimeMillis() - validationStart;
@ -2958,7 +3010,7 @@ public class SameDiff extends SDBaseOps {
List<String> neededOutputs; List<String> neededOutputs;
if (outputs != null) { if (outputs != null && outputs.length != 0) {
neededOutputs = Arrays.asList(outputs); neededOutputs = Arrays.asList(outputs);
} else { } else {
neededOutputs = outputs(); neededOutputs = outputs();
@ -4618,8 +4670,8 @@ public class SameDiff extends SDBaseOps {
} }
} }
//Also add loss values - we need these so we can report them to listeners... //Also add loss values - we need these so we can report them to listeners or loss curves...
if (!listeners.isEmpty()) { if (!activeListeners.isEmpty() || op == Operation.TRAINING) {
varGradNames.addAll(lossVariables); varGradNames.addAll(lossVariables);
} }

View File

@ -31,6 +31,9 @@ import org.nd4j.autodiff.util.TrainingUtils;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@ -94,6 +97,20 @@ public class OutputConfig {
return this; return this;
} }
/**
* Set the data to use as input.
*/
public OutputConfig data(@NonNull DataSet data){
return data(new SingletonMultiDataSetIterator(data.toMultiDataSet()));
}
/**
* Set the data to use as input.
*/
public OutputConfig data(@NonNull MultiDataSet data){
return data(new SingletonMultiDataSetIterator(data));
}
/** /**
* Add listeners for this operation * Add listeners for this operation
*/ */

View File

@ -16,9 +16,16 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import static org.junit.Assert.assertTrue;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.Test; import org.junit.Test;
import org.nd4j.autodiff.listeners.impl.ScoreListener; import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTest;
@ -29,16 +36,14 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.*; import org.nd4j.linalg.learning.config.AMSGrad;
import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.weightinit.impl.XavierInitScheme; import org.nd4j.weightinit.impl.XavierInitScheme;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertTrue;
@Slf4j @Slf4j
public class SameDiffTrainingTest extends BaseNd4jTest { public class SameDiffTrainingTest extends BaseNd4jTest {
@ -118,6 +123,110 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
} }
@Test
public void irisTrainingEvalTest() {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
std.fit(iter);
iter.setPreProcessor(std);
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10);
SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10);
SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3);
SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3);
SDVariable z0 = in.mmul(w0).add(b0);
SDVariable a0 = sd.math().tanh(z0);
SDVariable z1 = a0.mmul(w1).add("prediction", b1);
SDVariable a1 = sd.nn().softmax(z1);
SDVariable diff = sd.f().squaredDifference(a1, label);
SDVariable lossMse = diff.mul(diff).mean();
TrainingConfig conf = new TrainingConfig.Builder()
.l2(1e-4)
.updater(new Adam(1e-2))
.dataSetFeatureMapping("input")
.dataSetLabelMapping("label")
.trainEvaluation("prediction", 0, new Evaluation())
.build();
sd.setTrainingConfig(conf);
History hist = sd.fit().train(iter, 50).exec();
Evaluation e = hist.finalTrainingEvaluations().evaluation("prediction");
System.out.println(e.stats());
double acc = e.accuracy();
assertTrue("Accuracy bad: " + acc, acc >= 0.75);
}
@Test
public void irisTrainingValidationTest() {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
std.fit(iter);
iter.setPreProcessor(std);
DataSetIterator valIter = new IrisDataSetIterator(30, 60);
NormalizerStandardize valStd = new NormalizerStandardize();
valStd.fit(valIter);
valIter.setPreProcessor(std);
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10);
SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10);
SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3);
SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3);
SDVariable z0 = in.mmul(w0).add(b0);
SDVariable a0 = sd.math().tanh(z0);
SDVariable z1 = a0.mmul(w1).add("prediction", b1);
SDVariable a1 = sd.nn().softmax(z1);
SDVariable diff = sd.f().squaredDifference(a1, label);
SDVariable lossMse = diff.mul(diff).mean();
TrainingConfig conf = new TrainingConfig.Builder()
.l2(1e-4)
.updater(new Adam(1e-2))
.dataSetFeatureMapping("input")
.dataSetLabelMapping("label")
.validationEvaluation("prediction", 0, new Evaluation())
.build();
sd.setTrainingConfig(conf);
History hist = sd.fit().train(iter, 50).validate(valIter, 5).exec();
Evaluation e = hist.finalValidationEvaluations().evaluation("prediction");
System.out.println(e.stats());
double acc = e.accuracy();
assertTrue("Accuracy bad: " + acc, acc >= 0.75);
}
@Test @Test
public void testTrainingMixedDtypes(){ public void testTrainingMixedDtypes(){