Add placeholder value for keras learning phase if required (#30)

* automatically add placeholders for keras_learning_phase if required

Signed-off-by: Ryan Nett <rnett@skymind.io>

* comment

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-06-28 18:50:01 -07:00 committed by AlexDBlack
parent c28372cb49
commit 366d850f5e
1 changed files with 11 additions and 2 deletions

View File

@ -31,6 +31,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import java.util.*;
import org.nd4j.linalg.factory.Nd4j;
/**
* Additional functionality to add:
@ -193,8 +194,16 @@ public abstract class AbstractSession<T, O> {
}
if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){
throw new IllegalStateException("An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
" but a placeholder value was not provided");
// Some Keras layers (like GRU) do different things depending on whether the model is training.
// We provide this value directly.
if(s.endsWith("keras_learning_phase")){
placeholderValues.put(s, (T) Nd4j.scalar(training));
} else {
throw new IllegalStateException(
"An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
" but a placeholder value was not provided");
}
}
}
}