Assorted fixes (#318)

* #8777 MultiLayerNetwork.evaluate(MultiDataSetIterator) overload

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8768 SameDiff.equals

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8750 shade freemarker library and switch to it in DL4J UI

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8704 DL4J UI redirect

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8776 RecordReaderDataSetIterator builder collectMetaData fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8718 Fix DL4J doEvaluation metadata

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8715 ArchiveUtils - Add option to not log every extracted file

Signed-off-by: Alex Black <blacka101@gmail.com>

* No exception for evaluations that don't support metadata

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8765 CompGraph+MDS fix for SharedTrainingMaster

Signed-off-by: Alex Black <blacka101@gmail.com>

* small fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Timeout

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore

Signed-off-by: Alex Black <blacka101@gmail.com>

* Revert freemarker shading

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-27 00:33:13 +11:00 committed by GitHub
parent 9970aadc5a
commit 63c9223bc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 381 additions and 209 deletions

View File

@ -1381,4 +1381,17 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertNotNull(ds.getFeatures()); assertNotNull(ds.getFeatures());
assertNull(ds.getLabels()); assertNull(ds.getLabels());
} }
@Test
public void testCollectMetaData(){
RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(new CollectionRecordReader(Collections.<List<Writable>>emptyList()), 1)
.collectMetaData(true)
.build();
assertTrue(trainIter.isCollectMetaData());
trainIter.setCollectMetaData(false);
assertFalse(trainIter.isCollectMetaData());
trainIter.setCollectMetaData(true);
assertTrue(trainIter.isCollectMetaData());
}
} }

View File

@ -33,7 +33,6 @@ import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.eval.meta.Prediction;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
@ -52,19 +51,13 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.util.*; import java.util.*;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
/** /**
* Created by agibsonccc on 12/22/14. * Created by agibsonccc on 12/22/14.
@ -165,7 +158,7 @@ public class EvalTest extends BaseDL4JTest {
assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix()); assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
} }
@Test(timeout = 300000) @Test
public void testEvaluationWithMetaData() throws Exception { public void testEvaluationWithMetaData() throws Exception {
RecordReader csv = new CSVRecordReader(); RecordReader csv = new CSVRecordReader();
@ -256,6 +249,30 @@ public class EvalTest extends BaseDL4JTest {
assertEquals(actualCounts[i], actualClassI.size()); assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size()); assertEquals(predictedCounts[i], predictedClassI.size());
} }
//Finally: test doEvaluation methods
rrdsi.reset();
org.nd4j.evaluation.classification.Evaluation e2 = new org.nd4j.evaluation.classification.Evaluation();
net.doEvaluation(rrdsi, e2);
for (int i = 0; i < 3; i++) {
List<org.nd4j.evaluation.meta.Prediction> actualClassI = e2.getPredictionsByActualClass(i);
List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e2.getPredictionByPredictedClass(i);
assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size());
}
ComputationGraph cg = net.toComputationGraph();
rrdsi.reset();
e2 = new org.nd4j.evaluation.classification.Evaluation();
cg.doEvaluation(rrdsi, e2);
for (int i = 0; i < 3; i++) {
List<org.nd4j.evaluation.meta.Prediction> actualClassI = e2.getPredictionsByActualClass(i);
List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e2.getPredictionByPredictedClass(i);
assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size());
}
} }
private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) { private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
@ -504,11 +521,11 @@ public class EvalTest extends BaseDL4JTest {
list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()})); list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()}));
} }
Evaluation e = new Evaluation(); org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
RegressionEvaluation e2 = new RegressionEvaluation(); org.nd4j.evaluation.regression.RegressionEvaluation e2 = new org.nd4j.evaluation.regression.RegressionEvaluation();
Map<Integer,IEvaluation[]> evals = new HashMap<>(); Map<Integer,org.nd4j.evaluation.IEvaluation[]> evals = new HashMap<>();
evals.put(0, new IEvaluation[]{(IEvaluation) e}); evals.put(0, new org.nd4j.evaluation.IEvaluation[]{e});
evals.put(1, new IEvaluation[]{(IEvaluation) e2}); evals.put(1, new org.nd4j.evaluation.IEvaluation[]{e2});
cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals); cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals);
@ -567,14 +584,14 @@ public class EvalTest extends BaseDL4JTest {
} }
try { try {
net.evaluateROC(iter); net.evaluateROC(iter, 0);
fail("Expected exception"); fail("Expected exception");
} catch (IllegalStateException e){ } catch (IllegalStateException e){
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC")); assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
} }
try { try {
net.evaluateROCMultiClass(iter); net.evaluateROCMultiClass(iter, 0);
fail("Expected exception"); fail("Expected exception");
} catch (IllegalStateException e){ } catch (IllegalStateException e){
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass")); assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
@ -589,14 +606,14 @@ public class EvalTest extends BaseDL4JTest {
} }
try { try {
cg.evaluateROC(iter); cg.evaluateROC(iter, 0);
fail("Expected exception"); fail("Expected exception");
} catch (IllegalStateException e){ } catch (IllegalStateException e){
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC")); assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
} }
try { try {
cg.evaluateROCMultiClass(iter); cg.evaluateROCMultiClass(iter, 0);
fail("Expected exception"); fail("Expected exception");
} catch (IllegalStateException e){ } catch (IllegalStateException e){
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass")); assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
@ -606,10 +623,10 @@ public class EvalTest extends BaseDL4JTest {
//Disable validation, and check same thing: //Disable validation, and check same thing:
net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false); net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false);
net.evaluate(iter); net.evaluate(iter);
net.evaluateROCMultiClass(iter); net.evaluateROCMultiClass(iter, 0);
cg.getConfiguration().setValidateOutputLayerConfig(false); cg.getConfiguration().setValidateOutputLayerConfig(false);
cg.evaluate(iter); cg.evaluate(iter);
cg.evaluateROCMultiClass(iter); cg.evaluateROCMultiClass(iter, 0);
} }
} }

View File

@ -61,7 +61,7 @@ public class RegressionEvalTest extends BaseDL4JTest {
DataSet ds = new DataSet(f, l); DataSet ds = new DataSet(f, l);
DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds)); DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds));
RegressionEvaluation re = net.evaluateRegression(iter); org.nd4j.evaluation.regression.RegressionEvaluation re = net.evaluateRegression(iter);
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
assertEquals(1.0, re.meanSquaredError(i), 1e-6); assertEquals(1.0, re.meanSquaredError(i), 1e-6);

View File

@ -86,7 +86,7 @@ public abstract class CacheableExtractableDataSetFetcher implements CacheableDat
} }
try { try {
ArchiveUtils.unzipFileTo(tmpFile.getAbsolutePath(), localCacheDir.getAbsolutePath()); ArchiveUtils.unzipFileTo(tmpFile.getAbsolutePath(), localCacheDir.getAbsolutePath(), false);
} catch (Throwable t){ } catch (Throwable t){
//Catch any errors during extraction, and delete the directory to avoid leaving the dir in an invalid state //Catch any errors during extraction, and delete the directory to avoid leaving the dir in an invalid state
if(localCacheDir.exists()) if(localCacheDir.exists())

View File

@ -205,6 +205,7 @@ public class RecordReaderDataSetIterator implements DataSetIterator {
this.numPossibleLabels = b.numPossibleLabels; this.numPossibleLabels = b.numPossibleLabels;
this.regression = b.regression; this.regression = b.regression;
this.preProcessor = b.preProcessor; this.preProcessor = b.preProcessor;
this.collectMetaData = b.collectMetaData;
} }
/** /**

View File

@ -67,7 +67,7 @@ public class KuromojiBinFilesFetcher {
new URL("https://dl4jdata.blob.core.windows.net/kuromoji/kuromoji_bin_files.tar.gz"), new URL("https://dl4jdata.blob.core.windows.net/kuromoji/kuromoji_bin_files.tar.gz"),
tarFile); tarFile);
} }
ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), rootDir.getAbsolutePath()); ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), rootDir.getAbsolutePath(), false);
return rootDir.getAbsoluteFile(); return rootDir.getAbsoluteFile();
} }

View File

@ -4170,6 +4170,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
INDArray[] featuresMasks = next.getFeaturesMaskArrays(); INDArray[] featuresMasks = next.getFeaturesMaskArrays();
INDArray[] labels = next.getLabels(); INDArray[] labels = next.getLabels();
INDArray[] labelMasks = next.getLabelsMaskArrays(); INDArray[] labelMasks = next.getLabelsMaskArrays();
List<Serializable> meta = next.getExampleMetaData();
try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) { try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws); INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);
@ -4188,7 +4189,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
for (IEvaluation evaluation : evalsThisOutput) for (IEvaluation evaluation : evalsThisOutput)
evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i)); evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i), meta);
} }
} }
} }

View File

@ -23,6 +23,9 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.io.Serializable;
import java.util.List;
public class ComputationGraphUtil { public class ComputationGraphUtil {
private ComputationGraphUtil() {} private ComputationGraphUtil() {}
@ -33,13 +36,16 @@ public class ComputationGraphUtil {
INDArray l = dataSet.getLabels(); INDArray l = dataSet.getLabels();
INDArray fMask = dataSet.getFeaturesMaskArray(); INDArray fMask = dataSet.getFeaturesMaskArray();
INDArray lMask = dataSet.getLabelsMaskArray(); INDArray lMask = dataSet.getLabelsMaskArray();
List<Serializable> meta = dataSet.getExampleMetaData();
INDArray[] fNew = f == null ? null : new INDArray[] {f}; INDArray[] fNew = f == null ? null : new INDArray[] {f};
INDArray[] lNew = l == null ? null : new INDArray[] {l}; INDArray[] lNew = l == null ? null : new INDArray[] {l};
INDArray[] fMaskNew = (fMask != null ? new INDArray[] {fMask} : null); INDArray[] fMaskNew = (fMask != null ? new INDArray[] {fMask} : null);
INDArray[] lMaskNew = (lMask != null ? new INDArray[] {lMask} : null); INDArray[] lMaskNew = (lMask != null ? new INDArray[] {lMask} : null);
return new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew); org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
mds.setExampleMetaData(meta);
return mds;
} }
/** Convert a DataSetIterator to a MultiDataSetIterator, via an adaptor class */ /** Convert a DataSetIterator to a MultiDataSetIterator, via an adaptor class */

