Add placeholder value for keras learning phase if required (#30)
* automatically add placeholders for keras_learning_phase if required Signed-off-by: Ryan Nett <rnett@skymind.io> * comment Signed-off-by: Ryan Nett <rnett@skymind.io>
This commit is contained in:
		
							parent
							
								
									c28372cb49
								
							
						
					
					
						commit
						366d850f5e
					
				| @ -31,6 +31,7 @@ import org.nd4j.base.Preconditions; | ||||
| import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; | ||||
| 
 | ||||
| import java.util.*; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| /** | ||||
|  * Additional functionality to add: | ||||
| @ -193,8 +194,16 @@ public abstract class AbstractSession<T, O> { | ||||
|                 } | ||||
| 
 | ||||
|                 if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){ | ||||
|                     throw new IllegalStateException("An input placeholder \"" + s + "\" is required to calculate the requested outputs," + | ||||
|                             " but a placeholder value was not provided"); | ||||
| 
 | ||||
|                     // Some Keras layers (like GRU) do different things depending on whether the model is training. | ||||
|                     // We provide this value directly. | ||||
|                     if(s.endsWith("keras_learning_phase")){ | ||||
|                         placeholderValues.put(s, (T) Nd4j.scalar(training)); | ||||
|                     } else { | ||||
|                         throw new IllegalStateException( | ||||
|                                 "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + | ||||
|                                         " but a placeholder value was not provided"); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user