Fix SameDiff session termination condition when listener requests array after final requested output (#423)

* custom listener test

* abst session

Signed-off-by: eraly <susan.eraly@gmail.com>

* Partial fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix for execution termination condition

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small error message improvement

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: eraly <susan.eraly@gmail.com>
master
Alex Black 2020-04-30 10:47:32 +10:00 committed by GitHub
parent 88d3c4867f
commit ffab4eec42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 95 additions and 11 deletions

View File

@ -248,16 +248,17 @@ public abstract class AbstractSession<T, O> {
*/
Map<String, T> out = new HashMap<>(); //Outputs, returned to the user
Set<String> allExecuted = new HashSet<>();
int step = 0; //Number of execution steps
//Next 3: current execution frame
String currentFrame = OUTER_FRAME;
int currentFrameIter = 0;
FrameIter currParentFrame = null;
ExecStepPredicate predicate = new ExecStepPredicate();
while (out.size() < userRequestedUnique.size()) {
while (allExecuted.size() < allRequired.size()) {
if (!dt.hasNewAllSatisfied()) {
//Haven't got all of the outputs the user requested, but there's nothing left that we can execute. Should not happen.
execFailed(userRequestedUnique, out, step);
execFailed(userRequestedUnique, out, allRequired, allExecuted, step);
}
//Get variable in the current frame/iteration and execute it's corresponding op
@ -289,10 +290,13 @@ public abstract class AbstractSession<T, O> {
Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid);
nodeOutputs.put(vid, arr);
outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
if (allRequired.contains(es.getName())) {
if (userRequestedUnique.contains(es.getName())) {
//User requested const/variable as one of the outputs
out.put(es.getName(), arr);
}
if(allRequired.contains(es.getName())){
allExecuted.add(es.getName());
}
} else if (es.getType() == ExecType.PLACEHOLDER) {
VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
T phVal = placeholderValues == null ? null : placeholderValues.get(es.getName());
@ -305,6 +309,9 @@ public abstract class AbstractSession<T, O> {
//User requested placeholder value as one of the outputs
out.put(es.getName(), placeholderValues.get(es.getName()));
}
if(allRequired.contains(es.getName())){
allExecuted.add(es.getName());
}
} else if (es.getType() == ExecType.OP) {
String opName = es.getName();
SameDiffOp op = sameDiff.getOps().get(opName);
@ -399,9 +406,12 @@ public abstract class AbstractSession<T, O> {
VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame());
nodeOutputs.put(vid, opOutputValues[i]);
if (allRequired.contains(n)) {
if (userRequestedUnique.contains(n)) {
out.put(n, opOutputValues[i]);
}
if(allRequired.contains(n)){
allExecuted.add(n);
}
}
//Post execution: update dependency tracker so we know what is available to execute next, given we now
@ -508,17 +518,19 @@ public abstract class AbstractSession<T, O> {
* @param out Current outputs
* @param step Execution step
*/
protected void execFailed(Set<String> userRequestedUnique, Map<String, T> out, int step) {
protected void execFailed(Set<String> userRequestedUnique, Map<String, T> out, Set<String> allRequired, Set<String> allExecuted, int step) {
int missingCount = userRequestedUnique.size() - out.size();
StringBuilder sb = new StringBuilder();
sb.append("No variable are available for execution at step ")
.append(step).append(": ").append(missingCount).append(" values remaining");
.append(step).append(": ").append(missingCount).append(" requested output values remaining, ")
.append(allExecuted.size() - allRequired.size()).append(" variables required to be executed remaining");
Set<String> missing = new HashSet<>();
for (String s : userRequestedUnique) {
if (!out.containsKey(s)) {
missing.add(s);
}
}
if (missingCount <= 10) {
sb.append(". Missing variables: ");
sb.append(missing);

View File

@ -16,11 +16,6 @@
package org.nd4j.autodiff.samediff.listeners;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.*;
import org.junit.Test;
import org.nd4j.autodiff.listeners.*;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
@ -49,6 +44,13 @@ import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.weightinit.impl.XavierInitScheme;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.*;
public class ListenerTest extends BaseNd4jTest {
public ListenerTest(Nd4jBackend backend) {
@ -260,6 +262,42 @@ public class ListenerTest extends BaseNd4jTest {
}
}
@Test
public void testCustomListener() {
    // Build a small softmax classifier graph: input -> linear("z") -> softmax("out") -> loss
    SameDiff sd = SameDiff.create();
    SDVariable input = sd.placeHolder("input", DataType.FLOAT, -1, 4);
    SDVariable labels = sd.placeHolder("label", DataType.FLOAT, -1, 3);
    SDVariable weights = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
    SDVariable bias = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
    SDVariable linear = sd.nn().linear("z", input, weights, bias);
    SDVariable softmax = sd.nn().softmax("out", linear, 1);
    SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, softmax, null);

    // Training configuration: Adam + L2, mapping the DataSet arrays onto the placeholders
    TrainingConfig config = new TrainingConfig.Builder()
            .l2(1e-4)                           //L2 regularization
            .updater(new Adam(1e-3))            //Adam optimizer, learning rate 1e-3
            .dataSetFeatureMapping("input")     //DataSet features array -> variable "input"
            .dataSetLabelMapping("label")       //DataSet labels array -> variable "label"
            .addEvaluations(false, "out", 0, new Evaluation())
            .build();
    sd.setTrainingConfig(config);

    // Run inference requesting only "out"; the listener separately requires "z" and "out"
    CustomListener listener = new CustomListener();
    Map<String, INDArray> outputs = sd.output()
            .data(new IrisDataSetIterator(150, 150))
            .output("out")
            .listeners(listener)
            .exec();

    // Only the explicitly requested output is returned to the caller...
    assertEquals(1, outputs.size());
    assertTrue(outputs.containsKey("out"));
    // ...but the listener must still have received both of its required activations
    assertNotNull(listener.z);
    assertNotNull(listener.out);
}
private static class TestListener implements Listener {
public TestListener(Operation operation){
@ -356,4 +394,38 @@ public class ListenerTest extends BaseNd4jTest {
preUpdateCount++;
}
}
/**
 * Listener that captures the inference-time activations of the "z" and "out"
 * variables, so the test can verify both were computed even when only "out"
 * is requested as a graph output.
 */
private static class CustomListener extends BaseListener {
    public INDArray z;
    public INDArray out;

    // Only listen during inference operations
    @Override
    public boolean isActive(Operation operation) {
        return operation == Operation.INFERENCE;
    }

    // Ask the session to compute "z" and "out" on this listener's behalf
    @Override
    public ListenerVariables requiredVariables(SameDiff sd) {
        return new ListenerVariables.Builder().inferenceVariables("z", "out").build();
    }

    // Called when a variable's activation becomes available; capture the two we care about
    @Override
    public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                                    String varName, INDArray activation) {
        System.out.println("activation:" + varName);
        switch (varName) {
            case "z":
                z = activation.detach().dup();
                break;
            case "out":
                out = activation.detach().dup();
                break;
            default:
                break;
        }
    }
}
}