Dependency version upgrades + small SameDiff fix (#405)

* #8861 Training evaluation on variables not required for loss fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Dependency version updates flagged by dependabot

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-04-23 15:38:42 +10:00 committed by GitHub
parent 8f765c80ff
commit f5f77df846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 9 deletions

View File

@ -132,7 +132,7 @@ public abstract class AbstractSession<T, O> {
Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty"); Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty");
if (requiredActivations == null) if (requiredActivations == null)
requiredActivations = Collections.emptyList(); requiredActivations = Collections.emptySet();
if (at == null) if (at == null)
at = At.defaultAt(); at = At.defaultAt();

View File

@ -75,7 +75,7 @@ public class TrainingSession extends InferenceSession {
this.listeners = filtered.isEmpty() ? null : filtered; this.listeners = filtered.isEmpty() ? null : filtered;
} }
List<String> requiredActivations = new ArrayList<>(); Set<String> requiredActivations = new HashSet<>();
gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for
for (String s : paramsToTrain) { for (String s : paramsToTrain) {
Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s); Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
@ -95,6 +95,12 @@ public class TrainingSession extends InferenceSession {
gradVarToVarMap.put(grad.name(), s); gradVarToVarMap.put(grad.name(), s);
} }
//Also add evaluations - in case we want to evaluate something that isn't required to determine loss
// (hence wouldn't normally be calculated)
if(config.getTrainEvaluations() != null){
requiredActivations.addAll(config.getTrainEvaluations().keySet());
}
//Set up losses //Set up losses
lossVarsToLossIdx = new LinkedHashMap<>(); lossVarsToLossIdx = new LinkedHashMap<>();
List<String> lossVars; List<String> lossVars;

View File

