SameDiff op runtime benchmarking listener (#42)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
18c01f5bdc
commit
1d96bb9e6e
|
@ -509,7 +509,7 @@ public abstract class DifferentialFunction {
|
||||||
* @return the arguments for a given function
|
* @return the arguments for a given function
|
||||||
*/
|
*/
|
||||||
public SDVariable[] args() {
|
public SDVariable[] args() {
|
||||||
return sameDiff.getInputVariablesForOp(this);
|
return sameDiff == null ? null : sameDiff.getInputVariablesForOp(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -0,0 +1,189 @@
|
||||||
|
package org.nd4j.autodiff.listeners.debugging;
|
||||||
|
|
||||||
|
import lombok.*;
|
||||||
|
import org.nd4j.autodiff.listeners.At;
|
||||||
|
import org.nd4j.autodiff.listeners.BaseListener;
|
||||||
|
import org.nd4j.autodiff.listeners.Operation;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.Op;
|
||||||
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
|
import java.text.DecimalFormat;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple listener for benchmarking single operations in SameDiff<br>
|
||||||
|
* Supports 2 modes:<br>
|
||||||
|
* - SINGLE_ITER_PRINT: Print the runtime of the first iteration<br>
|
||||||
|
* - AGGREGATE: Collect statistics for multiple runs, that can be accessed (by op name) via {@link #getAggregateModeMap()}
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
public class OpBenchmarkListener extends BaseListener {
|
||||||
|
|
||||||
|
public enum Mode {SINGLE_ITER_PRINT, AGGREGATE}
|
||||||
|
|
||||||
|
private final Operation operation;
|
||||||
|
private final Mode mode;
|
||||||
|
private final long minRuntime;
|
||||||
|
private Map<String,OpExec> aggregateModeMap;
|
||||||
|
|
||||||
|
@Getter(AccessLevel.PRIVATE)
|
||||||
|
private long start;
|
||||||
|
@Getter(AccessLevel.PRIVATE)
|
||||||
|
private boolean printActive;
|
||||||
|
private boolean printDone;
|
||||||
|
|
||||||
|
public OpBenchmarkListener(Operation operation, @NonNull Mode mode) {
|
||||||
|
this(operation, mode, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param operation Operation to collect stats for
|
||||||
|
* @param mode Mode - see {@link OpBenchmarkListener}
|
||||||
|
* @param minRuntime Minimum runtime - only applies to Mode.SINGLE_ITER_PRINT. If op runtime below this: don't print
|
||||||
|
*/
|
||||||
|
public OpBenchmarkListener(Operation operation, @NonNull Mode mode, long minRuntime) {
|
||||||
|
this.operation = operation;
|
||||||
|
this.mode = mode;
|
||||||
|
this.minRuntime = minRuntime;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isActive(Operation operation) {
|
||||||
|
return this.operation == null || this.operation == operation;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void operationStart(SameDiff sd, Operation op) {
|
||||||
|
if(printDone)
|
||||||
|
return;
|
||||||
|
if(this.operation == null || this.operation == op)
|
||||||
|
printActive = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void operationEnd(SameDiff sd, Operation op) {
|
||||||
|
if(printDone)
|
||||||
|
return;
|
||||||
|
if(this.operation == null || this.operation == op) {
|
||||||
|
printActive = false;
|
||||||
|
printDone = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
|
||||||
|
start = System.currentTimeMillis();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||||
|
long now = System.currentTimeMillis();
|
||||||
|
|
||||||
|
if (mode == Mode.SINGLE_ITER_PRINT && printActive && (now-start) > this.minRuntime) {
|
||||||
|
System.out.println(getOpString(op, now));
|
||||||
|
} else if (mode == Mode.AGGREGATE) {
|
||||||
|
if(aggregateModeMap == null)
|
||||||
|
aggregateModeMap = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
if(!aggregateModeMap.containsKey(op.getName())){
|
||||||
|
String s = getOpString(op, null);
|
||||||
|
OpExec oe = new OpExec(op.getName(), op.getOp().opName(), op.getOp().getClass(),
|
||||||
|
new ArrayList<Long>(), s);
|
||||||
|
aggregateModeMap.put(op.getName(), oe);
|
||||||
|
}
|
||||||
|
|
||||||
|
aggregateModeMap.get(op.getName()).getRuntimeMs().add(now-start);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getOpString(SameDiffOp op, Long now){
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
sb.append(op.getName()).append(" - ").append(op.getOp().getClass().getSimpleName())
|
||||||
|
.append("(").append(op.getOp().opName()).append(") - ");
|
||||||
|
if(now != null) {
|
||||||
|
sb.append(now - start).append(" ms\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op.getOp() instanceof DynamicCustomOp) {
|
||||||
|
DynamicCustomOp dco = (DynamicCustomOp) op.getOp();
|
||||||
|
int x = 0;
|
||||||
|
|
||||||
|
for (INDArray i : dco.inputArguments()) {
|
||||||
|
sb.append(" in ").append(x++).append(": ").append(i.shapeInfoToString()).append("\n");
|
||||||
|
}
|
||||||
|
x = 0;
|
||||||
|
for (INDArray o : dco.outputArguments()) {
|
||||||
|
sb.append(" out ").append(x++).append(": ").append(o.shapeInfoToString()).append("\n");
|
||||||
|
}
|
||||||
|
long[] iargs = dco.iArgs();
|
||||||
|
boolean[] bargs = dco.bArgs();
|
||||||
|
double[] targs = dco.tArgs();
|
||||||
|
if (iargs != null && iargs.length > 0) {
|
||||||
|
sb.append(" iargs: ").append(Arrays.toString(iargs)).append("\n");
|
||||||
|
}
|
||||||
|
if (bargs != null && bargs.length > 0) {
|
||||||
|
sb.append(" bargs: ").append(Arrays.toString(bargs)).append("\n");
|
||||||
|
}
|
||||||
|
if (targs != null && targs.length > 0) {
|
||||||
|
sb.append(" targs: ").append(Arrays.toString(targs)).append("\n");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Op o = (Op) op.getOp();
|
||||||
|
if (o.x() != null)
|
||||||
|
sb.append(" x: ").append(o.x().shapeInfoToString());
|
||||||
|
if (o.y() != null)
|
||||||
|
sb.append(" y: ").append(o.y().shapeInfoToString());
|
||||||
|
if (o.z() != null)
|
||||||
|
sb.append(" z: ").append(o.z().shapeInfoToString());
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Data
|
||||||
|
public static class OpExec {
|
||||||
|
private final String opOwnName;
|
||||||
|
private final String opName;
|
||||||
|
private final Class<?> opClass;
|
||||||
|
private List<Long> runtimeMs;
|
||||||
|
private String firstIter;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
DecimalFormat df = new DecimalFormat("0.000");
|
||||||
|
|
||||||
|
return opOwnName + " - op class: " + opClass.getSimpleName() + " (op name: " + opName + ")\n"
|
||||||
|
+ "count: " + runtimeMs.size() + ", mean: " + df.format(avgMs()) + "ms, std: " + df.format(stdMs()) + "ms, min: " + minMs() + "ms, max: " + maxMs() + "ms\n"
|
||||||
|
+ firstIter;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double avgMs() {
|
||||||
|
long sum = 0;
|
||||||
|
for (Long l : runtimeMs) {
|
||||||
|
sum += l;
|
||||||
|
}
|
||||||
|
return sum / (double) runtimeMs.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
public double stdMs() {
|
||||||
|
return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).stdNumber().doubleValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
public long minMs() {
|
||||||
|
return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).minNumber().longValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
public long maxMs() {
|
||||||
|
return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).maxNumber().longValue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue