SameDiff: Add Java-level assertion check/exception (#96)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-12-02 18:07:54 +11:00 committed by GitHub
parent 4ada65b384
commit 8123d9fa9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 0 deletions

View File

@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.*;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
@ -458,6 +459,25 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape());
out.assign(arr);
return new INDArray[]{out};
} else if (op instanceof Assert) {
Assert a = (Assert)op;
boolean condition = a.getInputArgument(0).getDouble(0) != 0.0;
if(!condition){
//Assertion failed
String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution";
if(a.numInputArguments() >= 3) {
INDArray msg = a.getInputArgument(2);
if (msg != null && msg.dataType() == DataType.UTF8) {
s += ": " + msg.getString(0);
}
}
if(a.numInputArguments() >= 5){
INDArray arr = a.getInputArgument(4);
s += "\n" + arr;
}
throw new IllegalStateException(s);
}
return ((Assert) op).outputArguments();
} else if (op instanceof CustomOp) {
CustomOp c = (CustomOp) op;
Nd4j.exec(c);