From f5f77df8467c65da3985c1be079a2c5bd16ec2c7 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 23 Apr 2020 15:38:42 +1000 Subject: [PATCH] Dependency version upgrades + small SameDiff fix (#405) * #8861 Training evaluation on variables not required for loss fix Signed-off-by: Alex Black * Dependency version updates flagged by dependabot Signed-off-by: Alex Black --- .../samediff/internal/AbstractSession.java | 2 +- .../samediff/internal/TrainingSession.java | 8 ++++- .../samediff/SameDiffTrainingTest.java | 35 +++++++++++++++++++ nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml | 2 +- nd4j/nd4j-remote/nd4j-json-server/pom.xml | 2 +- pom.xml | 10 +++--- 6 files changed, 50 insertions(+), 9 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index d89fe05a5..a6273e0b3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -132,7 +132,7 @@ public abstract class AbstractSession { Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty"); if (requiredActivations == null) - requiredActivations = Collections.emptyList(); + requiredActivations = Collections.emptySet(); if (at == null) at = At.defaultAt(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java index e683acc47..d0d0b14cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java @@ -75,7 +75,7 @@ public class TrainingSession extends InferenceSession { this.listeners = filtered.isEmpty() ? null : filtered; } - List requiredActivations = new ArrayList<>(); + Set requiredActivations = new HashSet<>(); gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for for (String s : paramsToTrain) { Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s); @@ -95,6 +95,12 @@ public class TrainingSession extends InferenceSession { gradVarToVarMap.put(grad.name(), s); } + //Also add evaluations - in case we want to evaluate something that isn't required to determine loss + // (hence wouldn't normally be calculated) + if(config.getTrainEvaluations() != null){ + requiredActivations.addAll(config.getTrainEvaluations().keySet()); + } + //Set up losses lossVarsToLossIdx = new LinkedHashMap<>(); List lossVars; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java index 3ab940937..75dc77cbc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.util.Collections; @@ -35,6 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator; import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; @@ -341,6 +343,39 @@ public class SameDiffTrainingTest extends BaseNd4jTest { History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1); } + @Test + public void testTrainingEvalVarNotReqForLoss(){ + //If a variable is not required for the loss - normally it won't be calculated + //But we want to make sure it IS calculated here - so we can perform evaluation on it + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable z = in.mmul(w); + SDVariable out = sd.nn.softmax("softmax", z); + SDVariable loss = sd.loss.logLoss("loss", label, out); + SDVariable notRequiredForLoss = sd.nn.softmax("notRequiredForLoss", z); + + sd.setTrainingConfig(TrainingConfig.builder() + .updater(new Adam(0.001)) + .dataSetFeatureMapping("in") + .dataSetLabelMapping("label") + .trainEvaluation("notRequiredForLoss", 0, new Evaluation()) + .build()); + +// sd.setListeners(new ScoreListener(1)); + + DataSet ds = new DataSet(Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.createFromArray(new float[][]{{1,0,0}, {0,1,0}, {0,0,1}})); + + History h = sd.fit() + .train(new SingletonDataSetIterator(ds), 4) + .exec(); + + List l = h.trainingEval(Evaluation.Metric.ACCURACY); + assertEquals(4, l.size()); + } + @Override public char ordering() { diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml index fdaf7cc89..618b44c9e 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml @@ -40,7 +40,7 @@ com.mchange c3p0 - 0.9.5-pre5 + 0.9.5.4 org.nd4j diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index 5537216ca..004145101 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -121,7 +121,7 @@ javax.activation activation - 1.1 + 1.1.1 diff --git a/pom.xml b/pom.xml index 15af4658d..b34148656 100644 --- a/pom.xml +++ b/pom.xml @@ -271,7 +271,7 @@ 0.5.0 2.3.23 2.8.1 - 2.7.0 + 2.9.8 2.3 3.2 3.1 @@ -325,14 +325,14 @@ 3.2.2 4.1 - 2.4.3 + 2.4.5 2 2.0.29 1.7.21 4.12 1.2.3 2.10.1 - 2.10.1 + 2.10.3 1.24 2.8.7 1.18.12 @@ -353,7 +353,7 @@ 2.8.0 1.2.0-3f79e055 4.10.0 - 3.8.3 + 3.9.0 1.10.0 1.14.0 @@ -368,7 +368,7 @@ 3.7.0 3.3.1 3.0.1 - 1.0.0-beta8 + 1.0.0 2.2.2 ${maven-git-commit-plugin.version}