Add placeholder value for keras learning phase if required (#30)
* automatically add placeholders for keras_learning_phase if required Signed-off-by: Ryan Nett <rnett@skymind.io> * comment Signed-off-by: Ryan Nett <rnett@skymind.io>master
parent
c28372cb49
commit
366d850f5e
|
@ -31,6 +31,7 @@ import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Additional functionality to add:
|
* Additional functionality to add:
|
||||||
|
@ -193,11 +194,19 @@ public abstract class AbstractSession<T, O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){
|
if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){
|
||||||
throw new IllegalStateException("An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
|
|
||||||
|
// Some Keras layers (like GRU) do different things depending on whether the model is training.
|
||||||
|
// We provide this value directly.
|
||||||
|
if(s.endsWith("keras_learning_phase")){
|
||||||
|
placeholderValues.put(s, (T) Nd4j.scalar(training));
|
||||||
|
} else {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
|
||||||
" but a placeholder value was not provided");
|
" but a placeholder value was not provided");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//Step 2: execute in any order, until we have all required nodeOutputs
|
//Step 2: execute in any order, until we have all required nodeOutputs
|
||||||
/*
|
/*
|
||||||
|
|
Loading…
Reference in New Issue