SameDiff op runtime benchmarking listener (#42)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
18c01f5bdc
commit
1d96bb9e6e
|
@ -509,7 +509,7 @@ public abstract class DifferentialFunction {
|
|||
* @return the arguments for a given function
|
||||
*/
|
||||
public SDVariable[] args() {
|
||||
return sameDiff.getInputVariablesForOp(this);
|
||||
return sameDiff == null ? null : sameDiff.getInputVariablesForOp(this);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
package org.nd4j.autodiff.listeners.debugging;
|
||||
|
||||
import lombok.*;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* A simple listener for benchmarking single operations in SameDiff<br>
|
||||
* Supports 2 modes:<br>
|
||||
* - SINGLE_ITER_PRINT: Print the runtime of the first iteration<br>
|
||||
* - AGGREGATE: Collect statistics for multiple runs, that can be accessed (by op name) via {@link #getAggregateModeMap()}
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Getter
|
||||
public class OpBenchmarkListener extends BaseListener {
|
||||
|
||||
public enum Mode {SINGLE_ITER_PRINT, AGGREGATE}
|
||||
|
||||
private final Operation operation;
|
||||
private final Mode mode;
|
||||
private final long minRuntime;
|
||||
private Map<String,OpExec> aggregateModeMap;
|
||||
|
||||
@Getter(AccessLevel.PRIVATE)
|
||||
private long start;
|
||||
@Getter(AccessLevel.PRIVATE)
|
||||
private boolean printActive;
|
||||
private boolean printDone;
|
||||
|
||||
public OpBenchmarkListener(Operation operation, @NonNull Mode mode) {
|
||||
this(operation, mode, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param operation Operation to collect stats for
|
||||
* @param mode Mode - see {@link OpBenchmarkListener}
|
||||
* @param minRuntime Minimum runtime - only applies to Mode.SINGLE_ITER_PRINT. If op runtime below this: don't print
|
||||
*/
|
||||
public OpBenchmarkListener(Operation operation, @NonNull Mode mode, long minRuntime) {
|
||||
this.operation = operation;
|
||||
this.mode = mode;
|
||||
this.minRuntime = minRuntime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return this.operation == null || this.operation == operation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationStart(SameDiff sd, Operation op) {
|
||||
if(printDone)
|
||||
return;
|
||||
if(this.operation == null || this.operation == op)
|
||||
printActive = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationEnd(SameDiff sd, Operation op) {
|
||||
if(printDone)
|
||||
return;
|
||||
if(this.operation == null || this.operation == op) {
|
||||
printActive = false;
|
||||
printDone = true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
|
||||
start = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
long now = System.currentTimeMillis();
|
||||
|
||||
if (mode == Mode.SINGLE_ITER_PRINT && printActive && (now-start) > this.minRuntime) {
|
||||
System.out.println(getOpString(op, now));
|
||||
} else if (mode == Mode.AGGREGATE) {
|
||||
if(aggregateModeMap == null)
|
||||
aggregateModeMap = new LinkedHashMap<>();
|
||||
|
||||
if(!aggregateModeMap.containsKey(op.getName())){
|
||||
String s = getOpString(op, null);
|
||||
OpExec oe = new OpExec(op.getName(), op.getOp().opName(), op.getOp().getClass(),
|
||||
new ArrayList<Long>(), s);
|
||||
aggregateModeMap.put(op.getName(), oe);
|
||||
}
|
||||
|
||||
aggregateModeMap.get(op.getName()).getRuntimeMs().add(now-start);
|
||||
}
|
||||
}
|
||||
|
||||
private String getOpString(SameDiffOp op, Long now){
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append(op.getName()).append(" - ").append(op.getOp().getClass().getSimpleName())
|
||||
.append("(").append(op.getOp().opName()).append(") - ");
|
||||
if(now != null) {
|
||||
sb.append(now - start).append(" ms\n");
|
||||
}
|
||||
|
||||
if (op.getOp() instanceof DynamicCustomOp) {
|
||||
DynamicCustomOp dco = (DynamicCustomOp) op.getOp();
|
||||
int x = 0;
|
||||
|
||||
for (INDArray i : dco.inputArguments()) {
|
||||
sb.append(" in ").append(x++).append(": ").append(i.shapeInfoToString()).append("\n");
|
||||
}
|
||||
x = 0;
|
||||
for (INDArray o : dco.outputArguments()) {
|
||||
sb.append(" out ").append(x++).append(": ").append(o.shapeInfoToString()).append("\n");
|
||||
}
|
||||
long[] iargs = dco.iArgs();
|
||||
boolean[] bargs = dco.bArgs();
|
||||
double[] targs = dco.tArgs();
|
||||
if (iargs != null && iargs.length > 0) {
|
||||
sb.append(" iargs: ").append(Arrays.toString(iargs)).append("\n");
|
||||
}
|
||||
if (bargs != null && bargs.length > 0) {
|
||||
sb.append(" bargs: ").append(Arrays.toString(bargs)).append("\n");
|
||||
}
|
||||
if (targs != null && targs.length > 0) {
|
||||
sb.append(" targs: ").append(Arrays.toString(targs)).append("\n");
|
||||
}
|
||||
} else {
|
||||
Op o = (Op) op.getOp();
|
||||
if (o.x() != null)
|
||||
sb.append(" x: ").append(o.x().shapeInfoToString());
|
||||
if (o.y() != null)
|
||||
sb.append(" y: ").append(o.y().shapeInfoToString());
|
||||
if (o.z() != null)
|
||||
sb.append(" z: ").append(o.z().shapeInfoToString());
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
public static class OpExec {
|
||||
private final String opOwnName;
|
||||
private final String opName;
|
||||
private final Class<?> opClass;
|
||||
private List<Long> runtimeMs;
|
||||
private String firstIter;
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
DecimalFormat df = new DecimalFormat("0.000");
|
||||
|
||||
return opOwnName + " - op class: " + opClass.getSimpleName() + " (op name: " + opName + ")\n"
|
||||
+ "count: " + runtimeMs.size() + ", mean: " + df.format(avgMs()) + "ms, std: " + df.format(stdMs()) + "ms, min: " + minMs() + "ms, max: " + maxMs() + "ms\n"
|
||||
+ firstIter;
|
||||
}
|
||||
|
||||
public double avgMs() {
|
||||
long sum = 0;
|
||||
for (Long l : runtimeMs) {
|
||||
sum += l;
|
||||
}
|
||||
return sum / (double) runtimeMs.size();
|
||||
}
|
||||
|
||||
public double stdMs() {
|
||||
return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).stdNumber().doubleValue();
|
||||
}
|
||||
|
||||
public long minMs() {
|
||||
return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).minNumber().longValue();
|
||||
}
|
||||
|
||||
public long maxMs() {
|
||||
return Nd4j.createFromArray(ArrayUtil.toArrayLong(runtimeMs)).maxNumber().longValue();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue