Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
e18e2dc014
commit
cd41c3540d
|
@ -26,8 +26,11 @@ import org.nd4j.linalg.api.ops.CustomOp;
|
|||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||
import org.nd4j.linalg.profiler.OpProfiler;
|
||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
@ -41,8 +44,7 @@ public class OpExecutionerUtil {
|
|||
private OpExecutionerUtil() {}
|
||||
|
||||
public static void checkForNaN(INDArray z) {
|
||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
|
||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
||||
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
|
||||
return;
|
||||
|
||||
if(z.isEmpty() || !z.dataType().isFPType())
|
||||
|
@ -63,7 +65,7 @@ public class OpExecutionerUtil {
|
|||
}
|
||||
|
||||
if (match > 0)
|
||||
throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
|
||||
throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
|
||||
}
|
||||
|
||||
public static void checkForAny(INDArray z) {
|
||||
|
@ -72,8 +74,7 @@ public class OpExecutionerUtil {
|
|||
}
|
||||
|
||||
public static void checkForInf(INDArray z) {
|
||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
|
||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
||||
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
|
||||
return;
|
||||
|
||||
if(z.isEmpty() || !z.dataType().isFPType())
|
||||
|
@ -94,13 +95,12 @@ public class OpExecutionerUtil {
|
|||
}
|
||||
|
||||
if (match > 0)
|
||||
throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)");
|
||||
throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)");
|
||||
|
||||
}
|
||||
|
||||
public static void checkForNaN(Op op) {
|
||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
|
||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
||||
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
|
||||
return;
|
||||
|
||||
if (op.z() != null && !(op instanceof MatchCondition)) {
|
||||
|
@ -109,8 +109,7 @@ public class OpExecutionerUtil {
|
|||
}
|
||||
|
||||
public static void checkForInf(Op op) {
|
||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
|
||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
||||
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
|
||||
return;
|
||||
|
||||
if (op.z() != null && !(op instanceof MatchCondition)) {
|
||||
|
@ -119,8 +118,7 @@ public class OpExecutionerUtil {
|
|||
}
|
||||
|
||||
public static void checkForInf(CustomOp op) {
|
||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
|
||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
||||
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
|
||||
return;
|
||||
|
||||
for (val input: op.inputArguments())
|
||||
|
@ -132,8 +130,7 @@ public class OpExecutionerUtil {
|
|||
|
||||
|
||||
public static void checkForNaN(CustomOp op) {
|
||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
|
||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
||||
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
|
||||
return;
|
||||
|
||||
for (val input: op.inputArguments())
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
package org.nd4j.linalg.exception;
|
||||
|
||||
/**
|
||||
* ND4JOpProfilerException: Thrown by the op profiler (if enabled) for example on NaN panic
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class ND4JOpProfilerException extends ND4JIllegalStateException {
|
||||
public ND4JOpProfilerException() {
|
||||
}
|
||||
public ND4JOpProfilerException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public ND4JOpProfilerException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
public ND4JOpProfilerException(Throwable cause) {
|
||||
super(cause);
|
||||
}
|
||||
}
|
|
@ -60,6 +60,7 @@ import org.nd4j.linalg.cpu.nativecpu.CpuTADManager;
|
|||
import org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.memory.MemcpyDirection;
|
||||
import org.nd4j.linalg.primitives.AtomicBoolean;
|
||||
|
@ -1628,7 +1629,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
*/
|
||||
@Override
|
||||
public INDArray[] exec(@NonNull CustomOp op) {
|
||||
long st = profilingConfigurableHookIn(op);
|
||||
|
||||
if (op.numOutputArguments() == 0 && !op.isInplaceCall()) {
|
||||
try {
|
||||
|
@ -1668,6 +1668,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
|
||||
|
||||
return result;
|
||||
} catch (ND4JOpProfilerException e){
|
||||
throw e;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Op [" + name + "] execution failed", e);
|
||||
}
|
||||
|
@ -2062,6 +2064,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public INDArray[] exec(CustomOp op, @NonNull OpContext context) {
|
||||
long st = profilingConfigurableHookIn(op);
|
||||
boolean mklOverride = false;
|
||||
try {
|
||||
if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) {
|
||||
|
@ -2073,6 +2076,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
loop.execCustomOp2(null, op.opHash(), context.contextPointer());
|
||||
|
||||
if (context.getOutputArrays().isEmpty())
|
||||
|
@ -2139,6 +2143,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
} finally {
|
||||
if (mklOverride)
|
||||
Nd4jCpu.Environment.getInstance().setUseMKLDNN(true);
|
||||
profilingConfigurableHookOut(op, st);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,12 +25,15 @@ import org.junit.Ignore;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.profiler.OpProfiler;
|
||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||
|
@ -428,4 +431,61 @@ public class OperationProfilerTests extends BaseNd4jTest {
|
|||
assertEquals(1.0f, stats.getMeanValue(), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNanPanic(){
|
||||
try {
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("add")
|
||||
.addInputs(Nd4j.valueArrayOf(10, Double.NaN).castTo(DataType.DOUBLE), Nd4j.scalar(0.0))
|
||||
.addOutputs(Nd4j.create(DataType.DOUBLE, 10))
|
||||
.build();
|
||||
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(true).build());
|
||||
try {
|
||||
Nd4j.exec(op); //Should trigger NaN panic
|
||||
fail();
|
||||
} catch (Exception e){
|
||||
assertTrue(e.getMessage(), e.getMessage().contains("NaN"));
|
||||
}
|
||||
|
||||
INDArray in = op.getInputArgument(0);
|
||||
|
||||
try {
|
||||
Transforms.sigmoid(in);
|
||||
fail();
|
||||
} catch (Exception e){
|
||||
assertTrue(e.getMessage().contains("NaN"));
|
||||
}
|
||||
} finally {
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(false).build());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInfPanic(){
|
||||
try {
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("add")
|
||||
.addInputs(Nd4j.valueArrayOf(10, Double.POSITIVE_INFINITY).castTo(DataType.DOUBLE), Nd4j.scalar(0.0))
|
||||
.addOutputs(Nd4j.create(DataType.DOUBLE, 10))
|
||||
.build();
|
||||
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(true).build());
|
||||
try {
|
||||
Nd4j.exec(op); //Should trigger NaN panic
|
||||
fail();
|
||||
} catch (Exception e){
|
||||
assertTrue(e.getMessage(), e.getMessage().contains("Inf"));
|
||||
}
|
||||
|
||||
INDArray in = op.getInputArgument(0);
|
||||
|
||||
try {
|
||||
Transforms.max(in, 1.0, false);
|
||||
fail();
|
||||
} catch (Exception e){
|
||||
assertTrue(e.getMessage().contains("Inf"));
|
||||
}
|
||||
} finally {
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue