', at least in SUN / Oracle JVMs
+ final String jvmName = ManagementFactory.getRuntimeMXBean().getName();
+ final int index = jvmName.indexOf('@');
+
+ if (index < 1) {
+ // part before '@' empty (index = 0) / '@' not found (index = -1)
+ return 0;
+ }
+
+ try {
+ return Long.parseLong(jvmName.substring(0, index));
+ } catch (NumberFormatException e) {
+ // ignore
+ }
+ return 0;
+ }
+
+ /**
+ * Get a new JSON mapper for serializing/deserializing the profiler's JSON format
+ */
+ public static ObjectMapper jsonMapper() {
+ ObjectMapper json = new ObjectMapper();
+ json.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+ json.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ json.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
+ json.disable(SerializationFeature.INDENT_OUTPUT); //One line
+
+ return json;
+ }
+
+ /**
+ * Create a new builder
+ * @param outputFile Output file. Will be overwritten if file already exists
+ */
+ public static Builder builder(File outputFile) {
+ return new Builder(outputFile);
+ }
+
+ public static class Builder {
+ private final File outputFile;
+ private boolean all = true;
+ private int warmup = 0;
+ private int nIter = -1;
+ private long nMs = -1;
+ private Operation[] operations;
+
+ public Builder(@NonNull File outputFile) {
+ this.outputFile = outputFile;
+ }
+
+ /**
+ * If called, all data will be profiled with no limits (other than a warmup, if set)
+ */
+ public Builder recordAll() {
+ this.all = true;
+ this.nIter = -1;
+ this.nMs = -1;
+ return this;
+ }
+
+ /**
+ * Specify the number of warmup iterations - i.e., these will be excluded from profiling results
+ */
+ public Builder warmup(int iterations) {
+ this.warmup = iterations;
+ return this;
+ }
+
+ /**
+ * Set a limit on the maximum number of iterations to profile (after warmup, if any).
+ * Any ops executed after the specified number of iterations will not be profiled/recorded
+ */
+ public Builder maxProfileIterations(int iterations) {
+ this.nIter = iterations;
+ this.all = false;
+ return this;
+ }
+
+ /**
+ * Set a limit on the maximum duration for profiling, in milliseconds.
+ * Any ops executed after the specified amount of time since the first (non-warmup) operation start will not be
+ * profiled/recorded
+ */
+ public Builder maxProfilerMilliseconds(long ms) {
+ this.nMs = ms;
+ this.all = false;
+ return this;
+ }
+
+ /**
+ * Specify the operations (training, inference, etc) to profile.
+ * If not set, all operations are profiled
+ */
+ public Builder operations(Operation... operations) {
+ this.operations = operations;
+ return this;
+ }
+
+ /**
+ * Create the profiling listener
+ */
+ public ProfilingListener build() {
+ return new ProfilingListener(outputFile, all, warmup, nIter, nMs, operations);
+ }
+ }
+}
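A minimal usage sketch for the builder above, assuming an existing SameDiff instance sd and a placeholder map named placeholders (both hypothetical here, mirroring the ProfilingListenerTest added later in this patch):

    File outputFile = new File("profile.json");             // overwritten if it already exists
    ProfilingListener listener = ProfilingListener.builder(outputFile)
            .warmup(5)                                       // exclude the first 5 iterations from the results
            .maxProfileIterations(100)                       // record at most 100 iterations after warmup
            .build();
    sd.setListeners(listener);                               // subsequent executions are traced to outputFile
    sd.outputSingle(placeholders, "predictions");            // example forward pass, as in the test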
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java
new file mode 100644
index 000000000..c8f3340d7
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java
@@ -0,0 +1,441 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.autodiff.listeners.profiler.comparison;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.io.FileUtils;
+import org.nd4j.autodiff.functions.DifferentialFunction;
+import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
+import org.nd4j.autodiff.listeners.profiler.data.Phase;
+import org.nd4j.autodiff.listeners.profiler.data.TraceEvent;
+import org.nd4j.autodiff.listeners.profiler.data.TraceEvents;
+import org.nd4j.base.Preconditions;
+import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.list.NDArrayList;
+import org.nd4j.shade.jackson.databind.ObjectMapper;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.*;
+
+/**
+ * A profile analyzer, used for analyzing Chrome-format profiler dumps generated by both SameDiff's
+ * {@link ProfilingListener} and TensorFlow's profiler.
+ * Has methods for summarizing profiler statistics, as well as comparing two profiler dumps.
+ *
+ * Also supports analyzing/aggregating multiple JSON files in a directory, via the "...Directory(...)" methods.
+ *
+ * See {@link ProfilingListener}
+ * See {@link TraceEvent}
+ *
+ * @author Alex Black
+ */
+@Slf4j
+public class ProfileAnalyzer {
+
+ /**
+ * Chrome profiler supports 2 formats:
+ * SameDiff == JSON Array Format
+ * TensorFlow == JSON Object Format
+ */
+ public enum ProfileFormat {SAMEDIFF, TENSORFLOW}
+
+ /**
+ * Only applicable for profile comparisons.
+ * PROFILE1_PC - sort by profile 1 percentage of total time
+ * PROFILE2_PC - sort by profile 2 percentage of total time
+ * RATIO - sort by highest ratio (mean op time profile 1 / mean op time profile 2)
+ */
+ public enum SortBy {PROFILE1_PC, PROFILE2_PC, RATIO}
+
+
+ /**
+ * Summarize and print to stdout the specified profile file
+ *
+ * @param file Profile file
+ * @param profileFormat Format of the profiler file
+ */
+ public static void summarizeProfile(File file, ProfileFormat profileFormat) {
+ System.out.println(summarizeProfileStr(file, profileFormat));
+ }
+
+ /**
+ * Summarize and return as a string the specified profile file
+ *
+ * @param file Profile file
+ * @param profileFormat Format of the profiler file
+ */
+ public static String summarizeProfileStr(File file, ProfileFormat profileFormat) {
+ TraceEvent[] events = getTraceEvents(file, profileFormat);
+ return summarizeTraceEvents(events);
+ }
+
+ /**
+ * Aggregate, summarize and print to stdout all .json profile files in the specified directory (not recursive)
+ *
+ * @param dir Directory containing the profiles
+ * @param profileFormat Profile format
+ */
+ public static void summarizeProfileDirectory(File dir, ProfileFormat profileFormat) {
+ System.out.println(summarizeProfileDirectoryStr(dir, profileFormat));
+ }
+
+ /**
+ * Aggregate, summarize and return as a String all .json profile files in the specified directory (not recursive)
+ *
+ * @param dir Directory containing the profiles
+ * @param profileFormat Profile format
+ */
+ public static String summarizeProfileDirectoryStr(File dir, ProfileFormat profileFormat) {
+ return summarizeTraceEvents(getTraceEventsDir(dir, profileFormat));
+ }
+
+ /**
+ * Load, aggregate and return the TraceEvent object from all profiles in the specified directory
+ *
+ * @param dir Directory containing the profiles
+ * @param profileFormat Profile format
+ */
+ public static TraceEvent[] getTraceEventsDir(File dir, ProfileFormat profileFormat) {
+ File[] files = dir.listFiles();
+ Preconditions.checkState(files != null && files.length > 0, "No profiles found in directory: %s", dir);
+ List<TraceEvent> l = new ArrayList<>();
+ for (File f : files) {
+ if (!f.getName().endsWith(".json")) {
+ log.info("Skipping non-JSON file in directory - {}", f.getAbsolutePath());
+ continue;
+ }
+ TraceEvent[] e = getTraceEvents(f, profileFormat);
+ Collections.addAll(l, e);
+ }
+ return l.toArray(new TraceEvent[0]);
+ }
+
+ /**
+ * Load and return the TraceEvent object from the specified profile file
+ *
+ * @param file Profile file
+ * @param profileFormat Profile format
+ */
+ public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) {
+ ObjectMapper json = ProfilingListener.jsonMapper();
+
+ String content;
+ try {
+ content = FileUtils.readFileToString(file, StandardCharsets.UTF_8);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ if (!content.matches(".*]\\s*")) {
+ if (content.endsWith(",")) {
+ //Has comma, missing ]
+ content = content.substring(0, content.length() - 1) + "]";
+ } else if (content.endsWith(",\n")) {
+ //Has comma and newline, missing ]
+ content = content.substring(0, content.length() - 2) + "]";
+ } else {
+ content = content + "]";
+ }
+ }
+
+ TraceEvent[] events;
+ if (profileFormat == ProfileFormat.SAMEDIFF) {
+ try {
+ events = json.readValue(content, TraceEvent[].class);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ } else {
+ //TF format
+ TraceEvents traceEvents;
+ try {
+ traceEvents = json.readValue(content, TraceEvents.class);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ events = traceEvents.getTraceEvents().toArray(new TraceEvent[0]);
+
+ //Clean up TF format - sometimes things like "Softmax" are actually profiled as "_MklSoftmax"
+ //And we'll align TF names to SameDiff names
+ for (TraceEvent te : events) {
+ if (TF_PROFILE_ALIASES.containsKey(te.getName())) {
+ te.setName(TF_PROFILE_ALIASES.get(te.getName()));
+ }
+
+ DifferentialFunction df = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(te.getName());
+ if (df != null) {
+ te.setName(df.opName());
+ }
+ }
+ }
+
+ return events;
+ }
+
+ /**
+ * Summarize the specified TraceEvents as a String
+ *
+ * @param events Events to summarize
+ */
+ public static String summarizeTraceEvents(TraceEvent[] events) {
+ Pair<Long, Map<String, OpStats>> p = aggregateTraceEvents(events);
+ final Map<String, OpStats> stats = p.getSecond();
+ long allOpsUs = p.getFirst();
+
+ //Summarize by op type:
+ List<String> l = new ArrayList<>(stats.keySet());
+ Collections.sort(l, new Comparator<String>() {
+ @Override
+ public int compare(String o1, String o2) {
+ return -Long.compare(stats.get(o1).sumUs, stats.get(o2).sumUs);
+ }
+ });
+
+ //Work out longest name and op name:
+ int longestName = 30;
+ int longestOpName = 30;
+ for (String s : l) {
+ longestName = Math.max(longestName, s.length() + 1);
+ longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1);
+ }
+
+ StringBuilder sb = new StringBuilder();
+ String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
+ sb.append(String.format(headerFormat, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
+ String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
+ for (String s : l) {
+ OpStats st = stats.get(s);
+ double pc = (100.0 * st.sumUs) / allOpsUs;
+ INDArray arr = st.timesUs.array();
+ long min = arr.minNumber().longValue();
+ long max = arr.maxNumber().longValue();
+ double std = arr.stdNumber().doubleValue();
+ sb.append(String.format(format, s, st.opName, st.count, st.getSumUs(), pc, min, max, std));
+ }
+
+ return sb.toString();
+ }
+
+ private static Pair<Long, Map<String, OpStats>> aggregateTraceEvents(TraceEvent[] events) {
+ //Summarize by op (instance) name:
+ final Map<String, OpStats> stats = new HashMap<>();
+ for (TraceEvent e : events) {
+ if (e.getPh() != Phase.X || e.getDur() == null) {
+ continue;
+ }
+
+ OpStats s;
+ String instanceName = (String) e.getArgs().get("name");
+ if (stats.containsKey(instanceName)) {
+ s = stats.get(instanceName);
+ } else {
+ s = new OpStats(e.getName(), 0, new NDArrayList(DataType.LONG, 0), null);
+ stats.put(instanceName, s);
+ }
+ s.count++;
+ s.timesUs.add((double) e.getDur());
+ }
+
+ long allOpsUs = 0;
+ for (OpStats s : stats.values()) {
+ s.sumUs = s.timesUs.array().sumNumber().longValue();
+ allOpsUs += s.sumUs;
+ }
+
+ return new Pair<>(allOpsUs, stats);
+ }
+
+ @AllArgsConstructor
+ @NoArgsConstructor
+ @Data
+ private static class OpStats {
+ private String opName;
+ private int count;
+ private NDArrayList timesUs;
+ private Long sumUs;
+
+ }
+
+ /**
+ * Compare the specified profile files, sorted by profile 1 % of total time
+ *
+ * @param file1 First profile file
+ * @param file2 Second profile file
+ * @param format1 Format of first profile
+ * @param format2 Format of second profile
+ * @return Comparison summary as a String
+ */
+ public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2) {
+ return compareProfiles(file1, file2, format1, format2, false, false, null, null, SortBy.PROFILE1_PC);
+ }
+
+ /**
+ * Compare the specified profile files or directories of profiles
+ *
+ * @param file1 First profile file or directory of profiles
+ * @param file2 Second profile file or directory of profiles
+ * @param format1 Format for first profile file/s
+ * @param format2 Format for second profile file/s
+ * @param firstIsDir True if the first File object is a directory
+ * @param secondIsDir True if the second File object is a directory
+ * @param name1 Name of the first profile (just for display purposes). Optional
+ * @param name2 Name of the second profile (just for display purposes). Optional
+ * @param sortBy What to sort the summary results by
+ * @return Comparison summary as a String
+ */
+ public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2,
+ boolean firstIsDir, boolean secondIsDir, String name1, String name2, final SortBy sortBy) {
+
+ TraceEvent[] t1 = firstIsDir ? getTraceEventsDir(file1, format1) : getTraceEvents(file1, format1);
+ TraceEvent[] t2 = secondIsDir ? getTraceEventsDir(file2, format2) : getTraceEvents(file2, format2);
+
+ final Pair<Long, Map<String, OpStats>> p1 = aggregateTraceEvents(t1);
+ final Pair<Long, Map<String, OpStats>> p2 = aggregateTraceEvents(t2);
+
+ List<String> l = new ArrayList<>(sortBy != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet());
+ Collections.sort(l, new Comparator<String>() {
+ @Override
+ public int compare(String o1, String o2) {
+ switch (sortBy) {
+ case PROFILE1_PC:
+ return -Long.compare(p1.getSecond().get(o1).sumUs, p1.getSecond().get(o2).sumUs);
+ case PROFILE2_PC:
+ return -Long.compare(p2.getSecond().get(o1).sumUs, p2.getSecond().get(o2).sumUs);
+ case RATIO:
+
+ double m1a = meanTime(p1, o1);
+ double m1b = meanTime(p1, o2);
+ double m2a = meanTime(p2, o1);
+ double m2b = meanTime(p2, o2);
+ double ratio1 = m1a / m2a;
+ double ratio2 = m1b / m2b;
+ return -Double.compare(ratio1, ratio2);
+ default:
+ throw new RuntimeException();
+ }
+ }
+ });
+
+ Set<String> set = new HashSet<>(l);
+
+
+ StringBuilder sb = new StringBuilder();
+ sb.append("1 = ").append(name1 == null ? "Profile 1" : name1).append("\n")
+ .append("2 = ").append(name2 == null ? "Profile 2" : name2).append("\n");
+
+ //Work out longest name and op name:
+ int longestName = 30;
+ int longestOpName = 30;
+ Map<String, OpStats> stats = sortBy == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond();
+ for (String s : l) {
+ longestName = Math.max(longestName, s.length() + 1);
+ longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1);
+ }
+
+ String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n";
+ sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)"));
+ String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n";
+
+ for (String s : l) {
+ OpStats s1 = p1.getSecond().get(s);
+ OpStats s2 = p2.getSecond().get(s);
+
+ double m1 = s1 == null ? 0 : s1.getTimesUs().array().meanNumber().doubleValue();
+ double m2 = s2 == null ? 0 : s2.getTimesUs().array().meanNumber().doubleValue();
+ double ratio = m1 / m2;
+
+ double pc1 = s1 == null ? 0 : 100.0 * s1.sumUs / p1.getFirst();
+ double pc2 = s2 == null ? 0 : 100.0 * s2.sumUs / p2.getFirst();
+
+ sb.append(String.format(format, s, s1 != null ? s1.opName : s2.opName,
+ s1 != null ? s1.count : 0,
+ s2 != null ? s2.count : 0,
+ //Ratio of means, means
+ ratio,
+ m1, m2,
+ //Total us, percent of op total
+ s1 != null ? s1.sumUs : 0,
+ s2 != null ? s2.sumUs : 0,
+ pc1, pc2,
+ //Min, max, std
+ s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0,
+ s2 != null ? s2.getTimesUs().array().minNumber().longValue() : 0,
+ s1 != null ? s1.getTimesUs().array().maxNumber().longValue() : 0,
+ s2 != null ? s2.getTimesUs().array().maxNumber().longValue() : 0,
+ s1 != null ? s1.getTimesUs().array().stdNumber().doubleValue() : 0.0,
+ s2 != null ? s2.getTimesUs().array().stdNumber().doubleValue() : 0.0));
+ }
+
+ boolean header = false;
+ String headerFormat2 = null;
+ String format3 = null;
+ for (String s : (sortBy == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) {
+
+ if (!set.contains(s)) {
+ Map<String, OpStats> m = sortBy == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond();
+ if (!header) {
+
+ longestName = 30;
+ longestOpName = 30;
+ for(String s2 : m.keySet()){
+ longestName = Math.max(longestName, s2.length()+1);
+ longestOpName = Math.max(longestOpName, m.get(s2).opName.length()+1);
+ }
+ headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
+ format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
+
+ sb.append(" *** Operations not in profile ").append(sortBy == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ")
+ .append(sortBy == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n");
+ sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
+ header = true;
+ }
+ long allOpsUs = sortBy == SortBy.PROFILE2_PC ? p1.getFirst() : p2.getFirst();
+ OpStats st = m.get(s);
+ double pc = (100.0 * st.getTimesUs().array().sumNumber().longValue()) / allOpsUs;
+ INDArray arr = st.timesUs.array();
+ long min = arr.minNumber().longValue();
+ long max = arr.maxNumber().longValue();
+ double std = arr.stdNumber().doubleValue();
+ sb.append(String.format(format3, s, st.opName, st.count, st.getSumUs(), pc, min, max, std));
+ }
+ }
+ return sb.toString();
+ }
+
+ private static double meanTime(Pair<Long, Map<String, OpStats>> p, String name) {
+ if (!p.getSecond().containsKey(name)) {
+ return 0.0;
+ }
+ return p.getSecond().get(name).getTimesUs().array().meanNumber().doubleValue();
+ }
+
+
+ private static Map<String, String> TF_PROFILE_ALIASES = new HashMap<>();
+
+ static {
+ TF_PROFILE_ALIASES.put("_MklSoftmax", "Softmax");
+ }
+
+}
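A sketch of how the analyzer might be driven; the file and directory paths here are hypothetical:

    File sdProfile = new File("samediff_profile.json");
    File tfProfileDir = new File("tf_profiles");

    // Per-op summary of a single SameDiff-format profile, printed to stdout
    ProfileAnalyzer.summarizeProfile(sdProfile, ProfileAnalyzer.ProfileFormat.SAMEDIFF);

    // Aggregate and summarize all .json profiles in a directory (TensorFlow format)
    ProfileAnalyzer.summarizeProfileDirectory(tfProfileDir, ProfileAnalyzer.ProfileFormat.TENSORFLOW);

    // Compare a single SameDiff profile against a directory of TF profiles,
    // sorted by the ratio of mean op times (profile 1 / profile 2)
    String comparison = ProfileAnalyzer.compareProfiles(sdProfile, tfProfileDir,
            ProfileAnalyzer.ProfileFormat.SAMEDIFF, ProfileAnalyzer.ProfileFormat.TENSORFLOW,
            false, true, "SameDiff", "TensorFlow", ProfileAnalyzer.SortBy.RATIO);
    System.out.println(comparison);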
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java
new file mode 100644
index 000000000..0d9c08deb
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/ColorName.java
@@ -0,0 +1,19 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.autodiff.listeners.profiler.data;
+
+public enum ColorName {
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java
new file mode 100644
index 000000000..bca7feb39
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/Phase.java
@@ -0,0 +1,44 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.autodiff.listeners.profiler.data;
+
+/**
+ * Chrome Profiler phase, for details see:
+ *
+ * https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit
+ */
+public enum Phase {
+ B,
+ E,
+ X,
+ I,
+ C,
+ b,
+ n,
+ e,
+ s,
+ t,
+ f,
+ P,
+ N,
+ O,
+ D,
+ M,
+ V,
+ v,
+ R,
+ c
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java
new file mode 100644
index 000000000..e4270edd1
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvent.java
@@ -0,0 +1,53 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.autodiff.listeners.profiler.data;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A TraceEvent, such as an operation execution.
+ * Intended mainly for JSON serialization/deserialization in Chrome profiler format
+ * Profiler format is described here:
+ * https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit
+ * See {@link org.nd4j.autodiff.listeners.profiler.ProfilingListener}
+ * See {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer}
+ *
+ * @author Alex Black
+ */
+@Builder
+@Data
+@AllArgsConstructor
+@NoArgsConstructor
+public class TraceEvent {
+
+ private String name; //Name of event (usually op name)
+ private List<String> categories; //Comma separated list of categories
+ private Phase ph; //Event type - phase (see table for options)
+ private long ts; //Timestamp, in microseconds (us)
+ private Long dur; //Duration, optional
+ private Long tts; //Optional, thread timestamp, in microseconds
+ private long pid; //Process ID
+ private long tid; //Thread ID
+ private Map<String, Object> args; //Args
+ private ColorName cname; //Optional, color name (must be one of reserved color names: https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html )
+
+}
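To make the field mapping concrete, here is a sketch of a single "complete" (phase X) op event built with the Lombok builder; the values are illustrative only, and the "name" entry in args is what ProfileAnalyzer uses as the op instance name when aggregating:

    TraceEvent event = TraceEvent.builder()
            .name("matmul")                // op name
            .ph(Phase.X)                   // complete event: start timestamp plus duration
            .ts(1_000_000L)                // start timestamp, microseconds
            .dur(87L)                      // duration, microseconds
            .pid(4242L)                    // process id
            .tid(1L)                       // thread id
            .args(Collections.<String, Object>singletonMap("name", "matmul_instance"))
            .build();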
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java
new file mode 100644
index 000000000..b3ebf6d8a
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/data/TraceEvents.java
@@ -0,0 +1,34 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.autodiff.listeners.profiler.data;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.List;
+
+/**
+ * A simple holder for a list of trace events
+ *
+ * @author Alex Black
+ */
+@AllArgsConstructor
+@NoArgsConstructor
+@Data
+public class TraceEvents {
+ private List<TraceEvent> traceEvents;
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java
index 939708291..f3b539d8c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/list/NDArrayList.java
@@ -18,6 +18,7 @@ package org.nd4j.list;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
@@ -46,7 +47,12 @@ public class NDArrayList extends BaseNDArrayList {
* @param size the initial size of the array
*/
public NDArrayList(int size) {
- this.container = Nd4j.create(10L);
+ this(DataType.DOUBLE, size);
+ }
+
+ public NDArrayList(DataType dataType, int size) {
+ Preconditions.checkState(size >= 0, "Size must be non-negative - got %s", size);
+ this.container = Nd4j.create(dataType, Math.max(10L, size));
this.size = size;
}
@@ -84,6 +90,7 @@ public class NDArrayList extends BaseNDArrayList {
* directly, this gives you the relevant subset that reflects the content of the list)
* @return the view of the underlying ndarray relative to the collection's real size
*/
+ @Override
public INDArray array() {
if(isEmpty()) {
throw new ND4JIllegalStateException("Array is empty!");
@@ -137,6 +144,8 @@ public class NDArrayList extends BaseNDArrayList {
return true;
}
+
+
@Override
public boolean remove(Object o) {
int idx = BooleanIndexing.firstIndex(container,new EqualsCondition((double) o)).getInt(0);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java
new file mode 100644
index 000000000..c026446ad
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java
@@ -0,0 +1,105 @@
+package org.nd4j.autodiff.samediff.listeners;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
+import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
+
+import java.io.File;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+public class ProfilingListenerTest extends BaseNd4jTest {
+
+ public ProfilingListenerTest(Nd4jBackend backend) {
+ super(backend);
+ }
+
+ @Override
+ public char ordering() {
+ return 'c';
+ }
+
+ @Rule
+ public TemporaryFolder testDir = new TemporaryFolder();
+
+ @Test
+ public void testProfilingListenerSimple() throws Exception {
+
+ SameDiff sd = SameDiff.create();
+ SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
+ SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2);
+ SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
+ SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2));
+ SDVariable sm = sd.nn.softmax("predictions", in.mmul("matmul", w).add("addbias", b));
+ SDVariable loss = sd.loss.logLoss("loss", label, sm);
+
+ INDArray i = Nd4j.rand(DataType.FLOAT, 1, 3);
+ INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2);
+
+
+ File dir = testDir.newFolder();
+ File f = new File(dir, "test.json");
+ ProfilingListener listener = ProfilingListener.builder(f)
+ .recordAll()
+ .warmup(5)
+ .build();
+
+ sd.setListeners(listener);
+
+ Map<String, INDArray> ph = new HashMap<>();
+ ph.put("in", i);
+
+ for( int x=0; x<10; x++ ) {
+ sd.outputSingle(ph, "predictions");
+ }
+
+ String content = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
+ System.out.println(content);
+
+ //Should be 2 begins and 2 ends for each entry
+ //5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name
+ String[] opNames = {"mmul", "add", "softmax"};
+ for(String s : opNames){
+ assertEquals(s, 10, StringUtils.countMatches(content, s));
+ }
+
+
+ System.out.println("///////////////////////////////////////////");
+ ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.SAMEDIFF);
+
+ }
+
+ /*
+ @Test
+ public void testLoadTfProfile(){
+ File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
+ ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
+ }
+
+ @Test
+ public void testLoadTfProfileDir(){
+ File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
+ ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
+ }
+
+ @Test
+ public void testLoadTfProfileDir2(){
+ File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
+ ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
+ }
+ */
+}