diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java index 0521755cf..6a1ca5c3d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java @@ -26,8 +26,11 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.profiler.OpProfiler; +import org.nd4j.linalg.profiler.ProfilerConfig; import org.nd4j.linalg.util.ArrayUtil; import java.util.Arrays; @@ -41,8 +44,7 @@ public class OpExecutionerUtil { private OpExecutionerUtil() {} public static void checkForNaN(INDArray z) { - if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC - && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC) + if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) return; if(z.isEmpty() || !z.dataType().isFPType()) @@ -63,7 +65,7 @@ public class OpExecutionerUtil { } if (match > 0) - throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): "); + throw new ND4JOpProfilerException("P.A.N.I.C.! 
Op.Z() contains " + match + " NaN value(s): "); } public static void checkForAny(INDArray z) { @@ -72,8 +74,7 @@ public class OpExecutionerUtil { } public static void checkForInf(INDArray z) { - if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC - && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC) + if (!OpProfiler.getInstance().getConfig().isCheckForINF()) return; if(z.isEmpty() || !z.dataType().isFPType()) @@ -94,13 +95,12 @@ public class OpExecutionerUtil { } if (match > 0) - throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)"); + throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)"); } public static void checkForNaN(Op op) { - if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC - && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC) + if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) return; if (op.z() != null && !(op instanceof MatchCondition)) { @@ -109,8 +109,7 @@ public class OpExecutionerUtil { } public static void checkForInf(Op op) { - if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC - && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC) + if (!OpProfiler.getInstance().getConfig().isCheckForINF()) return; if (op.z() != null && !(op instanceof MatchCondition)) { @@ -119,8 +118,7 @@ public class OpExecutionerUtil { } public static void checkForInf(CustomOp op) { - if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC - && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC) + if (!OpProfiler.getInstance().getConfig().isCheckForINF()) return; for (val input: op.inputArguments()) @@ -132,8 +130,7 @@ public class OpExecutionerUtil { public static void checkForNaN(CustomOp op) { - if 
(Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC - && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC) + if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) return; for (val input: op.inputArguments()) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/exception/ND4JOpProfilerException.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/exception/ND4JOpProfilerException.java new file mode 100644 index 000000000..66ced8014 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/exception/ND4JOpProfilerException.java @@ -0,0 +1,22 @@ +package org.nd4j.linalg.exception; + +/** + * ND4JOpProfilerException: Thrown by the op profiler (if enabled) for example on NaN panic + * + * @author Alex Black + */ +public class ND4JOpProfilerException extends ND4JIllegalStateException { + public ND4JOpProfilerException() { + } + public ND4JOpProfilerException(String message) { + super(message); + } + + public ND4JOpProfilerException(String message, Throwable cause) { + super(message, cause); + } + + public ND4JOpProfilerException(Throwable cause) { + super(cause); + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 5494dff50..23cedba09 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -60,6 +60,7 @@ import org.nd4j.linalg.cpu.nativecpu.CpuTADManager; import org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom; import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import 
org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -1628,7 +1629,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { */ @Override public INDArray[] exec(@NonNull CustomOp op) { - long st = profilingConfigurableHookIn(op); if (op.numOutputArguments() == 0 && !op.isInplaceCall()) { try { @@ -1668,6 +1668,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); return result; + } catch (ND4JOpProfilerException e){ + throw e; } catch (Exception e) { throw new RuntimeException("Op [" + name + "] execution failed", e); } @@ -2062,6 +2064,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public INDArray[] exec(CustomOp op, @NonNull OpContext context) { + long st = profilingConfigurableHookIn(op); boolean mklOverride = false; try { if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) { @@ -2073,6 +2076,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } + loop.execCustomOp2(null, op.opHash(), context.contextPointer()); if (context.getOutputArrays().isEmpty()) @@ -2139,6 +2143,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } finally { if (mklOverride) Nd4jCpu.Environment.getInstance().setUseMKLDNN(true); + profilingConfigurableHookOut(op, st); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index e84e136fe..8a67bd2c2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -25,12 +25,15 @@ 
import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.profiler.OpProfiler; import org.nd4j.linalg.profiler.ProfilerConfig; @@ -428,4 +431,61 @@ public class OperationProfilerTests extends BaseNd4jTest { assertEquals(1.0f, stats.getMeanValue(), 1e-5); } + @Test + public void testNanPanic(){ + try { + DynamicCustomOp op = DynamicCustomOp.builder("add") + .addInputs(Nd4j.valueArrayOf(10, Double.NaN).castTo(DataType.DOUBLE), Nd4j.scalar(0.0)) + .addOutputs(Nd4j.create(DataType.DOUBLE, 10)) + .build(); + + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(true).build()); + try { + Nd4j.exec(op); //Should trigger NaN panic + fail(); + } catch (Exception e){ + assertTrue(e.getMessage(), e.getMessage().contains("NaN")); + } + + INDArray in = op.getInputArgument(0); + + try { + Transforms.sigmoid(in); + fail(); + } catch (Exception e){ + assertTrue(e.getMessage().contains("NaN")); + } + } finally { + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(false).build()); + } + } + + @Test + public void testInfPanic(){ + try { + DynamicCustomOp op = DynamicCustomOp.builder("add") + .addInputs(Nd4j.valueArrayOf(10, Double.POSITIVE_INFINITY).castTo(DataType.DOUBLE), Nd4j.scalar(0.0)) + .addOutputs(Nd4j.create(DataType.DOUBLE, 10)) + .build(); + + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(true).build()); + try { + 
Nd4j.exec(op); //Should trigger Inf panic + fail(); + } catch (Exception e){ + assertTrue(e.getMessage(), e.getMessage().contains("Inf")); + } + + INDArray in = op.getInputArgument(0); + + try { + Transforms.max(in, 1.0, false); + fail(); + } catch (Exception e){ + assertTrue(e.getMessage().contains("Inf")); + } + } finally { + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build()); + } + } }