Dependency version upgrades + small SameDiff fix (#405)
* #8861 Training evaluation on variables not required for loss fix Signed-off-by: Alex Black <blacka101@gmail.com> * Dependency version updates flagged by dependabot Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
8f765c80ff
commit
f5f77df846
|
@ -132,7 +132,7 @@ public abstract class AbstractSession<T, O> {
|
|||
Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty");
|
||||
|
||||
if (requiredActivations == null)
|
||||
requiredActivations = Collections.emptyList();
|
||||
requiredActivations = Collections.emptySet();
|
||||
|
||||
if (at == null)
|
||||
at = At.defaultAt();
|
||||
|
|
|
@ -75,7 +75,7 @@ public class TrainingSession extends InferenceSession {
|
|||
this.listeners = filtered.isEmpty() ? null : filtered;
|
||||
}
|
||||
|
||||
List<String> requiredActivations = new ArrayList<>();
|
||||
Set<String> requiredActivations = new HashSet<>();
|
||||
gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for
|
||||
for (String s : paramsToTrain) {
|
||||
Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
|
||||
|
@ -95,6 +95,12 @@ public class TrainingSession extends InferenceSession {
|
|||
gradVarToVarMap.put(grad.name(), s);
|
||||
}
|
||||
|
||||
//Also add evaluations - in case we want to evaluate something that isn't required to determine loss
|
||||
// (hence wouldn't normally be calculated)
|
||||
if(config.getTrainEvaluations() != null){
|
||||
requiredActivations.addAll(config.getTrainEvaluations().keySet());
|
||||
}
|
||||
|
||||
//Set up losses
|
||||
lossVarsToLossIdx = new LinkedHashMap<>();
|
||||
List<String> lossVars;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.Collections;
|
||||
|
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||
|
@ -341,6 +343,39 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
|
|||
History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTrainingEvalVarNotReqForLoss(){
|
||||
//If a variable is not required for the loss - normally it won't be calculated
|
||||
//But we want to make sure it IS calculated here - so we can perform evaluation on it
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
|
||||
SDVariable z = in.mmul(w);
|
||||
SDVariable out = sd.nn.softmax("softmax", z);
|
||||
SDVariable loss = sd.loss.logLoss("loss", label, out);
|
||||
SDVariable notRequiredForLoss = sd.nn.softmax("notRequiredForLoss", z);
|
||||
|
||||
sd.setTrainingConfig(TrainingConfig.builder()
|
||||
.updater(new Adam(0.001))
|
||||
.dataSetFeatureMapping("in")
|
||||
.dataSetLabelMapping("label")
|
||||
.trainEvaluation("notRequiredForLoss", 0, new Evaluation())
|
||||
.build());
|
||||
|
||||
// sd.setListeners(new ScoreListener(1));
|
||||
|
||||
DataSet ds = new DataSet(Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.createFromArray(new float[][]{{1,0,0}, {0,1,0}, {0,0,1}}));
|
||||
|
||||
History h = sd.fit()
|
||||
.train(new SingletonDataSetIterator(ds), 4)
|
||||
.exec();
|
||||
|
||||
List<Double> l = h.trainingEval(Evaluation.Metric.ACCURACY);
|
||||
assertEquals(4, l.size());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
|
|
@ -40,7 +40,7 @@
|
|||
<dependency>
|
||||
<groupId>com.mchange</groupId>
|
||||
<artifactId>c3p0</artifactId>
|
||||
<version>0.9.5-pre5</version>
|
||||
<version>0.9.5.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -121,7 +121,7 @@
|
|||
<dependency>
|
||||
<groupId>javax.activation</groupId>
|
||||
<artifactId>activation</artifactId>
|
||||
<version>1.1</version>
|
||||
<version>1.1.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
10
pom.xml
10
pom.xml
|
@ -271,7 +271,7 @@
|
|||
<jodah.typetools.version>0.5.0</jodah.typetools.version>
|
||||
<freemarker.version>2.3.23</freemarker.version>
|
||||
<geoip2.version>2.8.1</geoip2.version>
|
||||
<stream.analytics.version>2.7.0</stream.analytics.version>
|
||||
<stream.analytics.version>2.9.8</stream.analytics.version>
|
||||
<opencsv.version>2.3</opencsv.version>
|
||||
<tdigest.version>3.2</tdigest.version>
|
||||
<jtransforms.version>3.1</jtransforms.version>
|
||||
|
@ -325,14 +325,14 @@
|
|||
<commons-collections.version>3.2.2</commons-collections.version>
|
||||
<commons-collections4.version>4.1</commons-collections4.version>
|
||||
|
||||
<spark.version>2.4.3</spark.version>
|
||||
<spark.version>2.4.5</spark.version>
|
||||
<spark.major.version>2</spark.major.version>
|
||||
<args4j.version>2.0.29</args4j.version>
|
||||
<slf4j.version>1.7.21</slf4j.version>
|
||||
<junit.version>4.12</junit.version>
|
||||
<logback.version>1.2.3</logback.version>
|
||||
<jackson.version>2.10.1</jackson.version>
|
||||
<jackson.databind.version>2.10.1</jackson.databind.version>
|
||||
<jackson.databind.version>2.10.3</jackson.databind.version>
|
||||
<shaded.snakeyaml.version>1.24</shaded.snakeyaml.version>
|
||||
<geo.jackson.version>2.8.7</geo.jackson.version>
|
||||
<lombok.version>1.18.12</lombok.version>
|
||||
|
@ -353,7 +353,7 @@
|
|||
<gson.version>2.8.0</gson.version>
|
||||
<fbs.version>1.2.0-3f79e055</fbs.version>
|
||||
<threadly.version>4.10.0</threadly.version>
|
||||
<vertx.version>3.8.3</vertx.version>
|
||||
<vertx.version>3.9.0</vertx.version>
|
||||
|
||||
<flatbuffers.version>1.10.0</flatbuffers.version>
|
||||
<grpc.version>1.14.0</grpc.version>
|
||||
|
@ -368,7 +368,7 @@
|
|||
<maven-compiler-plugin.version>3.7.0</maven-compiler-plugin.version>
|
||||
<maven-scala-plugin.version>3.3.1</maven-scala-plugin.version>
|
||||
<maven-resources-plugin.version>3.0.1</maven-resources-plugin.version>
|
||||
<sbt-compiler-maven-plugin.version>1.0.0-beta8</sbt-compiler-maven-plugin.version>
|
||||
<sbt-compiler-maven-plugin.version>1.0.0</sbt-compiler-maven-plugin.version>
|
||||
<maven-git-commit-plugin.version>2.2.2</maven-git-commit-plugin.version>
|
||||
<maven-git-commit-id-plugin.version>${maven-git-commit-plugin.version}
|
||||
</maven-git-commit-id-plugin.version>
|
||||
|
|
Loading…
Reference in New Issue