SameDiff: Add Java-level assertion check/exception (#96)
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
4ada65b384
commit
8123d9fa9b
|
@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
|
|||
import org.nd4j.linalg.api.ops.impl.shape.Concat;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Stack;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.tensorops.*;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
|
@ -458,6 +459,25 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
|
|||
INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape());
|
||||
out.assign(arr);
|
||||
return new INDArray[]{out};
|
||||
} else if (op instanceof Assert) {
|
||||
Assert a = (Assert)op;
|
||||
boolean condition = a.getInputArgument(0).getDouble(0) != 0.0;
|
||||
if(!condition){
|
||||
//Assertion failed
|
||||
String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution";
|
||||
if(a.numInputArguments() >= 3) {
|
||||
INDArray msg = a.getInputArgument(2);
|
||||
if (msg != null && msg.dataType() == DataType.UTF8) {
|
||||
s += ": " + msg.getString(0);
|
||||
}
|
||||
}
|
||||
if(a.numInputArguments() >= 5){
|
||||
INDArray arr = a.getInputArgument(4);
|
||||
s += "\n" + arr;
|
||||
}
|
||||
throw new IllegalStateException(s);
|
||||
}
|
||||
return ((Assert) op).outputArguments();
|
||||
} else if (op instanceof CustomOp) {
|
||||
CustomOp c = (CustomOp) op;
|
||||
Nd4j.exec(c);
|
||||
|
|
Loading…
Reference in New Issue