diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 32a1cc362..4a6a5ce53 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.impl.shape.Stack; import org.nd4j.linalg.api.ops.impl.shape.tensorops.*; +import org.nd4j.linalg.api.ops.impl.transforms.Assert; import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; import org.nd4j.linalg.api.shape.LongShapeDescriptor; @@ -458,6 +459,25 @@ public class InferenceSession extends AbstractSession { INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape()); out.assign(arr); return new INDArray[]{out}; + } else if (op instanceof Assert) { + Assert a = (Assert)op; + boolean condition = a.getInputArgument(0).getDouble(0) != 0.0; + if(!condition){ + //Assertion failed + String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution"; + if(a.numInputArguments() >= 3) { + INDArray msg = a.getInputArgument(2); + if (msg != null && msg.dataType() == DataType.UTF8) { + s += ": " + msg.getString(0); + } + } + if(a.numInputArguments() >= 5){ + INDArray arr = a.getInputArgument(4); + s += "\n" + arr; + } + throw new IllegalStateException(s); + } + return ((Assert) op).outputArguments(); } else if (op instanceof CustomOp) { CustomOp c = (CustomOp) op; Nd4j.exec(c);