@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import java.util.Collections; import java.util.Collections;
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.IrisDataSetIterator;
import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
@ -341,6 +343,39 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1); History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1);
} }
@Test
public void testTrainingEvalVarNotReqForLoss(){
//If a variable is not required for the loss - normally it won't be calculated
//But we want to make sure it IS calculated here - so we can perform evaluation on it
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
SDVariable z = in.mmul(w);
SDVariable out = sd.nn.softmax("softmax", z);
SDVariable loss = sd.loss.logLoss("loss", label, out);
SDVariable notRequiredForLoss = sd.nn.softmax("notRequiredForLoss", z);
sd.setTrainingConfig(TrainingConfig.builder()
.updater(new Adam(0.001))
.dataSetFeatureMapping("in")
.dataSetLabelMapping("label")
.trainEvaluation("notRequiredForLoss", 0, new Evaluation())
.build());
// sd.setListeners(new ScoreListener(1));
DataSet ds = new DataSet(Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.createFromArray(new float[][]{{1,0,0}, {0,1,0}, {0,0,1}}));
History h = sd.fit()
.train(new SingletonDataSetIterator(ds), 4)
.exec();
List<Double> l = h.trainingEval(Evaluation.Metric.ACCURACY);
assertEquals(4, l.size());
}
@Override @Override
public char ordering() { public char ordering() {

View File

@ -40,7 +40,7 @@
<dependency> <dependency>
<groupId>com.mchange</groupId> <groupId>com.mchange</groupId>
<artifactId>c3p0</artifactId> <artifactId>c3p0</artifactId>
<version>0.9.5-pre5</version> <version>0.9.5.4</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>

View File

@ -121,7 +121,7 @@
<dependency> <dependency>
<groupId>javax.activation</groupId> <groupId>javax.activation</groupId>
<artifactId>activation</artifactId> <artifactId>activation</artifactId>
<version>1.1</version> <version>1.1.1</version>
</dependency> </dependency>
<dependency> <dependency>

10
pom.xml
View File

@ -271,7 +271,7 @@
<jodah.typetools.version>0.5.0</jodah.typetools.version> <jodah.typetools.version>0.5.0</jodah.typetools.version>
<freemarker.version>2.3.23</freemarker.version> <freemarker.version>2.3.23</freemarker.version>
<geoip2.version>2.8.1</geoip2.version> <geoip2.version>2.8.1</geoip2.version>
<stream.analytics.version>2.7.0</stream.analytics.version> <stream.analytics.version>2.9.8</stream.analytics.version>
<opencsv.version>2.3</opencsv.version> <opencsv.version>2.3</opencsv.version>
<tdigest.version>3.2</tdigest.version> <tdigest.version>3.2</tdigest.version>
<jtransforms.version>3.1</jtransforms.version> <jtransforms.version>3.1</jtransforms.version>
@ -325,14 +325,14 @@
<commons-collections.version>3.2.2</commons-collections.version> <commons-collections.version>3.2.2</commons-collections.version>
<commons-collections4.version>4.1</commons-collections4.version> <commons-collections4.version>4.1</commons-collections4.version>
<spark.version>2.4.3</spark.version> <spark.version>2.4.5</spark.version>
<spark.major.version>2</spark.major.version> <spark.major.version>2</spark.major.version>
<args4j.version>2.0.29</args4j.version> <args4j.version>2.0.29</args4j.version>
<slf4j.version>1.7.21</slf4j.version> <slf4j.version>1.7.21</slf4j.version>
<junit.version>4.12</junit.version> <junit.version>4.12</junit.version>
<logback.version>1.2.3</logback.version> <logback.version>1.2.3</logback.version>
<jackson.version>2.10.1</jackson.version> <jackson.version>2.10.1</jackson.version>
<jackson.databind.version>2.10.1</jackson.databind.version> <jackson.databind.version>2.10.3</jackson.databind.version>
<shaded.snakeyaml.version>1.24</shaded.snakeyaml.version> <shaded.snakeyaml.version>1.24</shaded.snakeyaml.version>
<geo.jackson.version>2.8.7</geo.jackson.version> <geo.jackson.version>2.8.7</geo.jackson.version>
<lombok.version>1.18.12</lombok.version> <lombok.version>1.18.12</lombok.version>
@ -353,7 +353,7 @@
<gson.version>2.8.0</gson.version> <gson.version>2.8.0</gson.version>
<fbs.version>1.2.0-3f79e055</fbs.version> <fbs.version>1.2.0-3f79e055</fbs.version>
<threadly.version>4.10.0</threadly.version> <threadly.version>4.10.0</threadly.version>
<vertx.version>3.8.3</vertx.version> <vertx.version>3.9.0</vertx.version>
<flatbuffers.version>1.10.0</flatbuffers.version> <flatbuffers.version>1.10.0</flatbuffers.version>
<grpc.version>1.14.0</grpc.version> <grpc.version>1.14.0</grpc.version>
@ -368,7 +368,7 @@
<maven-compiler-plugin.version>3.7.0</maven-compiler-plugin.version> <maven-compiler-plugin.version>3.7.0</maven-compiler-plugin.version>
<maven-scala-plugin.version>3.3.1</maven-scala-plugin.version> <maven-scala-plugin.version>3.3.1</maven-scala-plugin.version>
<maven-resources-plugin.version>3.0.1</maven-resources-plugin.version> <maven-resources-plugin.version>3.0.1</maven-resources-plugin.version>
<sbt-compiler-maven-plugin.version>1.0.0-beta8</sbt-compiler-maven-plugin.version> <sbt-compiler-maven-plugin.version>1.0.0</sbt-compiler-maven-plugin.version>
<maven-git-commit-plugin.version>2.2.2</maven-git-commit-plugin.version> <maven-git-commit-plugin.version>2.2.2</maven-git-commit-plugin.version>
<maven-git-commit-id-plugin.version>${maven-git-commit-plugin.version} <maven-git-commit-id-plugin.version>${maven-git-commit-plugin.version}
</maven-git-commit-id-plugin.version> </maven-git-commit-id-plugin.version>