SameDiff: Add Java-level assertion check/exception (#96)
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
4ada65b384
commit
8123d9fa9b
|
@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Concat;
|
import org.nd4j.linalg.api.ops.impl.shape.Concat;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Stack;
|
import org.nd4j.linalg.api.ops.impl.shape.Stack;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.tensorops.*;
|
import org.nd4j.linalg.api.ops.impl.shape.tensorops.*;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
|
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
|
@ -458,6 +459,25 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
|
||||||
INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape());
|
INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape());
|
||||||
out.assign(arr);
|
out.assign(arr);
|
||||||
return new INDArray[]{out};
|
return new INDArray[]{out};
|
||||||
|
} else if (op instanceof Assert) {
|
||||||
|
Assert a = (Assert)op;
|
||||||
|
boolean condition = a.getInputArgument(0).getDouble(0) != 0.0;
|
||||||
|
if(!condition){
|
||||||
|
//Assertion failed
|
||||||
|
String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution";
|
||||||
|
if(a.numInputArguments() >= 3) {
|
||||||
|
INDArray msg = a.getInputArgument(2);
|
||||||
|
if (msg != null && msg.dataType() == DataType.UTF8) {
|
||||||
|
s += ": " + msg.getString(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(a.numInputArguments() >= 5){
|
||||||
|
INDArray arr = a.getInputArgument(4);
|
||||||
|
s += "\n" + arr;
|
||||||
|
}
|
||||||
|
throw new IllegalStateException(s);
|
||||||
|
}
|
||||||
|
return ((Assert) op).outputArguments();
|
||||||
} else if (op instanceof CustomOp) {
|
} else if (op instanceof CustomOp) {
|
||||||
CustomOp c = (CustomOp) op;
|
CustomOp c = (CustomOp) op;
|
||||||
Nd4j.exec(c);
|
Nd4j.exec(c);
|
||||||
|
|
Loading…
Reference in New Issue