InferenceSession additional validation for shape calc (#122)
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
9f2ba6a85d
commit
65ff18383a
|
@ -682,6 +682,8 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
||||||
List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
|
List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
|
||||||
Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
|
Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
|
||||||
String[] outNames = df.outputVariablesNames();
|
String[] outNames = df.outputVariablesNames();
|
||||||
|
Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" +
|
||||||
|
" with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length);
|
||||||
for( int i=0; i<outShape.size(); i++ ){
|
for( int i=0; i<outShape.size(); i++ ){
|
||||||
INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i));
|
INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i));
|
||||||
LongShapeDescriptor reqShape = outShape.get(i);
|
LongShapeDescriptor reqShape = outShape.get(i);
|
||||||
|
|
Loading…
Reference in New Issue