diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 65191ea84..1b520e7aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -31,6 +31,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; import java.util.*; +import org.nd4j.linalg.factory.Nd4j; /** * Additional functionality to add: @@ -193,8 +194,16 @@ public abstract class AbstractSession { } if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){ - throw new IllegalStateException("An input placeholder \"" + s + "\" is required to calculate the requested outputs," + - " but a placeholder value was not provided"); + + // Some Keras layers (like GRU) do different things depending on whether the model is training. + // We provide this value directly. + if(s.endsWith("keras_learning_phase")){ + placeholderValues.put(s, (T) Nd4j.scalar(training)); + } else { + throw new IllegalStateException( + "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + + " but a placeholder value was not provided"); + } } } }