Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
e18e2dc014
commit
cd41c3540d
|
@ -26,8 +26,11 @@ import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.Op;
|
import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
|
import org.nd4j.linalg.profiler.OpProfiler;
|
||||||
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -41,8 +44,7 @@ public class OpExecutionerUtil {
|
||||||
private OpExecutionerUtil() {}
|
private OpExecutionerUtil() {}
|
||||||
|
|
||||||
public static void checkForNaN(INDArray z) {
|
public static void checkForNaN(INDArray z) {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
|
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
|
||||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if(z.isEmpty() || !z.dataType().isFPType())
|
if(z.isEmpty() || !z.dataType().isFPType())
|
||||||
|
@ -63,7 +65,7 @@ public class OpExecutionerUtil {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (match > 0)
|
if (match > 0)
|
||||||
throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
|
throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void checkForAny(INDArray z) {
|
public static void checkForAny(INDArray z) {
|
||||||
|
@ -72,8 +74,7 @@ public class OpExecutionerUtil {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void checkForInf(INDArray z) {
|
public static void checkForInf(INDArray z) {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
|
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
|
||||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if(z.isEmpty() || !z.dataType().isFPType())
|
if(z.isEmpty() || !z.dataType().isFPType())
|
||||||
|
@ -94,13 +95,12 @@ public class OpExecutionerUtil {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (match > 0)
|
if (match > 0)
|
||||||
throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)");
|
throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void checkForNaN(Op op) {
|
public static void checkForNaN(Op op) {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
|
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
|
||||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (op.z() != null && !(op instanceof MatchCondition)) {
|
if (op.z() != null && !(op instanceof MatchCondition)) {
|
||||||
|
@ -109,8 +109,7 @@ public class OpExecutionerUtil {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void checkForInf(Op op) {
|
public static void checkForInf(Op op) {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
|
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
|
||||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (op.z() != null && !(op instanceof MatchCondition)) {
|
if (op.z() != null && !(op instanceof MatchCondition)) {
|
||||||
|
@ -119,8 +118,7 @@ public class OpExecutionerUtil {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void checkForInf(CustomOp op) {
|
public static void checkForInf(CustomOp op) {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.INF_PANIC
|
if (!OpProfiler.getInstance().getConfig().isCheckForINF())
|
||||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
|
||||||
return;
|
return;
|
||||||
|
|
||||||
for (val input: op.inputArguments())
|
for (val input: op.inputArguments())
|
||||||
|
@ -132,8 +130,7 @@ public class OpExecutionerUtil {
|
||||||
|
|
||||||
|
|
||||||
public static void checkForNaN(CustomOp op) {
|
public static void checkForNaN(CustomOp op) {
|
||||||
if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
|
if (!OpProfiler.getInstance().getConfig().isCheckForNAN())
|
||||||
&& Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
|
|
||||||
return;
|
return;
|
||||||
|
|
||||||
for (val input: op.inputArguments())
|
for (val input: op.inputArguments())
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
package org.nd4j.linalg.exception;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ND4JOpProfilerException: Thrown by the op profiler (if enabled) for example on NaN panic
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
public class ND4JOpProfilerException extends ND4JIllegalStateException {
|
||||||
|
public ND4JOpProfilerException() {
|
||||||
|
}
|
||||||
|
public ND4JOpProfilerException(String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ND4JOpProfilerException(String message, Throwable cause) {
|
||||||
|
super(message, cause);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ND4JOpProfilerException(Throwable cause) {
|
||||||
|
super(cause);
|
||||||
|
}
|
||||||
|
}
|
|
@ -60,6 +60,7 @@ import org.nd4j.linalg.cpu.nativecpu.CpuTADManager;
|
||||||
import org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom;
|
import org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
|
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.memory.MemcpyDirection;
|
import org.nd4j.linalg.memory.MemcpyDirection;
|
||||||
import org.nd4j.linalg.primitives.AtomicBoolean;
|
import org.nd4j.linalg.primitives.AtomicBoolean;
|
||||||
|
@ -1628,7 +1629,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] exec(@NonNull CustomOp op) {
|
public INDArray[] exec(@NonNull CustomOp op) {
|
||||||
long st = profilingConfigurableHookIn(op);
|
|
||||||
|
|
||||||
if (op.numOutputArguments() == 0 && !op.isInplaceCall()) {
|
if (op.numOutputArguments() == 0 && !op.isInplaceCall()) {
|
||||||
try {
|
try {
|
||||||
|
@ -1668,6 +1668,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
|
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
} catch (ND4JOpProfilerException e){
|
||||||
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException("Op [" + name + "] execution failed", e);
|
throw new RuntimeException("Op [" + name + "] execution failed", e);
|
||||||
}
|
}
|
||||||
|
@ -2062,6 +2064,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] exec(CustomOp op, @NonNull OpContext context) {
|
public INDArray[] exec(CustomOp op, @NonNull OpContext context) {
|
||||||
|
long st = profilingConfigurableHookIn(op);
|
||||||
boolean mklOverride = false;
|
boolean mklOverride = false;
|
||||||
try {
|
try {
|
||||||
if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) {
|
if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) {
|
||||||
|
@ -2073,6 +2076,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
loop.execCustomOp2(null, op.opHash(), context.contextPointer());
|
loop.execCustomOp2(null, op.opHash(), context.contextPointer());
|
||||||
|
|
||||||
if (context.getOutputArrays().isEmpty())
|
if (context.getOutputArrays().isEmpty())
|
||||||
|
@ -2139,6 +2143,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
} finally {
|
} finally {
|
||||||
if (mklOverride)
|
if (mklOverride)
|
||||||
Nd4jCpu.Environment.getInstance().setUseMKLDNN(true);
|
Nd4jCpu.Environment.getInstance().setUseMKLDNN(true);
|
||||||
|
profilingConfigurableHookOut(op, st);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,12 +25,15 @@ import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.profiler.OpProfiler;
|
import org.nd4j.linalg.profiler.OpProfiler;
|
||||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
|
@ -428,4 +431,61 @@ public class OperationProfilerTests extends BaseNd4jTest {
|
||||||
assertEquals(1.0f, stats.getMeanValue(), 1e-5);
|
assertEquals(1.0f, stats.getMeanValue(), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNanPanic(){
|
||||||
|
try {
|
||||||
|
DynamicCustomOp op = DynamicCustomOp.builder("add")
|
||||||
|
.addInputs(Nd4j.valueArrayOf(10, Double.NaN).castTo(DataType.DOUBLE), Nd4j.scalar(0.0))
|
||||||
|
.addOutputs(Nd4j.create(DataType.DOUBLE, 10))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(true).build());
|
||||||
|
try {
|
||||||
|
Nd4j.exec(op); //Should trigger NaN panic
|
||||||
|
fail();
|
||||||
|
} catch (Exception e){
|
||||||
|
assertTrue(e.getMessage(), e.getMessage().contains("NaN"));
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray in = op.getInputArgument(0);
|
||||||
|
|
||||||
|
try {
|
||||||
|
Transforms.sigmoid(in);
|
||||||
|
fail();
|
||||||
|
} catch (Exception e){
|
||||||
|
assertTrue(e.getMessage().contains("NaN"));
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(false).build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInfPanic(){
|
||||||
|
try {
|
||||||
|
DynamicCustomOp op = DynamicCustomOp.builder("add")
|
||||||
|
.addInputs(Nd4j.valueArrayOf(10, Double.POSITIVE_INFINITY).castTo(DataType.DOUBLE), Nd4j.scalar(0.0))
|
||||||
|
.addOutputs(Nd4j.create(DataType.DOUBLE, 10))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(true).build());
|
||||||
|
try {
|
||||||
|
Nd4j.exec(op); //Should trigger NaN panic
|
||||||
|
fail();
|
||||||
|
} catch (Exception e){
|
||||||
|
assertTrue(e.getMessage(), e.getMessage().contains("Inf"));
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray in = op.getInputArgument(0);
|
||||||
|
|
||||||
|
try {
|
||||||
|
Transforms.max(in, 1.0, false);
|
||||||
|
fail();
|
||||||
|
} catch (Exception e){
|
||||||
|
assertTrue(e.getMessage().contains("Inf"));
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue