InferenceSession additional validation for shape calc (#122)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-08-16 14:19:06 +10:00 committed by GitHub
parent 9f2ba6a85d
commit 65ff18383a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -682,6 +682,8 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
List<LongShapeDescriptor> outShape = customOp.calculateOutputShape(); List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName()); Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
String[] outNames = df.outputVariablesNames(); String[] outNames = df.outputVariablesNames();
Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" +
" with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length);
for( int i=0; i<outShape.size(); i++ ){ for( int i=0; i<outShape.size(); i++ ){
INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i)); INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i));
LongShapeDescriptor reqShape = outShape.get(i); LongShapeDescriptor reqShape = outShape.get(i);