Fix SameDiff session termination condition when listener requests array after final requested output (#423)
* custom listener test * abst session Signed-off-by: eraly <susan.eraly@gmail.com> * Partial fix Signed-off-by: Alex Black <blacka101@gmail.com> * Fix for execution termination condition Signed-off-by: Alex Black <blacka101@gmail.com> * Small error mesage improvement Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: eraly <susan.eraly@gmail.com>master
parent
88d3c4867f
commit
ffab4eec42
|
@ -248,16 +248,17 @@ public abstract class AbstractSession<T, O> {
|
|||
*/
|
||||
|
||||
Map<String, T> out = new HashMap<>(); //Outputs, returned to the user
|
||||
Set<String> allExecuted = new HashSet<>();
|
||||
int step = 0; //Number of execution steps
|
||||
//Next 3: current execution frame
|
||||
String currentFrame = OUTER_FRAME;
|
||||
int currentFrameIter = 0;
|
||||
FrameIter currParentFrame = null;
|
||||
ExecStepPredicate predicate = new ExecStepPredicate();
|
||||
while (out.size() < userRequestedUnique.size()) {
|
||||
while (allExecuted.size() < allRequired.size()) {
|
||||
if (!dt.hasNewAllSatisfied()) {
|
||||
//Haven't got all of the outputs the user requested, but there's nothing left that we can execute. Should not happen.
|
||||
execFailed(userRequestedUnique, out, step);
|
||||
execFailed(userRequestedUnique, out, allRequired, allExecuted, step);
|
||||
}
|
||||
|
||||
//Get variable in the current frame/iteration and execute it's corresponding op
|
||||
|
@ -289,10 +290,13 @@ public abstract class AbstractSession<T, O> {
|
|||
Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid);
|
||||
nodeOutputs.put(vid, arr);
|
||||
outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
|
||||
if (allRequired.contains(es.getName())) {
|
||||
if (userRequestedUnique.contains(es.getName())) {
|
||||
//User requested const/variable as one of the outputs
|
||||
out.put(es.getName(), arr);
|
||||
}
|
||||
if(allRequired.contains(es.getName())){
|
||||
allExecuted.add(es.getName());
|
||||
}
|
||||
} else if (es.getType() == ExecType.PLACEHOLDER) {
|
||||
VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
|
||||
T phVal = placeholderValues == null ? null : placeholderValues.get(es.getName());
|
||||
|
@ -305,6 +309,9 @@ public abstract class AbstractSession<T, O> {
|
|||
//User requested placeholder value as one of the outputs
|
||||
out.put(es.getName(), placeholderValues.get(es.getName()));
|
||||
}
|
||||
if(allRequired.contains(es.getName())){
|
||||
allExecuted.add(es.getName());
|
||||
}
|
||||
} else if (es.getType() == ExecType.OP) {
|
||||
String opName = es.getName();
|
||||
SameDiffOp op = sameDiff.getOps().get(opName);
|
||||
|
@ -399,9 +406,12 @@ public abstract class AbstractSession<T, O> {
|
|||
VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame());
|
||||
nodeOutputs.put(vid, opOutputValues[i]);
|
||||
|
||||
if (allRequired.contains(n)) {
|
||||
if (userRequestedUnique.contains(n)) {
|
||||
out.put(n, opOutputValues[i]);
|
||||
}
|
||||
if(allRequired.contains(n)){
|
||||
allExecuted.add(n);
|
||||
}
|
||||
}
|
||||
|
||||
//Post execution: update dependency tracker so we know what is available to execute next, given we now
|
||||
|
@ -508,17 +518,19 @@ public abstract class AbstractSession<T, O> {
|
|||
* @param out Current outputs
|
||||
* @param step Execution step
|
||||
*/
|
||||
protected void execFailed(Set<String> userRequestedUnique, Map<String, T> out, int step) {
|
||||
protected void execFailed(Set<String> userRequestedUnique, Map<String, T> out, Set<String> allRequired, Set<String> allExecuted, int step) {
|
||||
int missingCount = userRequestedUnique.size() - out.size();
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("No variable are available for execution at step ")
|
||||
.append(step).append(": ").append(missingCount).append(" values remaining");
|
||||
.append(step).append(": ").append(missingCount).append(" requested output values remaining, ")
|
||||
.append(allExecuted.size() - allRequired.size()).append(" variables required to be executed remaining");
|
||||
Set<String> missing = new HashSet<>();
|
||||
for (String s : userRequestedUnique) {
|
||||
if (!out.containsKey(s)) {
|
||||
missing.add(s);
|
||||
}
|
||||
}
|
||||
|
||||
if (missingCount <= 10) {
|
||||
sb.append(". Missing variables: ");
|
||||
sb.append(missing);
|
||||
|
|
|
@ -16,11 +16,6 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff.listeners;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.autodiff.listeners.*;
|
||||
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
||||
|
@ -49,6 +44,13 @@ import org.nd4j.linalg.learning.config.Adam;
|
|||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.weightinit.impl.XavierInitScheme;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class ListenerTest extends BaseNd4jTest {
|
||||
|
||||
public ListenerTest(Nd4jBackend backend) {
|
||||
|
@ -260,6 +262,42 @@ public class ListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCustomListener() {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
|
||||
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
|
||||
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
|
||||
SDVariable z = sd.nn().linear("z", in, w, b);
|
||||
SDVariable out = sd.nn().softmax("out", z, 1);
|
||||
SDVariable loss = sd.loss().softmaxCrossEntropy("loss", label, out, null);
|
||||
|
||||
//Create and set the training configuration
|
||||
double learningRate = 1e-3;
|
||||
TrainingConfig config = new TrainingConfig.Builder()
|
||||
.l2(1e-4) //L2 regularization
|
||||
.updater(new Adam(learningRate)) //Adam optimizer with specified learning rate
|
||||
.dataSetFeatureMapping("input") //DataSet features array should be associated with variable "input"
|
||||
.dataSetLabelMapping("label") //DataSet label array should be associated with variable "label
|
||||
.addEvaluations(false,"out",0,new Evaluation())
|
||||
.build();
|
||||
sd.setTrainingConfig(config);
|
||||
|
||||
CustomListener listener = new CustomListener();
|
||||
Map<String,INDArray> m = sd.output()
|
||||
.data(new IrisDataSetIterator(150, 150))
|
||||
.output("out")
|
||||
.listeners(listener)
|
||||
.exec();
|
||||
|
||||
assertEquals(1, m.size());
|
||||
assertTrue(m.containsKey("out"));
|
||||
assertNotNull(listener.z);
|
||||
assertNotNull(listener.out);
|
||||
|
||||
}
|
||||
|
||||
private static class TestListener implements Listener {
|
||||
|
||||
public TestListener(Operation operation){
|
||||
|
@ -356,4 +394,38 @@ public class ListenerTest extends BaseNd4jTest {
|
|||
preUpdateCount++;
|
||||
}
|
||||
}
|
||||
|
||||
private static class CustomListener extends BaseListener {
|
||||
|
||||
public INDArray z;
|
||||
public INDArray out;
|
||||
|
||||
// Specify that this listener is active during inference operations
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return operation == Operation.INFERENCE;
|
||||
}
|
||||
|
||||
// Specify that this listener requires the activations of "z" and "out"
|
||||
@Override
|
||||
public ListenerVariables requiredVariables(SameDiff sd) {
|
||||
return new ListenerVariables.Builder().inferenceVariables("z", "out").build();
|
||||
}
|
||||
|
||||
// Called when the activation of a variable becomes available
|
||||
@Override
|
||||
public void activationAvailable(SameDiff sd, At at,
|
||||
MultiDataSet batch, SameDiffOp op,
|
||||
String varName, INDArray activation) {
|
||||
System.out.println("activation:" + varName);
|
||||
|
||||
// if the variable is z or out, store its activation
|
||||
if (varName.equals("z")) {
|
||||
z = activation.detach().dup();
|
||||
} else if (varName.equals("out")) {
|
||||
out = activation.detach().dup();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue