From 366d850f5e35786328c2b9a5440cffec9f823633 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 28 Jun 2019 18:50:01 -0700 Subject: [PATCH] Add placeholder value for keras learning phase if required (#30) * automatically add placeholders for keras_learning_phase if required Signed-off-by: Ryan Nett * comment Signed-off-by: Ryan Nett --- .../autodiff/samediff/internal/AbstractSession.java | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 65191ea84..1b520e7aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -31,6 +31,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; import java.util.*; +import org.nd4j.linalg.factory.Nd4j; /** * Additional functionality to add: @@ -193,8 +194,16 @@ public abstract class AbstractSession { } if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){ - throw new IllegalStateException("An input placeholder \"" + s + "\" is required to calculate the requested outputs," + - " but a placeholder value was not provided"); + + // Some Keras layers (like GRU) do different things depending on whether the model is training. + // We provide this value directly. + if(s.endsWith("keras_learning_phase")){ + placeholderValues.put(s, (T) Nd4j.scalar(training)); + } else { + throw new IllegalStateException( + "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + + " but a placeholder value was not provided"); + } } } }