Add placeholder value for keras learning phase if required (#30)
* automatically add placeholders for keras_learning_phase if required Signed-off-by: Ryan Nett <rnett@skymind.io> * comment Signed-off-by: Ryan Nett <rnett@skymind.io>master
parent
c28372cb49
commit
366d850f5e
|
@ -31,6 +31,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
|
||||
|
||||
import java.util.*;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* Additional functionality to add:
|
||||
|
@ -193,8 +194,16 @@ public abstract class AbstractSession<T, O> {
|
|||
}
|
||||
|
||||
if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){
|
||||
throw new IllegalStateException("An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
|
||||
" but a placeholder value was not provided");
|
||||
|
||||
// Some Keras layers (like GRU) do different things depending on whether the model is training.
|
||||
// We provide this value directly.
|
||||
if(s.endsWith("keras_learning_phase")){
|
||||
placeholderValues.put(s, (T) Nd4j.scalar(training));
|
||||
} else {
|
||||
throw new IllegalStateException(
|
||||
"An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
|
||||
" but a placeholder value was not provided");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue