diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
index c20b5855f..2b7121af4 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
@@ -1381,4 +1381,17 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
         assertNotNull(ds.getFeatures());
         assertNull(ds.getLabels());
     }
+
+
+    @Test
+    public void testCollectMetaData(){
+        RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(new CollectionRecordReader(Collections.<List<Writable>>emptyList()), 1)
+                .collectMetaData(true)
+                .build();
+        assertTrue(trainIter.isCollectMetaData());
+        trainIter.setCollectMetaData(false);
+        assertFalse(trainIter.isCollectMetaData());
+        trainIter.setCollectMetaData(true);
+        assertTrue(trainIter.isCollectMetaData());
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java
index 812ea2b08..bd65af6a3 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java
@@ -33,7 +33,6 @@ import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
-import org.deeplearning4j.eval.meta.Prediction;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.*;
 import org.deeplearning4j.nn.conf.layers.*;
@@ -52,19 +51,13 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.INDArrayIndex;
-import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.io.ClassPathResource;
 import org.nd4j.linalg.learning.config.Sgd;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
-import org.nd4j.linalg.util.FeatureUtil;
 import org.nd4j.resources.Resources;
 
 import java.util.*;
 
 import static org.junit.Assert.*;
-import static org.nd4j.linalg.indexing.NDArrayIndex.all;
-import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
 
 /**
  * Created by agibsonccc on 12/22/14.
@@ -165,7 +158,7 @@ public class EvalTest extends BaseDL4JTest {
         assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
     }
 
-    @Test(timeout = 300000)
+    @Test
     public void testEvaluationWithMetaData() throws Exception {
 
         RecordReader csv = new CSVRecordReader();
@@ -256,6 +249,30 @@ public class EvalTest extends BaseDL4JTest {
             assertEquals(actualCounts[i], actualClassI.size());
             assertEquals(predictedCounts[i], predictedClassI.size());
         }
+
+
+        //Finally: test doEvaluation methods
+        rrdsi.reset();
+        org.nd4j.evaluation.classification.Evaluation e2 = new org.nd4j.evaluation.classification.Evaluation();
+        net.doEvaluation(rrdsi, e2);
+        for (int i = 0; i < 3; i++) {
+            List<Prediction> actualClassI = e2.getPredictionsByActualClass(i);
+            List<Prediction> predictedClassI = e2.getPredictionByPredictedClass(i);
+            assertEquals(actualCounts[i], actualClassI.size());
+            assertEquals(predictedCounts[i], predictedClassI.size());
+        }
+
+        ComputationGraph cg = net.toComputationGraph();
+        rrdsi.reset();
+        e2 = new org.nd4j.evaluation.classification.Evaluation();
+        cg.doEvaluation(rrdsi, e2);
+        for (int i = 0; i < 3; i++) {
+            List<Prediction> actualClassI = e2.getPredictionsByActualClass(i);
+            List<Prediction> predictedClassI = e2.getPredictionByPredictedClass(i);
+            assertEquals(actualCounts[i], actualClassI.size());
+            assertEquals(predictedCounts[i], predictedClassI.size());
+        }
+
     }
 
     private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
@@ -504,11 +521,11 @@ public class EvalTest extends BaseDL4JTest {
             list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()}));
         }
 
-        Evaluation e = new Evaluation();
-        RegressionEvaluation e2 = new RegressionEvaluation();
-        Map<Integer, IEvaluation[]> evals = new HashMap<>();
-        evals.put(0, new IEvaluation[]{(IEvaluation) e});
-        evals.put(1, new IEvaluation[]{(IEvaluation) e2});
+        org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
+        org.nd4j.evaluation.regression.RegressionEvaluation e2 = new org.nd4j.evaluation.regression.RegressionEvaluation();
+        Map<Integer, org.nd4j.evaluation.IEvaluation[]> evals = new HashMap<>();
+        evals.put(0, new org.nd4j.evaluation.IEvaluation[]{e});
+        evals.put(1, new org.nd4j.evaluation.IEvaluation[]{e2});
 
         cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals);
 
@@ -567,14 +584,14 @@ public class EvalTest extends BaseDL4JTest {
         }
 
         try {
-            net.evaluateROC(iter);
+            net.evaluateROC(iter, 0);
            fail("Expected exception");
        } catch (IllegalStateException e){
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
        }
 
        try {
-            net.evaluateROCMultiClass(iter);
+            net.evaluateROCMultiClass(iter, 0);
            fail("Expected exception");
        } catch (IllegalStateException e){
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
@@ -589,14 +606,14 @@ public class EvalTest extends BaseDL4JTest {
        }
 
        try {
-            cg.evaluateROC(iter);
+            cg.evaluateROC(iter, 0);
            fail("Expected exception");
        } catch (IllegalStateException e){
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
        }
 
        try {
-            cg.evaluateROCMultiClass(iter);
+            cg.evaluateROCMultiClass(iter, 0);
            fail("Expected exception");
        } catch (IllegalStateException e){
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
@@ -606,10 +623,10 @@ public class EvalTest extends BaseDL4JTest {
        //Disable validation, and check same thing:
        net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false);
        net.evaluate(iter);
-        net.evaluateROCMultiClass(iter);
+        net.evaluateROCMultiClass(iter, 0);
 
        cg.getConfiguration().setValidateOutputLayerConfig(false);
        cg.evaluate(iter);
-        cg.evaluateROCMultiClass(iter);
+        cg.evaluateROCMultiClass(iter, 0);
    }
 }
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java
index 7df75f6da..ba469546d 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java
@@ -61,7 +61,7 @@ public class RegressionEvalTest extends BaseDL4JTest {
         DataSet ds = new DataSet(f, l);
         DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds));
 
-        RegressionEvaluation re = net.evaluateRegression(iter);
+        org.nd4j.evaluation.regression.RegressionEvaluation re = net.evaluateRegression(iter);
 
         for (int i = 0; i < 5; i++) {
             assertEquals(1.0, re.meanSquaredError(i), 1e-6);
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/CacheableExtractableDataSetFetcher.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/CacheableExtractableDataSetFetcher.java
index 97574a99f..4a8d01aa4 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/CacheableExtractableDataSetFetcher.java
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/CacheableExtractableDataSetFetcher.java
@@ -86,7 +86,7 @@ public abstract class CacheableExtractableDataSetFetcher implements CacheableDataSetFetcher {
        }
 
        try {
-            ArchiveUtils.unzipFileTo(tmpFile.getAbsolutePath(), localCacheDir.getAbsolutePath());
+            ArchiveUtils.unzipFileTo(tmpFile.getAbsolutePath(), localCacheDir.getAbsolutePath(), false);
        } catch (Throwable t){
            //Catch any errors during extraction, and delete the directory to avoid leaving the dir in an invalid state
            if(localCacheDir.exists())
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java
index bda0f9c95..9f7813d5c 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java
@@ -205,6 +205,7 @@ public class RecordReaderDataSetIterator implements DataSetIterator {
         this.numPossibleLabels = b.numPossibleLabels;
         this.regression = b.regression;
         this.preProcessor = b.preProcessor;
+        this.collectMetaData = b.collectMetaData;
     }
 
     /**
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/util/KuromojiBinFilesFetcher.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/util/KuromojiBinFilesFetcher.java
index adcd87b5b..d2945cf3d 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/util/KuromojiBinFilesFetcher.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/util/KuromojiBinFilesFetcher.java
@@ -67,7 +67,7 @@ public class KuromojiBinFilesFetcher {
                            new URL("https://dl4jdata.blob.core.windows.net/kuromoji/kuromoji_bin_files.tar.gz"), tarFile);
        }
 
-        ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), rootDir.getAbsolutePath());
+        ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), rootDir.getAbsolutePath(), false);
 
        return rootDir.getAbsoluteFile();
    }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
index 0a34fe95a..582b20a15 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
@@ -4170,6 +4170,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
             INDArray[] featuresMasks = next.getFeaturesMaskArrays();
             INDArray[] labels = next.getLabels();
             INDArray[] labelMasks = next.getLabelsMaskArrays();
+            List<Serializable> meta = next.getExampleMetaData();
 
             try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
                 INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);
@@ -4188,7 +4189,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
 
                         try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                             for (IEvaluation evaluation : evalsThisOutput)
-                                evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i));
+                                evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i), meta);
                         }
                     }
                 }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/util/ComputationGraphUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/util/ComputationGraphUtil.java
index a27ce9a4c..4b9918203 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/util/ComputationGraphUtil.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/util/ComputationGraphUtil.java
@@ -23,6 +23,9 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 
+import java.io.Serializable;
+import java.util.List;
+
 public class ComputationGraphUtil {
 
     private ComputationGraphUtil() {}
@@ -33,13 +36,16 @@ public class ComputationGraphUtil {
         INDArray l = dataSet.getLabels();
         INDArray fMask = dataSet.getFeaturesMaskArray();
         INDArray lMask = dataSet.getLabelsMaskArray();
+        List<Serializable> meta = dataSet.getExampleMetaData();
         INDArray[] fNew = f == null ? null : new INDArray[] {f};
         INDArray[] lNew = l == null ? null : new INDArray[] {l};
         INDArray[] fMaskNew = (fMask != null ? new INDArray[] {fMask} : null);
         INDArray[] lMaskNew = (lMask != null ? new INDArray[] {lMask} : null);
-        return new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
+        org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
+        mds.setExampleMetaData(meta);
+        return mds;
     }
 
     /** Convert a DataSetIterator to a MultiDataSetIterator, via an adaptor class */
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
index bce86b9ce..5cc536810 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
@@ -25,14 +25,11 @@ import lombok.val;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.bytedeco.javacpp.Pointer;
-import org.nd4j.adapters.OutputAdapter;
-import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
 import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
-import org.deeplearning4j.eval.RegressionEvaluation;
 import org.deeplearning4j.exception.DL4JException;
 import org.deeplearning4j.exception.DL4JInvalidInputException;
-import org.deeplearning4j.nn.api.*;
 import org.deeplearning4j.nn.api.Updater;
+import org.deeplearning4j.nn.api.*;
 import org.deeplearning4j.nn.api.layers.IOutputLayer;
 import org.deeplearning4j.nn.api.layers.RecurrentLayer;
 import org.deeplearning4j.nn.conf.*;
@@ -44,8 +41,8 @@ import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.FrozenLayer;
 import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
-import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
 import org.deeplearning4j.nn.layers.LayerHelper;
+import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
 import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
 import org.deeplearning4j.nn.updater.UpdaterCreator;
 import org.deeplearning4j.nn.workspace.ArrayType;
@@ -58,19 +55,23 @@ import org.deeplearning4j.util.CrashReportingUtil;
 import org.deeplearning4j.util.ModelSerializer;
 import org.deeplearning4j.util.NetworkUtils;
 import org.deeplearning4j.util.OutputLayerUtil;
+import org.nd4j.adapters.OutputAdapter;
 import org.nd4j.base.Preconditions;
 import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.evaluation.classification.ROC;
 import org.nd4j.evaluation.classification.ROCMultiClass;
+import org.nd4j.evaluation.regression.RegressionEvaluation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
 import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
 import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
 import org.nd4j.linalg.api.memory.enums.LearningPolicy;
 import org.nd4j.linalg.api.memory.enums.ResetPolicy;
 import org.nd4j.linalg.api.memory.enums.SpillPolicy;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.AsyncDataSetIterator;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -84,7 +85,6 @@ import org.nd4j.linalg.heartbeat.reports.Task;
 import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
 import org.nd4j.linalg.heartbeat.utils.TaskUtils;
 import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.primitives.Triple;
 import org.nd4j.linalg.schedule.ISchedule;
@@ -96,6 +96,8 @@ import org.nd4j.util.OneTimeLogger;
 
 import java.io.*;
 import java.util.*;
 
+;
+
 /**
  * MultiLayerNetwork is a neural network with multiple layers in a stack, and usually an output layer.
@@ -3315,19 +3317,39 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
     * @param iterator Iterator to evaluate on
     * @return Evaluation object; results of evaluation on all examples in the data set
     */
-    public <T extends Evaluation> T evaluate(DataSetIterator iterator) {
+    public <T extends Evaluation> T evaluate(@NonNull DataSetIterator iterator) {
        return (T)evaluate(iterator, null);
    }
 
+    /**
+     * Evaluate the network (classification performance).
+     * Can only be used with MultiDataSetIterator instances with a single input/output array
+     *
+     * @param iterator Iterator to evaluate on
+     * @return Evaluation object; results of evaluation on all examples in the data set
+     */
+    public Evaluation evaluate(@NonNull MultiDataSetIterator iterator) {
+        return evaluate(new MultiDataSetWrapperIterator(iterator));
+    }
+
    /**
     * Evaluate the network for regression performance
     * @param iterator Data to evaluate on
-     * @return
+     * @return Regression evaluation
     */
    public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
        return (T)doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0];
    }
 
+    /**
+     * Evaluate the network for regression performance
+     * Can only be used with MultiDataSetIterator instances with a single input/output array
+     * @param iterator Data to evaluate on
+     */
+    public org.nd4j.evaluation.regression.RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator) {
+        return evaluateRegression(new MultiDataSetWrapperIterator(iterator));
+    }
+
    /**
     * @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
     */
@@ -3424,6 +3446,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
            INDArray labels = next.getLabels();
            INDArray fMask = next.getFeaturesMaskArray();
            INDArray lMask = next.getLabelsMaskArray();
+            List<Serializable> meta = next.getExampleMetaData();
 
            if (!useRnnSegments) {
@@ -3433,7 +3456,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
 
                    try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                        for (T evaluation : evaluations)
-                            evaluation.eval(labels, out, lMask);
+                            evaluation.eval(labels, out, lMask, meta);
                    }
                }
            } else {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/encoding/threshold/AdaptiveThresholdAlgorithm.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/encoding/threshold/AdaptiveThresholdAlgorithm.java
index 299732287..7dd56815c 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/encoding/threshold/AdaptiveThresholdAlgorithm.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/encoding/threshold/AdaptiveThresholdAlgorithm.java
@@ -222,8 +222,11 @@ public class AdaptiveThresholdAlgorithm implements ThresholdAlgorithm {
        if(a == null || Double.isNaN(a.lastThreshold))
            return;
+
        lastThresholdSum += a.lastThreshold;
-        lastSparsitySum += a.lastSparsity;
+        if (!Double.isNaN(a.lastSparsity)) {
+            lastSparsitySum += a.lastSparsity;
+        }
        count++;
    }
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
index daf0dd9b7..1198ae733 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
@@ -38,16 +38,22 @@
             <artifactId>nd4j-aeron</artifactId>
             <version>${nd4j.version}</version>
         </dependency>
-        <dependency>
-            <groupId>org.nd4j</groupId>
-            <artifactId>nd4j-parameter-server-node_2.11</artifactId>
-            <version>${nd4j.version}</version>
-        </dependency>
         <dependency>
             <groupId>org.deeplearning4j</groupId>
             <artifactId>dl4j-spark_2.11</artifactId>
             <version>${project.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-parameter-server-node_2.11</artifactId>
+            <version>${nd4j.version}</version>
+            <exclusions>
+                <exclusion>
+                    <groupId>net.jpountz.lz4</groupId>
+                    <artifactId>lz4</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
         <dependency>
             <groupId>org.projectlombok</groupId>
             <artifactId>lombok</artifactId>
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java
index 1de2d8636..a3c3b43a8 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java
@@ -23,6 +23,7 @@ import org.nd4j.linalg.dataset.api.iterator.ParallelMultiDataSetIterator;
 
 import java.util.Iterator;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 
 /**
  * This MultiDataSetIterator implementation does accumulation of MultiDataSets from different Spark executors, wrt Thread/Device Affinity
 */
@@ -32,14 +33,16 @@ import java.util.List;
 public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator {
 
     protected final List<Iterator<MultiDataSet>> iterators;
+    protected final AtomicInteger position;
 
     public VirtualMultiDataSetIterator(@NonNull List<Iterator<MultiDataSet>> iterators) {
         this.iterators = iterators;
+        this.position = new AtomicInteger(0);
     }
 
     @Override
     public MultiDataSet next(int num) {
-        return null;
+        return next();
     }
 
     @Override
@@ -59,27 +62,34 @@ public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator {
 
     @Override
     public boolean asyncSupported() {
-        return false;
+        return true;
     }
 
     @Override
     public void reset() {
-
+        throw new UnsupportedOperationException();
     }
 
     @Override
     public boolean hasNext() {
-        return false;
+        // just checking if that's not the last iterator, or if that's the last one - check if it has something
+        boolean ret = position.get() < iterators.size() - 1
+                || (position.get() < iterators.size() && iterators.get(position.get()).hasNext());
+        return ret;
     }
 
     @Override
     public MultiDataSet next() {
-        return null;
+        // TODO: this solution isn't ideal, it assumes non-empty iterators all the time. Would be nice to do something here
+        if (!iterators.get(position.get()).hasNext())
+            position.getAndIncrement();
+
+        return iterators.get(position.get()).next();
     }
 
     @Override
     public void remove() {
-
+        // no-op
     }
 
     @Override
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java
index 81fce7fbf..b6a8bb81c 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java
@@ -109,6 +109,7 @@ public class SharedTrainingWrapper {
 
         // now we're creating DataSetIterators, to feed ParallelWrapper
         iteratorDS = new VirtualDataSetIterator(iteratorsDS);
+        iteratorMDS = new VirtualMultiDataSetIterator(iteratorsMDS);
     }
 
     public static synchronized SharedTrainingWrapper getInstance(long id) {
@@ -447,17 +448,19 @@ public class SharedTrainingWrapper {
                     throw new DL4JInvalidConfigException("No iterators were defined for training");
 
                 try {
-                    while((iteratorDS != null && iteratorDS.hasNext()) || (iteratorMDS != null && iteratorMDS.hasNext())) {
+                    boolean dsNext;
+                    boolean mdsNext;
+                    while((dsNext = iteratorDS != null && iteratorDS.hasNext()) || (mdsNext = iteratorMDS != null && iteratorMDS.hasNext())) {
                         //Loop as a guard against concurrent modifications and RCs
 
                         if (wrapper != null) {
-                            if (iteratorDS != null)
+                            if (dsNext)
                                 wrapper.fit(iteratorDS);
                             else
                                 wrapper.fit(iteratorMDS);
                         } else {
                             // if wrapper is null, we're fitting standalone model then
-                            if (iteratorDS != null) {
+                            if (dsNext) {
                                 if (model instanceof ComputationGraph) {
                                     ((ComputationGraph) originalModel).fit(iteratorDS);
                                 } else if (model instanceof MultiLayerNetwork) {
@@ -472,7 +475,8 @@ public class SharedTrainingWrapper {
                             }
                         }
 
-                        consumer.getUpdatesQueue().purge();
+                        if(consumer != null)
+                            consumer.getUpdatesQueue().purge();
                     }
                 } catch (Throwable t){
                     log.warn("Exception encountered during fit operation", t);
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java
index c97292a2c..50aa564c1 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java
@@ -116,8 +116,7 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable {
     }
 
     protected int numExecutors() {
-        int numProc = Runtime.getRuntime().availableProcessors();
-        return Math.min(4, numProc);
+        return 4;
     }
 
     protected MultiLayerConfiguration getBasicConf() {
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
index 53a4b32b1..ab034604e 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
@@ -49,6 +49,7 @@ import org.junit.rules.TemporaryFolder;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.AMSGrad;
@@ -66,137 +67,170 @@ import java.util.concurrent.ConcurrentHashMap;
 
 import static org.junit.Assert.*;
 
 @Slf4j
-@Ignore("AB 2019/05/21 - Failing - Issue #7657")
+//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
 public class GradientSharingTrainingTest extends BaseSparkTest {
 
     @Rule
     public TemporaryFolder testDir = new TemporaryFolder();
 
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L;
+    }
+
     @Test
     public void trainSanityCheck() throws Exception {
 
-        INDArray last = null;
-        INDArray lastDup = null;
-        for (String s : new String[]{"paths", "direct", "export"}) {
-            System.out.println("--------------------------------------------------------------------------------------------------------------");
-            log.info("Starting: {}", s);
-            boolean isPaths = "paths".equals(s);
+        for(boolean mds : new boolean[]{false, true}) {
+            INDArray last = null;
+            INDArray lastDup = null;
+            for (String s : new String[]{"paths", "direct", "export"}) {
+                System.out.println("--------------------------------------------------------------------------------------------------------------");
+                log.info("Starting: {} - {}", s, (mds ? "MultiDataSet" : "DataSet"));
+                boolean isPaths = "paths".equals(s);
 
-            RDDTrainingApproach rddTrainingApproach;
-            switch (s) {
-                case "direct":
-                    rddTrainingApproach = RDDTrainingApproach.Direct;
-                    break;
-                case "export":
-                    rddTrainingApproach = RDDTrainingApproach.Export;
-                    break;
-                case "paths":
-                    rddTrainingApproach = RDDTrainingApproach.Direct;   //Actualy not used for fitPaths
-                    break;
-                default:
-                    throw new RuntimeException();
-            }
-
-            File temp = testDir.newFolder();
-
-
-            //TODO this probably won't work everywhere...
-            String controller = Inet4Address.getLocalHost().getHostAddress();
-            String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";
-
-            VoidConfiguration voidConfiguration = VoidConfiguration.builder()
-                    .unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
-                    .networkMask(networkMask) // Local network mask
-                    .controllerAddress(controller)
-                    .meshBuildMode(MeshBuildMode.PLAIN) // everyone is connected to the master
-                    .build();
-            TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new AdaptiveThresholdAlgorithm(1e-3), 16)
-                    .rngSeed(12345)
-                    .collectTrainingStats(false)
-                    .batchSizePerWorker(16) // Minibatch size for each worker
-                    .workersPerNode(2) // Workers per node
-                    .rddTrainingApproach(rddTrainingApproach)
-                    .exportDirectory("file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/"))
-                    .build();
-
-
-            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
-                    .seed(12345)
-                    .updater(new AMSGrad(0.1))
-                    .graphBuilder()
-                    .addInputs("in")
-                    .layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
-                            .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
-                    .setOutputs("out")
-                    .build();
-
-
-            SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
-            sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
-
-            System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
-            File f = testDir.newFolder();
-            DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
-            int count = 0;
-            List<String> paths = new ArrayList<>();
-            List<DataSet> ds = new ArrayList<>();
-            while (iter.hasNext() && count++ < 8) {
-                DataSet d = iter.next();
-                if (isPaths) {
-                    File out = new File(f, count + ".bin");
-                    d.save(out);
-                    String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
-                    paths.add(path);
-                }
-                ds.add(d);
-            }
-
-            int numIter = 1;
-            double[] acc = new double[numIter + 1];
-            for (int i = 0; i < numIter; i++) {
-                //Check accuracy before:
-                DataSetIterator testIter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10);
-                Evaluation eBefore = sparkNet.getNetwork().evaluate(testIter);
-
-                INDArray paramsBefore = sparkNet.getNetwork().params().dup();
-                ComputationGraph after;
+                RDDTrainingApproach rddTrainingApproach;
                switch (s) {
                    case "direct":
+                        rddTrainingApproach = RDDTrainingApproach.Direct;
+                        break;
                    case "export":
-                        JavaRDD<DataSet> dsRDD = sc.parallelize(ds);
-                        after = sparkNet.fit(dsRDD);
+                        rddTrainingApproach = RDDTrainingApproach.Export;
                        break;
                    case "paths":
-                        JavaRDD<String> pathRdd = sc.parallelize(paths);
-                        after = sparkNet.fitPaths(pathRdd);
+                        rddTrainingApproach = RDDTrainingApproach.Direct;   //Actualy not used for fitPaths
                        break;
                    default:
                        throw new RuntimeException();
                }
 
-                INDArray paramsAfter = after.params();
-                System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
-                System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
-                System.out.println(Arrays.toString(
-                        Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
-                assertNotEquals(paramsBefore, paramsAfter);
+                File temp = testDir.newFolder();
 
-                testIter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10);
-                Evaluation eAfter = after.evaluate(testIter);
+                //TODO this probably won't work everywhere...
+                String controller = Inet4Address.getLocalHost().getHostAddress();
+                String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";
 
-                double accAfter = eAfter.accuracy();
-                double accBefore = eBefore.accuracy();
-                assertTrue("after: " + accAfter + ", before=" + accBefore, accAfter >= accBefore + 0.005);
+                VoidConfiguration voidConfiguration = VoidConfiguration.builder()
+                        .unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
+                        .networkMask(networkMask) // Local network mask
+                        .controllerAddress(controller)
+                        .meshBuildMode(MeshBuildMode.PLAIN) // everyone is connected to the master
+                        .build();
+                TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new AdaptiveThresholdAlgorithm(1e-3), 16)
+                        .rngSeed(12345)
+                        .collectTrainingStats(false)
+                        .batchSizePerWorker(16) // Minibatch size for each worker
+                        .workersPerNode(2) // Workers per node
+                        .rddTrainingApproach(rddTrainingApproach)
+                        .exportDirectory("file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/"))
+                        .build();
 
-                if (i == 0) {
-                    acc[0] = eBefore.accuracy();
+
+                ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .updater(new AMSGrad(0.1))
+                        .graphBuilder()
+                        .addInputs("in")
+                        .layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
+                                .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
+                        .setOutputs("out")
+                        .build();
+
+
+                SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
+                sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
+
+                System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+                File f = testDir.newFolder();
+                DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
+                int count = 0;
+                List<String> paths = new ArrayList<>();
+                List<DataSet> ds = new ArrayList<>();
+                while (iter.hasNext() && count++ < 8) {
+                    DataSet d = iter.next();
+                    if (isPaths) {
+                        File out = new File(f, count + ".bin");
+                        if(mds){
+                            d.toMultiDataSet().save(out);
+                        } else {
+                            d.save(out);
+                        }
+                        String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
+                        paths.add(path);
+                    }
+                    ds.add(d);
                }
-                acc[i + 1] = eAfter.accuracy();
+
+                int numIter = 1;
+                double[] acc = new double[numIter + 1];
+                for (int i = 0; i < numIter; i++) {
+                    //Check accuracy before:
+                    DataSetIterator testIter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10);
+                    Evaluation eBefore = sparkNet.getNetwork().evaluate(testIter);
+
+                    INDArray paramsBefore = sparkNet.getNetwork().params().dup();
+                    ComputationGraph after;
+                    if(mds) {
+                        //Fitting from MultiDataSet
+                        List<MultiDataSet> mdsList = new ArrayList<>();
+                        for(DataSet d : ds){
+                            mdsList.add(d.toMultiDataSet());
+                        }
+                        switch (s) {
+                            case "direct":
+                            case "export":
+                                JavaRDD<MultiDataSet> dsRDD = sc.parallelize(mdsList);
+                                after = sparkNet.fitMultiDataSet(dsRDD);
+                                break;
+                            case "paths":
+                                JavaRDD<String> pathRdd = sc.parallelize(paths);
+                                after = sparkNet.fitPathsMultiDataSet(pathRdd);
+                                break;
+                            default:
+                                throw new RuntimeException();
+                        }
+                    } else {
+                        //Fitting from DataSet
+                        switch (s) {
+                            case "direct":
+                            case "export":
+                                JavaRDD<DataSet> dsRDD = sc.parallelize(ds);
+                                after = sparkNet.fit(dsRDD);
+                                break;
+                            case "paths":
+                                JavaRDD<String> pathRdd = sc.parallelize(paths);
+                                after = sparkNet.fitPaths(pathRdd);
+                                break;
+                            default:
+                                throw new RuntimeException();
+                        }
+                    }
+
+                    INDArray paramsAfter = after.params();
+                    System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+                    System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+                    System.out.println(Arrays.toString(
+                            Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+                    assertNotEquals(paramsBefore, paramsAfter);
+
+                    testIter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10);
+                    Evaluation eAfter = after.evaluate(testIter);
+
+                    double accAfter = eAfter.accuracy();
+                    double accBefore = eBefore.accuracy();
+                    assertTrue("after: " + accAfter + ", before=" + accBefore, accAfter >= accBefore + 0.005);
+
+                    if (i == 0) {
+                        acc[0] = eBefore.accuracy();
+                    }
+                    acc[i + 1] = eAfter.accuracy();
+                }
+                log.info("Accuracies: {}", Arrays.toString(acc));
+                last = sparkNet.getNetwork().params();
+                lastDup = last.dup();
            }
-            log.info("Accuracies: {}", Arrays.toString(acc));
-            last = sparkNet.getNetwork().params();
-            lastDup = last.dup();
        }
    }
@@ -289,7 +323,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
    }
 
 
-    @Test
+    @Test @Ignore
    public void testEpochUpdating() throws Exception {
 
        //Ensure that epoch counter is incremented properly on the workers
@@ -316,7 +350,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
 
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
-                .updater(new AMSGrad(0.1))
+                .updater(new AMSGrad(0.001))
                .graphBuilder()
                .addInputs("in")
                .layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties
index 5d1edb39f..4bee14770 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties
@@ -20,12 +20,12 @@ log4j.appender.Console.layout=org.apache.log4j.PatternLayout
 log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
 
 log4j.appender.org.springframework=DEBUG
-log4j.appender.org.deeplearning4j=DEBUG
-log4j.appender.org.nd4j=DEBUG
+log4j.appender.org.deeplearning4j=INFO
+log4j.appender.org.nd4j=INFO
 
 log4j.logger.org.springframework=INFO
-log4j.logger.org.deeplearning4j=DEBUG
-log4j.logger.org.nd4j=DEBUG
+log4j.logger.org.deeplearning4j=INFO
+log4j.logger.org.nd4j=INFO
 
 log4j.logger.org.apache.spark=WARN
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml
index 4d94f2516..9605642db 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml
@@ -35,7 +35,7 @@
 
-    
+    
 
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
index 4f2436e28..8f83b803e 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
@@ -25,10 +25,6 @@
     <artifactId>deeplearning4j-ui-components</artifactId>
 
-    <properties>
-        <freemarker.version>2.3.23</freemarker.version>
-    </properties>
-
     <dependencies>
 
         <dependency>
             <groupId>org.projectlombok</groupId>
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestStandAlone.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestStandAlone.java
index aaca2eb26..7ba9f9c36 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestStandAlone.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestStandAlone.java
@@ -24,6 +24,7 @@ import org.deeplearning4j.ui.components.chart.style.StyleChart;
 import org.deeplearning4j.ui.components.table.ComponentTable;
 import org.deeplearning4j.ui.components.table.style.StyleTable;
 import org.deeplearning4j.ui.standalone.StaticPageUtil;
+import org.junit.Ignore;
 import org.junit.Test;
 
 import java.awt.*;
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
index 4405d15f7..a66b85ece 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
@@ -60,7 +60,7 @@
         <dependency>
             <groupId>org.freemarker</groupId>
             <artifactId>freemarker</artifactId>
-            <version>2.3.29</version>
+            <version>${freemarker.version}</version>
         </dependency>
 
         <dependency>
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java
index 5648de738..00e2c9422 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java
@@ -200,6 +200,7 @@ public class TrainModule implements UIModule {
             }));
             r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
         } else {
+            r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
             r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
"" : currentSessionID))); r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc))); r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc))); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 6d9e34ed0..65416a659 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -1654,29 +1654,6 @@ public class SDVariable implements Serializable { return x; } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof SDVariable)) { - return false; - } - - SDVariable that = (SDVariable) o; - - if (!Objects.equals(varName, that.varName)) { - return false; - } - if (variableType != that.variableType) { - return false; - } - if(sameDiff != that.sameDiff){ - return false; - } - return dataType == that.dataType; - } - @Override public int hashCode() { int result = super.hashCode(); @@ -1695,4 +1672,26 @@ public class SDVariable implements Serializable { v.sameDiff = sd; return v; } + + @Override + public boolean equals(Object o){ + if(o == this) return true; + if(!(o instanceof SDVariable)) + return false; + + SDVariable s = (SDVariable)o; + if(!varName.equals(s.varName)) + return false; + if(variableType != s.variableType) + return false; + if(dataType != s.dataType) + return false; + + if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){ + INDArray a1 = getArr(); + INDArray a2 = s.getArr(); + return a1.equals(a2); + } + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 7ca809b2d..3411e2007 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -1234,13 +1234,14 @@ public class SameDiff extends SDBaseOps { @Override public boolean equals(Object o) { if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (o == null || getClass() != o.getClass()) + return false; SameDiff sameDiff = (SameDiff) o; - if (variables != null ? !variables.equals(sameDiff.variables) : sameDiff.variables != null) - return false; - return sameDiffFunctionInstances != null ? 
+        boolean eqVars = variables.equals(sameDiff.variables);
+        boolean eqOps = ops.equals(sameDiff.ops);
+        return eqVars && eqOps;
     }
 
     /**
@@ -5843,4 +5844,10 @@ public class SameDiff extends SDBaseOps {
 
         return base + "_" + inc;
     }
+
+
+    @Override
+    public String toString(){
+        return "SameDiff(nVars=" + variables.size() + ",nOps=" + ops.size() + ")";
+    }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java
index e8041955b..4e7c88a4b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java
@@ -16,10 +16,7 @@
 
 package org.nd4j.autodiff.samediff.internal;
 
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
+import lombok.*;
 import org.nd4j.autodiff.samediff.SDVariable;
 
 import java.util.List;
@@ -28,6 +25,7 @@ import java.util.List;
 @NoArgsConstructor
 @Data //TODO immutable?
 @Builder
+@EqualsAndHashCode(exclude = {"gradient", "variableIndex"})
 public class Variable {
     protected String name;
     protected SDVariable variable;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java
index 0d2f1fb62..fea9b7308 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java
@@ -173,9 +173,6 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
 
     @Override
     public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
-        if(recordMetaData != null){
-            throw new UnsupportedOperationException("Evaluation with record metadata not yet implemented for EvaluationBinary");
-        }
         eval(labels, networkPredictions, maskArray);
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java
index 0d137d0e9..1a0348324 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java
@@ -325,7 +325,7 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
 
     @Override
     public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
-        throw new UnsupportedOperationException("Not yet implemented");
+        eval(labels, networkPredictions, maskArray);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java
index cc206f0df..b5fac0dd4 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java
@@ -229,7 +229,7 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
 
     @Override
     public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
-        throw new UnsupportedOperationException("Not yet implemented");
+        eval(labels, networkPredictions, maskArray);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
index 409ac422a..73780538a 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
@@ -3556,4 +3556,52 @@ public class SameDiffTests extends BaseNd4jTest {
             assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided"));
         }
     }
+
+
+    @Test
+    public void testEquals1(){
+
+        SameDiff sd1 = SameDiff.create();
+        SameDiff sd2 = SameDiff.create();
+
+        assertEquals(sd1, sd2);
+
+        SDVariable p1 = sd1.placeHolder("ph", DataType.FLOAT, -1, 10);
+        SDVariable p2 = sd2.placeHolder("ph", DataType.FLOAT, -1, 10);
+
+        assertEquals(sd1, sd2);
+
+        SDVariable w1 = sd1.constant("c1",1.0f);
+        SDVariable w2 = sd2.constant("c1",1.0f);
+
+        assertEquals(sd1, sd2);
+
+        SDVariable a1 = p1.add("add", w1);
+        SDVariable a2 = p2.add("add", w2);
+
+        assertEquals(sd1, sd2);
+
+        SDVariable w1a = sd1.constant("c2", 2.0f);
+        SDVariable w2a = sd2.constant("cX", 2.0f);
+
+        assertNotEquals(sd1, sd2);
+        w2a.rename("c2");
+
+        assertEquals(sd1, sd2);
+
+        sd2.createGradFunction("ph");
+
+        assertEquals(sd1, sd2);
+
+        w2a.getArr().assign(3.0f);
+
+        assertNotEquals(sd1, sd2);
+
+        w1a.getArr().assign(3.0f);
+        assertEquals(sd1, sd2);
+
+        SDVariable s1 = p1.sub("op", w1);
+        SDVariable s2 = p2.add("op", w1);
+        assertNotEquals(sd1, sd2);
+    }
 }
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java
index 454739496..03b469e70 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java
@@ -61,7 +61,7 @@ public class OpsMappingTests extends BaseNd4jTest {
 
     @Override
     public long getTimeoutMilliseconds() {
-        return 90000L;
+        return 180000L;     //Can be slow on some CI machines such as PPC
     }
 
     @Test
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java
index ecaf3ea7f..05c44c29e 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java
@@ -95,7 +95,7 @@ public class Downloader {
         }
         // try extracting
         try{
-            ArchiveUtils.unzipFileTo(f.getAbsolutePath(), extractToDir.getAbsolutePath());
+            ArchiveUtils.unzipFileTo(f.getAbsolutePath(), extractToDir.getAbsolutePath(), false);
         } catch (Throwable t){
             log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t);
             f.delete();
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java b/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java
index f0c6ef318..d51d9ca9b 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java
@@ -51,6 +51,10 @@ public class ArchiveUtils {
      * @throws IOException
      */
     public static void unzipFileTo(String file, String dest) throws IOException {
+        unzipFileTo(file, dest, true);
+    }
+
+    public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
         File target = new File(file);
         if (!target.exists())
             throw new IllegalArgumentException("Archive doesnt exist");
@@ -93,7 +97,9 @@ public class ArchiveUtils {
                 fos.close();
                 ze = zis.getNextEntry();
-                log.debug("File extracted: " + newFile.getAbsoluteFile());
+                if(logFiles) {
+                    log.info("File extracted: " + newFile.getAbsoluteFile());
+                }
             }
 
             zis.closeEntry();
@@ -112,7 +118,9 @@ public class ArchiveUtils {
            TarArchiveEntry entry;
 
            /* Read the tar entries using the getNextEntry method **/
            while ((entry = (TarArchiveEntry) tarIn.getNextEntry()) != null) {
-               log.info("Extracting: " + entry.getName());
+               if(logFiles) {
+                   log.info("Extracting: " + entry.getName());
+               }
 
                /* If the entry is a directory, create the directory. */
               if (entry.isDirectory()) {
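
For reference, a minimal usage sketch (not part of the patch itself) of the three-argument ArchiveUtils.unzipFileTo overload introduced in the hunk above; the archive and destination paths below are placeholders chosen for illustration, not files referenced anywhere in this diff:

    import org.nd4j.util.ArchiveUtils;

    import java.io.IOException;

    public class QuietExtractionExample {
        public static void main(String[] args) throws IOException {
            // Placeholder paths - substitute a real archive and destination directory
            String archive = "/tmp/dataset.tar.gz";
            String dest = "/tmp/extracted";

            // New overload: logFiles = false suppresses the per-entry "Extracting" / "File extracted" log lines
            ArchiveUtils.unzipFileTo(archive, dest, false);

            // The existing two-argument form is unchanged for callers; per the patch it now delegates with logFiles = true
            ArchiveUtils.unzipFileTo(archive, dest);
        }
    }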