Dependency version upgrades + small SameDiff fix (#405)
* #8861 Training evaluation on variables not required for loss fix Signed-off-by: Alex Black <blacka101@gmail.com> * Dependency version updates flagged by dependabot Signed-off-by: Alex Black <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									8f765c80ff
								
							
						
					
					
						commit
						f5f77df846
					
				| @ -132,7 +132,7 @@ public abstract class AbstractSession<T, O> { | ||||
|         Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty"); | ||||
| 
 | ||||
|         if (requiredActivations == null) | ||||
|             requiredActivations = Collections.emptyList(); | ||||
|             requiredActivations = Collections.emptySet(); | ||||
| 
 | ||||
|         if (at == null) | ||||
|             at = At.defaultAt(); | ||||
|  | ||||
| @ -75,7 +75,7 @@ public class TrainingSession extends InferenceSession { | ||||
|             this.listeners = filtered.isEmpty() ? null : filtered; | ||||
|         } | ||||
| 
 | ||||
|         List<String> requiredActivations = new ArrayList<>(); | ||||
|         Set<String> requiredActivations = new HashSet<>(); | ||||
|         gradVarToVarMap = new HashMap<>();       //Key: gradient variable. Value: variable that the key is gradient for | ||||
|         for (String s : paramsToTrain) { | ||||
|             Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s); | ||||
| @ -95,6 +95,12 @@ public class TrainingSession extends InferenceSession { | ||||
|             gradVarToVarMap.put(grad.name(), s); | ||||
|         } | ||||
| 
 | ||||
|         //Also add evaluations - in case we want to evaluate something that isn't required to determine loss | ||||
|         // (hence wouldn't normally be calculated) | ||||
|         if(config.getTrainEvaluations() != null){ | ||||
|             requiredActivations.addAll(config.getTrainEvaluations().keySet()); | ||||
|         } | ||||
| 
 | ||||
|         //Set up losses | ||||
|         lossVarsToLossIdx = new LinkedHashMap<>(); | ||||
|         List<String> lossVars; | ||||
|  | ||||
| @ -16,6 +16,7 @@ | ||||
| 
 | ||||
| package org.nd4j.autodiff.samediff; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertTrue; | ||||
| 
 | ||||
| import java.util.Collections; | ||||
| @ -35,6 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| import org.nd4j.linalg.dataset.IrisDataSetIterator; | ||||
| import org.nd4j.linalg.dataset.MultiDataSet; | ||||
| import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator; | ||||
| import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; | ||||
| import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | ||||
| import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; | ||||
| @ -341,6 +343,39 @@ public class SameDiffTrainingTest extends BaseNd4jTest { | ||||
|         History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testTrainingEvalVarNotReqForLoss(){ | ||||
|         //If a variable is not required for the loss - normally it won't be calculated | ||||
|         //But we want to make sure it IS calculated here - so we can perform evaluation on it | ||||
| 
 | ||||
|         SameDiff sd = SameDiff.create(); | ||||
|         SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); | ||||
|         SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); | ||||
|         SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); | ||||
|         SDVariable z = in.mmul(w); | ||||
|         SDVariable out = sd.nn.softmax("softmax", z); | ||||
|         SDVariable loss = sd.loss.logLoss("loss", label, out); | ||||
|         SDVariable notRequiredForLoss = sd.nn.softmax("notRequiredForLoss", z); | ||||
| 
 | ||||
|         sd.setTrainingConfig(TrainingConfig.builder() | ||||
|                 .updater(new Adam(0.001)) | ||||
|                 .dataSetFeatureMapping("in") | ||||
|                 .dataSetLabelMapping("label") | ||||
|                 .trainEvaluation("notRequiredForLoss", 0, new Evaluation()) | ||||
|                 .build()); | ||||
| 
 | ||||
| //        sd.setListeners(new ScoreListener(1)); | ||||
| 
 | ||||
|         DataSet ds = new DataSet(Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.createFromArray(new float[][]{{1,0,0}, {0,1,0}, {0,0,1}})); | ||||
| 
 | ||||
|         History h = sd.fit() | ||||
|                 .train(new SingletonDataSetIterator(ds), 4) | ||||
|                 .exec(); | ||||
| 
 | ||||
|         List<Double> l = h.trainingEval(Evaluation.Metric.ACCURACY); | ||||
|         assertEquals(4, l.size()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public char ordering() { | ||||
|  | ||||
| @ -40,7 +40,7 @@ | ||||
|         <dependency> | ||||
|             <groupId>com.mchange</groupId> | ||||
|             <artifactId>c3p0</artifactId> | ||||
|             <version>0.9.5-pre5</version> | ||||
|             <version>0.9.5.4</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.nd4j</groupId> | ||||
|  | ||||
| @ -121,7 +121,7 @@ | ||||
|         <dependency> | ||||
|             <groupId>javax.activation</groupId> | ||||
|             <artifactId>activation</artifactId> | ||||
|             <version>1.1</version> | ||||
|             <version>1.1.1</version> | ||||
|         </dependency> | ||||
| 
 | ||||
|         <dependency> | ||||
|  | ||||
							
								
								
									
										10
									
								
								pom.xml
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								pom.xml
									
									
									
									
									
								
							| @ -271,7 +271,7 @@ | ||||
|         <jodah.typetools.version>0.5.0</jodah.typetools.version> | ||||
|         <freemarker.version>2.3.23</freemarker.version> | ||||
|         <geoip2.version>2.8.1</geoip2.version> | ||||
|         <stream.analytics.version>2.7.0</stream.analytics.version> | ||||
|         <stream.analytics.version>2.9.8</stream.analytics.version> | ||||
|         <opencsv.version>2.3</opencsv.version> | ||||
|         <tdigest.version>3.2</tdigest.version> | ||||
|         <jtransforms.version>3.1</jtransforms.version> | ||||
| @ -325,14 +325,14 @@ | ||||
|         <commons-collections.version>3.2.2</commons-collections.version> | ||||
|         <commons-collections4.version>4.1</commons-collections4.version> | ||||
| 
 | ||||
|         <spark.version>2.4.3</spark.version> | ||||
|         <spark.version>2.4.5</spark.version> | ||||
|         <spark.major.version>2</spark.major.version> | ||||
|         <args4j.version>2.0.29</args4j.version> | ||||
|         <slf4j.version>1.7.21</slf4j.version> | ||||
|         <junit.version>4.12</junit.version> | ||||
|         <logback.version>1.2.3</logback.version> | ||||
|         <jackson.version>2.10.1</jackson.version> | ||||
|         <jackson.databind.version>2.10.1</jackson.databind.version> | ||||
|         <jackson.databind.version>2.10.3</jackson.databind.version> | ||||
|         <shaded.snakeyaml.version>1.24</shaded.snakeyaml.version> | ||||
|         <geo.jackson.version>2.8.7</geo.jackson.version> | ||||
|         <lombok.version>1.18.12</lombok.version> | ||||
| @ -353,7 +353,7 @@ | ||||
|         <gson.version>2.8.0</gson.version> | ||||
|         <fbs.version>1.2.0-3f79e055</fbs.version> | ||||
|         <threadly.version>4.10.0</threadly.version> | ||||
|         <vertx.version>3.8.3</vertx.version> | ||||
|         <vertx.version>3.9.0</vertx.version> | ||||
| 
 | ||||
|         <flatbuffers.version>1.10.0</flatbuffers.version> | ||||
|         <grpc.version>1.14.0</grpc.version> | ||||
| @ -368,7 +368,7 @@ | ||||
|         <maven-compiler-plugin.version>3.7.0</maven-compiler-plugin.version> | ||||
|         <maven-scala-plugin.version>3.3.1</maven-scala-plugin.version> | ||||
|         <maven-resources-plugin.version>3.0.1</maven-resources-plugin.version> | ||||
|         <sbt-compiler-maven-plugin.version>1.0.0-beta8</sbt-compiler-maven-plugin.version> | ||||
|         <sbt-compiler-maven-plugin.version>1.0.0</sbt-compiler-maven-plugin.version> | ||||
|         <maven-git-commit-plugin.version>2.2.2</maven-git-commit-plugin.version> | ||||
|         <maven-git-commit-id-plugin.version>${maven-git-commit-plugin.version} | ||||
|         </maven-git-commit-id-plugin.version> | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user