Assorted fixes (#318)
* #8777 MultiLayerNetwork.evaluate(MultiDataSetIterator) overload Signed-off-by: Alex Black <blacka101@gmail.com> * #8768 SameDiff.equals Signed-off-by: Alex Black <blacka101@gmail.com> * #8750 shade freemarker library and switch to it in DL4J UI Signed-off-by: Alex Black <blacka101@gmail.com> * #8704 DL4J UI redirect Signed-off-by: Alex Black <blacka101@gmail.com> * #8776 RecordReaderDataSetIterator builder collectMetaData fix Signed-off-by: Alex Black <blacka101@gmail.com> * #8718 Fix DL4J doEvaluation metadata Signed-off-by: Alex Black <blacka101@gmail.com> * #8715 ArchiveUtils - Add option to not log every extracted file Signed-off-by: Alex Black <blacka101@gmail.com> * No exception for evaluations that don't support metadata Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8765 CompGraph+MDS fix for SharedTrainingMaster Signed-off-by: Alex Black <blacka101@gmail.com> * small fix Signed-off-by: Alex Black <blacka101@gmail.com> * Timeout Signed-off-by: Alex Black <blacka101@gmail.com> * Ignore Signed-off-by: Alex Black <blacka101@gmail.com> * Revert freemarker shading Signed-off-by: Alex Black <blacka101@gmail.com> * Ignore Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
9970aadc5a
commit
63c9223bc2
|
@ -1381,4 +1381,17 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||||
assertNotNull(ds.getFeatures());
|
assertNotNull(ds.getFeatures());
|
||||||
assertNull(ds.getLabels());
|
assertNull(ds.getLabels());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCollectMetaData(){
|
||||||
|
RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(new CollectionRecordReader(Collections.<List<Writable>>emptyList()), 1)
|
||||||
|
.collectMetaData(true)
|
||||||
|
.build();
|
||||||
|
assertTrue(trainIter.isCollectMetaData());
|
||||||
|
trainIter.setCollectMetaData(false);
|
||||||
|
assertFalse(trainIter.isCollectMetaData());
|
||||||
|
trainIter.setCollectMetaData(true);
|
||||||
|
assertTrue(trainIter.isCollectMetaData());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,6 @@ import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
||||||
import org.deeplearning4j.eval.meta.Prediction;
|
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.*;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
@ -52,19 +51,13 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
import org.nd4j.linalg.learning.config.Sgd;
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import org.nd4j.linalg.util.FeatureUtil;
|
|
||||||
import org.nd4j.resources.Resources;
|
import org.nd4j.resources.Resources;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
|
|
||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by agibsonccc on 12/22/14.
|
* Created by agibsonccc on 12/22/14.
|
||||||
|
@ -165,7 +158,7 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
|
assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
@Test
|
||||||
public void testEvaluationWithMetaData() throws Exception {
|
public void testEvaluationWithMetaData() throws Exception {
|
||||||
|
|
||||||
RecordReader csv = new CSVRecordReader();
|
RecordReader csv = new CSVRecordReader();
|
||||||
|
@ -256,6 +249,30 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
assertEquals(actualCounts[i], actualClassI.size());
|
assertEquals(actualCounts[i], actualClassI.size());
|
||||||
assertEquals(predictedCounts[i], predictedClassI.size());
|
assertEquals(predictedCounts[i], predictedClassI.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//Finally: test doEvaluation methods
|
||||||
|
rrdsi.reset();
|
||||||
|
org.nd4j.evaluation.classification.Evaluation e2 = new org.nd4j.evaluation.classification.Evaluation();
|
||||||
|
net.doEvaluation(rrdsi, e2);
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
List<org.nd4j.evaluation.meta.Prediction> actualClassI = e2.getPredictionsByActualClass(i);
|
||||||
|
List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e2.getPredictionByPredictedClass(i);
|
||||||
|
assertEquals(actualCounts[i], actualClassI.size());
|
||||||
|
assertEquals(predictedCounts[i], predictedClassI.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputationGraph cg = net.toComputationGraph();
|
||||||
|
rrdsi.reset();
|
||||||
|
e2 = new org.nd4j.evaluation.classification.Evaluation();
|
||||||
|
cg.doEvaluation(rrdsi, e2);
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
List<org.nd4j.evaluation.meta.Prediction> actualClassI = e2.getPredictionsByActualClass(i);
|
||||||
|
List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e2.getPredictionByPredictedClass(i);
|
||||||
|
assertEquals(actualCounts[i], actualClassI.size());
|
||||||
|
assertEquals(predictedCounts[i], predictedClassI.size());
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
|
private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
|
||||||
|
@ -504,11 +521,11 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()}));
|
list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()}));
|
||||||
}
|
}
|
||||||
|
|
||||||
Evaluation e = new Evaluation();
|
org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
|
||||||
RegressionEvaluation e2 = new RegressionEvaluation();
|
org.nd4j.evaluation.regression.RegressionEvaluation e2 = new org.nd4j.evaluation.regression.RegressionEvaluation();
|
||||||
Map<Integer,IEvaluation[]> evals = new HashMap<>();
|
Map<Integer,org.nd4j.evaluation.IEvaluation[]> evals = new HashMap<>();
|
||||||
evals.put(0, new IEvaluation[]{(IEvaluation) e});
|
evals.put(0, new org.nd4j.evaluation.IEvaluation[]{e});
|
||||||
evals.put(1, new IEvaluation[]{(IEvaluation) e2});
|
evals.put(1, new org.nd4j.evaluation.IEvaluation[]{e2});
|
||||||
|
|
||||||
cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals);
|
cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals);
|
||||||
|
|
||||||
|
@ -567,14 +584,14 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
net.evaluateROC(iter);
|
net.evaluateROC(iter, 0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
|
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
net.evaluateROCMultiClass(iter);
|
net.evaluateROCMultiClass(iter, 0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
|
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
|
||||||
|
@ -589,14 +606,14 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
cg.evaluateROC(iter);
|
cg.evaluateROC(iter, 0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
|
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
cg.evaluateROCMultiClass(iter);
|
cg.evaluateROCMultiClass(iter, 0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
|
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
|
||||||
|
@ -606,10 +623,10 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
//Disable validation, and check same thing:
|
//Disable validation, and check same thing:
|
||||||
net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false);
|
net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false);
|
||||||
net.evaluate(iter);
|
net.evaluate(iter);
|
||||||
net.evaluateROCMultiClass(iter);
|
net.evaluateROCMultiClass(iter, 0);
|
||||||
|
|
||||||
cg.getConfiguration().setValidateOutputLayerConfig(false);
|
cg.getConfiguration().setValidateOutputLayerConfig(false);
|
||||||
cg.evaluate(iter);
|
cg.evaluate(iter);
|
||||||
cg.evaluateROCMultiClass(iter);
|
cg.evaluateROCMultiClass(iter, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class RegressionEvalTest extends BaseDL4JTest {
|
||||||
|
|
||||||
DataSet ds = new DataSet(f, l);
|
DataSet ds = new DataSet(f, l);
|
||||||
DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds));
|
DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds));
|
||||||
RegressionEvaluation re = net.evaluateRegression(iter);
|
org.nd4j.evaluation.regression.RegressionEvaluation re = net.evaluateRegression(iter);
|
||||||
|
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
assertEquals(1.0, re.meanSquaredError(i), 1e-6);
|
assertEquals(1.0, re.meanSquaredError(i), 1e-6);
|
||||||
|
|
|
@ -86,7 +86,7 @@ public abstract class CacheableExtractableDataSetFetcher implements CacheableDat
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
ArchiveUtils.unzipFileTo(tmpFile.getAbsolutePath(), localCacheDir.getAbsolutePath());
|
ArchiveUtils.unzipFileTo(tmpFile.getAbsolutePath(), localCacheDir.getAbsolutePath(), false);
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
//Catch any errors during extraction, and delete the directory to avoid leaving the dir in an invalid state
|
//Catch any errors during extraction, and delete the directory to avoid leaving the dir in an invalid state
|
||||||
if(localCacheDir.exists())
|
if(localCacheDir.exists())
|
||||||
|
|
|
@ -205,6 +205,7 @@ public class RecordReaderDataSetIterator implements DataSetIterator {
|
||||||
this.numPossibleLabels = b.numPossibleLabels;
|
this.numPossibleLabels = b.numPossibleLabels;
|
||||||
this.regression = b.regression;
|
this.regression = b.regression;
|
||||||
this.preProcessor = b.preProcessor;
|
this.preProcessor = b.preProcessor;
|
||||||
|
this.collectMetaData = b.collectMetaData;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -67,7 +67,7 @@ public class KuromojiBinFilesFetcher {
|
||||||
new URL("https://dl4jdata.blob.core.windows.net/kuromoji/kuromoji_bin_files.tar.gz"),
|
new URL("https://dl4jdata.blob.core.windows.net/kuromoji/kuromoji_bin_files.tar.gz"),
|
||||||
tarFile);
|
tarFile);
|
||||||
}
|
}
|
||||||
ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), rootDir.getAbsolutePath());
|
ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), rootDir.getAbsolutePath(), false);
|
||||||
|
|
||||||
return rootDir.getAbsoluteFile();
|
return rootDir.getAbsoluteFile();
|
||||||
}
|
}
|
||||||
|
|
|
@ -4170,6 +4170,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
INDArray[] featuresMasks = next.getFeaturesMaskArrays();
|
INDArray[] featuresMasks = next.getFeaturesMaskArrays();
|
||||||
INDArray[] labels = next.getLabels();
|
INDArray[] labels = next.getLabels();
|
||||||
INDArray[] labelMasks = next.getLabelsMaskArrays();
|
INDArray[] labelMasks = next.getLabelsMaskArrays();
|
||||||
|
List<Serializable> meta = next.getExampleMetaData();
|
||||||
|
|
||||||
try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
|
try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
|
||||||
INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);
|
INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);
|
||||||
|
@ -4188,7 +4189,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
|
|
||||||
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
for (IEvaluation evaluation : evalsThisOutput)
|
for (IEvaluation evaluation : evalsThisOutput)
|
||||||
evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i));
|
evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i), meta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,9 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class ComputationGraphUtil {
|
public class ComputationGraphUtil {
|
||||||
|
|
||||||
private ComputationGraphUtil() {}
|
private ComputationGraphUtil() {}
|
||||||
|
@ -33,13 +36,16 @@ public class ComputationGraphUtil {
|
||||||
INDArray l = dataSet.getLabels();
|
INDArray l = dataSet.getLabels();
|
||||||
INDArray fMask = dataSet.getFeaturesMaskArray();
|
INDArray fMask = dataSet.getFeaturesMaskArray();
|
||||||
INDArray lMask = dataSet.getLabelsMaskArray();
|
INDArray lMask = dataSet.getLabelsMaskArray();
|
||||||
|
List<Serializable> meta = dataSet.getExampleMetaData();
|
||||||
|
|
||||||
INDArray[] fNew = f == null ? null : new INDArray[] {f};
|
INDArray[] fNew = f == null ? null : new INDArray[] {f};
|
||||||
INDArray[] lNew = l == null ? null : new INDArray[] {l};
|
INDArray[] lNew = l == null ? null : new INDArray[] {l};
|
||||||
INDArray[] fMaskNew = (fMask != null ? new INDArray[] {fMask} : null);
|
INDArray[] fMaskNew = (fMask != null ? new INDArray[] {fMask} : null);
|
||||||
INDArray[] lMaskNew = (lMask != null ? new INDArray[] {lMask} : null);
|
INDArray[] lMaskNew = (lMask != null ? new INDArray[] {lMask} : null);
|
||||||
|
|
||||||
return new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
|
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
|
||||||
|
mds.setExampleMetaData(meta);
|
||||||
|
return mds;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert a DataSetIterator to a MultiDataSetIterator, via an adaptor class */
|
/** Convert a DataSetIterator to a MultiDataSetIterator, via an adaptor class */
|
||||||
|
|
|
@ -25,14 +25,11 @@ import lombok.val;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.adapters.OutputAdapter;
|
|
||||||
import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
|
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||||
import org.deeplearning4j.eval.RegressionEvaluation;
|
|
||||||
import org.deeplearning4j.exception.DL4JException;
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||||
import org.deeplearning4j.nn.api.*;
|
|
||||||
import org.deeplearning4j.nn.api.Updater;
|
import org.deeplearning4j.nn.api.Updater;
|
||||||
|
import org.deeplearning4j.nn.api.*;
|
||||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||||
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
|
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
|
||||||
import org.deeplearning4j.nn.conf.*;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
|
@ -44,8 +41,8 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.FrozenLayer;
|
import org.deeplearning4j.nn.layers.FrozenLayer;
|
||||||
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
|
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
|
||||||
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
|
|
||||||
import org.deeplearning4j.nn.layers.LayerHelper;
|
import org.deeplearning4j.nn.layers.LayerHelper;
|
||||||
|
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
|
||||||
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
||||||
import org.deeplearning4j.nn.updater.UpdaterCreator;
|
import org.deeplearning4j.nn.updater.UpdaterCreator;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
|
@ -58,19 +55,23 @@ import org.deeplearning4j.util.CrashReportingUtil;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.deeplearning4j.util.NetworkUtils;
|
import org.deeplearning4j.util.NetworkUtils;
|
||||||
import org.deeplearning4j.util.OutputLayerUtil;
|
import org.deeplearning4j.util.OutputLayerUtil;
|
||||||
|
import org.nd4j.adapters.OutputAdapter;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.evaluation.classification.ROC;
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||||
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
|
||||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||||
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||||
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||||
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
||||||
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
@ -84,7 +85,6 @@ import org.nd4j.linalg.heartbeat.reports.Task;
|
||||||
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
|
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
|
||||||
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
|
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.primitives.Triple;
|
import org.nd4j.linalg.primitives.Triple;
|
||||||
import org.nd4j.linalg.schedule.ISchedule;
|
import org.nd4j.linalg.schedule.ISchedule;
|
||||||
|
@ -96,6 +96,8 @@ import org.nd4j.util.OneTimeLogger;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* MultiLayerNetwork is a neural network with multiple layers in a stack, and usually an output layer.<br>
|
* MultiLayerNetwork is a neural network with multiple layers in a stack, and usually an output layer.<br>
|
||||||
|
@ -3315,19 +3317,39 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
* @param iterator Iterator to evaluate on
|
* @param iterator Iterator to evaluate on
|
||||||
* @return Evaluation object; results of evaluation on all examples in the data set
|
* @return Evaluation object; results of evaluation on all examples in the data set
|
||||||
*/
|
*/
|
||||||
public <T extends Evaluation> T evaluate(DataSetIterator iterator) {
|
public <T extends Evaluation> T evaluate(@NonNull DataSetIterator iterator) {
|
||||||
return (T)evaluate(iterator, null);
|
return (T)evaluate(iterator, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluate the network (classification performance).
|
||||||
|
* Can only be used with MultiDataSetIterator instances with a single input/output array
|
||||||
|
*
|
||||||
|
* @param iterator Iterator to evaluate on
|
||||||
|
* @return Evaluation object; results of evaluation on all examples in the data set
|
||||||
|
*/
|
||||||
|
public Evaluation evaluate(@NonNull MultiDataSetIterator iterator) {
|
||||||
|
return evaluate(new MultiDataSetWrapperIterator(iterator));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Evaluate the network for regression performance
|
* Evaluate the network for regression performance
|
||||||
* @param iterator Data to evaluate on
|
* @param iterator Data to evaluate on
|
||||||
* @return
|
* @return Regression evaluation
|
||||||
*/
|
*/
|
||||||
public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
|
public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
|
||||||
return (T)doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0];
|
return (T)doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluate the network for regression performance
|
||||||
|
* Can only be used with MultiDataSetIterator instances with a single input/output array
|
||||||
|
* @param iterator Data to evaluate on
|
||||||
|
*/
|
||||||
|
public org.nd4j.evaluation.regression.RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator) {
|
||||||
|
return evaluateRegression(new MultiDataSetWrapperIterator(iterator));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
|
* @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
|
||||||
*/
|
*/
|
||||||
|
@ -3424,6 +3446,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
INDArray labels = next.getLabels();
|
INDArray labels = next.getLabels();
|
||||||
INDArray fMask = next.getFeaturesMaskArray();
|
INDArray fMask = next.getFeaturesMaskArray();
|
||||||
INDArray lMask = next.getLabelsMaskArray();
|
INDArray lMask = next.getLabelsMaskArray();
|
||||||
|
List<Serializable> meta = next.getExampleMetaData();
|
||||||
|
|
||||||
|
|
||||||
if (!useRnnSegments) {
|
if (!useRnnSegments) {
|
||||||
|
@ -3433,7 +3456,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
|
|
||||||
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
for (T evaluation : evaluations)
|
for (T evaluation : evaluations)
|
||||||
evaluation.eval(labels, out, lMask);
|
evaluation.eval(labels, out, lMask, meta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -222,8 +222,11 @@ public class AdaptiveThresholdAlgorithm implements ThresholdAlgorithm {
|
||||||
if(a == null || Double.isNaN(a.lastThreshold))
|
if(a == null || Double.isNaN(a.lastThreshold))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
|
||||||
lastThresholdSum += a.lastThreshold;
|
lastThresholdSum += a.lastThreshold;
|
||||||
|
if (!Double.isNaN(a.lastSparsity)) {
|
||||||
lastSparsitySum += a.lastSparsity;
|
lastSparsitySum += a.lastSparsity;
|
||||||
|
}
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,16 +38,22 @@
|
||||||
<artifactId>nd4j-aeron</artifactId>
|
<artifactId>nd4j-aeron</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.deeplearning4j</groupId>
|
<groupId>org.deeplearning4j</groupId>
|
||||||
<artifactId>dl4j-spark_2.11</artifactId>
|
<artifactId>dl4j-spark_2.11</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
|
||||||
|
<version>${nd4j.version}</version>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>net.jpountz.lz4</groupId>
|
||||||
|
<artifactId>lz4</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.projectlombok</groupId>
|
<groupId>org.projectlombok</groupId>
|
||||||
<artifactId>lombok</artifactId>
|
<artifactId>lombok</artifactId>
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.nd4j.linalg.dataset.api.iterator.ParallelMultiDataSetIterator;
|
||||||
|
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This MultiDataSetIterator implementation does accumulation of MultiDataSets from different Spark executors, wrt Thread/Device Affinity
|
* This MultiDataSetIterator implementation does accumulation of MultiDataSets from different Spark executors, wrt Thread/Device Affinity
|
||||||
|
@ -32,14 +33,16 @@ import java.util.List;
|
||||||
public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator {
|
public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator {
|
||||||
|
|
||||||
protected final List<Iterator<MultiDataSet>> iterators;
|
protected final List<Iterator<MultiDataSet>> iterators;
|
||||||
|
protected final AtomicInteger position;
|
||||||
|
|
||||||
public VirtualMultiDataSetIterator(@NonNull List<Iterator<MultiDataSet>> iterators) {
|
public VirtualMultiDataSetIterator(@NonNull List<Iterator<MultiDataSet>> iterators) {
|
||||||
this.iterators = iterators;
|
this.iterators = iterators;
|
||||||
|
this.position = new AtomicInteger(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MultiDataSet next(int num) {
|
public MultiDataSet next(int num) {
|
||||||
return null;
|
return next();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -59,27 +62,34 @@ public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean asyncSupported() {
|
public boolean asyncSupported() {
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void reset() {
|
public void reset() {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean hasNext() {
|
public boolean hasNext() {
|
||||||
return false;
|
// just checking if that's not the last iterator, or if that's the last one - check if it has something
|
||||||
|
boolean ret = position.get() < iterators.size() - 1
|
||||||
|
|| (position.get() < iterators.size() && iterators.get(position.get()).hasNext());
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MultiDataSet next() {
|
public MultiDataSet next() {
|
||||||
return null;
|
// TODO: this solution isn't ideal, it assumes non-empty iterators all the time. Would be nice to do something here
|
||||||
|
if (!iterators.get(position.get()).hasNext())
|
||||||
|
position.getAndIncrement();
|
||||||
|
|
||||||
|
return iterators.get(position.get()).next();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void remove() {
|
public void remove() {
|
||||||
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -109,6 +109,7 @@ public class SharedTrainingWrapper {
|
||||||
|
|
||||||
// now we're creating DataSetIterators, to feed ParallelWrapper
|
// now we're creating DataSetIterators, to feed ParallelWrapper
|
||||||
iteratorDS = new VirtualDataSetIterator(iteratorsDS);
|
iteratorDS = new VirtualDataSetIterator(iteratorsDS);
|
||||||
|
iteratorMDS = new VirtualMultiDataSetIterator(iteratorsMDS);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static synchronized SharedTrainingWrapper getInstance(long id) {
|
public static synchronized SharedTrainingWrapper getInstance(long id) {
|
||||||
|
@ -447,17 +448,19 @@ public class SharedTrainingWrapper {
|
||||||
throw new DL4JInvalidConfigException("No iterators were defined for training");
|
throw new DL4JInvalidConfigException("No iterators were defined for training");
|
||||||
|
|
||||||
try {
|
try {
|
||||||
while((iteratorDS != null && iteratorDS.hasNext()) || (iteratorMDS != null && iteratorMDS.hasNext())) {
|
boolean dsNext;
|
||||||
|
boolean mdsNext;
|
||||||
|
while((dsNext = iteratorDS != null && iteratorDS.hasNext()) || (mdsNext = iteratorMDS != null && iteratorMDS.hasNext())) {
|
||||||
//Loop as a guard against concurrent modifications and RCs
|
//Loop as a guard against concurrent modifications and RCs
|
||||||
|
|
||||||
if (wrapper != null) {
|
if (wrapper != null) {
|
||||||
if (iteratorDS != null)
|
if (dsNext)
|
||||||
wrapper.fit(iteratorDS);
|
wrapper.fit(iteratorDS);
|
||||||
else
|
else
|
||||||
wrapper.fit(iteratorMDS);
|
wrapper.fit(iteratorMDS);
|
||||||
} else {
|
} else {
|
||||||
// if wrapper is null, we're fitting standalone model then
|
// if wrapper is null, we're fitting standalone model then
|
||||||
if (iteratorDS != null) {
|
if (dsNext) {
|
||||||
if (model instanceof ComputationGraph) {
|
if (model instanceof ComputationGraph) {
|
||||||
((ComputationGraph) originalModel).fit(iteratorDS);
|
((ComputationGraph) originalModel).fit(iteratorDS);
|
||||||
} else if (model instanceof MultiLayerNetwork) {
|
} else if (model instanceof MultiLayerNetwork) {
|
||||||
|
@ -472,6 +475,7 @@ public class SharedTrainingWrapper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(consumer != null)
|
||||||
consumer.getUpdatesQueue().purge();
|
consumer.getUpdatesQueue().purge();
|
||||||
}
|
}
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
|
|
|
@ -116,8 +116,7 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
|
||||||
}
|
}
|
||||||
|
|
||||||
protected int numExecutors() {
|
protected int numExecutors() {
|
||||||
int numProc = Runtime.getRuntime().availableProcessors();
|
return 4;
|
||||||
return Math.min(4, numProc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected MultiLayerConfiguration getBasicConf() {
|
protected MultiLayerConfiguration getBasicConf() {
|
||||||
|
|
|
@ -49,6 +49,7 @@ import org.junit.rules.TemporaryFolder;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.learning.config.AMSGrad;
|
import org.nd4j.linalg.learning.config.AMSGrad;
|
||||||
|
@ -66,20 +67,26 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Ignore("AB 2019/05/21 - Failing - Issue #7657")
|
//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
|
||||||
public class GradientSharingTrainingTest extends BaseSparkTest {
|
public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void trainSanityCheck() throws Exception {
|
public void trainSanityCheck() throws Exception {
|
||||||
|
|
||||||
|
for(boolean mds : new boolean[]{false, true}) {
|
||||||
INDArray last = null;
|
INDArray last = null;
|
||||||
INDArray lastDup = null;
|
INDArray lastDup = null;
|
||||||
for (String s : new String[]{"paths", "direct", "export"}) {
|
for (String s : new String[]{"paths", "direct", "export"}) {
|
||||||
System.out.println("--------------------------------------------------------------------------------------------------------------");
|
System.out.println("--------------------------------------------------------------------------------------------------------------");
|
||||||
log.info("Starting: {}", s);
|
log.info("Starting: {} - {}", s, (mds ? "MultiDataSet" : "DataSet"));
|
||||||
boolean isPaths = "paths".equals(s);
|
boolean isPaths = "paths".equals(s);
|
||||||
|
|
||||||
RDDTrainingApproach rddTrainingApproach;
|
RDDTrainingApproach rddTrainingApproach;
|
||||||
|
@ -144,7 +151,11 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
DataSet d = iter.next();
|
DataSet d = iter.next();
|
||||||
if (isPaths) {
|
if (isPaths) {
|
||||||
File out = new File(f, count + ".bin");
|
File out = new File(f, count + ".bin");
|
||||||
|
if(mds){
|
||||||
|
d.toMultiDataSet().save(out);
|
||||||
|
} else {
|
||||||
d.save(out);
|
d.save(out);
|
||||||
|
}
|
||||||
String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
|
String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
|
||||||
paths.add(path);
|
paths.add(path);
|
||||||
}
|
}
|
||||||
|
@ -160,6 +171,27 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
|
|
||||||
INDArray paramsBefore = sparkNet.getNetwork().params().dup();
|
INDArray paramsBefore = sparkNet.getNetwork().params().dup();
|
||||||
ComputationGraph after;
|
ComputationGraph after;
|
||||||
|
if(mds) {
|
||||||
|
//Fitting from MultiDataSet
|
||||||
|
List<MultiDataSet> mdsList = new ArrayList<>();
|
||||||
|
for(DataSet d : ds){
|
||||||
|
mdsList.add(d.toMultiDataSet());
|
||||||
|
}
|
||||||
|
switch (s) {
|
||||||
|
case "direct":
|
||||||
|
case "export":
|
||||||
|
JavaRDD<MultiDataSet> dsRDD = sc.parallelize(mdsList);
|
||||||
|
after = sparkNet.fitMultiDataSet(dsRDD);
|
||||||
|
break;
|
||||||
|
case "paths":
|
||||||
|
JavaRDD<String> pathRdd = sc.parallelize(paths);
|
||||||
|
after = sparkNet.fitPathsMultiDataSet(pathRdd);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
//Fitting from DataSet
|
||||||
switch (s) {
|
switch (s) {
|
||||||
case "direct":
|
case "direct":
|
||||||
case "export":
|
case "export":
|
||||||
|
@ -173,6 +205,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
INDArray paramsAfter = after.params();
|
INDArray paramsAfter = after.params();
|
||||||
System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
|
System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
|
||||||
|
@ -199,6 +232,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
lastDup = last.dup();
|
lastDup = last.dup();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -289,7 +323,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test @Ignore
|
||||||
public void testEpochUpdating() throws Exception {
|
public void testEpochUpdating() throws Exception {
|
||||||
//Ensure that epoch counter is incremented properly on the workers
|
//Ensure that epoch counter is incremented properly on the workers
|
||||||
|
|
||||||
|
@ -316,7 +350,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
|
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.updater(new AMSGrad(0.1))
|
.updater(new AMSGrad(0.001))
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
|
.layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
|
||||||
|
|
|
@ -20,12 +20,12 @@ log4j.appender.Console.layout=org.apache.log4j.PatternLayout
|
||||||
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
|
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
|
||||||
|
|
||||||
log4j.appender.org.springframework=DEBUG
|
log4j.appender.org.springframework=DEBUG
|
||||||
log4j.appender.org.deeplearning4j=DEBUG
|
log4j.appender.org.deeplearning4j=INFO
|
||||||
log4j.appender.org.nd4j=DEBUG
|
log4j.appender.org.nd4j=INFO
|
||||||
|
|
||||||
log4j.logger.org.springframework=INFO
|
log4j.logger.org.springframework=INFO
|
||||||
log4j.logger.org.deeplearning4j=DEBUG
|
log4j.logger.org.deeplearning4j=INFO
|
||||||
log4j.logger.org.nd4j=DEBUG
|
log4j.logger.org.nd4j=INFO
|
||||||
log4j.logger.org.apache.spark=WARN
|
log4j.logger.org.apache.spark=WARN
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@
|
||||||
|
|
||||||
<logger name="org.apache.catalina.core" level="DEBUG" />
|
<logger name="org.apache.catalina.core" level="DEBUG" />
|
||||||
<logger name="org.springframework" level="DEBUG" />
|
<logger name="org.springframework" level="DEBUG" />
|
||||||
<logger name="org.deeplearning4j" level="DEBUG" />
|
<logger name="org.deeplearning4j" level="INFO" />
|
||||||
<logger name="org.datavec" level="INFO" />
|
<logger name="org.datavec" level="INFO" />
|
||||||
<logger name="org.nd4j" level="INFO" />
|
<logger name="org.nd4j" level="INFO" />
|
||||||
<logger name="opennlp.uima.util" level="OFF" />
|
<logger name="opennlp.uima.util" level="OFF" />
|
||||||
|
|
|
@ -25,10 +25,6 @@
|
||||||
|
|
||||||
<artifactId>deeplearning4j-ui-components</artifactId>
|
<artifactId>deeplearning4j-ui-components</artifactId>
|
||||||
|
|
||||||
<properties>
|
|
||||||
<freemarker.version>2.3.23</freemarker.version>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.projectlombok</groupId>
|
<groupId>org.projectlombok</groupId>
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.ui.components.chart.style.StyleChart;
|
||||||
import org.deeplearning4j.ui.components.table.ComponentTable;
|
import org.deeplearning4j.ui.components.table.ComponentTable;
|
||||||
import org.deeplearning4j.ui.components.table.style.StyleTable;
|
import org.deeplearning4j.ui.components.table.style.StyleTable;
|
||||||
import org.deeplearning4j.ui.standalone.StaticPageUtil;
|
import org.deeplearning4j.ui.standalone.StaticPageUtil;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
|
|
|
@ -60,7 +60,7 @@
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.freemarker</groupId>
|
<groupId>org.freemarker</groupId>
|
||||||
<artifactId>freemarker</artifactId>
|
<artifactId>freemarker</artifactId>
|
||||||
<version>2.3.29</version>
|
<version>${freemarker.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|
|
@ -200,6 +200,7 @@ public class TrainModule implements UIModule {
|
||||||
}));
|
}));
|
||||||
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
|
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
|
||||||
} else {
|
} else {
|
||||||
|
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
|
||||||
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
|
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
|
||||||
r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)));
|
r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)));
|
||||||
r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc)));
|
r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc)));
|
||||||
|
|
|
@ -1654,29 +1654,6 @@ public class SDVariable implements Serializable {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (this == o) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (!(o instanceof SDVariable)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDVariable that = (SDVariable) o;
|
|
||||||
|
|
||||||
if (!Objects.equals(varName, that.varName)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (variableType != that.variableType) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if(sameDiff != that.sameDiff){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return dataType == that.dataType;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
int result = super.hashCode();
|
int result = super.hashCode();
|
||||||
|
@ -1695,4 +1672,26 @@ public class SDVariable implements Serializable {
|
||||||
v.sameDiff = sd;
|
v.sameDiff = sd;
|
||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o){
|
||||||
|
if(o == this) return true;
|
||||||
|
if(!(o instanceof SDVariable))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
SDVariable s = (SDVariable)o;
|
||||||
|
if(!varName.equals(s.varName))
|
||||||
|
return false;
|
||||||
|
if(variableType != s.variableType)
|
||||||
|
return false;
|
||||||
|
if(dataType != s.dataType)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){
|
||||||
|
INDArray a1 = getArr();
|
||||||
|
INDArray a2 = s.getArr();
|
||||||
|
return a1.equals(a2);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1234,13 +1234,14 @@ public class SameDiff extends SDBaseOps {
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass())
|
||||||
|
return false;
|
||||||
|
|
||||||
SameDiff sameDiff = (SameDiff) o;
|
SameDiff sameDiff = (SameDiff) o;
|
||||||
|
|
||||||
if (variables != null ? !variables.equals(sameDiff.variables) : sameDiff.variables != null)
|
boolean eqVars = variables.equals(sameDiff.variables);
|
||||||
return false;
|
boolean eqOps = ops.equals(sameDiff.ops);
|
||||||
return sameDiffFunctionInstances != null ? sameDiffFunctionInstances.equals(sameDiff.sameDiffFunctionInstances) : sameDiff.sameDiffFunctionInstances == null;
|
return eqVars && eqOps;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -5843,4 +5844,10 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
return base + "_" + inc;
|
return base + "_" + inc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return "SameDiff(nVars=" + variables.size() + ",nOps=" + ops.size() + ")";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,10 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.internal;
|
package org.nd4j.autodiff.samediff.internal;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.*;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -28,6 +25,7 @@ import java.util.List;
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data //TODO immutable?
|
@Data //TODO immutable?
|
||||||
@Builder
|
@Builder
|
||||||
|
@EqualsAndHashCode(exclude = {"gradient", "variableIndex"})
|
||||||
public class Variable {
|
public class Variable {
|
||||||
protected String name;
|
protected String name;
|
||||||
protected SDVariable variable;
|
protected SDVariable variable;
|
||||||
|
|
|
@ -173,9 +173,6 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
||||||
if(recordMetaData != null){
|
|
||||||
throw new UnsupportedOperationException("Evaluation with record metadata not yet implemented for EvaluationBinary");
|
|
||||||
}
|
|
||||||
eval(labels, networkPredictions, maskArray);
|
eval(labels, networkPredictions, maskArray);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -325,7 +325,7 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
eval(labels, networkPredictions, maskArray);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -229,7 +229,7 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
eval(labels, networkPredictions, maskArray);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -3556,4 +3556,52 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided"));
|
assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEquals1(){
|
||||||
|
|
||||||
|
SameDiff sd1 = SameDiff.create();
|
||||||
|
SameDiff sd2 = SameDiff.create();
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable p1 = sd1.placeHolder("ph", DataType.FLOAT, -1, 10);
|
||||||
|
SDVariable p2 = sd2.placeHolder("ph", DataType.FLOAT, -1, 10);
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable w1 = sd1.constant("c1",1.0f);
|
||||||
|
SDVariable w2 = sd2.constant("c1",1.0f);
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable a1 = p1.add("add", w1);
|
||||||
|
SDVariable a2 = p2.add("add", w2);
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable w1a = sd1.constant("c2", 2.0f);
|
||||||
|
SDVariable w2a = sd2.constant("cX", 2.0f);
|
||||||
|
|
||||||
|
assertNotEquals(sd1, sd2);
|
||||||
|
w2a.rename("c2");
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
sd2.createGradFunction("ph");
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
w2a.getArr().assign(3.0f);
|
||||||
|
|
||||||
|
assertNotEquals(sd1, sd2);
|
||||||
|
|
||||||
|
w1a.getArr().assign(3.0f);
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable s1 = p1.sub("op", w1);
|
||||||
|
SDVariable s2 = p2.add("op", w1);
|
||||||
|
assertNotEquals(sd1, sd2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class OpsMappingTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L;
|
return 180000L; //Can be slow on some CI machines such as PPC
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -95,7 +95,7 @@ public class Downloader {
|
||||||
}
|
}
|
||||||
// try extracting
|
// try extracting
|
||||||
try{
|
try{
|
||||||
ArchiveUtils.unzipFileTo(f.getAbsolutePath(), extractToDir.getAbsolutePath());
|
ArchiveUtils.unzipFileTo(f.getAbsolutePath(), extractToDir.getAbsolutePath(), false);
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t);
|
log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t);
|
||||||
f.delete();
|
f.delete();
|
||||||
|
|
|
@ -51,6 +51,10 @@ public class ArchiveUtils {
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
public static void unzipFileTo(String file, String dest) throws IOException {
|
public static void unzipFileTo(String file, String dest) throws IOException {
|
||||||
|
unzipFileTo(file, dest, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
|
||||||
File target = new File(file);
|
File target = new File(file);
|
||||||
if (!target.exists())
|
if (!target.exists())
|
||||||
throw new IllegalArgumentException("Archive doesnt exist");
|
throw new IllegalArgumentException("Archive doesnt exist");
|
||||||
|
@ -93,7 +97,9 @@ public class ArchiveUtils {
|
||||||
|
|
||||||
fos.close();
|
fos.close();
|
||||||
ze = zis.getNextEntry();
|
ze = zis.getNextEntry();
|
||||||
log.debug("File extracted: " + newFile.getAbsoluteFile());
|
if(logFiles) {
|
||||||
|
log.info("File extracted: " + newFile.getAbsoluteFile());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
zis.closeEntry();
|
zis.closeEntry();
|
||||||
|
@ -112,7 +118,9 @@ public class ArchiveUtils {
|
||||||
TarArchiveEntry entry;
|
TarArchiveEntry entry;
|
||||||
/* Read the tar entries using the getNextEntry method **/
|
/* Read the tar entries using the getNextEntry method **/
|
||||||
while ((entry = (TarArchiveEntry) tarIn.getNextEntry()) != null) {
|
while ((entry = (TarArchiveEntry) tarIn.getNextEntry()) != null) {
|
||||||
|
if(logFiles) {
|
||||||
log.info("Extracting: " + entry.getName());
|
log.info("Extracting: " + entry.getName());
|
||||||
|
}
|
||||||
/* If the entry is a directory, create the directory. */
|
/* If the entry is a directory, create the directory. */
|
||||||
|
|
||||||
if (entry.isDirectory()) {
|
if (entry.isDirectory()) {
|
||||||
|
|
Loading…
Reference in New Issue