View File

@ -25,14 +25,11 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.*;
import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.*;
import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.*;
@ -44,8 +41,8 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop; import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.updater.UpdaterCreator; import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
@ -58,19 +55,23 @@ import org.deeplearning4j.util.CrashReportingUtil;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.NetworkUtils; import org.deeplearning4j.util.NetworkUtils;
import org.deeplearning4j.util.OutputLayerUtil; import org.deeplearning4j.util.OutputLayerUtil;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy; import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy; import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy; import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -84,7 +85,6 @@ import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils; import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils; import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple; import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
@ -96,6 +96,8 @@ import org.nd4j.util.OneTimeLogger;
import java.io.*; import java.io.*;
import java.util.*; import java.util.*;
;
/** /**
* MultiLayerNetwork is a neural network with multiple layers in a stack, and usually an output layer.<br> * MultiLayerNetwork is a neural network with multiple layers in a stack, and usually an output layer.<br>
@ -3315,19 +3317,39 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
* @param iterator Iterator to evaluate on * @param iterator Iterator to evaluate on
* @return Evaluation object; results of evaluation on all examples in the data set * @return Evaluation object; results of evaluation on all examples in the data set
*/ */
public <T extends Evaluation> T evaluate(DataSetIterator iterator) { public <T extends Evaluation> T evaluate(@NonNull DataSetIterator iterator) {
return (T)evaluate(iterator, null); return (T)evaluate(iterator, null);
} }
/**
* Evaluate the network (classification performance).
* Can only be used with MultiDataSetIterator instances with a single input/output array
*
* @param iterator Iterator to evaluate on
* @return Evaluation object; results of evaluation on all examples in the data set
*/
public Evaluation evaluate(@NonNull MultiDataSetIterator iterator) {
return evaluate(new MultiDataSetWrapperIterator(iterator));
}
/** /**
* Evaluate the network for regression performance * Evaluate the network for regression performance
* @param iterator Data to evaluate on * @param iterator Data to evaluate on
* @return * @return Regression evaluation
*/ */
public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) { public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
return (T)doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0]; return (T)doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0];
} }
/**
* Evaluate the network for regression performance
* Can only be used with MultiDataSetIterator instances with a single input/output array
* @param iterator Data to evaluate on
*/
public org.nd4j.evaluation.regression.RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator) {
return evaluateRegression(new MultiDataSetWrapperIterator(iterator));
}
/** /**
* @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration * @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
*/ */
@ -3424,6 +3446,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
INDArray labels = next.getLabels(); INDArray labels = next.getLabels();
INDArray fMask = next.getFeaturesMaskArray(); INDArray fMask = next.getFeaturesMaskArray();
INDArray lMask = next.getLabelsMaskArray(); INDArray lMask = next.getLabelsMaskArray();
List<Serializable> meta = next.getExampleMetaData();
if (!useRnnSegments) { if (!useRnnSegments) {
@ -3433,7 +3456,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
for (T evaluation : evaluations) for (T evaluation : evaluations)
evaluation.eval(labels, out, lMask); evaluation.eval(labels, out, lMask, meta);
} }
} }
} else { } else {

View File

@ -222,8 +222,11 @@ public class AdaptiveThresholdAlgorithm implements ThresholdAlgorithm {
if(a == null || Double.isNaN(a.lastThreshold)) if(a == null || Double.isNaN(a.lastThreshold))
return; return;
lastThresholdSum += a.lastThreshold; lastThresholdSum += a.lastThreshold;
lastSparsitySum += a.lastSparsity; if (!Double.isNaN(a.lastSparsity)) {
lastSparsitySum += a.lastSparsity;
}
count++; count++;
} }

View File

@ -38,16 +38,22 @@
<artifactId>nd4j-aeron</artifactId> <artifactId>nd4j-aeron</artifactId>
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_2.11</artifactId> <artifactId>dl4j-spark_2.11</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
<version>${nd4j.version}</version>
<exclusions>
<exclusion>
<groupId>net.jpountz.lz4</groupId>
<artifactId>lz4</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency> <dependency>
<groupId>org.projectlombok</groupId> <groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId> <artifactId>lombok</artifactId>

View File

@ -23,6 +23,7 @@ import org.nd4j.linalg.dataset.api.iterator.ParallelMultiDataSetIterator;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
/** /**
* This MultiDataSetIterator implementation does accumulation of MultiDataSets from different Spark executors, wrt Thread/Device Affinity * This MultiDataSetIterator implementation does accumulation of MultiDataSets from different Spark executors, wrt Thread/Device Affinity
@ -32,14 +33,16 @@ import java.util.List;
public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator { public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator {
protected final List<Iterator<MultiDataSet>> iterators; protected final List<Iterator<MultiDataSet>> iterators;
protected final AtomicInteger position;
public VirtualMultiDataSetIterator(@NonNull List<Iterator<MultiDataSet>> iterators) { public VirtualMultiDataSetIterator(@NonNull List<Iterator<MultiDataSet>> iterators) {
this.iterators = iterators; this.iterators = iterators;
this.position = new AtomicInteger(0);
} }
@Override @Override
public MultiDataSet next(int num) { public MultiDataSet next(int num) {
return null; return next();
} }
@Override @Override
@ -59,27 +62,34 @@ public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator
@Override @Override
public boolean asyncSupported() { public boolean asyncSupported() {
return false; return true;
} }
@Override @Override
public void reset() { public void reset() {
throw new UnsupportedOperationException();
} }
@Override @Override
public boolean hasNext() { public boolean hasNext() {
return false; // just checking if that's not the last iterator, or if that's the last one - check if it has something
boolean ret = position.get() < iterators.size() - 1
|| (position.get() < iterators.size() && iterators.get(position.get()).hasNext());
return ret;
} }
@Override @Override
public MultiDataSet next() { public MultiDataSet next() {
return null; // TODO: this solution isn't ideal, it assumes non-empty iterators all the time. Would be nice to do something here
if (!iterators.get(position.get()).hasNext())
position.getAndIncrement();
return iterators.get(position.get()).next();
} }
@Override @Override
public void remove() { public void remove() {
// no-op
} }
@Override @Override

View File

@ -109,6 +109,7 @@ public class SharedTrainingWrapper {
// now we're creating DataSetIterators, to feed ParallelWrapper // now we're creating DataSetIterators, to feed ParallelWrapper
iteratorDS = new VirtualDataSetIterator(iteratorsDS); iteratorDS = new VirtualDataSetIterator(iteratorsDS);
iteratorMDS = new VirtualMultiDataSetIterator(iteratorsMDS);
} }
public static synchronized SharedTrainingWrapper getInstance(long id) { public static synchronized SharedTrainingWrapper getInstance(long id) {
@ -447,17 +448,19 @@ public class SharedTrainingWrapper {
throw new DL4JInvalidConfigException("No iterators were defined for training"); throw new DL4JInvalidConfigException("No iterators were defined for training");
try { try {
while((iteratorDS != null && iteratorDS.hasNext()) || (iteratorMDS != null && iteratorMDS.hasNext())) { boolean dsNext;
boolean mdsNext;
while((dsNext = iteratorDS != null && iteratorDS.hasNext()) || (mdsNext = iteratorMDS != null && iteratorMDS.hasNext())) {
//Loop as a guard against concurrent modifications and RCs //Loop as a guard against concurrent modifications and RCs
if (wrapper != null) { if (wrapper != null) {
if (iteratorDS != null) if (dsNext)
wrapper.fit(iteratorDS); wrapper.fit(iteratorDS);
else else
wrapper.fit(iteratorMDS); wrapper.fit(iteratorMDS);
} else { } else {
// if wrapper is null, we're fitting standalone model then // if wrapper is null, we're fitting standalone model then
if (iteratorDS != null) { if (dsNext) {
if (model instanceof ComputationGraph) { if (model instanceof ComputationGraph) {
((ComputationGraph) originalModel).fit(iteratorDS); ((ComputationGraph) originalModel).fit(iteratorDS);
} else if (model instanceof MultiLayerNetwork) { } else if (model instanceof MultiLayerNetwork) {
@ -472,7 +475,8 @@ public class SharedTrainingWrapper {
} }
} }
consumer.getUpdatesQueue().purge(); if(consumer != null)
consumer.getUpdatesQueue().purge();
} }
} catch (Throwable t){ } catch (Throwable t){
log.warn("Exception encountered during fit operation", t); log.warn("Exception encountered during fit operation", t);

View File

@ -116,8 +116,7 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
} }
protected int numExecutors() { protected int numExecutors() {
int numProc = Runtime.getRuntime().availableProcessors(); return 4;
return Math.min(4, numProc);
} }
protected MultiLayerConfiguration getBasicConf() { protected MultiLayerConfiguration getBasicConf() {

View File

@ -49,6 +49,7 @@ import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AMSGrad; import org.nd4j.linalg.learning.config.AMSGrad;
@ -66,137 +67,170 @@ import java.util.concurrent.ConcurrentHashMap;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@Slf4j @Slf4j
@Ignore("AB 2019/05/21 - Failing - Issue #7657") //@Ignore("AB 2019/05/21 - Failing - Issue #7657")
public class GradientSharingTrainingTest extends BaseSparkTest { public class GradientSharingTrainingTest extends BaseSparkTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void trainSanityCheck() throws Exception { public void trainSanityCheck() throws Exception {
INDArray last = null; for(boolean mds : new boolean[]{false, true}) {
INDArray lastDup = null; INDArray last = null;
for (String s : new String[]{"paths", "direct", "export"}) { INDArray lastDup = null;
System.out.println("--------------------------------------------------------------------------------------------------------------"); for (String s : new String[]{"paths", "direct", "export"}) {
log.info("Starting: {}", s); System.out.println("--------------------------------------------------------------------------------------------------------------");
boolean isPaths = "paths".equals(s); log.info("Starting: {} - {}", s, (mds ? "MultiDataSet" : "DataSet"));
boolean isPaths = "paths".equals(s);
RDDTrainingApproach rddTrainingApproach; RDDTrainingApproach rddTrainingApproach;
switch (s) {
case "direct":
rddTrainingApproach = RDDTrainingApproach.Direct;
break;
case "export":
rddTrainingApproach = RDDTrainingApproach.Export;
break;
case "paths":
rddTrainingApproach = RDDTrainingApproach.Direct; //Actualy not used for fitPaths
break;
default:
throw new RuntimeException();
}
File temp = testDir.newFolder();
//TODO this probably won't work everywhere...
String controller = Inet4Address.getLocalHost().getHostAddress();
String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";
VoidConfiguration voidConfiguration = VoidConfiguration.builder()
.unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
.networkMask(networkMask) // Local network mask
.controllerAddress(controller)
.meshBuildMode(MeshBuildMode.PLAIN) // everyone is connected to the master
.build();
TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new AdaptiveThresholdAlgorithm(1e-3), 16)
.rngSeed(12345)
.collectTrainingStats(false)
.batchSizePerWorker(16) // Minibatch size for each worker
.workersPerNode(2) // Workers per node
.rddTrainingApproach(rddTrainingApproach)
.exportDirectory("file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/"))
.build();
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.updater(new AMSGrad(0.1))
.graphBuilder()
.addInputs("in")
.layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("out")
.build();
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
File f = testDir.newFolder();
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
int count = 0;
List<String> paths = new ArrayList<>();
List<DataSet> ds = new ArrayList<>();
while (iter.hasNext() && count++ < 8) {
DataSet d = iter.next();
if (isPaths) {
File out = new File(f, count + ".bin");
d.save(out);
String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
paths.add(path);
}
ds.add(d);
}
int numIter = 1;
double[] acc = new double[numIter + 1];
for (int i = 0; i < numIter; i++) {
//Check accuracy before:
DataSetIterator testIter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10);
Evaluation eBefore = sparkNet.getNetwork().evaluate(testIter);
INDArray paramsBefore = sparkNet.getNetwork().params().dup();
ComputationGraph after;
switch (s) { switch (s) {
case "direct": case "direct":
rddTrainingApproach = RDDTrainingApproach.Direct;
break;
case "export": case "export":
JavaRDD<DataSet> dsRDD = sc.parallelize(ds); rddTrainingApproach = RDDTrainingApproach.Export;
after = sparkNet.fit(dsRDD);
break; break;
case "paths": case "paths":
JavaRDD<String> pathRdd = sc.parallelize(paths); rddTrainingApproach = RDDTrainingApproach.Direct; //Actualy not used for fitPaths
after = sparkNet.fitPaths(pathRdd);
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
} }
INDArray paramsAfter = after.params(); File temp = testDir.newFolder();
System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
System.out.println(Arrays.toString(
Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
assertNotEquals(paramsBefore, paramsAfter);
testIter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10); //TODO this probably won't work everywhere...
Evaluation eAfter = after.evaluate(testIter); String controller = Inet4Address.getLocalHost().getHostAddress();
String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";
double accAfter = eAfter.accuracy(); VoidConfiguration voidConfiguration = VoidConfiguration.builder()
double accBefore = eBefore.accuracy(); .unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
assertTrue("after: " + accAfter + ", before=" + accBefore, accAfter >= accBefore + 0.005); .networkMask(networkMask) // Local network mask
.controllerAddress(controller)
.meshBuildMode(MeshBuildMode.PLAIN) // everyone is connected to the master
.build();
TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new AdaptiveThresholdAlgorithm(1e-3), 16)
.rngSeed(12345)
.collectTrainingStats(false)
.batchSizePerWorker(16) // Minibatch size for each worker
.workersPerNode(2) // Workers per node
.rddTrainingApproach(rddTrainingApproach)
.exportDirectory("file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/"))
.build();
if (i == 0) {
acc[0] = eBefore.accuracy(); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.updater(new AMSGrad(0.1))
.graphBuilder()
.addInputs("in")
.layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("out")
.build();
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
File f = testDir.newFolder();
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
int count = 0;
List<String> paths = new ArrayList<>();
List<DataSet> ds = new ArrayList<>();
while (iter.hasNext() && count++ < 8) {
DataSet d = iter.next();
if (isPaths) {
File out = new File(f, count + ".bin");
if(mds){
d.toMultiDataSet().save(out);
} else {
d.save(out);
}
String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
paths.add(path);
}
ds.add(d);
} }
acc[i + 1] = eAfter.accuracy();
int numIter = 1;
double[] acc = new double[numIter + 1];
for (int i = 0; i < numIter; i++) {
//Check accuracy before:
DataSetIterator testIter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10);
Evaluation eBefore = sparkNet.getNetwork().evaluate(testIter);
INDArray paramsBefore = sparkNet.getNetwork().params().dup();
ComputationGraph after;
if(mds) {
//Fitting from MultiDataSet
List<MultiDataSet> mdsList = new ArrayList<>();
for(DataSet d : ds){
mdsList.add(d.toMultiDataSet());
}
switch (s) {
case "direct":
case "export":
JavaRDD<MultiDataSet> dsRDD = sc.parallelize(mdsList);
after = sparkNet.fitMultiDataSet(dsRDD);
break;
case "paths":
JavaRDD<String> pathRdd = sc.parallelize(paths);
after = sparkNet.fitPathsMultiDataSet(pathRdd);
break;
default:
throw new RuntimeException();
}
} else {
//Fitting from DataSet
switch (s) {
case "direct":
case "export":
JavaRDD<DataSet> dsRDD = sc.parallelize(ds);
after = sparkNet.fit(dsRDD);
break;
case "paths":
JavaRDD<String> pathRdd = sc.parallelize(paths);
after = sparkNet.fitPaths(pathRdd);
break;
default:
throw new RuntimeException();
}
}
INDArray paramsAfter = after.params();
System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
System.out.println(Arrays.toString(
Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
assertNotEquals(paramsBefore, paramsAfter);
testIter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10);
Evaluation eAfter = after.evaluate(testIter);
double accAfter = eAfter.accuracy();
double accBefore = eBefore.accuracy();
assertTrue("after: " + accAfter + ", before=" + accBefore, accAfter >= accBefore + 0.005);
if (i == 0) {
acc[0] = eBefore.accuracy();
}
acc[i + 1] = eAfter.accuracy();
}
log.info("Accuracies: {}", Arrays.toString(acc));
last = sparkNet.getNetwork().params();
lastDup = last.dup();
} }
log.info("Accuracies: {}", Arrays.toString(acc));
last = sparkNet.getNetwork().params();
lastDup = last.dup();
} }
} }
@ -289,7 +323,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
} }
@Test @Test @Ignore
public void testEpochUpdating() throws Exception { public void testEpochUpdating() throws Exception {
//Ensure that epoch counter is incremented properly on the workers //Ensure that epoch counter is incremented properly on the workers
@ -316,7 +350,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345) .seed(12345)
.updater(new AMSGrad(0.1)) .updater(new AMSGrad(0.001))
.graphBuilder() .graphBuilder()
.addInputs("in") .addInputs("in")
.layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX) .layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)

View File

@ -20,12 +20,12 @@ log4j.appender.Console.layout=org.apache.log4j.PatternLayout
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
log4j.appender.org.springframework=DEBUG log4j.appender.org.springframework=DEBUG
log4j.appender.org.deeplearning4j=DEBUG log4j.appender.org.deeplearning4j=INFO
log4j.appender.org.nd4j=DEBUG log4j.appender.org.nd4j=INFO
log4j.logger.org.springframework=INFO log4j.logger.org.springframework=INFO
log4j.logger.org.deeplearning4j=DEBUG log4j.logger.org.deeplearning4j=INFO
log4j.logger.org.nd4j=DEBUG log4j.logger.org.nd4j=INFO
log4j.logger.org.apache.spark=WARN log4j.logger.org.apache.spark=WARN

View File

@ -35,7 +35,7 @@
<logger name="org.apache.catalina.core" level="DEBUG" /> <logger name="org.apache.catalina.core" level="DEBUG" />
<logger name="org.springframework" level="DEBUG" /> <logger name="org.springframework" level="DEBUG" />
<logger name="org.deeplearning4j" level="DEBUG" /> <logger name="org.deeplearning4j" level="INFO" />
<logger name="org.datavec" level="INFO" /> <logger name="org.datavec" level="INFO" />
<logger name="org.nd4j" level="INFO" /> <logger name="org.nd4j" level="INFO" />
<logger name="opennlp.uima.util" level="OFF" /> <logger name="opennlp.uima.util" level="OFF" />

View File

@ -25,10 +25,6 @@
<artifactId>deeplearning4j-ui-components</artifactId> <artifactId>deeplearning4j-ui-components</artifactId>
<properties>
<freemarker.version>2.3.23</freemarker.version>
</properties>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.projectlombok</groupId> <groupId>org.projectlombok</groupId>

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.ui.components.chart.style.StyleChart;
import org.deeplearning4j.ui.components.table.ComponentTable; import org.deeplearning4j.ui.components.table.ComponentTable;
import org.deeplearning4j.ui.components.table.style.StyleTable; import org.deeplearning4j.ui.components.table.style.StyleTable;
import org.deeplearning4j.ui.standalone.StaticPageUtil; import org.deeplearning4j.ui.standalone.StaticPageUtil;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import java.awt.*; import java.awt.*;

View File

@ -60,7 +60,7 @@
<dependency> <dependency>
<groupId>org.freemarker</groupId> <groupId>org.freemarker</groupId>
<artifactId>freemarker</artifactId> <artifactId>freemarker</artifactId>
<version>2.3.29</version> <version>${freemarker.version}</version>
</dependency> </dependency>
<dependency> <dependency>

View File

@ -200,6 +200,7 @@ public class TrainModule implements UIModule {
})); }));
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc))); r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
} else { } else {
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID))); r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc))); r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)));
r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc))); r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc)));

View File

@ -1654,29 +1654,6 @@ public class SDVariable implements Serializable {
return x; return x;
} }
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SDVariable)) {
return false;
}
SDVariable that = (SDVariable) o;
if (!Objects.equals(varName, that.varName)) {
return false;
}
if (variableType != that.variableType) {
return false;
}
if(sameDiff != that.sameDiff){
return false;
}
return dataType == that.dataType;
}
@Override @Override
public int hashCode() { public int hashCode() {
int result = super.hashCode(); int result = super.hashCode();
@ -1695,4 +1672,26 @@ public class SDVariable implements Serializable {
v.sameDiff = sd; v.sameDiff = sd;
return v; return v;
} }
@Override
public boolean equals(Object o){
if(o == this) return true;
if(!(o instanceof SDVariable))
return false;
SDVariable s = (SDVariable)o;
if(!varName.equals(s.varName))
return false;
if(variableType != s.variableType)
return false;
if(dataType != s.dataType)
return false;
if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){
INDArray a1 = getArr();
INDArray a2 = s.getArr();
return a1.equals(a2);
}
return true;
}
} }

View File

@ -1234,13 +1234,14 @@ public class SameDiff extends SDBaseOps {
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass())
return false;
SameDiff sameDiff = (SameDiff) o; SameDiff sameDiff = (SameDiff) o;
if (variables != null ? !variables.equals(sameDiff.variables) : sameDiff.variables != null) boolean eqVars = variables.equals(sameDiff.variables);
return false; boolean eqOps = ops.equals(sameDiff.ops);
return sameDiffFunctionInstances != null ? sameDiffFunctionInstances.equals(sameDiff.sameDiffFunctionInstances) : sameDiff.sameDiffFunctionInstances == null; return eqVars && eqOps;
} }
/** /**
@ -5843,4 +5844,10 @@ public class SameDiff extends SDBaseOps {
return base + "_" + inc; return base + "_" + inc;
} }
@Override
public String toString(){
return "SameDiff(nVars=" + variables.size() + ",nOps=" + ops.size() + ")";
}
} }

View File

@ -16,10 +16,7 @@
package org.nd4j.autodiff.samediff.internal; package org.nd4j.autodiff.samediff.internal;
import lombok.AllArgsConstructor; import lombok.*;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import java.util.List; import java.util.List;
@ -28,6 +25,7 @@ import java.util.List;
@NoArgsConstructor @NoArgsConstructor
@Data //TODO immutable? @Data //TODO immutable?
@Builder @Builder
@EqualsAndHashCode(exclude = {"gradient", "variableIndex"})
public class Variable { public class Variable {
protected String name; protected String name;
protected SDVariable variable; protected SDVariable variable;

View File

@ -173,9 +173,6 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
@Override @Override
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) { public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
if(recordMetaData != null){
throw new UnsupportedOperationException("Evaluation with record metadata not yet implemented for EvaluationBinary");
}
eval(labels, networkPredictions, maskArray); eval(labels, networkPredictions, maskArray);
} }

View File

@ -325,7 +325,7 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
@Override @Override
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) { public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
throw new UnsupportedOperationException("Not yet implemented"); eval(labels, networkPredictions, maskArray);
} }
@Override @Override

View File

@ -229,7 +229,7 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
@Override @Override
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) { public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
throw new UnsupportedOperationException("Not yet implemented"); eval(labels, networkPredictions, maskArray);
} }
@Override @Override

View File

@ -3556,4 +3556,52 @@ public class SameDiffTests extends BaseNd4jTest {
assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided")); assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided"));
} }
} }
@Test
public void testEquals1(){
SameDiff sd1 = SameDiff.create();
SameDiff sd2 = SameDiff.create();
assertEquals(sd1, sd2);
SDVariable p1 = sd1.placeHolder("ph", DataType.FLOAT, -1, 10);
SDVariable p2 = sd2.placeHolder("ph", DataType.FLOAT, -1, 10);
assertEquals(sd1, sd2);
SDVariable w1 = sd1.constant("c1",1.0f);
SDVariable w2 = sd2.constant("c1",1.0f);
assertEquals(sd1, sd2);
SDVariable a1 = p1.add("add", w1);
SDVariable a2 = p2.add("add", w2);
assertEquals(sd1, sd2);
SDVariable w1a = sd1.constant("c2", 2.0f);
SDVariable w2a = sd2.constant("cX", 2.0f);
assertNotEquals(sd1, sd2);
w2a.rename("c2");
assertEquals(sd1, sd2);
sd2.createGradFunction("ph");
assertEquals(sd1, sd2);
w2a.getArr().assign(3.0f);
assertNotEquals(sd1, sd2);
w1a.getArr().assign(3.0f);
assertEquals(sd1, sd2);
SDVariable s1 = p1.sub("op", w1);
SDVariable s2 = p2.add("op", w1);
assertNotEquals(sd1, sd2);
}
} }

View File

@ -61,7 +61,7 @@ public class OpsMappingTests extends BaseNd4jTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; return 180000L; //Can be slow on some CI machines such as PPC
} }
@Test @Test

View File

@ -95,7 +95,7 @@ public class Downloader {
} }
// try extracting // try extracting
try{ try{
ArchiveUtils.unzipFileTo(f.getAbsolutePath(), extractToDir.getAbsolutePath()); ArchiveUtils.unzipFileTo(f.getAbsolutePath(), extractToDir.getAbsolutePath(), false);
} catch (Throwable t){ } catch (Throwable t){
log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t); log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t);
f.delete(); f.delete();

View File

@ -51,6 +51,10 @@ public class ArchiveUtils {
* @throws IOException * @throws IOException
*/ */
public static void unzipFileTo(String file, String dest) throws IOException { public static void unzipFileTo(String file, String dest) throws IOException {
unzipFileTo(file, dest, true);
}
public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
File target = new File(file); File target = new File(file);
if (!target.exists()) if (!target.exists())
throw new IllegalArgumentException("Archive doesnt exist"); throw new IllegalArgumentException("Archive doesnt exist");
@ -93,7 +97,9 @@ public class ArchiveUtils {
fos.close(); fos.close();
ze = zis.getNextEntry(); ze = zis.getNextEntry();
log.debug("File extracted: " + newFile.getAbsoluteFile()); if(logFiles) {
log.info("File extracted: " + newFile.getAbsoluteFile());
}
} }
zis.closeEntry(); zis.closeEntry();
@ -112,7 +118,9 @@ public class ArchiveUtils {
TarArchiveEntry entry; TarArchiveEntry entry;
/* Read the tar entries using the getNextEntry method **/ /* Read the tar entries using the getNextEntry method **/
while ((entry = (TarArchiveEntry) tarIn.getNextEntry()) != null) { while ((entry = (TarArchiveEntry) tarIn.getNextEntry()) != null) {
log.info("Extracting: " + entry.getName()); if(logFiles) {
log.info("Extracting: " + entry.getName());
}
/* If the entry is a directory, create the directory. */ /* If the entry is a directory, create the directory. */
if (entry.isDirectory()) { if (entry.isDirectory()) {