diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/Config.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/Config.java new file mode 100644 index 000000000..602a26f99 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/Config.java @@ -0,0 +1,42 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.comparison; + +import lombok.Builder; +import lombok.Data; +import lombok.experimental.Accessors; +import org.nd4j.linalg.function.BiFunction; + +import java.io.File; + +@Data +@Accessors(fluent = true) +@Builder +public class Config { + + private String p1Name; + private String p2Name; + private File profile1; + private File profile2; + private boolean profile1IsDir; + private boolean profile2IsDir; + @Builder.Default private ProfileAnalyzer.ProfileFormat profile1Format = ProfileAnalyzer.ProfileFormat.SAMEDIFF; + @Builder.Default private ProfileAnalyzer.ProfileFormat profile2Format = ProfileAnalyzer.ProfileFormat.SAMEDIFF; + @Builder.Default private ProfileAnalyzer.SortBy sortBy = ProfileAnalyzer.SortBy.PROFILE1_PC; + private BiFunction filter; //Return true to keep, false to remove + @Builder.Default private ProfileAnalyzer.OutputFormat format = ProfileAnalyzer.OutputFormat.TEXT; + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/OpStats.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/OpStats.java new file mode 100644 index 000000000..0949020af --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/OpStats.java @@ -0,0 +1,32 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.listeners.profiler.comparison; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.nd4j.list.NDArrayList; + +@AllArgsConstructor +@NoArgsConstructor +@Data +public class OpStats { + private String opInstanceName; + private String opName; + private int count; + private NDArrayList timesUs; + private Long sumUs; +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java index c8f3340d7..421c13cb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java @@ -69,6 +69,12 @@ public class ProfileAnalyzer { */ public enum SortBy {PROFILE1_PC, PROFILE2_PC, RATIO} + /** + * TEXT: Human readable, columns padded for alignment
+ * CSV: CSV format, comma separated + */ + public enum OutputFormat {TEXT,CSV} + /** * Summarize and print to stdout the specified profile file @@ -139,6 +145,10 @@ public class ProfileAnalyzer { * @param profileFormat Profile format */ public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) { + return getTraceEvents(file, profileFormat, true); + } + + public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat, boolean aggregateTFSubOps) { ObjectMapper json = ProfilingListener.jsonMapper(); String content; @@ -189,6 +199,94 @@ public class ProfileAnalyzer { te.setName(df.opName()); } } + + + if(aggregateTFSubOps){ + //For CUDA ops, TF will log sub-ops like: + //fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@cudnn::maxwell::gemm::computeOffsetsKernel(cudnn::maxwell::gemm::ComputeOffsetsParams) + //fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@maxwell_scudnn_128x64_relu_interior_nn + //fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@void tensorflow::functor::ShuffleInTensor3Simple(int, float const*, tensorflow::functor::Dimension<3>, float*) + //We'll join these into one op, then strip everything after the ":" to recover the op name + + //Also, TF has multiple sub-ops like this, sequentially, that need to be joined: + //19 = {TraceEvent@3157} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601259742, dur=466, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)" + //20 = {TraceEvent@3181} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260229, dur=29, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)" + //21 = {TraceEvent@3206} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260329, dur=31, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)" + //22 = {TraceEvent@3247} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260390, dur=4998, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)" + + Map map = new HashMap<>(); //Key: Op name with ID + List out = new ArrayList<>(); + TraceEvent last = null; + for(TraceEvent te : events){ + if(last != null && last.getPh() == Phase.X && te.getPh() == Phase.X && + last.getName().equals(te.getName()) && + last.getArgs() != null && te.getArgs() != null && + last.getArgs().get("name").equals(te.getArgs().get("name")) && + last.getArgs().get("op").equals(te.getArgs().get("op"))){ + //Aggregate - same names, ops, etc + last.setDur(last.getDur() + te.getDur()); + continue; + } + + last = te; + if(te.getArgs() == null || te.getArgs().isEmpty()){ + out.add(te); + continue; + } + + + String n = (String) te.getArgs().get("name"); + + //Aggregate by op name... + //"fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/..." -> "fire2/e1x1/Conv2D" + //We're relying on TF's "one iteration per json file" here + if(n.matches("[\\w/_-]+:[\\w/_-]+#id=\\d+.*")) { + int idx = n.indexOf("#"); + String sub1 = n.substring(0, idx); + String sub; + if (sub1.contains(":")) { + sub = sub1.substring(0, sub1.lastIndexOf(":")); + } else { + sub = sub1; + } + if (map.containsKey(sub)) { + TraceEvent t = map.get(sub); + Long dur = t.getDur(); + if (dur == null && te.getDur() == null) + continue; + t.setDur(dur == null ? te.getDur() : dur + (te.getDur() == null ? 0 : te.getDur())); + } else { + map.put(sub, te); + out.add(te); + } + } else { + if(map.containsKey(n)){ + TraceEvent t = map.get(n); + t.setDur(t.getDur() + te.getDur()); + } else { + map.put(n, te); + out.add(te); + } + } + } + + //Strip everything after ":" in "fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/..." + for( int i=0; i() { @Override public int compare(String o1, String o2) { - return -Long.compare(stats.get(o1).sumUs, stats.get(o2).sumUs); + return -Long.compare(stats.get(o1).getSumUs(), stats.get(o2).getSumUs()); } }); @@ -218,7 +316,7 @@ public class ProfileAnalyzer { int longestOpName = 30; for (String s : l) { longestName = Math.max(longestName, s.length() + 1); - longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1); + longestOpName = Math.max(longestOpName, stats.get(s).getOpName().length() + 1); } StringBuilder sb = new StringBuilder(); @@ -227,12 +325,12 @@ public class ProfileAnalyzer { String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n"; for (String s : l) { OpStats st = stats.get(s); - double pc = (100.0 * st.sumUs) / allOpsUs; - INDArray arr = st.timesUs.array(); + double pc = (100.0 * st.getSumUs()) / allOpsUs; + INDArray arr = st.getTimesUs().array(); long min = arr.minNumber().longValue(); long max = arr.maxNumber().longValue(); double std = arr.stdNumber().doubleValue(); - sb.append(String.format(format, s, st.opName, st.count, st.getSumUs(), pc, min, max, std)); + sb.append(String.format(format, s, st.getOpName(), st.getCount(), st.getSumUs(), pc, min, max, std)); } return sb.toString(); @@ -251,33 +349,21 @@ public class ProfileAnalyzer { if (stats.containsKey(instanceName)) { s = stats.get(instanceName); } else { - s = new OpStats(e.getName(), 0, new NDArrayList(DataType.LONG, 0), null); + s = new OpStats(instanceName, e.getName(), 0, new NDArrayList(DataType.LONG, 0), null); stats.put(instanceName, s); } - s.count++; - s.timesUs.add((double) e.getDur()); + s.setCount(s.getCount() + 1); + s.getTimesUs().add((double) e.getDur()); } long allOpsUs = 0; for (OpStats s : stats.values()) { - s.sumUs = s.timesUs.array().sumNumber().longValue(); - allOpsUs += s.sumUs; + s.setSumUs( s.getTimesUs().array().sumNumber().longValue()); + allOpsUs += s.getSumUs(); } return new Pair<>(allOpsUs, stats); } - - @AllArgsConstructor - @NoArgsConstructor - @Data - private static class OpStats { - private String opName; - private int count; - private NDArrayList timesUs; - private Long sumUs; - - } - /** * Compare the specified profile files, sorted by profile 1 % of total time * @@ -307,24 +393,36 @@ public class ProfileAnalyzer { */ public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2, boolean firstIsDir, boolean secondIsDir, String name1, String name2, final SortBy sortBy) { + return compareProfiles(Config.builder() + .profile1(file1) + .profile2(file2) + .profile1Format(format1) + .profile2Format(format2) + .profile1IsDir(firstIsDir) + .profile2IsDir(secondIsDir) + .p1Name(name1) + .p2Name(name2) + .sortBy(sortBy) + .build()); + } - TraceEvent[] t1 = firstIsDir ? getTraceEventsDir(file1, format1) : getTraceEvents(file1, format1); - TraceEvent[] t2 = secondIsDir ? getTraceEventsDir(file2, format2) : getTraceEvents(file2, format2); + public static String compareProfiles(final Config c){ + TraceEvent[] t1 = c.profile1IsDir() ? getTraceEventsDir(c.profile1(), c.profile1Format()) : getTraceEvents(c.profile1(), c.profile1Format()); + TraceEvent[] t2 = c.profile2IsDir() ? getTraceEventsDir(c.profile2(), c.profile2Format()) : getTraceEvents(c.profile2(), c.profile2Format()); final Pair> p1 = aggregateTraceEvents(t1); final Pair> p2 = aggregateTraceEvents(t2); - List l = new ArrayList<>(sortBy != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet()); + List l = new ArrayList<>(c.sortBy() != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet()); Collections.sort(l, new Comparator() { @Override public int compare(String o1, String o2) { - switch (sortBy) { + switch (c.sortBy()) { case PROFILE1_PC: - return -Long.compare(p1.getSecond().get(o1).sumUs, p1.getSecond().get(o2).sumUs); + return -Long.compare(p1.getSecond().get(o1).getSumUs(), p1.getSecond().get(o2).getSumUs()); case PROFILE2_PC: - return -Long.compare(p2.getSecond().get(o1).sumUs, p2.getSecond().get(o2).sumUs); + return -Long.compare(p2.getSecond().get(o1).getSumUs(), p2.getSecond().get(o2).getSumUs()); case RATIO: - double m1a = meanTime(p1, o1); double m1b = meanTime(p1, o2); double m2a = meanTime(p2, o1); @@ -342,42 +440,53 @@ public class ProfileAnalyzer { StringBuilder sb = new StringBuilder(); - sb.append("1 = ").append(name1 == null ? "Profile 1" : name1).append("\n") - .append("2 = ").append(name2 == null ? "Profile 2" : name2).append("\n"); + sb.append("1 = ").append(c.p1Name() == null ? "Profile 1" : c.p1Name()).append("\n") + .append("2 = ").append(c.p2Name() == null ? "Profile 2" : c.p2Name()).append("\n"); //Work out longest name and op name: int longestName = 30; int longestOpName = 30; - Map stats = sortBy == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond(); + Map stats = c.sortBy() == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond(); for (String s : l) { longestName = Math.max(longestName, s.length() + 1); - longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1); + longestOpName = Math.max(longestOpName, stats.get(s).getOpName().length() + 1); } - String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n"; + String headerFormat; + String format; + if(c.format() == null || c.format() == OutputFormat.TEXT){ + headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n"; + format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n"; + } else { + headerFormat = "%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s\n"; + format = "%s,%s,%d,%d,%.2f,%.2f,%.2f,%d,%d,%.2f,%.2f,%d,%d,%d,%d,%.2f,%.2f\n"; + } sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)")); - String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n"; + for (String s : l) { OpStats s1 = p1.getSecond().get(s); OpStats s2 = p2.getSecond().get(s); + if(c.filter() != null && !c.filter().apply(s1, s2)) + continue; + double m1 = s1 == null ? 0 : s1.getTimesUs().array().meanNumber().doubleValue(); double m2 = s2 == null ? 0 : s2.getTimesUs().array().meanNumber().doubleValue(); double ratio = m1 / m2; - double pc1 = s1 == null ? 0 : 100.0 * s1.sumUs / p1.getFirst(); - double pc2 = s2 == null ? 0 : 100.0 * s2.sumUs / p2.getFirst(); + double pc1 = s1 == null ? 0 : 100.0 * s1.getSumUs() / p1.getFirst(); + double pc2 = s2 == null ? 0 : 100.0 * s2.getSumUs() / p2.getFirst(); - sb.append(String.format(format, s, s1 != null ? s1.opName : s2.opName, - s1 != null ? s1.count : 0, - s2 != null ? s2.count : 0, + sb.append(String.format(format, s, s1 != null ? s1.getOpName() : s2.getOpName(), + s1 != null ? s1.getCount() : 0, + s2 != null ? s2.getCount() : 0, //Ratio of means, means ratio, m1, m2, //Total us, percent of op total - s1 != null ? s1.sumUs : 0, - s2 != null ? s2.sumUs : 0, + s1 != null ? s1.getSumUs() : 0, + s2 != null ? s2.getSumUs() : 0, pc1, pc2, //Min, max, std s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0, @@ -391,36 +500,57 @@ public class ProfileAnalyzer { boolean header = false; String headerFormat2 = null; String format3 = null; - for (String s : (sortBy == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) { + List toAppend = null; + for (String s : (c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) { if (!set.contains(s)) { - Map m = sortBy == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond(); + Map m = c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond(); + OpStats st = m.get(s); + if(c.filter() != null){ + OpStats other = c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond().get(s) : p2.getSecond().get(s); + boolean keep = c.filter().apply(other, st); + if(!keep) + continue; + } + if (!header) { + toAppend = new ArrayList<>(); longestName = 30; longestOpName = 30; for(String s2 : m.keySet()){ longestName = Math.max(longestName, s2.length()+1); - longestOpName = Math.max(longestOpName, m.get(s2).opName.length()+1); + longestOpName = Math.max(longestOpName, m.get(s2).getOpName().length()+1); + } + if(c.format() == null || c.format() == OutputFormat.TEXT) { + headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n"; + format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n"; + } else { + headerFormat2 = "%s,%s,%s,%s,%s,%s,%s,%s\n"; + format3 = "%s,%s,%d,%d,%.2f,%d,%d,%.2f\n"; } - headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n"; - format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n"; - sb.append(" *** Operations not in profile ").append(sortBy == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ") - .append(sortBy == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n"); + sb.append(" *** Operations not in profile ").append(c.sortBy() == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ") + .append(c.sortBy() == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n"); sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std")); header = true; } - long allOpsUs = sortBy == SortBy.PROFILE2_PC ? p1.getFirst() : p2.getFirst(); - OpStats st = m.get(s); + long allOpsUs = c.sortBy() == SortBy.PROFILE2_PC ? p1.getFirst() : p2.getFirst(); double pc = (100.0 * st.getTimesUs().array().sumNumber().longValue()) / allOpsUs; - INDArray arr = st.timesUs.array(); + INDArray arr = st.getTimesUs().array(); long min = arr.minNumber().longValue(); long max = arr.maxNumber().longValue(); double std = arr.stdNumber().doubleValue(); - sb.append(String.format(format3, s, st.opName, st.count, st.getSumUs(), pc, min, max, std)); + toAppend.add(String.format(format3, s, st.getOpName(), st.getCount(), st.getSumUs(), pc, min, max, std)); } } + if(toAppend != null){ + Collections.sort(toAppend); + for(String s : toAppend){ + sb.append(s); + } + } + return sb.toString(); }