Fix a couple SameDiff training issues (#253)
* fix execBackwards training issue Signed-off-by: Ryan Nett <rnett@skymind.io> * fix validation not specifying outputs Signed-off-by: Ryan Nett <rnett@skymind.io> * another fix for validation listeners and history Signed-off-by: Ryan Nett <rnett@skymind.io> * tests Signed-off-by: Ryan Nett <rnett@skymind.io> * add single batch dataset output methods Signed-off-by: Ryan Nett <rnett@skymind.io>master
parent
4f7b35ac82
commit
8a05ec2a97
|
@ -16,21 +16,58 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff;
|
package org.nd4j.autodiff.samediff;
|
||||||
|
|
||||||
import org.nd4j.shade.guava.base.Predicates;
|
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
|
||||||
import org.nd4j.shade.guava.collect.HashBasedTable;
|
|
||||||
import org.nd4j.shade.guava.collect.Maps;
|
|
||||||
import org.nd4j.shade.guava.collect.Table;
|
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
import lombok.*;
|
import java.io.BufferedInputStream;
|
||||||
|
import java.io.BufferedOutputStream;
|
||||||
|
import java.io.DataOutputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.IdentityHashMap;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Queue;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.Stack;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
|
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
|
||||||
import org.nd4j.autodiff.execution.conf.OutputMode;
|
import org.nd4j.autodiff.execution.conf.OutputMode;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
|
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
|
||||||
import org.nd4j.autodiff.listeners.*;
|
import org.nd4j.autodiff.listeners.At;
|
||||||
|
import org.nd4j.autodiff.listeners.Listener;
|
||||||
|
import org.nd4j.autodiff.listeners.ListenerResponse;
|
||||||
|
import org.nd4j.autodiff.listeners.Loss;
|
||||||
|
import org.nd4j.autodiff.listeners.Operation;
|
||||||
import org.nd4j.autodiff.listeners.impl.HistoryListener;
|
import org.nd4j.autodiff.listeners.impl.HistoryListener;
|
||||||
import org.nd4j.autodiff.listeners.records.History;
|
import org.nd4j.autodiff.listeners.records.History;
|
||||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||||
|
@ -38,14 +75,34 @@ import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
|
||||||
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
|
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
|
||||||
import org.nd4j.autodiff.samediff.config.FitConfig;
|
import org.nd4j.autodiff.samediff.config.FitConfig;
|
||||||
import org.nd4j.autodiff.samediff.config.OutputConfig;
|
import org.nd4j.autodiff.samediff.config.OutputConfig;
|
||||||
import org.nd4j.autodiff.samediff.internal.*;
|
import org.nd4j.autodiff.samediff.internal.AbstractSession;
|
||||||
import org.nd4j.autodiff.samediff.ops.*;
|
import org.nd4j.autodiff.samediff.internal.DataTypesSession;
|
||||||
|
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
||||||
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
|
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDBitwise;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDCNN;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDImage;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDLoss;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDMath;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDNN;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDRNN;
|
||||||
|
import org.nd4j.autodiff.samediff.ops.SDRandom;
|
||||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.evaluation.classification.ROC;
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
import org.nd4j.graph.*;
|
import org.nd4j.graph.ExecutionMode;
|
||||||
|
import org.nd4j.graph.FlatArray;
|
||||||
|
import org.nd4j.graph.FlatConfiguration;
|
||||||
|
import org.nd4j.graph.FlatGraph;
|
||||||
|
import org.nd4j.graph.FlatNode;
|
||||||
|
import org.nd4j.graph.FlatVariable;
|
||||||
|
import org.nd4j.graph.IntPair;
|
||||||
|
import org.nd4j.graph.OpType;
|
||||||
|
import org.nd4j.graph.UpdaterState;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
@ -84,23 +141,17 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
||||||
import org.nd4j.linalg.util.ND4JFileUtils;
|
import org.nd4j.linalg.util.ND4JFileUtils;
|
||||||
|
import org.nd4j.shade.guava.base.Predicates;
|
||||||
|
import org.nd4j.shade.guava.collect.HashBasedTable;
|
||||||
|
import org.nd4j.shade.guava.collect.Maps;
|
||||||
|
import org.nd4j.shade.guava.collect.Table;
|
||||||
|
import org.nd4j.shade.guava.primitives.Ints;
|
||||||
import org.nd4j.weightinit.WeightInitScheme;
|
import org.nd4j.weightinit.WeightInitScheme;
|
||||||
import org.nd4j.weightinit.impl.ConstantInitScheme;
|
import org.nd4j.weightinit.impl.ConstantInitScheme;
|
||||||
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
|
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
|
||||||
import org.nd4j.weightinit.impl.ZeroInitScheme;
|
import org.nd4j.weightinit.impl.ZeroInitScheme;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.lang.reflect.Method;
|
|
||||||
import java.nio.ByteBuffer;
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
import java.util.regex.Matcher;
|
|
||||||
import java.util.regex.Pattern;
|
|
||||||
|
|
||||||
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
|
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
|
||||||
* <p>
|
* <p>
|
||||||
|
@ -2064,7 +2115,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
List<Listener> activeListeners = new ArrayList<>();
|
List<Listener> activeListeners = new ArrayList<>();
|
||||||
|
|
||||||
if (!history.evaluations().isEmpty())
|
|
||||||
activeListeners.add(history);
|
activeListeners.add(history);
|
||||||
|
|
||||||
for (Listener l : this.listeners)
|
for (Listener l : this.listeners)
|
||||||
|
@ -2102,6 +2152,9 @@ public class SameDiff extends SDBaseOps {
|
||||||
requiredVars.addAll(l.requiredVariables(this).trainingVariables());
|
requiredVars.addAll(l.requiredVariables(this).trainingVariables());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ArrayList<Listener> listenersWitHistory = new ArrayList<>(listeners);
|
||||||
|
listenersWitHistory.add(history);
|
||||||
|
|
||||||
for (int i = 0; i < numEpochs; i++) {
|
for (int i = 0; i < numEpochs; i++) {
|
||||||
|
|
||||||
if (incrementEpochCount && hasListeners) {
|
if (incrementEpochCount && hasListeners) {
|
||||||
|
@ -2236,7 +2289,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hasListeners) {
|
|
||||||
double[] d = new double[lossVariables.size() + regScore.size()];
|
double[] d = new double[lossVariables.size() + regScore.size()];
|
||||||
List<String> lossVars;
|
List<String> lossVars;
|
||||||
if (regScore.size() > 0) {
|
if (regScore.size() > 0) {
|
||||||
|
@ -2252,7 +2304,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
lossVars = lossVariables;
|
lossVars = lossVariables;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//Collect the losses...
|
//Collect the losses...
|
||||||
SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
|
SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
@ -2283,6 +2334,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
lossCount++;
|
lossCount++;
|
||||||
|
|
||||||
|
if (hasListeners) {
|
||||||
for (Listener l : activeListeners) {
|
for (Listener l : activeListeners) {
|
||||||
l.iterationDone(this, at, ds, loss);
|
l.iterationDone(this, at, ds, loss);
|
||||||
}
|
}
|
||||||
|
@ -2294,7 +2346,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
long epochTime = System.currentTimeMillis() - epochStartTime;
|
long epochTime = System.currentTimeMillis() - epochStartTime;
|
||||||
|
|
||||||
if (incrementEpochCount && hasListeners) {
|
if (incrementEpochCount) {
|
||||||
for (int j = 0; j < lossSums.length; j++)
|
for (int j = 0; j < lossSums.length; j++)
|
||||||
lossSums[j] /= lossCount;
|
lossSums[j] /= lossCount;
|
||||||
|
|
||||||
|
@ -2341,7 +2393,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
long validationStart = System.currentTimeMillis();
|
long validationStart = System.currentTimeMillis();
|
||||||
outputHelper(validationData, new At(at.epoch(), 0, 0, 0, Operation.TRAINING_VALIDATION),
|
outputHelper(validationData, new At(at.epoch(), 0, 0, 0, Operation.TRAINING_VALIDATION),
|
||||||
listeners);
|
listenersWitHistory);
|
||||||
|
|
||||||
long validationTime = System.currentTimeMillis() - validationStart;
|
long validationTime = System.currentTimeMillis() - validationStart;
|
||||||
|
|
||||||
|
@ -2958,7 +3010,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
List<String> neededOutputs;
|
List<String> neededOutputs;
|
||||||
|
|
||||||
if (outputs != null) {
|
if (outputs != null && outputs.length != 0) {
|
||||||
neededOutputs = Arrays.asList(outputs);
|
neededOutputs = Arrays.asList(outputs);
|
||||||
} else {
|
} else {
|
||||||
neededOutputs = outputs();
|
neededOutputs = outputs();
|
||||||
|
@ -4618,8 +4670,8 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Also add loss values - we need these so we can report them to listeners...
|
//Also add loss values - we need these so we can report them to listeners or loss curves...
|
||||||
if (!listeners.isEmpty()) {
|
if (!activeListeners.isEmpty() || op == Operation.TRAINING) {
|
||||||
varGradNames.addAll(lossVariables);
|
varGradNames.addAll(lossVariables);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,9 @@ import org.nd4j.autodiff.util.TrainingUtils;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||||
|
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
|
||||||
|
@ -94,6 +97,20 @@ public class OutputConfig {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the data to use as input.
|
||||||
|
*/
|
||||||
|
public OutputConfig data(@NonNull DataSet data){
|
||||||
|
return data(new SingletonMultiDataSetIterator(data.toMultiDataSet()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the data to use as input.
|
||||||
|
*/
|
||||||
|
public OutputConfig data(@NonNull MultiDataSet data){
|
||||||
|
return data(new SingletonMultiDataSetIterator(data));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add listeners for this operation
|
* Add listeners for this operation
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -16,9 +16,16 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff;
|
package org.nd4j.autodiff.samediff;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
||||||
|
import org.nd4j.autodiff.listeners.records.History;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
@ -29,16 +36,14 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.learning.config.*;
|
import org.nd4j.linalg.learning.config.AMSGrad;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaMax;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.weightinit.impl.XavierInitScheme;
|
import org.nd4j.weightinit.impl.XavierInitScheme;
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SameDiffTrainingTest extends BaseNd4jTest {
|
public class SameDiffTrainingTest extends BaseNd4jTest {
|
||||||
|
|
||||||
|
@ -118,6 +123,110 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void irisTrainingEvalTest() {
|
||||||
|
|
||||||
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
NormalizerStandardize std = new NormalizerStandardize();
|
||||||
|
std.fit(iter);
|
||||||
|
iter.setPreProcessor(std);
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
|
||||||
|
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||||
|
|
||||||
|
SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10);
|
||||||
|
SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10);
|
||||||
|
|
||||||
|
SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3);
|
||||||
|
SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3);
|
||||||
|
|
||||||
|
SDVariable z0 = in.mmul(w0).add(b0);
|
||||||
|
SDVariable a0 = sd.math().tanh(z0);
|
||||||
|
SDVariable z1 = a0.mmul(w1).add("prediction", b1);
|
||||||
|
SDVariable a1 = sd.nn().softmax(z1);
|
||||||
|
|
||||||
|
SDVariable diff = sd.f().squaredDifference(a1, label);
|
||||||
|
SDVariable lossMse = diff.mul(diff).mean();
|
||||||
|
|
||||||
|
TrainingConfig conf = new TrainingConfig.Builder()
|
||||||
|
.l2(1e-4)
|
||||||
|
.updater(new Adam(1e-2))
|
||||||
|
.dataSetFeatureMapping("input")
|
||||||
|
.dataSetLabelMapping("label")
|
||||||
|
.trainEvaluation("prediction", 0, new Evaluation())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
sd.setTrainingConfig(conf);
|
||||||
|
|
||||||
|
History hist = sd.fit().train(iter, 50).exec();
|
||||||
|
|
||||||
|
Evaluation e = hist.finalTrainingEvaluations().evaluation("prediction");
|
||||||
|
|
||||||
|
System.out.println(e.stats());
|
||||||
|
|
||||||
|
double acc = e.accuracy();
|
||||||
|
|
||||||
|
assertTrue("Accuracy bad: " + acc, acc >= 0.75);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void irisTrainingValidationTest() {
|
||||||
|
|
||||||
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
NormalizerStandardize std = new NormalizerStandardize();
|
||||||
|
std.fit(iter);
|
||||||
|
iter.setPreProcessor(std);
|
||||||
|
|
||||||
|
DataSetIterator valIter = new IrisDataSetIterator(30, 60);
|
||||||
|
NormalizerStandardize valStd = new NormalizerStandardize();
|
||||||
|
valStd.fit(valIter);
|
||||||
|
valIter.setPreProcessor(std);
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
|
||||||
|
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||||
|
|
||||||
|
SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10);
|
||||||
|
SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10);
|
||||||
|
|
||||||
|
SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3);
|
||||||
|
SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3);
|
||||||
|
|
||||||
|
SDVariable z0 = in.mmul(w0).add(b0);
|
||||||
|
SDVariable a0 = sd.math().tanh(z0);
|
||||||
|
SDVariable z1 = a0.mmul(w1).add("prediction", b1);
|
||||||
|
SDVariable a1 = sd.nn().softmax(z1);
|
||||||
|
|
||||||
|
SDVariable diff = sd.f().squaredDifference(a1, label);
|
||||||
|
SDVariable lossMse = diff.mul(diff).mean();
|
||||||
|
|
||||||
|
TrainingConfig conf = new TrainingConfig.Builder()
|
||||||
|
.l2(1e-4)
|
||||||
|
.updater(new Adam(1e-2))
|
||||||
|
.dataSetFeatureMapping("input")
|
||||||
|
.dataSetLabelMapping("label")
|
||||||
|
.validationEvaluation("prediction", 0, new Evaluation())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
sd.setTrainingConfig(conf);
|
||||||
|
|
||||||
|
History hist = sd.fit().train(iter, 50).validate(valIter, 5).exec();
|
||||||
|
|
||||||
|
Evaluation e = hist.finalValidationEvaluations().evaluation("prediction");
|
||||||
|
|
||||||
|
System.out.println(e.stats());
|
||||||
|
|
||||||
|
double acc = e.accuracy();
|
||||||
|
|
||||||
|
assertTrue("Accuracy bad: " + acc, acc >= 0.75);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTrainingMixedDtypes(){
|
public void testTrainingMixedDtypes(){
|
||||||
|
|
Loading…
Reference in New Issue