SameDiff profiler analysis improvements (#141)
* #8555 SameDiff profiler analysis improvements
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Fix TF sub-op aggregation
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Small filtering tweak
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Copyright headers
  Signed-off-by: Alex Black <blacka101@gmail.com>
master
parent
ce02b6fae7
commit
1f9e1b6022
|
@ -0,0 +1,42 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.autodiff.listeners.profiler.comparison;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.experimental.Accessors;
|
||||||
|
import org.nd4j.linalg.function.BiFunction;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
|
||||||
|
/**
 * Configuration for comparing two profiles with {@code ProfileAnalyzer.compareProfiles(Config)}.
 * <p>
 * Instances are created via the Lombok builder ({@code Config.builder()...build()}). Because of
 * {@code @Accessors(fluent = true)} the generated accessors are named after the fields
 * (for example {@code sortBy()}, {@code profile1IsDir()}) rather than using a get/is prefix.
 * <ul>
 * <li>{@code p1Name} / {@code p2Name}: display names for the two profiles; when null the comparison output
 *     falls back to "Profile 1" / "Profile 2"</li>
 * <li>{@code profile1} / {@code profile2}: the profile locations passed to the trace event loaders</li>
 * <li>{@code profile1IsDir} / {@code profile2IsDir}: when true the location is loaded with
 *     {@code getTraceEventsDir}, otherwise with {@code getTraceEvents}</li>
 * <li>{@code profile1Format} / {@code profile2Format}: format of each profile; builder default is SAMEDIFF</li>
 * <li>{@code sortBy}: ordering of the comparison rows; builder default is PROFILE1_PC</li>
 * <li>{@code filter}: optional row filter; return true to keep the op, false to drop it. For the main
 *     comparison table it is applied to (profile 1 stats, profile 2 stats) for the op, and either argument
 *     may be null when the op is missing from that profile.
 *     NOTE(review): the "operations not in profile" section also applies this filter - confirm which
 *     stats objects it passes there before relying on argument order</li>
 * <li>{@code format}: output format; builder default is TEXT, and a null value is also treated as TEXT</li>
 * </ul>
 */
@Data
|
||||||
|
@Accessors(fluent = true)
|
||||||
|
@Builder
|
||||||
|
public class Config {
|
||||||
|
|
||||||||
|
private String p1Name;
|
||||||
|
private String p2Name;
|
||||||
|
private File profile1;
|
||||||
|
private File profile2;
|
||||||
|
private boolean profile1IsDir;
|
||||||
|
private boolean profile2IsDir;
|
||||||
|
@Builder.Default private ProfileAnalyzer.ProfileFormat profile1Format = ProfileAnalyzer.ProfileFormat.SAMEDIFF;
|
||||||
|
@Builder.Default private ProfileAnalyzer.ProfileFormat profile2Format = ProfileAnalyzer.ProfileFormat.SAMEDIFF;
|
||||||
|
@Builder.Default private ProfileAnalyzer.SortBy sortBy = ProfileAnalyzer.SortBy.PROFILE1_PC;
|
||||||
|
private BiFunction<OpStats,OpStats,Boolean> filter; //Return true to keep, false to remove
|
||||||
|
@Builder.Default private ProfileAnalyzer.OutputFormat format = ProfileAnalyzer.OutputFormat.TEXT;
|
||||||
|
|
||||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.autodiff.listeners.profiler.comparison;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.list.NDArrayList;
|
||||||
|
|
||||||
|
/**
 * Aggregated timing statistics for a single op instance, as built by
 * {@code ProfileAnalyzer.aggregateTraceEvents} and consumed by the profile summary/comparison output
 * and by {@code Config.filter}.
 * <ul>
 * <li>{@code opInstanceName}: the key under which the stats are stored in the aggregation map</li>
 * <li>{@code opName}: the name taken from the trace event ({@code TraceEvent.getName()})</li>
 * <li>{@code count}: number of trace events aggregated into this entry (incremented once per event)</li>
 * <li>{@code timesUs}: one entry per aggregated trace event holding that event's duration; reported as
 *     microseconds ("uS") in the output tables</li>
 * <li>{@code sumUs}: sum of {@code timesUs}; it is a boxed {@code Long} that is constructed as null and
 *     only filled in once all events have been aggregated</li>
 * </ul>
 * The all-args constructor takes the fields in declaration order:
 * (opInstanceName, opName, count, timesUs, sumUs).
 */
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
|
@Data
|
||||||
|
public class OpStats {
|
||||||
|
private String opInstanceName;
|
||||||
|
private String opName;
|
||||||
|
private int count;
|
||||||
|
private NDArrayList timesUs;
|
||||||
|
private Long sumUs;
|
||||||
|
}
|
|
@ -69,6 +69,12 @@ public class ProfileAnalyzer {
|
||||||
*/
|
*/
|
||||||
public enum SortBy {PROFILE1_PC, PROFILE2_PC, RATIO}
|
public enum SortBy {PROFILE1_PC, PROFILE2_PC, RATIO}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* TEXT: Human readable, columns padded for alignment<br>
|
||||||
|
* CSV: CSV format, comma separated
|
||||||
|
*/
|
||||||
|
public enum OutputFormat {TEXT,CSV}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Summarize and print to stdout the specified profile file
|
* Summarize and print to stdout the specified profile file
|
||||||
|
@ -139,6 +145,10 @@ public class ProfileAnalyzer {
|
||||||
* @param profileFormat Profile format
|
* @param profileFormat Profile format
|
||||||
*/
|
*/
|
||||||
public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) {
|
public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) {
|
||||||
|
return getTraceEvents(file, profileFormat, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat, boolean aggregateTFSubOps) {
|
||||||
ObjectMapper json = ProfilingListener.jsonMapper();
|
ObjectMapper json = ProfilingListener.jsonMapper();
|
||||||
|
|
||||||
String content;
|
String content;
|
||||||
|
@ -189,6 +199,94 @@ public class ProfileAnalyzer {
|
||||||
te.setName(df.opName());
|
te.setName(df.opName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if(aggregateTFSubOps){
|
||||||
|
//For CUDA ops, TF will log sub-ops like:
|
||||||
|
//fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@cudnn::maxwell::gemm::computeOffsetsKernel(cudnn::maxwell::gemm::ComputeOffsetsParams)
|
||||||
|
//fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@maxwell_scudnn_128x64_relu_interior_nn
|
||||||
|
//fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/task:0/device:GPU:0,async=false#@@void tensorflow::functor::ShuffleInTensor3Simple<float, 2, 1, 0, false>(int, float const*, tensorflow::functor::Dimension<3>, float*)
|
||||||
|
//We'll join these into one op, then strip everything after the ":" to recover the op name
|
||||||
|
|
||||||
|
//Also, TF has multiple sub-ops like this, sequentially, that need to be joined:
|
||||||
|
//19 = {TraceEvent@3157} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601259742, dur=466, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)"
|
||||||
|
//20 = {TraceEvent@3181} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260229, dur=29, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)"
|
||||||
|
//21 = {TraceEvent@3206} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260329, dur=31, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)"
|
||||||
|
//22 = {TraceEvent@3247} "TraceEvent(name=Conv2D#id=80,device=/job, categories=null, ph=X, ts=1576896601260390, dur=4998, tts=null, pid=5, tid=0, args={name=conv1/Conv2D, op=Conv2D#id=80,device=/job}, cname=null)"
|
||||||
|
|
||||||
|
Map<String,TraceEvent> map = new HashMap<>(); //Key: Op name with ID
|
||||||
|
List<TraceEvent> out = new ArrayList<>();
|
||||||
|
TraceEvent last = null;
|
||||||
|
for(TraceEvent te : events){
|
||||||
|
if(last != null && last.getPh() == Phase.X && te.getPh() == Phase.X &&
|
||||||
|
last.getName().equals(te.getName()) &&
|
||||||
|
last.getArgs() != null && te.getArgs() != null &&
|
||||||
|
last.getArgs().get("name").equals(te.getArgs().get("name")) &&
|
||||||
|
last.getArgs().get("op").equals(te.getArgs().get("op"))){
|
||||||
|
//Aggregate - same names, ops, etc
|
||||||
|
last.setDur(last.getDur() + te.getDur());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
last = te;
|
||||||
|
if(te.getArgs() == null || te.getArgs().isEmpty()){
|
||||||
|
out.add(te);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
String n = (String) te.getArgs().get("name");
|
||||||
|
|
||||||
|
//Aggregate by op name...
|
||||||
|
//"fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/replica:0/..." -> "fire2/e1x1/Conv2D"
|
||||||
|
//We're relying on TF's "one iteration per json file" here
|
||||||
|
if(n.matches("[\\w/_-]+:[\\w/_-]+#id=\\d+.*")) {
|
||||||
|
int idx = n.indexOf("#");
|
||||||
|
String sub1 = n.substring(0, idx);
|
||||||
|
String sub;
|
||||||
|
if (sub1.contains(":")) {
|
||||||
|
sub = sub1.substring(0, sub1.lastIndexOf(":"));
|
||||||
|
} else {
|
||||||
|
sub = sub1;
|
||||||
|
}
|
||||||
|
if (map.containsKey(sub)) {
|
||||||
|
TraceEvent t = map.get(sub);
|
||||||
|
Long dur = t.getDur();
|
||||||
|
if (dur == null && te.getDur() == null)
|
||||||
|
continue;
|
||||||
|
t.setDur(dur == null ? te.getDur() : dur + (te.getDur() == null ? 0 : te.getDur()));
|
||||||
|
} else {
|
||||||
|
map.put(sub, te);
|
||||||
|
out.add(te);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if(map.containsKey(n)){
|
||||||
|
TraceEvent t = map.get(n);
|
||||||
|
t.setDur(t.getDur() + te.getDur());
|
||||||
|
} else {
|
||||||
|
map.put(n, te);
|
||||||
|
out.add(te);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Strip everything after ":" in "fire2/e1x1/Conv2D:Conv2D#id=74,device=/job:localhost/..."
|
||||||
|
for( int i=0; i<out.size(); i++ ){
|
||||||
|
TraceEvent te = out.get(i);
|
||||||
|
if(te.getArgs() == null || te.getArgs().isEmpty()){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
String n = (String) te.getArgs().get("name");
|
||||||
|
if(n.matches("[\\w/_-]+:[\\w/_-]+#id=\\d+.*")){
|
||||||
|
int idx = n.indexOf(':');
|
||||||
|
String sub = n.substring(0,idx);
|
||||||
|
te.getArgs().put("name", sub);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
events = out.toArray(new TraceEvent[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return events;
|
return events;
|
||||||
|
@ -209,7 +307,7 @@ public class ProfileAnalyzer {
|
||||||
Collections.sort(l, new Comparator<String>() {
|
Collections.sort(l, new Comparator<String>() {
|
||||||
@Override
|
@Override
|
||||||
public int compare(String o1, String o2) {
|
public int compare(String o1, String o2) {
|
||||||
return -Long.compare(stats.get(o1).sumUs, stats.get(o2).sumUs);
|
return -Long.compare(stats.get(o1).getSumUs(), stats.get(o2).getSumUs());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -218,7 +316,7 @@ public class ProfileAnalyzer {
|
||||||
int longestOpName = 30;
|
int longestOpName = 30;
|
||||||
for (String s : l) {
|
for (String s : l) {
|
||||||
longestName = Math.max(longestName, s.length() + 1);
|
longestName = Math.max(longestName, s.length() + 1);
|
||||||
longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1);
|
longestOpName = Math.max(longestOpName, stats.get(s).getOpName().length() + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
|
@ -227,12 +325,12 @@ public class ProfileAnalyzer {
|
||||||
String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
|
String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
|
||||||
for (String s : l) {
|
for (String s : l) {
|
||||||
OpStats st = stats.get(s);
|
OpStats st = stats.get(s);
|
||||||
double pc = (100.0 * st.sumUs) / allOpsUs;
|
double pc = (100.0 * st.getSumUs()) / allOpsUs;
|
||||||
INDArray arr = st.timesUs.array();
|
INDArray arr = st.getTimesUs().array();
|
||||||
long min = arr.minNumber().longValue();
|
long min = arr.minNumber().longValue();
|
||||||
long max = arr.maxNumber().longValue();
|
long max = arr.maxNumber().longValue();
|
||||||
double std = arr.stdNumber().doubleValue();
|
double std = arr.stdNumber().doubleValue();
|
||||||
sb.append(String.format(format, s, st.opName, st.count, st.getSumUs(), pc, min, max, std));
|
sb.append(String.format(format, s, st.getOpName(), st.getCount(), st.getSumUs(), pc, min, max, std));
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.toString();
|
return sb.toString();
|
||||||
|
@ -251,33 +349,21 @@ public class ProfileAnalyzer {
|
||||||
if (stats.containsKey(instanceName)) {
|
if (stats.containsKey(instanceName)) {
|
||||||
s = stats.get(instanceName);
|
s = stats.get(instanceName);
|
||||||
} else {
|
} else {
|
||||||
s = new OpStats(e.getName(), 0, new NDArrayList(DataType.LONG, 0), null);
|
s = new OpStats(instanceName, e.getName(), 0, new NDArrayList(DataType.LONG, 0), null);
|
||||||
stats.put(instanceName, s);
|
stats.put(instanceName, s);
|
||||||
}
|
}
|
||||||
s.count++;
|
s.setCount(s.getCount() + 1);
|
||||||
s.timesUs.add((double) e.getDur());
|
s.getTimesUs().add((double) e.getDur());
|
||||||
}
|
}
|
||||||
|
|
||||||
long allOpsUs = 0;
|
long allOpsUs = 0;
|
||||||
for (OpStats s : stats.values()) {
|
for (OpStats s : stats.values()) {
|
||||||
s.sumUs = s.timesUs.array().sumNumber().longValue();
|
s.setSumUs( s.getTimesUs().array().sumNumber().longValue());
|
||||||
allOpsUs += s.sumUs;
|
allOpsUs += s.getSumUs();
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Pair<>(allOpsUs, stats);
|
return new Pair<>(allOpsUs, stats);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
@Data
|
|
||||||
private static class OpStats {
|
|
||||||
private String opName;
|
|
||||||
private int count;
|
|
||||||
private NDArrayList timesUs;
|
|
||||||
private Long sumUs;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compare the specified profile files, sorted by profile 1 % of total time
|
* Compare the specified profile files, sorted by profile 1 % of total time
|
||||||
*
|
*
|
||||||
|
@ -307,24 +393,36 @@ public class ProfileAnalyzer {
|
||||||
*/
|
*/
|
||||||
public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2,
|
public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2,
|
||||||
boolean firstIsDir, boolean secondIsDir, String name1, String name2, final SortBy sortBy) {
|
boolean firstIsDir, boolean secondIsDir, String name1, String name2, final SortBy sortBy) {
|
||||||
|
return compareProfiles(Config.builder()
|
||||||
|
.profile1(file1)
|
||||||
|
.profile2(file2)
|
||||||
|
.profile1Format(format1)
|
||||||
|
.profile2Format(format2)
|
||||||
|
.profile1IsDir(firstIsDir)
|
||||||
|
.profile2IsDir(secondIsDir)
|
||||||
|
.p1Name(name1)
|
||||||
|
.p2Name(name2)
|
||||||
|
.sortBy(sortBy)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
|
||||||
TraceEvent[] t1 = firstIsDir ? getTraceEventsDir(file1, format1) : getTraceEvents(file1, format1);
|
public static String compareProfiles(final Config c){
|
||||||
TraceEvent[] t2 = secondIsDir ? getTraceEventsDir(file2, format2) : getTraceEvents(file2, format2);
|
TraceEvent[] t1 = c.profile1IsDir() ? getTraceEventsDir(c.profile1(), c.profile1Format()) : getTraceEvents(c.profile1(), c.profile1Format());
|
||||||
|
TraceEvent[] t2 = c.profile2IsDir() ? getTraceEventsDir(c.profile2(), c.profile2Format()) : getTraceEvents(c.profile2(), c.profile2Format());
|
||||||
|
|
||||||
final Pair<Long, Map<String, OpStats>> p1 = aggregateTraceEvents(t1);
|
final Pair<Long, Map<String, OpStats>> p1 = aggregateTraceEvents(t1);
|
||||||
final Pair<Long, Map<String, OpStats>> p2 = aggregateTraceEvents(t2);
|
final Pair<Long, Map<String, OpStats>> p2 = aggregateTraceEvents(t2);
|
||||||
|
|
||||||
List<String> l = new ArrayList<>(sortBy != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet());
|
List<String> l = new ArrayList<>(c.sortBy() != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet());
|
||||||
Collections.sort(l, new Comparator<String>() {
|
Collections.sort(l, new Comparator<String>() {
|
||||||
@Override
|
@Override
|
||||||
public int compare(String o1, String o2) {
|
public int compare(String o1, String o2) {
|
||||||
switch (sortBy) {
|
switch (c.sortBy()) {
|
||||||
case PROFILE1_PC:
|
case PROFILE1_PC:
|
||||||
return -Long.compare(p1.getSecond().get(o1).sumUs, p1.getSecond().get(o2).sumUs);
|
return -Long.compare(p1.getSecond().get(o1).getSumUs(), p1.getSecond().get(o2).getSumUs());
|
||||||
case PROFILE2_PC:
|
case PROFILE2_PC:
|
||||||
return -Long.compare(p2.getSecond().get(o1).sumUs, p2.getSecond().get(o2).sumUs);
|
return -Long.compare(p2.getSecond().get(o1).getSumUs(), p2.getSecond().get(o2).getSumUs());
|
||||||
case RATIO:
|
case RATIO:
|
||||||
|
|
||||||
double m1a = meanTime(p1, o1);
|
double m1a = meanTime(p1, o1);
|
||||||
double m1b = meanTime(p1, o2);
|
double m1b = meanTime(p1, o2);
|
||||||
double m2a = meanTime(p2, o1);
|
double m2a = meanTime(p2, o1);
|
||||||
|
@ -342,42 +440,53 @@ public class ProfileAnalyzer {
|
||||||
|
|
||||||
|
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
sb.append("1 = ").append(name1 == null ? "Profile 1" : name1).append("\n")
|
sb.append("1 = ").append(c.p1Name() == null ? "Profile 1" : c.p1Name()).append("\n")
|
||||||
.append("2 = ").append(name2 == null ? "Profile 2" : name2).append("\n");
|
.append("2 = ").append(c.p2Name() == null ? "Profile 2" : c.p2Name()).append("\n");
|
||||||
|
|
||||||
//Work out longest name and op name:
|
//Work out longest name and op name:
|
||||||
int longestName = 30;
|
int longestName = 30;
|
||||||
int longestOpName = 30;
|
int longestOpName = 30;
|
||||||
Map<String, OpStats> stats = sortBy == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond();
|
Map<String, OpStats> stats = c.sortBy() == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond();
|
||||||
for (String s : l) {
|
for (String s : l) {
|
||||||
longestName = Math.max(longestName, s.length() + 1);
|
longestName = Math.max(longestName, s.length() + 1);
|
||||||
longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1);
|
longestOpName = Math.max(longestOpName, stats.get(s).getOpName().length() + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n";
|
String headerFormat;
|
||||||
|
String format;
|
||||||
|
if(c.format() == null || c.format() == OutputFormat.TEXT){
|
||||||
|
headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n";
|
||||||
|
format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n";
|
||||||
|
} else {
|
||||||
|
headerFormat = "%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s\n";
|
||||||
|
format = "%s,%s,%d,%d,%.2f,%.2f,%.2f,%d,%d,%.2f,%.2f,%d,%d,%d,%d,%.2f,%.2f\n";
|
||||||
|
}
|
||||||
sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)"));
|
sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)"));
|
||||||
String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n";
|
|
||||||
|
|
||||||
for (String s : l) {
|
for (String s : l) {
|
||||||
OpStats s1 = p1.getSecond().get(s);
|
OpStats s1 = p1.getSecond().get(s);
|
||||||
OpStats s2 = p2.getSecond().get(s);
|
OpStats s2 = p2.getSecond().get(s);
|
||||||
|
|
||||||
|
if(c.filter() != null && !c.filter().apply(s1, s2))
|
||||||
|
continue;
|
||||||
|
|
||||||
double m1 = s1 == null ? 0 : s1.getTimesUs().array().meanNumber().doubleValue();
|
double m1 = s1 == null ? 0 : s1.getTimesUs().array().meanNumber().doubleValue();
|
||||||
double m2 = s2 == null ? 0 : s2.getTimesUs().array().meanNumber().doubleValue();
|
double m2 = s2 == null ? 0 : s2.getTimesUs().array().meanNumber().doubleValue();
|
||||||
double ratio = m1 / m2;
|
double ratio = m1 / m2;
|
||||||
|
|
||||||
double pc1 = s1 == null ? 0 : 100.0 * s1.sumUs / p1.getFirst();
|
double pc1 = s1 == null ? 0 : 100.0 * s1.getSumUs() / p1.getFirst();
|
||||||
double pc2 = s2 == null ? 0 : 100.0 * s2.sumUs / p2.getFirst();
|
double pc2 = s2 == null ? 0 : 100.0 * s2.getSumUs() / p2.getFirst();
|
||||||
|
|
||||||
sb.append(String.format(format, s, s1 != null ? s1.opName : s2.opName,
|
sb.append(String.format(format, s, s1 != null ? s1.getOpName() : s2.getOpName(),
|
||||||
s1 != null ? s1.count : 0,
|
s1 != null ? s1.getCount() : 0,
|
||||||
s2 != null ? s2.count : 0,
|
s2 != null ? s2.getCount() : 0,
|
||||||
//Ratio of means, means
|
//Ratio of means, means
|
||||||
ratio,
|
ratio,
|
||||||
m1, m2,
|
m1, m2,
|
||||||
//Total us, percent of op total
|
//Total us, percent of op total
|
||||||
s1 != null ? s1.sumUs : 0,
|
s1 != null ? s1.getSumUs() : 0,
|
||||||
s2 != null ? s2.sumUs : 0,
|
s2 != null ? s2.getSumUs() : 0,
|
||||||
pc1, pc2,
|
pc1, pc2,
|
||||||
//Min, max, std
|
//Min, max, std
|
||||||
s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0,
|
s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0,
|
||||||
|
@ -391,36 +500,57 @@ public class ProfileAnalyzer {
|
||||||
boolean header = false;
|
boolean header = false;
|
||||||
String headerFormat2 = null;
|
String headerFormat2 = null;
|
||||||
String format3 = null;
|
String format3 = null;
|
||||||
for (String s : (sortBy == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) {
|
List<String> toAppend = null;
|
||||||
|
for (String s : (c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) {
|
||||||
|
|
||||||
if (!set.contains(s)) {
|
if (!set.contains(s)) {
|
||||||
Map<String, OpStats> m = sortBy == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond();
|
Map<String, OpStats> m = c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond();
|
||||||
|
OpStats st = m.get(s);
|
||||||
|
if(c.filter() != null){
|
||||||
|
OpStats other = c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond().get(s) : p2.getSecond().get(s);
|
||||||
|
boolean keep = c.filter().apply(other, st);
|
||||||
|
if(!keep)
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (!header) {
|
if (!header) {
|
||||||
|
toAppend = new ArrayList<>();
|
||||||
|
|
||||||
longestName = 30;
|
longestName = 30;
|
||||||
longestOpName = 30;
|
longestOpName = 30;
|
||||||
for(String s2 : m.keySet()){
|
for(String s2 : m.keySet()){
|
||||||
longestName = Math.max(longestName, s2.length()+1);
|
longestName = Math.max(longestName, s2.length()+1);
|
||||||
longestOpName = Math.max(longestOpName, m.get(s2).opName.length()+1);
|
longestOpName = Math.max(longestOpName, m.get(s2).getOpName().length()+1);
|
||||||
|
}
|
||||||
|
if(c.format() == null || c.format() == OutputFormat.TEXT) {
|
||||||
|
headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
|
||||||
|
format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
|
||||||
|
} else {
|
||||||
|
headerFormat2 = "%s,%s,%s,%s,%s,%s,%s,%s\n";
|
||||||
|
format3 = "%s,%s,%d,%d,%.2f,%d,%d,%.2f\n";
|
||||||
}
|
}
|
||||||
headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
|
|
||||||
format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
|
|
||||||
|
|
||||||
sb.append(" *** Operations not in profile ").append(sortBy == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ")
|
sb.append(" *** Operations not in profile ").append(c.sortBy() == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ")
|
||||||
.append(sortBy == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n");
|
.append(c.sortBy() == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n");
|
||||||
sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
|
sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
|
||||||
header = true;
|
header = true;
|
||||||
}
|
}
|
||||||
long allOpsUs = sortBy == SortBy.PROFILE2_PC ? p1.getFirst() : p2.getFirst();
|
long allOpsUs = c.sortBy() == SortBy.PROFILE2_PC ? p1.getFirst() : p2.getFirst();
|
||||||
OpStats st = m.get(s);
|
|
||||||
double pc = (100.0 * st.getTimesUs().array().sumNumber().longValue()) / allOpsUs;
|
double pc = (100.0 * st.getTimesUs().array().sumNumber().longValue()) / allOpsUs;
|
||||||
INDArray arr = st.timesUs.array();
|
INDArray arr = st.getTimesUs().array();
|
||||||
long min = arr.minNumber().longValue();
|
long min = arr.minNumber().longValue();
|
||||||
long max = arr.maxNumber().longValue();
|
long max = arr.maxNumber().longValue();
|
||||||
double std = arr.stdNumber().doubleValue();
|
double std = arr.stdNumber().doubleValue();
|
||||||
sb.append(String.format(format3, s, st.opName, st.count, st.getSumUs(), pc, min, max, std));
|
toAppend.add(String.format(format3, s, st.getOpName(), st.getCount(), st.getSumUs(), pc, min, max, std));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if(toAppend != null){
|
||||||
|
Collections.sort(toAppend);
|
||||||
|
for(String s : toAppend){
|
||||||
|
sb.append(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return sb.toString();
|
return sb.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue