#8038 Fix Op profiler NaN/Inf triggering + add tests (#93)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-02 00:02:31 +10:00 committed by AlexDBlack
parent e18e2dc014
commit cd41c3540d
4 changed files with 99 additions and 15 deletions

View File

@ -26,8 +26,11 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.Arrays;
@ -41,8 +44,7 @@ public class OpExecutionerUtil {
private OpExecutionerUtil() {}
public static void checkForNaN(INDArray z) {
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
return;
if(z.isEmpty() || !z.dataType().isFPType())
@ -63,7 +65,7 @@ public class OpExecutionerUtil {
}
if (match > 0)
throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
}
public static void checkForAny(INDArray z) {
@ -72,8 +74,7 @@ public class OpExecutionerUtil {
}
public static void checkForInf(INDArray z) {
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
return;
if(z.isEmpty() || !z.dataType().isFPType())
@ -94,13 +95,12 @@ public class OpExecutionerUtil {
}
if (match > 0)
throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)");
throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)");
}
public static void checkForNaN(Op op) {
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
return;
if (op.z() != null && !(op instanceof MatchCondition)) {
@ -109,8 +109,7 @@ public class OpExecutionerUtil {
}
public static void checkForInf(Op op) {
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
return;
if (op.z() != null && !(op instanceof MatchCondition)) {
@ -119,8 +118,7 @@ public class OpExecutionerUtil {
}
public static void checkForInf(CustomOp op) {
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
return;
for (val input: op.inputArguments())
@ -132,8 +130,7 @@ public class OpExecutionerUtil {
public static void checkForNaN(CustomOp op) {
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
return;
for (val input: op.inputArguments())

View File

@ -0,0 +1,22 @@
package org.nd4j.linalg.exception;
/**
* ND4JOpProfilerException: Thrown by the op profiler (if enabled) for example on NaN panic
*
* @author Alex Black
*/
public class ND4JOpProfilerException extends ND4JIllegalStateException {
public ND4JOpProfilerException() {
}
public ND4JOpProfilerException(String message) {
super(message);
}
public ND4JOpProfilerException(String message, Throwable cause) {
super(message, cause);
}
public ND4JOpProfilerException(Throwable cause) {
super(cause);
}
}

View File

@ -60,6 +60,7 @@ import org.nd4j.linalg.cpu.nativecpu.CpuTADManager;
import org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.primitives.AtomicBoolean;
@ -1628,7 +1629,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
*/
@Override
public INDArray[] exec(@NonNull CustomOp op) {
long st = profilingConfigurableHookIn(op);
if (op.numOutputArguments() == 0 && !op.isInplaceCall()) {
try {
@ -1668,6 +1668,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
return result;
} catch (ND4JOpProfilerException e){
throw e;
} catch (Exception e) {
throw new RuntimeException("Op [" + name + "] execution failed", e);
}
@ -2062,6 +2064,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
@Override
public INDArray[] exec(CustomOp op, @NonNull OpContext context) {
long st = profilingConfigurableHookIn(op);
boolean mklOverride = false;
try {
if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) {
@ -2073,6 +2076,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
}
}
loop.execCustomOp2(null, op.opHash(), context.contextPointer());
if (context.getOutputArrays().isEmpty())
@ -2139,6 +2143,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
} finally {
if (mklOverride)
Nd4jCpu.Environment.getInstance().setUseMKLDNN(true);
profilingConfigurableHookOut(op, st);
}
}

View File

@ -25,12 +25,15 @@ import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
@ -428,4 +431,61 @@ public class OperationProfilerTests extends BaseNd4jTest {
assertEquals(1.0f, stats.getMeanValue(), 1e-5);
}
@Test
public void testNanPanic(){
try {
DynamicCustomOp op = DynamicCustomOp.builder("add")
.addInputs(Nd4j.valueArrayOf(10, Double.NaN).castTo(DataType.DOUBLE), Nd4j.scalar(0.0))
.addOutputs(Nd4j.create(DataType.DOUBLE, 10))
.build();
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(true).build());
try {
Nd4j.exec(op); //Should trigger NaN panic
fail();
} catch (Exception e){
assertTrue(e.getMessage(), e.getMessage().contains("NaN"));
}
INDArray in = op.getInputArgument(0);
try {
Transforms.sigmoid(in);
fail();
} catch (Exception e){
assertTrue(e.getMessage().contains("NaN"));
}
} finally {
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(false).build());
}
}
@Test
public void testInfPanic(){
try {
DynamicCustomOp op = DynamicCustomOp.builder("add")
.addInputs(Nd4j.valueArrayOf(10, Double.POSITIVE_INFINITY).castTo(DataType.DOUBLE), Nd4j.scalar(0.0))
.addOutputs(Nd4j.create(DataType.DOUBLE, 10))
.build();
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(true).build());
try {
Nd4j.exec(op); //Should trigger NaN panic
fail();
} catch (Exception e){
assertTrue(e.getMessage(), e.getMessage().contains("Inf"));
}
INDArray in = op.getInputArgument(0);
try {
Transforms.max(in, 1.0, false);
fail();
} catch (Exception e){
assertTrue(e.getMessage().contains("Inf"));
}
} finally {
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build());
}
}
}