Fix SameDiff session termination condition when listener requests array after final requested output (#423)
* custom listener test * abst session Signed-off-by: eraly <susan.eraly@gmail.com> * Partial fix Signed-off-by: Alex Black <blacka101@gmail.com> * Fix for execution termination condition Signed-off-by: Alex Black <blacka101@gmail.com> * Small error mesage improvement Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: eraly <susan.eraly@gmail.com>master
parent
88d3c4867f
commit
ffab4eec42
|
@ -248,16 +248,17 @@ public abstract class AbstractSession<T, O> {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
Map<String, T> out = new HashMap<>(); //Outputs, returned to the user
|
Map<String, T> out = new HashMap<>(); //Outputs, returned to the user
|
||||||
|
Set<String> allExecuted = new HashSet<>();
|
||||||
int step = 0; //Number of execution steps
|
int step = 0; //Number of execution steps
|
||||||
//Next 3: current execution frame
|
//Next 3: current execution frame
|
||||||
String currentFrame = OUTER_FRAME;
|
String currentFrame = OUTER_FRAME;
|
||||||
int currentFrameIter = 0;
|
int currentFrameIter = 0;
|
||||||
FrameIter currParentFrame = null;
|
FrameIter currParentFrame = null;
|
||||||
ExecStepPredicate predicate = new ExecStepPredicate();
|
ExecStepPredicate predicate = new ExecStepPredicate();
|
||||||
while (out.size() < userRequestedUnique.size()) {
|
while (allExecuted.size() < allRequired.size()) {
|
||||||
if (!dt.hasNewAllSatisfied()) {
|
if (!dt.hasNewAllSatisfied()) {
|
||||||
//Haven't got all of the outputs the user requested, but there's nothing left that we can execute. Should not happen.
|
//Haven't got all of the outputs the user requested, but there's nothing left that we can execute. Should not happen.
|
||||||
execFailed(userRequestedUnique, out, step);
|
execFailed(userRequestedUnique, out, allRequired, allExecuted, step);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Get variable in the current frame/iteration and execute it's corresponding op
|
//Get variable in the current frame/iteration and execute it's corresponding op
|
||||||
|
@ -289,10 +290,13 @@ public abstract class AbstractSession<T, O> {
|
||||||
Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid);
|
Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid);
|
||||||
nodeOutputs.put(vid, arr);
|
nodeOutputs.put(vid, arr);
|
||||||
outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
|
outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
|
||||||
if (allRequired.contains(es.getName())) {
|
if (userRequestedUnique.contains(es.getName())) {
|
||||||
//User requested const/variable as one of the outputs
|
//User requested const/variable as one of the outputs
|
||||||
out.put(es.getName(), arr);
|
out.put(es.getName(), arr);
|
||||||
}
|
}
|
||||||
|
if(allRequired.contains(es.getName())){
|
||||||
|
allExecuted.add(es.getName());
|
||||||
|
}
|
||||||
} else if (es.getType() == ExecType.PLACEHOLDER) {
|
} else if (es.getType() == ExecType.PLACEHOLDER) {
|
||||||
VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
|
VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
|
||||||
T phVal = placeholderValues == null ? null : placeholderValues.get(es.getName());
|
T phVal = placeholderValues == null ? null : placeholderValues.get(es.getName());
|
||||||
|
@ -305,6 +309,9 @@ public abstract class AbstractSession<T, O> {
|
||||||
//User requested placeholder value as one of the outputs
|
//User requested placeholder value as one of the outputs
|
||||||
out.put(es.getName(), placeholderValues.get(es.getName()));
|
out.put(es.getName(), placeholderValues.get(es.getName()));
|
||||||
}
|
}
|
||||||
|
if(allRequired.contains(es.getName())){
|
||||||
|
allExecuted.add(es.getName());
|
||||||
|
}
|
||||||
} else if (es.getType() == ExecType.OP) {
|
} else if (es.getType() == ExecType.OP) {
|
||||||
String opName = es.getName();
|
String opName = es.getName();
|
||||||
SameDiffOp op = sameDiff.getOps().get(opName);
|
SameDiffOp op = sameDiff.getOps().get(opName);
|
||||||
|
@ -399,9 +406,12 @@ public abstract class AbstractSession<T, O> {
|
||||||
VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame());
|
VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame());
|
||||||
nodeOutputs.put(vid, opOutputValues[i]);
|
nodeOutputs.put(vid, opOutputValues[i]);
|
||||||
|
|
||||||
if (allRequired.contains(n)) {
|
if (userRequestedUnique.contains(n)) {
|
||||||
out.put(n, opOutputValues[i]);
|
out.put(n, opOutputValues[i]);
|
||||||
}
|
}
|
||||||
|
if(allRequired.contains(n)){
|
||||||
|
allExecuted.add(n);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Post execution: update dependency tracker so we know what is available to execute next, given we now
|
//Post execution: update dependency tracker so we know what is available to execute next, given we now
|
||||||
|
@ -508,17 +518,19 @@ public abstract class AbstractSession<T, O> {
|
||||||
* @param out Current outputs
|
* @param out Current outputs
|
||||||
* @param step Execution step
|
* @param step Execution step
|
||||||
*/
|
*/
|
||||||
protected void execFailed(Set<String> userRequestedUnique, Map<String, T> out, int step) {
|
protected void execFailed(Set<String> userRequestedUnique, Map<String, T> out, Set<String> allRequired, Set<String> allExecuted, int step) {
|
||||||
int missingCount = userRequestedUnique.size() - out.size();
|
int missingCount = userRequestedUnique.size() - out.size();
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
sb.append("No variable are available for execution at step ")
|
sb.append("No variable are available for execution at step ")
|
||||||
.append(step).append(": ").append(missingCount).append(" values remaining");
|
.append(step).append(": ").append(missingCount).append(" requested output values remaining, ")
|
||||||
|
.append(allExecuted.size() - allRequired.size()).append(" variables required to be executed remaining");
|
||||||
Set<String> missing = new HashSet<>();
|
Set<String> missing = new HashSet<>();
|
||||||
for (String s : userRequestedUnique) {
|
for (String s : userRequestedUnique) {
|
||||||
if (!out.containsKey(s)) {
|
if (!out.containsKey(s)) {
|
||||||
missing.add(s);
|
missing.add(s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (missingCount <= 10) {
|
if (missingCount <= 10) {
|
||||||
sb.append(". Missing variables: ");
|
sb.append(". Missing variables: ");
|
||||||
sb.append(missing);
|
sb.append(missing);
|
||||||
|
|
|
@ -16,11 +16,6 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.listeners;
|
package org.nd4j.autodiff.samediff.listeners;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.autodiff.listeners.*;
|
import org.nd4j.autodiff.listeners.*;
|
||||||
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
||||||
|
@ -49,6 +44,13 @@ import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.weightinit.impl.XavierInitScheme;
|
import org.nd4j.weightinit.impl.XavierInitScheme;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class ListenerTest extends BaseNd4jTest {
|
public class ListenerTest extends BaseNd4jTest {
|
||||||
|
|
||||||
public ListenerTest(Nd4jBackend backend) {
|
public ListenerTest(Nd4jBackend backend) {
|
||||||
|
@ -260,6 +262,42 @@ public class ListenerTest extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCustomListener() {
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
|
||||||
|
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||||
|
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
|
||||||
|
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
|
||||||
|
SDVariable z = sd.nn().linear("z", in, w, b);
|
||||||
|
SDVariable out = sd.nn().softmax("out", z, 1);
|
||||||
|
SDVariable loss = sd.loss().softmaxCrossEntropy("loss", label, out, null);
|
||||||
|
|
||||||
|
//Create and set the training configuration
|
||||||
|
double learningRate = 1e-3;
|
||||||
|
TrainingConfig config = new TrainingConfig.Builder()
|
||||||
|
.l2(1e-4) //L2 regularization
|
||||||
|
.updater(new Adam(learningRate)) //Adam optimizer with specified learning rate
|
||||||
|
.dataSetFeatureMapping("input") //DataSet features array should be associated with variable "input"
|
||||||
|
.dataSetLabelMapping("label") //DataSet label array should be associated with variable "label
|
||||||
|
.addEvaluations(false,"out",0,new Evaluation())
|
||||||
|
.build();
|
||||||
|
sd.setTrainingConfig(config);
|
||||||
|
|
||||||
|
CustomListener listener = new CustomListener();
|
||||||
|
Map<String,INDArray> m = sd.output()
|
||||||
|
.data(new IrisDataSetIterator(150, 150))
|
||||||
|
.output("out")
|
||||||
|
.listeners(listener)
|
||||||
|
.exec();
|
||||||
|
|
||||||
|
assertEquals(1, m.size());
|
||||||
|
assertTrue(m.containsKey("out"));
|
||||||
|
assertNotNull(listener.z);
|
||||||
|
assertNotNull(listener.out);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
private static class TestListener implements Listener {
|
private static class TestListener implements Listener {
|
||||||
|
|
||||||
public TestListener(Operation operation){
|
public TestListener(Operation operation){
|
||||||
|
@ -356,4 +394,38 @@ public class ListenerTest extends BaseNd4jTest {
|
||||||
preUpdateCount++;
|
preUpdateCount++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class CustomListener extends BaseListener {
|
||||||
|
|
||||||
|
public INDArray z;
|
||||||
|
public INDArray out;
|
||||||
|
|
||||||
|
// Specify that this listener is active during inference operations
|
||||||
|
@Override
|
||||||
|
public boolean isActive(Operation operation) {
|
||||||
|
return operation == Operation.INFERENCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specify that this listener requires the activations of "z" and "out"
|
||||||
|
@Override
|
||||||
|
public ListenerVariables requiredVariables(SameDiff sd) {
|
||||||
|
return new ListenerVariables.Builder().inferenceVariables("z", "out").build();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Called when the activation of a variable becomes available
|
||||||
|
@Override
|
||||||
|
public void activationAvailable(SameDiff sd, At at,
|
||||||
|
MultiDataSet batch, SameDiffOp op,
|
||||||
|
String varName, INDArray activation) {
|
||||||
|
System.out.println("activation:" + varName);
|
||||||
|
|
||||||
|
// if the variable is z or out, store its activation
|
||||||
|
if (varName.equals("z")) {
|
||||||
|
z = activation.detach().dup();
|
||||||
|
} else if (varName.equals("out")) {
|
||||||
|
out = activation.detach().dup();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue