Dependency version upgrades + small SameDiff fix (#405)
* #8861 Training evaluation on variables not required for loss fix Signed-off-by: Alex Black <blacka101@gmail.com> * Dependency version updates flagged by dependabot Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
8f765c80ff
commit
f5f77df846
|
@ -132,7 +132,7 @@ public abstract class AbstractSession<T, O> {
|
||||||
Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty");
|
Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty");
|
||||||
|
|
||||||
if (requiredActivations == null)
|
if (requiredActivations == null)
|
||||||
requiredActivations = Collections.emptyList();
|
requiredActivations = Collections.emptySet();
|
||||||
|
|
||||||
if (at == null)
|
if (at == null)
|
||||||
at = At.defaultAt();
|
at = At.defaultAt();
|
||||||
|
|
|
@ -75,7 +75,7 @@ public class TrainingSession extends InferenceSession {
|
||||||
this.listeners = filtered.isEmpty() ? null : filtered;
|
this.listeners = filtered.isEmpty() ? null : filtered;
|
||||||
}
|
}
|
||||||
|
|
||||||
List<String> requiredActivations = new ArrayList<>();
|
Set<String> requiredActivations = new HashSet<>();
|
||||||
gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for
|
gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for
|
||||||
for (String s : paramsToTrain) {
|
for (String s : paramsToTrain) {
|
||||||
Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
|
Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
|
||||||
|
@ -95,6 +95,12 @@ public class TrainingSession extends InferenceSession {
|
||||||
gradVarToVarMap.put(grad.name(), s);
|
gradVarToVarMap.put(grad.name(), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Also add evaluations - in case we want to evaluate something that isn't required to determine loss
|
||||||
|
// (hence wouldn't normally be calculated)
|
||||||
|
if(config.getTrainEvaluations() != null){
|
||||||
|
requiredActivations.addAll(config.getTrainEvaluations().keySet());
|
||||||
|
}
|
||||||
|
|
||||||
//Set up losses
|
//Set up losses
|
||||||
lossVarsToLossIdx = new LinkedHashMap<>();
|
lossVarsToLossIdx = new LinkedHashMap<>();
|
||||||
List<String> lossVars;
|
List<String> lossVars;
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff;
|
package org.nd4j.autodiff.samediff;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
|
import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
|
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||||
|
@ -341,6 +343,39 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
|
||||||
History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1);
|
History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTrainingEvalVarNotReqForLoss(){
|
||||||
|
//If a variable is not required for the loss - normally it won't be calculated
|
||||||
|
//But we want to make sure it IS calculated here - so we can perform evaluation on it
|
||||||
|
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||||
|
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||||
|
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
|
||||||
|
SDVariable z = in.mmul(w);
|
||||||
|
SDVariable out = sd.nn.softmax("softmax", z);
|
||||||
|
SDVariable loss = sd.loss.logLoss("loss", label, out);
|
||||||
|
SDVariable notRequiredForLoss = sd.nn.softmax("notRequiredForLoss", z);
|
||||||
|
|
||||||
|
sd.setTrainingConfig(TrainingConfig.builder()
|
||||||
|
.updater(new Adam(0.001))
|
||||||
|
.dataSetFeatureMapping("in")
|
||||||
|
.dataSetLabelMapping("label")
|
||||||
|
.trainEvaluation("notRequiredForLoss", 0, new Evaluation())
|
||||||
|
.build());
|
||||||
|
|
||||||
|
// sd.setListeners(new ScoreListener(1));
|
||||||
|
|
||||||
|
DataSet ds = new DataSet(Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.createFromArray(new float[][]{{1,0,0}, {0,1,0}, {0,0,1}}));
|
||||||
|
|
||||||
|
History h = sd.fit()
|
||||||
|
.train(new SingletonDataSetIterator(ds), 4)
|
||||||
|
.exec();
|
||||||
|
|
||||||
|
List<Double> l = h.trainingEval(Evaluation.Metric.ACCURACY);
|
||||||
|
assertEquals(4, l.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
|
|
|
@ -40,7 +40,7 @@
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.mchange</groupId>
|
<groupId>com.mchange</groupId>
|
||||||
<artifactId>c3p0</artifactId>
|
<artifactId>c3p0</artifactId>
|
||||||
<version>0.9.5-pre5</version>
|
<version>0.9.5.4</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -121,7 +121,7 @@
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>javax.activation</groupId>
|
<groupId>javax.activation</groupId>
|
||||||
<artifactId>activation</artifactId>
|
<artifactId>activation</artifactId>
|
||||||
<version>1.1</version>
|
<version>1.1.1</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|
10
pom.xml
10
pom.xml
|
@ -271,7 +271,7 @@
|
||||||
<jodah.typetools.version>0.5.0</jodah.typetools.version>
|
<jodah.typetools.version>0.5.0</jodah.typetools.version>
|
||||||
<freemarker.version>2.3.23</freemarker.version>
|
<freemarker.version>2.3.23</freemarker.version>
|
||||||
<geoip2.version>2.8.1</geoip2.version>
|
<geoip2.version>2.8.1</geoip2.version>
|
||||||
<stream.analytics.version>2.7.0</stream.analytics.version>
|
<stream.analytics.version>2.9.8</stream.analytics.version>
|
||||||
<opencsv.version>2.3</opencsv.version>
|
<opencsv.version>2.3</opencsv.version>
|
||||||
<tdigest.version>3.2</tdigest.version>
|
<tdigest.version>3.2</tdigest.version>
|
||||||
<jtransforms.version>3.1</jtransforms.version>
|
<jtransforms.version>3.1</jtransforms.version>
|
||||||
|
@ -325,14 +325,14 @@
|
||||||
<commons-collections.version>3.2.2</commons-collections.version>
|
<commons-collections.version>3.2.2</commons-collections.version>
|
||||||
<commons-collections4.version>4.1</commons-collections4.version>
|
<commons-collections4.version>4.1</commons-collections4.version>
|
||||||
|
|
||||||
<spark.version>2.4.3</spark.version>
|
<spark.version>2.4.5</spark.version>
|
||||||
<spark.major.version>2</spark.major.version>
|
<spark.major.version>2</spark.major.version>
|
||||||
<args4j.version>2.0.29</args4j.version>
|
<args4j.version>2.0.29</args4j.version>
|
||||||
<slf4j.version>1.7.21</slf4j.version>
|
<slf4j.version>1.7.21</slf4j.version>
|
||||||
<junit.version>4.12</junit.version>
|
<junit.version>4.12</junit.version>
|
||||||
<logback.version>1.2.3</logback.version>
|
<logback.version>1.2.3</logback.version>
|
||||||
<jackson.version>2.10.1</jackson.version>
|
<jackson.version>2.10.1</jackson.version>
|
||||||
<jackson.databind.version>2.10.1</jackson.databind.version>
|
<jackson.databind.version>2.10.3</jackson.databind.version>
|
||||||
<shaded.snakeyaml.version>1.24</shaded.snakeyaml.version>
|
<shaded.snakeyaml.version>1.24</shaded.snakeyaml.version>
|
||||||
<geo.jackson.version>2.8.7</geo.jackson.version>
|
<geo.jackson.version>2.8.7</geo.jackson.version>
|
||||||
<lombok.version>1.18.12</lombok.version>
|
<lombok.version>1.18.12</lombok.version>
|
||||||
|
@ -353,7 +353,7 @@
|
||||||
<gson.version>2.8.0</gson.version>
|
<gson.version>2.8.0</gson.version>
|
||||||
<fbs.version>1.2.0-3f79e055</fbs.version>
|
<fbs.version>1.2.0-3f79e055</fbs.version>
|
||||||
<threadly.version>4.10.0</threadly.version>
|
<threadly.version>4.10.0</threadly.version>
|
||||||
<vertx.version>3.8.3</vertx.version>
|
<vertx.version>3.9.0</vertx.version>
|
||||||
|
|
||||||
<flatbuffers.version>1.10.0</flatbuffers.version>
|
<flatbuffers.version>1.10.0</flatbuffers.version>
|
||||||
<grpc.version>1.14.0</grpc.version>
|
<grpc.version>1.14.0</grpc.version>
|
||||||
|
@ -368,7 +368,7 @@
|
||||||
<maven-compiler-plugin.version>3.7.0</maven-compiler-plugin.version>
|
<maven-compiler-plugin.version>3.7.0</maven-compiler-plugin.version>
|
||||||
<maven-scala-plugin.version>3.3.1</maven-scala-plugin.version>
|
<maven-scala-plugin.version>3.3.1</maven-scala-plugin.version>
|
||||||
<maven-resources-plugin.version>3.0.1</maven-resources-plugin.version>
|
<maven-resources-plugin.version>3.0.1</maven-resources-plugin.version>
|
||||||
<sbt-compiler-maven-plugin.version>1.0.0-beta8</sbt-compiler-maven-plugin.version>
|
<sbt-compiler-maven-plugin.version>1.0.0</sbt-compiler-maven-plugin.version>
|
||||||
<maven-git-commit-plugin.version>2.2.2</maven-git-commit-plugin.version>
|
<maven-git-commit-plugin.version>2.2.2</maven-git-commit-plugin.version>
|
||||||
<maven-git-commit-id-plugin.version>${maven-git-commit-plugin.version}
|
<maven-git-commit-id-plugin.version>${maven-git-commit-plugin.version}
|
||||||
</maven-git-commit-id-plugin.version>
|
</maven-git-commit-id-plugin.version>
|
||||||
|
|
Loading…
Reference in New Issue