SameDiff profiler / tracing and profile analysis/comparison (#133)

* Profiler

Signed-off-by: Alex Black <blacka101@gmail.com>

* Next steps, polishing, and loading SD/TF format JSON

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Next steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Profile comparison method

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Make profiling result writing async to reduce main thread overhead

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Profiling polishing

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Profile analyzer fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Polish

Signed-off-by: Alex Black <blacka101@gmail.com>

* Cleanup

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small formatting improvement

Signed-off-by: Alex Black <blacka101@gmail.com>

* Formatting tweak

Signed-off-by: Alex Black <blacka101@gmail.com>

* License headers

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-12-19 23:43:58 +11:00 committed by GitHub
parent e303c06042
commit 3d8f6d50a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1060 additions and 1 deletions

View File

@ -0,0 +1,354 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.listeners.profiler;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.profiler.data.Phase;
import org.nd4j.autodiff.listeners.profiler.data.TraceEvent;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import java.io.*;
import java.lang.management.ManagementFactory;
import java.text.DecimalFormat;
import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.LinkedBlockingDeque;
/**
* SameDiff profiling listener: for profiling operation execution<br>
* Writes profiles to a file in JSON format<br>
* Format is Chrome profiler format. The output can be read by Google Chrome browser; open Chrome and go to:
* chrome://tracing and load the output JSON format data
* <br>
* At present, only operation execution is profiled, not other aspects such as memory allocation and training-related
* functionality.<br>
* <br>
* Tracing can be configured in a few different ways via the builder, {@link #builder(File)}:<br>
* (a) warmup - don't record traces for the first N iterations<br>
* (b) "all" mode (default) - record all iterations, with no limit (after warmup, if applicable)<br>
* (c) "n iterations" mode: record at most the first N iterations (after warmup, if applicable)<br>
* (d) "n ms" mode: record for at most N milliseconds since the start of the first op execution (after warmup, if applicable)<br>
*
* Note: The Chrome Trace Event format can be found here:<br>
* <a href="https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit">https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit</a>
* SameDiff uses the JSON Array Format, as this can be written in an online/streaming manner.<br>
* Conversely, TensorFlow uses the JSON Object Format.<br>
* <br>
* For summarizing, analyzing and comparing the results (SameDiff or TensorFlow format), see {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer}<br>
*
* @author Alex Black
*/
@Getter
@Slf4j
public class ProfilingListener extends BaseListener {
private final File outputFile;
private final boolean all;
private final int warmup;
private final int nIter;
private final long nMs;
private final Operation[] operations;
private final long pid;
private final long tid;
private Long firstOpStart = null; //Used for time termination
private int countTotalIter = 0;
private boolean logActive = false;
private long opStartNano;
private Writer writer;
private ObjectMapper json;
private final Thread fileWritingThread;
private final BlockingQueue<TraceEvent> writeQueue;
private final AtomicBoolean writing = new AtomicBoolean(false);
protected ProfilingListener(@NonNull File outputFile, boolean all, int warmup, int nIter, long nMs, Operation[] operations) {
Preconditions.checkArgument(!outputFile.exists(), "Output file already exists: %s", outputFile);
this.outputFile = outputFile;
this.all = all;
this.warmup = warmup;
this.nIter = nIter;
this.nMs = nMs;
this.operations = operations;
this.pid = getProcessId();
this.tid = Thread.currentThread().getId();
try {
this.writer = new BufferedWriter(new FileWriter(outputFile, false));
this.writer.write("["); //JSON array open (array close is optional for Chrome profiler format)
} catch (IOException e) {
throw new RuntimeException(e);
}
this.json = jsonMapper();
//Set up a queue so file access doesn't add latency to the execution thread
writeQueue = new LinkedBlockingDeque<>();
fileWritingThread = new Thread(new Runnable() {
@Override
public void run() {
try {
runHelper();
} catch (Throwable t) {
log.error("Error when attempting to write results to file", t);
}
}
public void runHelper() throws Exception {
while (true) {
TraceEvent te = writeQueue.take(); //Blocking
writing.set(true);
try {
String j = json.writeValueAsString(te);
writer.append(j);
writer.append(",\n");
} catch (IOException e) {
throw new RuntimeException(e);
} finally {
writing.set(false);
}
}
}
});
fileWritingThread.setDaemon(true);
fileWritingThread.start();
}
@Override
public boolean isActive(Operation operation) {
return operations == null || ArrayUtils.contains(operations, operation);
}
@Override
public void operationStart(SameDiff sd, Operation op) {
this.logActive = operations == null || ArrayUtils.contains(operations, op);
}
@Override
public void operationEnd(SameDiff sd, Operation op) {
if (this.logActive) {
while ((!writeQueue.isEmpty() || writing.get()) && fileWritingThread.isAlive()) {
//Wait for file writing thread to catch up
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
try {
writer.flush();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
this.logActive = false;
if (op == Operation.INFERENCE) {
//Increment for inference; iteration done is called only for TRAINING
countTotalIter++;
}
}
@Override
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
//Increment for training
if (logActive) {
countTotalIter++;
}
}
@Override
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
if (logActive) {
opStartNano = System.nanoTime();
if(!all && nMs > 0 && firstOpStart == null)
firstOpStart = opStartNano;
}
}
@Override
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
if (logActive) {
long now = System.nanoTime();
if (warmup > 0 && countTotalIter < warmup) {
return; //Skip due to warmup phase
}
//Iteration termination
int terminationPt = this.nIter > 0 ? this.nIter : Integer.MAX_VALUE;
if (warmup > 0 && this.nIter > 0)
terminationPt += this.warmup;
if (countTotalIter > terminationPt) {
logActive = false;
return; //Skip due to max number of itertions
}
//Time termination
if(!all && nMs > 0 && (now - firstOpStart)/1000 > nMs) {
logActive = false;
return;
}
TraceEvent event = TraceEvent.builder()
.name(op.getOp().opName())
.categories(Collections.singletonList("Op"))
.ts(opStartNano / 1000)
.dur((now - opStartNano) / 1000)
.pid((int)pid)
.tid(tid)
.ph(Phase.X)
.args(Collections.<String, Object>singletonMap("name", op.getName()))
.build();
writeQueue.add(event);
}
}
private long getProcessId() {
// Note: may fail in some JVM implementations
// therefore fallback has to be provided
// something like '<pid>@<hostname>', at least in SUN / Oracle JVMs
final String jvmName = ManagementFactory.getRuntimeMXBean().getName();
final int index = jvmName.indexOf('@');
if (index < 1) {
// part before '@' empty (index = 0) / '@' not found (index = -1)
return 0;
}
try {
return Long.parseLong(jvmName.substring(0, index));
} catch (NumberFormatException e) {
// ignore
}
return 0;
}
/**
* Get a new JSON mapper for use in serializing/deserializing JSON format
*/
public static ObjectMapper jsonMapper() {
ObjectMapper json = new ObjectMapper();
json.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
json.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
json.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
json.disable(SerializationFeature.INDENT_OUTPUT); //One line
return json;
}
/**
* Create a new builder
* @param outputFile Output file. Will be overwritten if file already exists
*/
public static Builder builder(File outputFile) {
return new Builder(outputFile);
}
public static class Builder {
private final File outputFile;
private boolean all = true;
private int warmup = 0;
private int nIter = -1;
private long nMs = -1;
private Operation[] operations;
public Builder(@NonNull File outputFile) {
this.outputFile = outputFile;
}
/**
* If called, all data will be profiled with no limits (other than a warmup, if set)
*/
public Builder recordAll() {
this.all = true;
this.nIter = -1;
this.nMs = -1;
return this;
}
/**
* Specify the number of warmup iterations - i.e., these will be excluded from profiling results
*/
public Builder warmup(int iterations) {
this.warmup = iterations;
return this;
}
/**
* Set a limit on the maximum number of iterations to profile (after warmup, if any).
* Any ops executed after the specified number of iterations will not be profiled/recorded
*/
public Builder maxProfileIterations(int iterations) {
this.nIter = iterations;
this.all = false;
return this;
}
/**
* Set a limit on the maximum duration for profiling, in milliseconds.
* Any ops executed after the specified amount of time since the first (non-warmup) operation start will not be
* profiled/recorded
*/
public Builder maxProfilerMilliseconds(long ms) {
this.nMs = ms;
this.all = false;
return this;
}
/**
* Specify the operations (training, inference, etc) to profile.
* If not set, all operations are profiled
*/
public Builder operations(Operation... operations) {
this.operations = operations;
return this;
}
/**
* Create the profiling listener
*/
public ProfilingListener build() {
return new ProfilingListener(outputFile, all, warmup, nIter, nMs, operations);
}
}
}

View File

@ -0,0 +1,441 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.listeners.profiler.comparison;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
import org.nd4j.autodiff.listeners.profiler.data.Phase;
import org.nd4j.autodiff.listeners.profiler.data.TraceEvent;
import org.nd4j.autodiff.listeners.profiler.data.TraceEvents;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.list.NDArrayList;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
/**
* A profile analyzer, used for analyzing Chrome-format profiler dumps generated by both SameDiff's<br>
* {@link ProfilingListener} and TensorFlow's profiler.<br>
* Has methods for summarizing profiler statistics, as well as comparing two profiler dumps.<br>
* <br>
* Also supports analyzing/aggregating multiple JSON files in a directory, via the "...Directory(...)" methods.
* <p>
* See {@link ProfilingListener}<br>
* See {@link TraceEvent}
*
* @author Alex Black
*/
@Slf4j
public class ProfileAnalyzer {
/**
* Chrome profiler supports 2 formats:<br>
* SameDiff == JSON Array Format<br>
* TensorFlow == JSON Object Format<br>
*/
public enum ProfileFormat {SAMEDIFF, TENSORFLOW}
/**
* Only applicable for profile comparisons.<br>
* PROFILE1_PC - sort by profile 1 percentage of total time<br>
* PROFILE2_PC - sort by profile 2 percentage of total time<br>
* RATIO - sort by highest ratio (mean op time profile 1 / mean op time profile 2)
*/
public enum SortBy {PROFILE1_PC, PROFILE2_PC, RATIO}
/**
* Summarize and print to stdout the specified profile file
*
* @param file Profile file
* @param profileFormat Format of the profiler file
*/
public static void summarizeProfile(File file, ProfileFormat profileFormat) {
System.out.println(summarizeProfileStr(file, profileFormat));
}
/**
* Summarize and return as a string the specified profile file
*
* @param file Profile file
* @param profileFormat Format of the profiler file
*/
public static String summarizeProfileStr(File file, ProfileFormat profileFormat) {
TraceEvent[] events = getTraceEvents(file, profileFormat);
return summarizeTraceEvents(events);
}
/**
* Aggregate, summarize and print to stdout all .json profile files in the specified directory (not recursive)
*
* @param dir Directory containing the profiles
* @param profileFormat Profile format
*/
public static void summarizeProfileDirectory(File dir, ProfileFormat profileFormat) {
System.out.println(summarizeProfileDirectoryStr(dir, profileFormat));
}
/**
* Aggregate, summarize and return as a String all .json profile files in the specified directory (not recursive)
*
* @param dir Directory containing the profiles
* @param profileFormat Profile format
*/
public static String summarizeProfileDirectoryStr(File dir, ProfileFormat profileFormat) {
return summarizeTraceEvents(getTraceEventsDir(dir, profileFormat));
}
/**
* Load, aggregate and return the TraceEvent object from all profiles in the specified directory
*
* @param dir Directory containing the profiles
* @param profileFormat Profile format
*/
public static TraceEvent[] getTraceEventsDir(File dir, ProfileFormat profileFormat) {
File[] files = dir.listFiles();
Preconditions.checkState(files != null && files.length > 0, "No profiles found in directory: %s", dir);
List<TraceEvent> l = new ArrayList<>();
for (File f : files) {
if (!f.getName().endsWith(".json")) {
log.info("Skipping non-JSON file in directory - {}", f.getAbsolutePath());
continue;
}
TraceEvent[] e = getTraceEvents(f, profileFormat);
Collections.addAll(l, e);
}
return l.toArray(new TraceEvent[0]);
}
/**
* Load and return the TraceEvent object from the specified profile file
*
* @param file Profile file
* @param profileFormat Profile format
*/
public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) {
ObjectMapper json = ProfilingListener.jsonMapper();
String content;
try {
content = FileUtils.readFileToString(file, StandardCharsets.UTF_8);
} catch (IOException e) {
throw new RuntimeException(e);
}
if (!content.matches(".*]\\s*")) {
if (content.endsWith(",")) {
//Has comma, missing ]
content = content.substring(0, content.length() - 1) + "]";
} else if (content.endsWith(",\n")) {
//Has comma and newline, missing ]
content = content.substring(0, content.length() - 2) + "]";
} else {
content = content + "]";
}
}
TraceEvent[] events;
if (profileFormat == ProfileFormat.SAMEDIFF) {
try {
events = json.readValue(content, TraceEvent[].class);
} catch (IOException e) {
throw new RuntimeException(e);
}
} else {
//TF format
TraceEvents traceEvents;
try {
traceEvents = json.readValue(content, TraceEvents.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
events = traceEvents.getTraceEvents().toArray(new TraceEvent[0]);
//Clean up TF format - sometimes things like "Softmax" are actually profiled as "_MklSoftmax"
//And we'll align TF names to SameDiff names
for (TraceEvent te : events) {
if (TF_PROFILE_ALIASES.containsKey(te.getName())) {
te.setName(TF_PROFILE_ALIASES.get(te.getName()));
}
DifferentialFunction df = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(te.getName());
if (df != null) {
te.setName(df.opName());
}
}
}
return events;
}
/**
* Summarize the specified TraceEvents as a String
*
* @param events Events to summarize
*/
public static String summarizeTraceEvents(TraceEvent[] events) {
Pair<Long, Map<String, OpStats>> p = aggregateTraceEvents(events);
final Map<String, OpStats> stats = p.getSecond();
long allOpsUs = p.getFirst();
//Summarize by op type:
List<String> l = new ArrayList<>(stats.keySet());
Collections.sort(l, new Comparator<String>() {
@Override
public int compare(String o1, String o2) {
return -Long.compare(stats.get(o1).sumUs, stats.get(o2).sumUs);
}
});
//Work out longest name and op name:
int longestName = 30;
int longestOpName = 30;
for (String s : l) {
longestName = Math.max(longestName, s.length() + 1);
longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1);
}
StringBuilder sb = new StringBuilder();
String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
sb.append(String.format(headerFormat, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
for (String s : l) {
OpStats st = stats.get(s);
double pc = (100.0 * st.sumUs) / allOpsUs;
INDArray arr = st.timesUs.array();
long min = arr.minNumber().longValue();
long max = arr.maxNumber().longValue();
double std = arr.stdNumber().doubleValue();
sb.append(String.format(format, s, st.opName, st.count, st.getSumUs(), pc, min, max, std));
}
return sb.toString();
}
private static Pair<Long, Map<String, OpStats>> aggregateTraceEvents(TraceEvent[] events) {
//Summarize by op (instance) name:
final Map<String, OpStats> stats = new HashMap<>();
for (TraceEvent e : events) {
if (e.getPh() != Phase.X || e.getDur() == null) {
continue;
}
OpStats s;
String instanceName = (String) e.getArgs().get("name");
if (stats.containsKey(instanceName)) {
s = stats.get(instanceName);
} else {
s = new OpStats(e.getName(), 0, new NDArrayList(DataType.LONG, 0), null);
stats.put(instanceName, s);
}
s.count++;
s.timesUs.add((double) e.getDur());
}
long allOpsUs = 0;
for (OpStats s : stats.values()) {
s.sumUs = s.timesUs.array().sumNumber().longValue();
allOpsUs += s.sumUs;
}
return new Pair<>(allOpsUs, stats);
}
@AllArgsConstructor
@NoArgsConstructor
@Data
private static class OpStats {
private String opName;
private int count;
private NDArrayList timesUs;
private Long sumUs;
}
/**
* Compare the specified profile files, sorted by profile 1 % of total time
*
* @param file1 First profile file
* @param file2 Second profile file
* @param format1 Format of first profile
* @param format2 Format of second profile
* @return Comparison summary as a String
*/
public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2) {
return compareProfiles(file1, file2, format1, format2, false, false, null, null, SortBy.PROFILE1_PC);
}
/**
* Compare the specified profile files or directory
*
* @param file1 First profile file or directory of profiles
* @param file2 Second profile file or directory of profiles
* @param format1 Format for first profile file/s
* @param format2 Format for second profile file/s
* @param firstIsDir True if the first File object is a directory
* @param secondIsDir True if the second File object is a directory
* @param name1 Name of the first profile (just for display purposes). Optional
* @param name2 Name of the second profile (just for display purposes). Optional
* @param sortBy What to sort the summary results by
* @return Comparison summary as a String
*/
public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2,
boolean firstIsDir, boolean secondIsDir, String name1, String name2, final SortBy sortBy) {
TraceEvent[] t1 = firstIsDir ? getTraceEventsDir(file1, format1) : getTraceEvents(file1, format1);
TraceEvent[] t2 = secondIsDir ? getTraceEventsDir(file2, format2) : getTraceEvents(file2, format2);
final Pair<Long, Map<String, OpStats>> p1 = aggregateTraceEvents(t1);
final Pair<Long, Map<String, OpStats>> p2 = aggregateTraceEvents(t2);
List<String> l = new ArrayList<>(sortBy != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet());
Collections.sort(l, new Comparator<String>() {
@Override
public int compare(String o1, String o2) {
switch (sortBy) {
case PROFILE1_PC:
return -Long.compare(p1.getSecond().get(o1).sumUs, p1.getSecond().get(o2).sumUs);
case PROFILE2_PC:
return -Long.compare(p2.getSecond().get(o1).sumUs, p2.getSecond().get(o2).sumUs);
case RATIO:
double m1a = meanTime(p1, o1);
double m1b = meanTime(p1, o2);
double m2a = meanTime(p2, o1);
double m2b = meanTime(p2, o2);
double ratio1 = m1a / m2a;
double ratio2 = m1b / m2b;
return -Double.compare(ratio1, ratio2);
default:
throw new RuntimeException();
}
}
});
Set<String> set = new HashSet<>(l);
StringBuilder sb = new StringBuilder();
sb.append("1 = ").append(name1 == null ? "Profile 1" : name1).append("\n")
.append("2 = ").append(name2 == null ? "Profile 2" : name2).append("\n");
//Work out longest name and op name:
int longestName = 30;
int longestOpName = 30;
Map<String, OpStats> stats = sortBy == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond();
for (String s : l) {
longestName = Math.max(longestName, s.length() + 1);
longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1);
}
String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n";
sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)"));
String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n";
for (String s : l) {
OpStats s1 = p1.getSecond().get(s);
OpStats s2 = p2.getSecond().get(s);
double m1 = s1 == null ? 0 : s1.getTimesUs().array().meanNumber().doubleValue();
double m2 = s2 == null ? 0 : s2.getTimesUs().array().meanNumber().doubleValue();
double ratio = m1 / m2;
double pc1 = s1 == null ? 0 : 100.0 * s1.sumUs / p1.getFirst();
double pc2 = s2 == null ? 0 : 100.0 * s2.sumUs / p2.getFirst();
sb.append(String.format(format, s, s1 != null ? s1.opName : s2.opName,
s1 != null ? s1.count : 0,
s2 != null ? s2.count : 0,
//Ratio of means, means
ratio,
m1, m2,
//Total us, percent of op total
s1 != null ? s1.sumUs : 0,
s2 != null ? s2.sumUs : 0,
pc1, pc2,
//Min, max, std
s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0,
s2 != null ? s2.getTimesUs().array().minNumber().longValue() : 0,
s1 != null ? s1.getTimesUs().array().maxNumber().longValue() : 0,
s2 != null ? s2.getTimesUs().array().maxNumber().longValue() : 0,
s1 != null ? s1.getTimesUs().array().stdNumber().doubleValue() : 0.0,
s2 != null ? s2.getTimesUs().array().stdNumber().doubleValue() : 0.0));
}
boolean header = false;
String headerFormat2 = null;
String format3 = null;
for (String s : (sortBy == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) {
if (!set.contains(s)) {
Map<String, OpStats> m = sortBy == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond();
if (!header) {
longestName = 30;
longestOpName = 30;
for(String s2 : m.keySet()){
longestName = Math.max(longestName, s2.length()+1);
longestOpName = Math.max(longestOpName, m.get(s2).opName.length()+1);
}
headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
sb.append(" *** Operations not in profile ").append(sortBy == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ")
.append(sortBy == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n");
sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
header = true;
}
long allOpsUs = sortBy == SortBy.PROFILE2_PC ? p1.getFirst() : p2.getFirst();
OpStats st = m.get(s);
double pc = (100.0 * st.getTimesUs().array().sumNumber().longValue()) / allOpsUs;
INDArray arr = st.timesUs.array();
long min = arr.minNumber().longValue();
long max = arr.maxNumber().longValue();
double std = arr.stdNumber().doubleValue();
sb.append(String.format(format3, s, st.opName, st.count, st.getSumUs(), pc, min, max, std));
}
}
return sb.toString();
}
private static double meanTime(Pair<Long, Map<String, OpStats>> p, String name) {
if (!p.getSecond().containsKey(name)) {
return 0.0;
}
return p.getSecond().get(name).getTimesUs().array().meanNumber().doubleValue();
}
private static Map<String, String> TF_PROFILE_ALIASES = new HashMap<>();
static {
TF_PROFILE_ALIASES.put("_MklSoftmax", "Softmax");
}
}

View File

@ -0,0 +1,19 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.listeners.profiler.data;
public enum ColorName {
    //Intentionally empty placeholder: reserved Chrome trace color names (referenced by TraceEvent's optional
    //cname field) can be added here as needed. See:
    //https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html
}

View File

@ -0,0 +1,44 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.listeners.profiler.data;
/**
* Chrome Profiler phase, for details see:
* <a href="https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit">
* https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit</a>
*/
public enum Phase {
    //Constant meanings per the Chrome Trace Event Format document linked above:
    B,  //Duration event: begin
    E,  //Duration event: end
    X,  //Complete event (single event with ts + dur)
    I,  //Instant event
    C,  //Counter event
    b,  //Async event: begin (nestable)
    n,  //Async event: instant (nestable)
    e,  //Async event: end (nestable)
    s,  //Flow event: start
    t,  //Flow event: step
    f,  //Flow event: end
    P,  //Sample event
    N,  //Object event: created
    O,  //Object event: snapshot
    D,  //Object event: destroyed
    M,  //Metadata event
    V,  //Memory dump event: global
    v,  //Memory dump event: process
    R,  //Mark event
    c   //Clock sync event
}

View File

@ -0,0 +1,53 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.listeners.profiler.data;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
import java.util.Map;
/**
* A TraceEvent, such as an operation execution.<br>
* Intended mainly for JSON serialization/deserialization in Chrome profiler format<br>
* Profiler format is described here: <a href="https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit">
* https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit</a>
* See {@link org.nd4j.autodiff.listeners.profiler.ProfilingListener}<br>
* See {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer}
*
* @author Alex Black
*/
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class TraceEvent {
    private String name;                //Name of event (usually op name)
    private List<String> categories;    //Comma separated list of categories
    private Phase ph;                   //Event type - phase (see Phase enum / format doc for options)
    private long ts;                    //Timestamp, in microseconds (us)
    private Long dur;                   //Duration in microseconds, optional (used for Phase.X complete events)
    private Long tts;                   //Optional, thread timestamp, in microseconds
    private long pid;                   //Process ID
    private long tid;                   //Thread ID
    private Map<String, Object> args;   //Args (arbitrary key/value metadata for the event)
    private ColorName cname;            //Optional, color name (must be one of reserved color names: https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html )
}

View File

@ -0,0 +1,34 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.listeners.profiler.data;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
 * A simple holder for a list of {@link TraceEvent} instances.
 * <p>
 * Intended for JSON serialization/deserialization: the field name {@code traceEvents} matches the
 * top-level key of the Chrome profiler "JSON Object Format" ({@code {"traceEvents": [...]}}).
 *
 * @author Alex Black
 */
@AllArgsConstructor
@NoArgsConstructor
@Data
public class TraceEvents {
    private List<TraceEvent> traceEvents;   //The wrapped list of trace events; may be null until set (no-args constructor)
}

View File

@ -18,6 +18,7 @@ package org.nd4j.list;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -46,7 +47,12 @@ public class NDArrayList extends BaseNDArrayList<Double> {
* @param size the initial size of the array * @param size the initial size of the array
*/ */
public NDArrayList(int size) { public NDArrayList(int size) {
this.container = Nd4j.create(10L); this(DataType.DOUBLE, size);
}
public NDArrayList(DataType dataType, int size) {
Preconditions.checkState(size >= 0, "Size must be non-negative - got %s", size);
this.container = Nd4j.create(dataType, Math.max(10L, size));
this.size = size; this.size = size;
} }
@ -84,6 +90,7 @@ public class NDArrayList extends BaseNDArrayList<Double> {
* directly, this gives you the relevant subset that reflects the content of the list) * directly, this gives you the relevant subset that reflects the content of the list)
* @return the view of the underlying ndarray relative to the collection's real size * @return the view of the underlying ndarray relative to the collection's real size
*/ */
@Override
public INDArray array() { public INDArray array() {
if(isEmpty()) { if(isEmpty()) {
throw new ND4JIllegalStateException("Array is empty!"); throw new ND4JIllegalStateException("Array is empty!");
@ -137,6 +144,8 @@ public class NDArrayList extends BaseNDArrayList<Double> {
return true; return true;
} }
@Override @Override
public boolean remove(Object o) { public boolean remove(Object o) {
int idx = BooleanIndexing.firstIndex(container,new EqualsCondition((double) o)).getInt(0); int idx = BooleanIndexing.firstIndex(container,new EqualsCondition((double) o)).getInt(0);

View File

@ -0,0 +1,105 @@
package org.nd4j.autodiff.samediff.listeners;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
public class ProfilingListenerTest extends BaseNd4jTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    public ProfilingListenerTest(Nd4jBackend backend) {
        super(backend);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Builds a tiny softmax network, runs inference repeatedly with a {@link ProfilingListener}
     * attached, then checks the written JSON profile contains the expected op entries and can be
     * summarized by {@link ProfileAnalyzer}.
     */
    @Test
    public void testProfilingListenerSimple() throws Exception {
        //Graph: predictions = softmax(in * w + b), plus a log loss (loss op added to graph but not executed here)
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
        SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2);
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
        SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2));
        SDVariable sm = sd.nn.softmax("predictions", in.mmul("matmul", w).add("addbias", b));
        SDVariable loss = sd.loss.logLoss("loss", label, sm);

        INDArray inArr = Nd4j.rand(DataType.FLOAT, 1, 3);
        INDArray labelArr = Nd4j.rand(DataType.FLOAT, 1, 2);

        //Profile output goes to a temp file; recordAll + 5 warmup iterations before recording starts
        File dir = testDir.newFolder();
        File f = new File(dir, "test.json");
        ProfilingListener listener = ProfilingListener.builder(f)
                .recordAll()
                .warmup(5)
                .build();
        sd.setListeners(listener);

        Map<String, INDArray> placeholders = new HashMap<>();
        placeholders.put("in", inArr);

        //10 total runs: 5 warmup (not recorded) + 5 recorded
        for (int iter = 0; iter < 10; iter++) {
            sd.outputSingle(placeholders, "predictions");
        }

        String content = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
        System.out.println(content);

        //Should be 2 begins and 2 ends for each entry
        //5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name
        String[] opNames = {"mmul", "add", "softmax"};
        for (String opName : opNames) {
            assertEquals(opName, 10, StringUtils.countMatches(content, opName));
        }

        System.out.println("///////////////////////////////////////////");

        ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.SAMEDIFF);
    }

    /*
    @Test
    public void testLoadTfProfile(){
        File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
        ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
    }

    @Test
    public void testLoadTfProfileDir(){
        File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
        ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
    }

    @Test
    public void testLoadTfProfileDir2(){
        File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
        ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
    }
    */
}