diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java new file mode 100644 index 000000000..9b92b0412 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java @@ -0,0 +1,354 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.listeners.profiler.data.Phase; +import org.nd4j.autodiff.listeners.profiler.data.TraceEvent; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.AtomicBoolean; +import org.nd4j.linalg.util.ArrayUtil; +import org.nd4j.shade.jackson.databind.DeserializationFeature; +import org.nd4j.shade.jackson.databind.MapperFeature; +import org.nd4j.shade.jackson.databind.ObjectMapper; +import org.nd4j.shade.jackson.databind.SerializationFeature; + +import java.io.*; +import java.lang.management.ManagementFactory; +import java.text.DecimalFormat; +import java.util.*; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.LinkedBlockingDeque; + +/** + * SameDiff profiling listener: for profiling operation execution
+ * Writes profiles to a file in JSON format
+ * Format is Chrome profiler format. The output can be read by Google Chrome browser; open Chrome and go to: + * chrome://tracing and load the output JSON format data + *
+ * At present, only operation execution is profiled, not other aspects such as memory allocation and training-related + * functionality.
+ *
+ * Tracing can be configured in a few different ways via the builder, {@link #builder(File)}:
+ * (a) warmup - don't record traces for the first N iterations
+ * (b) "all" mode (default) - record all-iterations, with no limit (after warmup, if applicable)
+ * (c) "n iterations" mode: record at most the first N iterations (after warmup, if applicable)
+ * (d) "n ms" mode: record for at most N milliseconds since the start of the first op execution (after warmup, if applicable)
+ * + * Note: The Chrome Trace Event format can be found here:
+ * https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit + * SameDiff uses the JSON Array Format, as this can be written in an online/streaming manner.
+ * Conversely, TensorFlow uses the JSON Object Format.
+ *
+ * For summarizing, analyzing and comparing the results (SameDiff or TensorFlow format), see {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer}
+ * + * @author Alex Black + */ +@Getter +@Slf4j +public class ProfilingListener extends BaseListener { + + private final File outputFile; + private final boolean all; + private final int warmup; + private final int nIter; + private final long nMs; + private final Operation[] operations; + + private final long pid; + private final long tid; + private Long firstOpStart = null; //Used for time termination + private int countTotalIter = 0; + private boolean logActive = false; + private long opStartNano; + + private Writer writer; + private ObjectMapper json; + + private final Thread fileWritingThread; + private final BlockingQueue writeQueue; + private final AtomicBoolean writing = new AtomicBoolean(false); + + protected ProfilingListener(@NonNull File outputFile, boolean all, int warmup, int nIter, long nMs, Operation[] operations) { + Preconditions.checkArgument(!outputFile.exists(), "Output file already exists: %s", outputFile); + this.outputFile = outputFile; + this.all = all; + this.warmup = warmup; + this.nIter = nIter; + this.nMs = nMs; + this.operations = operations; + + this.pid = getProcessId(); + this.tid = Thread.currentThread().getId(); + + try { + this.writer = new BufferedWriter(new FileWriter(outputFile, false)); + this.writer.write("["); //JSON array open (array close is optional for Chrome profiler format) + } catch (IOException e) { + throw new RuntimeException(e); + } + + this.json = jsonMapper(); + + //Set up a queue so file access doesn't add latency to the execution thread + writeQueue = new LinkedBlockingDeque<>(); + fileWritingThread = new Thread(new Runnable() { + @Override + public void run() { + try { + runHelper(); + } catch (Throwable t) { + log.error("Error when attempting to write results to file", t); + } + } + + public void runHelper() throws Exception { + while (true) { + TraceEvent te = writeQueue.take(); //Blocking + writing.set(true); + try { + String j = json.writeValueAsString(te); + writer.append(j); + writer.append(",\n"); + } 
catch (IOException e) { + throw new RuntimeException(e); + } finally { + writing.set(false); + } + } + } + }); + fileWritingThread.setDaemon(true); + fileWritingThread.start(); + } + + @Override + public boolean isActive(Operation operation) { + return operations == null || ArrayUtils.contains(operations, operation); + } + + @Override + public void operationStart(SameDiff sd, Operation op) { + this.logActive = operations == null || ArrayUtils.contains(operations, op); + } + + @Override + public void operationEnd(SameDiff sd, Operation op) { + if (this.logActive) { + while ((!writeQueue.isEmpty() || writing.get()) && fileWritingThread.isAlive()) { + //Wait for file writing thread to catch up + try { + Thread.sleep(100); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + try { + writer.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + this.logActive = false; + if (op == Operation.INFERENCE) { + //Increment for inference; iteration done is called only for TRAINING + countTotalIter++; + } + } + + @Override + public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) { + //Increment for training + if (logActive) { + countTotalIter++; + } + } + + @Override + public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + if (logActive) { + opStartNano = System.nanoTime(); + + if(!all && nMs > 0 && firstOpStart == null) + firstOpStart = opStartNano; + } + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + if (logActive) { + long now = System.nanoTime(); + + if (warmup > 0 && countTotalIter < warmup) { + return; //Skip due to warmup phase + } + + //Iteration termination + int terminationPt = this.nIter > 0 ? 
this.nIter : Integer.MAX_VALUE; + if (warmup > 0 && this.nIter > 0) + terminationPt += this.warmup; + + if (countTotalIter > terminationPt) { + logActive = false; + return; //Skip due to max number of iterations + } + + //Time termination + if(!all && nMs > 0 && (now - firstOpStart)/1000 > nMs) { + logActive = false; + return; + } + + TraceEvent event = TraceEvent.builder() + .name(op.getOp().opName()) + .categories(Collections.singletonList("Op")) + .ts(opStartNano / 1000) + .dur((now - opStartNano) / 1000) + .pid((int)pid) + .tid(tid) + .ph(Phase.X) + .args(Collections.singletonMap("name", op.getName())) + .build(); + + writeQueue.add(event); + } + } + + + private long getProcessId() { + // Note: may fail in some JVM implementations + // therefore fallback has to be provided + + // something like '@', at least in SUN / Oracle JVMs + final String jvmName = ManagementFactory.getRuntimeMXBean().getName(); + final int index = jvmName.indexOf('@'); + + if (index < 1) { + // part before '@' empty (index = 0) / '@' not found (index = -1) + return 0; + } + + try { + return Long.parseLong(jvmName.substring(0, index)); + } catch (NumberFormatException e) { + // ignore + } + return 0; + } + + /** + * Get a new JSON mapper for use in serializing/deserializing JSON format + */ + public static ObjectMapper jsonMapper() { + ObjectMapper json = new ObjectMapper(); + json.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + json.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + json.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false); + json.disable(SerializationFeature.INDENT_OUTPUT); //One line + + return json; + } + + /** + * Create a new builder + * @param outputFile Output file. 
An exception is thrown if the file already exists + */ + public static Builder builder(File outputFile) { + return new Builder(outputFile); + } + + public static class Builder { + private final File outputFile; + private boolean all = true; + private int warmup = 0; + private int nIter = -1; + private long nMs = -1; + private Operation[] operations; + + public Builder(@NonNull File outputFile) { + this.outputFile = outputFile; + } + + /** + * If called, all data will be profiled with no limits (other than a warmup, if set) + */ + public Builder recordAll() { + this.all = true; + this.nIter = -1; + this.nMs = -1; + return this; + } + + /** + * Specify the number of warmup iterations - i.e., these will be excluded from profiling results + */ + public Builder warmup(int iterations) { + this.warmup = iterations; + return this; + } + + /** + * Set a limit on the maximum number of iterations to profile (after warmup, if any). + * Any ops executed after the specified number of iterations will not be profiled/recorded + */ + public Builder maxProfileIterations(int iterations) { + this.nIter = iterations; + this.all = false; + return this; + } + + /** + * Set a limit on the maximum duration for profiling, in milliseconds. + * Any ops executed after the specified amount of time since the first (non-warmup) operation start will not be + * profiled/recorded + */ + public Builder maxProfilerMilliseconds(long ms) { + this.nMs = ms; + this.all = false; + return this; + } + + /** + * Specify the operations (training, inference, etc) to profile. + * If not set, all operations are profiled + */ + public Builder operations(Operation... 
operations) { + this.operations = operations; + return this; + } + + /** + * Create the profiling listener + */ + public ProfilingListener build() { + return new ProfilingListener(outputFile, all, warmup, nIter, nMs, operations); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java new file mode 100644 index 000000000..c8f3340d7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java @@ -0,0 +1,441 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.comparison; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.FileUtils; +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.listeners.profiler.ProfilingListener; +import org.nd4j.autodiff.listeners.profiler.data.Phase; +import org.nd4j.autodiff.listeners.profiler.data.TraceEvent; +import org.nd4j.autodiff.listeners.profiler.data.TraceEvents; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.converters.DifferentialFunctionClassHolder; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.list.NDArrayList; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.*; + +/** + * A profile analyzer, used for analyzing Chrome-format profiler dumps generated by both SameDiff's
+ * {@link ProfilingListener} and TensorFlow's profiler.
+ * Has methods for summarizing profiler statistics, as well as comparing two profiler dumps.
+ *
+ * Also supports analyzing/aggregating multiple JSON files in a directory, via the "...Directory(...)" methods. + *

+ * See {@link ProfilingListener}
+ * See {@link TraceEvent} + * + * @author Alex Black + */ +@Slf4j +public class ProfileAnalyzer { + + /** + * Chrome profiler supports 2 formats:
+ * SameDiff == JSON Array Format
+ * TensorFlow == JSON Object Format
+ */ + public enum ProfileFormat {SAMEDIFF, TENSORFLOW} + + /** + * Only applicable for profile comparisons.
+ * PROFILE1_PC - sort by profile 1 percentage of total time
+ * PROFILE2_PC - sort by profile 2 percentage of total time
+ * RATIO - sort by highest ratio (mean op time profile 1 / mean op time profile 2) + */ + public enum SortBy {PROFILE1_PC, PROFILE2_PC, RATIO} + + + /** + * Summarize and print to stdout the specified profile file + * + * @param file Profile file + * @param profileFormat Format of the profiler file + */ + public static void summarizeProfile(File file, ProfileFormat profileFormat) { + System.out.println(summarizeProfileStr(file, profileFormat)); + } + + /** + * Summarize and return as a string the specified profile file + * + * @param file Profile file + * @param profileFormat Format of the profiler file + */ + public static String summarizeProfileStr(File file, ProfileFormat profileFormat) { + TraceEvent[] events = getTraceEvents(file, profileFormat); + return summarizeTraceEvents(events); + } + + /** + * Aggregate, summarize and print to stdout all .json profile files in the specified directory (not recursive) + * + * @param dir Directory containing the profiles + * @param profileFormat Profile format + */ + public static void summarizeProfileDirectory(File dir, ProfileFormat profileFormat) { + System.out.println(summarizeProfileDirectoryStr(dir, profileFormat)); + } + + /** + * Aggregate, summarize and return as a String all .json profile files in the specified directory (not recursive) + * + * @param dir Directory containing the profiles + * @param profileFormat Profile format + */ + public static String summarizeProfileDirectoryStr(File dir, ProfileFormat profileFormat) { + return summarizeTraceEvents(getTraceEventsDir(dir, profileFormat)); + } + + /** + * Load, aggregate and return the TraceEvent object from all profiles in the specified directory + * + * @param dir Directory containing the profiles + * @param profileFormat Profile format + */ + public static TraceEvent[] getTraceEventsDir(File dir, ProfileFormat profileFormat) { + File[] files = dir.listFiles(); + Preconditions.checkState(files != null && files.length > 0, "No profiles found in directory: 
%s", dir); + List l = new ArrayList<>(); + for (File f : files) { + if (!f.getName().endsWith(".json")) { + log.info("Skipping non-JSON file in directory - {}", f.getAbsolutePath()); + continue; + } + TraceEvent[] e = getTraceEvents(f, profileFormat); + Collections.addAll(l, e); + } + return l.toArray(new TraceEvent[0]); + } + + /** + * Load and return the TraceEvent object from the specified profile file + * + * @param file Profile file + * @param profileFormat Profile format + */ + public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) { + ObjectMapper json = ProfilingListener.jsonMapper(); + + String content; + try { + content = FileUtils.readFileToString(file, StandardCharsets.UTF_8); + } catch (IOException e) { + throw new RuntimeException(e); + } + + if (!content.matches(".*]\\s*")) { + if (content.endsWith(",")) { + //Has comma, missing ] + content = content.substring(0, content.length() - 1) + "]"; + } else if (content.endsWith(",\n")) { + //Has comma and newline, missing ] + content = content.substring(0, content.length() - 2) + "]"; + } else { + content = content + "]"; + } + } + + TraceEvent[] events; + if (profileFormat == ProfileFormat.SAMEDIFF) { + try { + events = json.readValue(content, TraceEvent[].class); + } catch (IOException e) { + throw new RuntimeException(e); + } + } else { + //TF format + TraceEvents traceEvents; + try { + traceEvents = json.readValue(content, TraceEvents.class); + } catch (IOException e) { + throw new RuntimeException(e); + } + events = traceEvents.getTraceEvents().toArray(new TraceEvent[0]); + + //Clean up TF format - sometimes things like "Softmax" are actually profiled as "_MklSoftmax" + //And we'll align TF names to SameDiff names + for (TraceEvent te : events) { + if (TF_PROFILE_ALIASES.containsKey(te.getName())) { + te.setName(TF_PROFILE_ALIASES.get(te.getName())); + } + + DifferentialFunction df = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(te.getName()); + if (df 
!= null) { + te.setName(df.opName()); + } + } + } + + return events; + } + + /** + * Summarize the specified TraceEvents as a String + * + * @param events Events to summarize + */ + public static String summarizeTraceEvents(TraceEvent[] events) { + Pair> p = aggregateTraceEvents(events); + final Map stats = p.getSecond(); + long allOpsUs = p.getFirst(); + + //Summarize by op type: + List l = new ArrayList<>(stats.keySet()); + Collections.sort(l, new Comparator() { + @Override + public int compare(String o1, String o2) { + return -Long.compare(stats.get(o1).sumUs, stats.get(o2).sumUs); + } + }); + + //Work out longest name and op name: + int longestName = 30; + int longestOpName = 30; + for (String s : l) { + longestName = Math.max(longestName, s.length() + 1); + longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1); + } + + StringBuilder sb = new StringBuilder(); + String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n"; + sb.append(String.format(headerFormat, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std")); + String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n"; + for (String s : l) { + OpStats st = stats.get(s); + double pc = (100.0 * st.sumUs) / allOpsUs; + INDArray arr = st.timesUs.array(); + long min = arr.minNumber().longValue(); + long max = arr.maxNumber().longValue(); + double std = arr.stdNumber().doubleValue(); + sb.append(String.format(format, s, st.opName, st.count, st.getSumUs(), pc, min, max, std)); + } + + return sb.toString(); + } + + private static Pair> aggregateTraceEvents(TraceEvent[] events) { + //Summarize by op (instance) name: + final Map stats = new HashMap<>(); + for (TraceEvent e : events) { + if (e.getPh() != Phase.X || e.getDur() == null) { + continue; + } + + OpStats s; + String instanceName = (String) e.getArgs().get("name"); + if (stats.containsKey(instanceName)) { + s = stats.get(instanceName); + } 
else { + s = new OpStats(e.getName(), 0, new NDArrayList(DataType.LONG, 0), null); + stats.put(instanceName, s); + } + s.count++; + s.timesUs.add((double) e.getDur()); + } + + long allOpsUs = 0; + for (OpStats s : stats.values()) { + s.sumUs = s.timesUs.array().sumNumber().longValue(); + allOpsUs += s.sumUs; + } + + return new Pair<>(allOpsUs, stats); + } + + @AllArgsConstructor + @NoArgsConstructor + @Data + private static class OpStats { + private String opName; + private int count; + private NDArrayList timesUs; + private Long sumUs; + + } + + /** + * Compare the specified profile files, sorted by profile 1 % of total time + * + * @param file1 First profile file + * @param file2 Second profile file + * @param format1 Format of first profile + * @param format2 Format of second profile + * @return Comparison summary as a String + */ + public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2) { + return compareProfiles(file1, file2, format1, format2, false, false, null, null, SortBy.PROFILE1_PC); + } + + /** + * Compare the specified profile files or directory + * + * @param file1 First profile file or directory of profiles + * @param file2 Second profile file or directory of profiles + * @param format1 Format for first profile file/s + * @param format2 Format for second profile file/s + * @param firstIsDir True if the first File object is a directory + * @param secondIsDir True if the second File object is a directory + * @param name1 Name of the first profile (just for display purposes). Optional + * @param name2 Name of the second profile (just for display purposes). 
Optional + * @param sortBy What to sort the summary results by + * @return Comparison summary as a String + */ + public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2, + boolean firstIsDir, boolean secondIsDir, String name1, String name2, final SortBy sortBy) { + + TraceEvent[] t1 = firstIsDir ? getTraceEventsDir(file1, format1) : getTraceEvents(file1, format1); + TraceEvent[] t2 = secondIsDir ? getTraceEventsDir(file2, format2) : getTraceEvents(file2, format2); + + final Pair> p1 = aggregateTraceEvents(t1); + final Pair> p2 = aggregateTraceEvents(t2); + + List l = new ArrayList<>(sortBy != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet()); + Collections.sort(l, new Comparator() { + @Override + public int compare(String o1, String o2) { + switch (sortBy) { + case PROFILE1_PC: + return -Long.compare(p1.getSecond().get(o1).sumUs, p1.getSecond().get(o2).sumUs); + case PROFILE2_PC: + return -Long.compare(p2.getSecond().get(o1).sumUs, p2.getSecond().get(o2).sumUs); + case RATIO: + + double m1a = meanTime(p1, o1); + double m1b = meanTime(p1, o2); + double m2a = meanTime(p2, o1); + double m2b = meanTime(p2, o2); + double ratio1 = m1a / m2a; + double ratio2 = m1b / m2b; + return -Double.compare(ratio1, ratio2); + default: + throw new RuntimeException(); + } + } + }); + + Set set = new HashSet<>(l); + + + StringBuilder sb = new StringBuilder(); + sb.append("1 = ").append(name1 == null ? "Profile 1" : name1).append("\n") + .append("2 = ").append(name2 == null ? "Profile 2" : name2).append("\n"); + + //Work out longest name and op name: + int longestName = 30; + int longestOpName = 30; + Map stats = sortBy == SortBy.PROFILE2_PC ? 
p2.getSecond() : p1.getSecond(); + for (String s : l) { + longestName = Math.max(longestName, s.length() + 1); + longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1); + } + + String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n"; + sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)")); + String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n"; + + for (String s : l) { + OpStats s1 = p1.getSecond().get(s); + OpStats s2 = p2.getSecond().get(s); + + double m1 = s1 == null ? 0 : s1.getTimesUs().array().meanNumber().doubleValue(); + double m2 = s2 == null ? 0 : s2.getTimesUs().array().meanNumber().doubleValue(); + double ratio = m1 / m2; + + double pc1 = s1 == null ? 0 : 100.0 * s1.sumUs / p1.getFirst(); + double pc2 = s2 == null ? 0 : 100.0 * s2.sumUs / p2.getFirst(); + + sb.append(String.format(format, s, s1 != null ? s1.opName : s2.opName, + s1 != null ? s1.count : 0, + s2 != null ? s2.count : 0, + //Ratio of means, means + ratio, + m1, m2, + //Total us, percent of op total + s1 != null ? s1.sumUs : 0, + s2 != null ? s2.sumUs : 0, + pc1, pc2, + //Min, max, std + s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0, + s2 != null ? s2.getTimesUs().array().minNumber().longValue() : 0, + s1 != null ? s1.getTimesUs().array().maxNumber().longValue() : 0, + s2 != null ? s2.getTimesUs().array().maxNumber().longValue() : 0, + s1 != null ? s1.getTimesUs().array().stdNumber().doubleValue() : 0.0, + s2 != null ? 
s2.getTimesUs().array().stdNumber().doubleValue() : 0.0)); + } + + boolean header = false; + String headerFormat2 = null; + String format3 = null; + for (String s : (sortBy == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) { + + if (!set.contains(s)) { + Map m = sortBy == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond(); + if (!header) { + + longestName = 30; + longestOpName = 30; + for(String s2 : m.keySet()){ + longestName = Math.max(longestName, s2.length()+1); + longestOpName = Math.max(longestOpName, m.get(s2).opName.length()+1); + } + headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n"; + format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n"; + + sb.append(" *** Operations not in profile ").append(sortBy == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ") + .append(sortBy == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n"); + sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std")); + header = true; + } + long allOpsUs = sortBy == SortBy.PROFILE2_PC ? 
p1.getFirst() : p2.getFirst(); + OpStats st = m.get(s); + double pc = (100.0 * st.getTimesUs().array().sumNumber().longValue()) / allOpsUs; + INDArray arr = st.timesUs.array(); + long min = arr.minNumber().longValue(); + long max = arr.maxNumber().longValue(); + double std = arr.stdNumber().doubleValue(); + sb.append(String.format(format3, s, st.opName, st.count, st.getSumUs(), pc, min, max, std)); + } + } + return sb.toString(); + } + + private static double meanTime(Pair> p, String name) { + if (!p.getSecond().containsKey(name)) { + return 0.0; + } + return p.getSecond().get(name).getTimesUs().array().meanNumber().doubleValue(); + } + + + private static Map TF_PROFILE_ALIASES = new HashMap<>(); + + static { + TF_PROFILE_ALIASES.put("_MklSoftmax", "Softmax"); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java new file mode 100644 index 000000000..0d9c08deb --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java @@ -0,0 +1,19 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.data; + +public enum ColorName { +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java new file mode 100644 index 000000000..bca7feb39 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java @@ -0,0 +1,44 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.data; + +/** + * Chrome Profiler phase, for details see: + * + * https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit + */ +public enum Phase { + B, + E, + X, + I, + C, + b, + n, + e, + s, + t, + f, + P, + N, + O, + D, + M, + V, + v, + R, + c +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java new file mode 100644 index 000000000..e4270edd1 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java @@ -0,0 +1,53 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.data; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +/** + * A TraceEvent, such as an operation execution.
+ * Intended mainly for JSON serialization/deserialization in Chrome profiler format
+ * Profiler format is described here: + * https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit + * See {@link org.nd4j.autodiff.listeners.profiler.ProfilingListener}
+ * See {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer} + * + * @author Alex Black + */ +@Builder +@Data +@AllArgsConstructor +@NoArgsConstructor +public class TraceEvent { + + private String name; //Name of event (usually op name) + private List categories; //Comma separated list of categories + private Phase ph; //Event type - phase (see table for options) + private long ts; //Timestamp, in microseconds (us) + private Long dur; //Duration, optional + private Long tts; //Optional, thlread timestamp, in microseconds + private long pid; //Process ID + private long tid; //Thread ID + private Map args; //Args + private ColorName cname; //Optional, color name (must be one of reserved color names: https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html ) + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java new file mode 100644 index 000000000..b3ebf6d8a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java @@ -0,0 +1,34 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.data; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +/** + * A simple holder for a list of trace events + * + * @author Alex Black + */ +@AllArgsConstructor +@NoArgsConstructor +@Data +public class TraceEvents { + private List traceEvents; +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java index 939708291..f3b539d8c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java @@ -18,6 +18,7 @@ package org.nd4j.list; import lombok.NonNull; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -46,7 +47,12 @@ public class NDArrayList extends BaseNDArrayList { * @param size the initial size of the array */ public NDArrayList(int size) { - this.container = Nd4j.create(10L); + this(DataType.DOUBLE, size); + } + + public NDArrayList(DataType dataType, int size) { + Preconditions.checkState(size >= 0, "Size must be non-negative - got %s", size); + this.container = Nd4j.create(dataType, Math.max(10L, size)); this.size = size; } @@ -84,6 +90,7 @@ public class NDArrayList extends BaseNDArrayList { * directly, this gives you the relevant subset that reflects the content of the list) * @return the view of the underlying ndarray relative to the collection's real size */ + @Override public INDArray array() { if(isEmpty()) { throw new ND4JIllegalStateException("Array is empty!"); @@ -137,6 
+144,8 @@ public class NDArrayList extends BaseNDArrayList { return true; } + + @Override public boolean remove(Object o) { int idx = BooleanIndexing.firstIndex(container,new EqualsCondition((double) o)).getInt(0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java new file mode 100644 index 000000000..c026446ad --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java @@ -0,0 +1,105 @@ +package org.nd4j.autodiff.samediff.listeners; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.StringUtils; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.autodiff.listeners.profiler.ProfilingListener; +import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +public class ProfilingListenerTest extends BaseNd4jTest { + + public ProfilingListenerTest(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + @Test + public void testProfilingListenerSimple() throws Exception { + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2); + SDVariable w = 
sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2)); + SDVariable sm = sd.nn.softmax("predictions", in.mmul("matmul", w).add("addbias", b)); + SDVariable loss = sd.loss.logLoss("loss", label, sm); + + INDArray i = Nd4j.rand(DataType.FLOAT, 1, 3); + INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2); + + + File dir = testDir.newFolder(); + File f = new File(dir, "test.json"); + ProfilingListener listener = ProfilingListener.builder(f) + .recordAll() + .warmup(5) + .build(); + + sd.setListeners(listener); + + Map ph = new HashMap<>(); + ph.put("in", i); + + for( int x=0; x<10; x++ ) { + sd.outputSingle(ph, "predictions"); + } + + String content = FileUtils.readFileToString(f, StandardCharsets.UTF_8); + System.out.println(content); + + //Should be 2 begins and 2 ends for each entry + //5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name + String[] opNames = {"mmul", "add", "softmax"}; + for(String s : opNames){ + assertEquals(s, 10, StringUtils.countMatches(content, s)); + } + + + System.out.println("///////////////////////////////////////////"); + ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.SAMEDIFF); + + } + + /* + @Test + public void testLoadTfProfile(){ + File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json"); + ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); + } + + @Test + public void testLoadTfProfileDir(){ + File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles"); + ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); + } + + @Test + public void testLoadTfProfileDir2(){ + File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0"); + ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); + } + */ +}