SameDiff profiler / tracing and profile analysis/comparison (#133)
* Profiler Signed-off-by: Alex Black <blacka101@gmail.com> * Next steps, polishing, and loading SD/TF format JSON Signed-off-by: AlexDBlack <blacka101@gmail.com> * Next steps Signed-off-by: AlexDBlack <blacka101@gmail.com> * Profile comparison method Signed-off-by: AlexDBlack <blacka101@gmail.com> * Make profiling result writing async to reduce main thread overhead Signed-off-by: AlexDBlack <blacka101@gmail.com> * Profiling polishing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Profile analyzer fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Polish Signed-off-by: Alex Black <blacka101@gmail.com> * Cleanup Signed-off-by: Alex Black <blacka101@gmail.com> * Small formatting improvement Signed-off-by: Alex Black <blacka101@gmail.com> * Formatting tweak Signed-off-by: Alex Black <blacka101@gmail.com> * License headers Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
e303c06042
commit
3d8f6d50a1
|
@ -0,0 +1,354 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.autodiff.listeners.profiler;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Loss;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.listeners.profiler.data.Phase;
|
||||
import org.nd4j.autodiff.listeners.profiler.data.TraceEvent;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.AtomicBoolean;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
||||
import org.nd4j.shade.jackson.databind.MapperFeature;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.management.ManagementFactory;
|
||||
import java.text.DecimalFormat;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
import java.util.concurrent.LinkedBlockingDeque;
|
||||
|
||||
/**
|
||||
* SameDiff profiling listener: for profiling operation execution<br>
|
||||
* Writes profiles to a file in JSON format<br>
|
||||
* Format is Chrome profiler format. The output can be read by Google Chrome browser; open Chrome and go to:
|
||||
* chrome://tracing and load the output JSON format data
|
||||
* <br>
|
||||
* At present, only operation execution is profiled, not other aspects such as memory allocation and training-related
|
||||
* functionality.<br>
|
||||
* <br>
|
||||
* Tracing can be configured in a few different ways via the builder, {@link #builder(File)}:<br>
|
||||
* (a) warmup - don't record traces for the first N iterations<br>
|
||||
* (b) "all" mode (default) - record all-iterations, with no limit (after warmup, if applicable)<br>
|
||||
* (c) "n iterations" mode: record at most the first N iterations (after warmup, if applicable)<br>
|
||||
* (d) "n ms" mod: record for at most N milliseconds since the start of the first op execution (after warmup, if applicable)<br>
|
||||
*
|
||||
* Note: The Chrome Trace Event format can be found here:<br>
|
||||
* <a href="https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit">https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit</a>
|
||||
* SameDiff uses the JSON Array Format, as this can be written in an online/streaming manner.<br>
|
||||
* Conversely, TensorFlow uses the JSON Object Format.<br>
|
||||
* <br>
|
||||
* For summarizing, analyzing and comparing the results (SameDiff or TensorFlow format), see {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer}<br>
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Getter
|
||||
@Slf4j
|
||||
public class ProfilingListener extends BaseListener {
|
||||
|
||||
private final File outputFile;
|
||||
private final boolean all;
|
||||
private final int warmup;
|
||||
private final int nIter;
|
||||
private final long nMs;
|
||||
private final Operation[] operations;
|
||||
|
||||
private final long pid;
|
||||
private final long tid;
|
||||
private Long firstOpStart = null; //Used for time termination
|
||||
private int countTotalIter = 0;
|
||||
private boolean logActive = false;
|
||||
private long opStartNano;
|
||||
|
||||
private Writer writer;
|
||||
private ObjectMapper json;
|
||||
|
||||
private final Thread fileWritingThread;
|
||||
private final BlockingQueue<TraceEvent> writeQueue;
|
||||
private final AtomicBoolean writing = new AtomicBoolean(false);
|
||||
|
||||
protected ProfilingListener(@NonNull File outputFile, boolean all, int warmup, int nIter, long nMs, Operation[] operations) {
|
||||
Preconditions.checkArgument(!outputFile.exists(), "Output file already exists: %s", outputFile);
|
||||
this.outputFile = outputFile;
|
||||
this.all = all;
|
||||
this.warmup = warmup;
|
||||
this.nIter = nIter;
|
||||
this.nMs = nMs;
|
||||
this.operations = operations;
|
||||
|
||||
this.pid = getProcessId();
|
||||
this.tid = Thread.currentThread().getId();
|
||||
|
||||
try {
|
||||
this.writer = new BufferedWriter(new FileWriter(outputFile, false));
|
||||
this.writer.write("["); //JSON array open (array close is optional for Chrome profiler format)
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
this.json = jsonMapper();
|
||||
|
||||
//Set up a queue so file access doesn't add latency to the execution thread
|
||||
writeQueue = new LinkedBlockingDeque<>();
|
||||
fileWritingThread = new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
runHelper();
|
||||
} catch (Throwable t) {
|
||||
log.error("Error when attempting to write results to file", t);
|
||||
}
|
||||
}
|
||||
|
||||
public void runHelper() throws Exception {
|
||||
while (true) {
|
||||
TraceEvent te = writeQueue.take(); //Blocking
|
||||
writing.set(true);
|
||||
try {
|
||||
String j = json.writeValueAsString(te);
|
||||
writer.append(j);
|
||||
writer.append(",\n");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
} finally {
|
||||
writing.set(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
fileWritingThread.setDaemon(true);
|
||||
fileWritingThread.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return operations == null || ArrayUtils.contains(operations, operation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationStart(SameDiff sd, Operation op) {
|
||||
this.logActive = operations == null || ArrayUtils.contains(operations, op);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationEnd(SameDiff sd, Operation op) {
|
||||
if (this.logActive) {
|
||||
while ((!writeQueue.isEmpty() || writing.get()) && fileWritingThread.isAlive()) {
|
||||
//Wait for file writing thread to catch up
|
||||
try {
|
||||
Thread.sleep(100);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
try {
|
||||
writer.flush();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
this.logActive = false;
|
||||
if (op == Operation.INFERENCE) {
|
||||
//Increment for inference; iteration done is called only for TRAINING
|
||||
countTotalIter++;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
|
||||
//Increment for training
|
||||
if (logActive) {
|
||||
countTotalIter++;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
|
||||
if (logActive) {
|
||||
opStartNano = System.nanoTime();
|
||||
|
||||
if(!all && nMs > 0 && firstOpStart == null)
|
||||
firstOpStart = opStartNano;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
if (logActive) {
|
||||
long now = System.nanoTime();
|
||||
|
||||
if (warmup > 0 && countTotalIter < warmup) {
|
||||
return; //Skip due to warmup phase
|
||||
}
|
||||
|
||||
//Iteration termination
|
||||
int terminationPt = this.nIter > 0 ? this.nIter : Integer.MAX_VALUE;
|
||||
if (warmup > 0 && this.nIter > 0)
|
||||
terminationPt += this.warmup;
|
||||
|
||||
if (countTotalIter > terminationPt) {
|
||||
logActive = false;
|
||||
return; //Skip due to max number of itertions
|
||||
}
|
||||
|
||||
//Time termination
|
||||
if(!all && nMs > 0 && (now - firstOpStart)/1000 > nMs) {
|
||||
logActive = false;
|
||||
return;
|
||||
}
|
||||
|
||||
TraceEvent event = TraceEvent.builder()
|
||||
.name(op.getOp().opName())
|
||||
.categories(Collections.singletonList("Op"))
|
||||
.ts(opStartNano / 1000)
|
||||
.dur((now - opStartNano) / 1000)
|
||||
.pid((int)pid)
|
||||
.tid(tid)
|
||||
.ph(Phase.X)
|
||||
.args(Collections.<String, Object>singletonMap("name", op.getName()))
|
||||
.build();
|
||||
|
||||
writeQueue.add(event);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private long getProcessId() {
|
||||
// Note: may fail in some JVM implementations
|
||||
// therefore fallback has to be provided
|
||||
|
||||
// something like '<pid>@<hostname>', at least in SUN / Oracle JVMs
|
||||
final String jvmName = ManagementFactory.getRuntimeMXBean().getName();
|
||||
final int index = jvmName.indexOf('@');
|
||||
|
||||
if (index < 1) {
|
||||
// part before '@' empty (index = 0) / '@' not found (index = -1)
|
||||
return 0;
|
||||
}
|
||||
|
||||
try {
|
||||
return Long.parseLong(jvmName.substring(0, index));
|
||||
} catch (NumberFormatException e) {
|
||||
// ignore
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a new JSON mapper for use in serializing/deserializing JSON format
|
||||
*/
|
||||
public static ObjectMapper jsonMapper() {
|
||||
ObjectMapper json = new ObjectMapper();
|
||||
json.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||
json.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
||||
json.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
|
||||
json.disable(SerializationFeature.INDENT_OUTPUT); //One line
|
||||
|
||||
return json;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new builder
|
||||
* @param outputFile Output file. Will be overwritten if file already exists
|
||||
*/
|
||||
public static Builder builder(File outputFile) {
|
||||
return new Builder(outputFile);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private final File outputFile;
|
||||
private boolean all = true;
|
||||
private int warmup = 0;
|
||||
private int nIter = -1;
|
||||
private long nMs = -1;
|
||||
private Operation[] operations;
|
||||
|
||||
public Builder(@NonNull File outputFile) {
|
||||
this.outputFile = outputFile;
|
||||
}
|
||||
|
||||
/**
|
||||
* If called, all data will be profiled with no limits (other than a warmup, if set)
|
||||
*/
|
||||
public Builder recordAll() {
|
||||
this.all = true;
|
||||
this.nIter = -1;
|
||||
this.nMs = -1;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Specify the number of warmup iterations - i.e., these will be excluded from profiling results
|
||||
*/
|
||||
public Builder warmup(int iterations) {
|
||||
this.warmup = iterations;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a limit on the maximum number of iterations to profile (after warmup, if any).
|
||||
* Any ops executed after the specified number of iterations will not be profiled/recorded
|
||||
*/
|
||||
public Builder maxProfileIterations(int iterations) {
|
||||
this.nIter = iterations;
|
||||
this.all = false;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a limit on the maximum duration for profiling, in milliseconds.
|
||||
* Any ops executed after the specified amount of time since the first (non-warmup) operation start will not be
|
||||
* profiled/recorded
|
||||
*/
|
||||
public Builder maxProfilerMilliseconds(long ms) {
|
||||
this.nMs = ms;
|
||||
this.all = false;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Specify the operations (training, inference, etc) to profile.
|
||||
* If not set, all operations are profiled
|
||||
*/
|
||||
public Builder operations(Operation... operations) {
|
||||
this.operations = operations;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the profiling listener
|
||||
*/
|
||||
public ProfilingListener build() {
|
||||
return new ProfilingListener(outputFile, all, warmup, nIter, nMs, operations);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,441 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.autodiff.listeners.profiler.comparison;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
|
||||
import org.nd4j.autodiff.listeners.profiler.data.Phase;
|
||||
import org.nd4j.autodiff.listeners.profiler.data.TraceEvent;
|
||||
import org.nd4j.autodiff.listeners.profiler.data.TraceEvents;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.list.NDArrayList;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* A profile analyzer, used for analyzing Chrome-format profiler dumps generated by both SameDiff's<br>
|
||||
* {@link ProfilingListener} and TensorFlow's profiler.<br>
|
||||
* Has methods for summarizing profiler statistics, as well as comparing two profiler dumps.<br>
|
||||
* <br>
|
||||
* Also supports analyzing/aggregating multiple JSON files in a directory, via the "...Directory(...)" methods.
|
||||
* <p>
|
||||
* See {@link ProfilingListener}<br>
|
||||
* See {@link TraceEvent}
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
public class ProfileAnalyzer {
|
||||
|
||||
/**
|
||||
* Chrome profiler supports 2 formats:<br>
|
||||
* SameDiff == JSON Array Format<br>
|
||||
* TensorFlow == JSON Object Format<br>
|
||||
*/
|
||||
public enum ProfileFormat {SAMEDIFF, TENSORFLOW}
|
||||
|
||||
/**
|
||||
* Only applicable for profile comparisons.<br>
|
||||
* PROFILE1_PC - sort by profile 1 percentage of total time<br>
|
||||
* PROFILE2_PC - sort by profile 2 percentage of total time<br>
|
||||
* RATIO - sort by highest ratio (mean op time profile 1 / mean op time profile 2)
|
||||
*/
|
||||
public enum SortBy {PROFILE1_PC, PROFILE2_PC, RATIO}
|
||||
|
||||
|
||||
/**
|
||||
* Summarize and print to stdout the specified profile file
|
||||
*
|
||||
* @param file Profile file
|
||||
* @param profileFormat Format of the profiler file
|
||||
*/
|
||||
public static void summarizeProfile(File file, ProfileFormat profileFormat) {
|
||||
System.out.println(summarizeProfileStr(file, profileFormat));
|
||||
}
|
||||
|
||||
/**
|
||||
* Summarize and return as a string the specified profile file
|
||||
*
|
||||
* @param file Profile file
|
||||
* @param profileFormat Format of the profiler file
|
||||
*/
|
||||
public static String summarizeProfileStr(File file, ProfileFormat profileFormat) {
|
||||
TraceEvent[] events = getTraceEvents(file, profileFormat);
|
||||
return summarizeTraceEvents(events);
|
||||
}
|
||||
|
||||
/**
|
||||
* Aggregate, summarize and print to stdout all .json profile files in the specified directory (not recursive)
|
||||
*
|
||||
* @param dir Directory containing the profiles
|
||||
* @param profileFormat Profile format
|
||||
*/
|
||||
public static void summarizeProfileDirectory(File dir, ProfileFormat profileFormat) {
|
||||
System.out.println(summarizeProfileDirectoryStr(dir, profileFormat));
|
||||
}
|
||||
|
||||
/**
|
||||
* Aggregate, summarize and return as a String all .json profile files in the specified directory (not recursive)
|
||||
*
|
||||
* @param dir Directory containing the profiles
|
||||
* @param profileFormat Profile format
|
||||
*/
|
||||
public static String summarizeProfileDirectoryStr(File dir, ProfileFormat profileFormat) {
|
||||
return summarizeTraceEvents(getTraceEventsDir(dir, profileFormat));
|
||||
}
|
||||
|
||||
/**
|
||||
* Load, aggregate and return the TraceEvent object from all profiles in the specified directory
|
||||
*
|
||||
* @param dir Directory containing the profiles
|
||||
* @param profileFormat Profile format
|
||||
*/
|
||||
public static TraceEvent[] getTraceEventsDir(File dir, ProfileFormat profileFormat) {
|
||||
File[] files = dir.listFiles();
|
||||
Preconditions.checkState(files != null && files.length > 0, "No profiles found in directory: %s", dir);
|
||||
List<TraceEvent> l = new ArrayList<>();
|
||||
for (File f : files) {
|
||||
if (!f.getName().endsWith(".json")) {
|
||||
log.info("Skipping non-JSON file in directory - {}", f.getAbsolutePath());
|
||||
continue;
|
||||
}
|
||||
TraceEvent[] e = getTraceEvents(f, profileFormat);
|
||||
Collections.addAll(l, e);
|
||||
}
|
||||
return l.toArray(new TraceEvent[0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load and return the TraceEvent object from the specified profile file
|
||||
*
|
||||
* @param file Profile file
|
||||
* @param profileFormat Profile format
|
||||
*/
|
||||
public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) {
|
||||
ObjectMapper json = ProfilingListener.jsonMapper();
|
||||
|
||||
String content;
|
||||
try {
|
||||
content = FileUtils.readFileToString(file, StandardCharsets.UTF_8);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
if (!content.matches(".*]\\s*")) {
|
||||
if (content.endsWith(",")) {
|
||||
//Has comma, missing ]
|
||||
content = content.substring(0, content.length() - 1) + "]";
|
||||
} else if (content.endsWith(",\n")) {
|
||||
//Has comma and newline, missing ]
|
||||
content = content.substring(0, content.length() - 2) + "]";
|
||||
} else {
|
||||
content = content + "]";
|
||||
}
|
||||
}
|
||||
|
||||
TraceEvent[] events;
|
||||
if (profileFormat == ProfileFormat.SAMEDIFF) {
|
||||
try {
|
||||
events = json.readValue(content, TraceEvent[].class);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else {
|
||||
//TF format
|
||||
TraceEvents traceEvents;
|
||||
try {
|
||||
traceEvents = json.readValue(content, TraceEvents.class);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
events = traceEvents.getTraceEvents().toArray(new TraceEvent[0]);
|
||||
|
||||
//Clean up TF format - sometimes things like "Softmax" are actually profiled as "_MklSoftmax"
|
||||
//And we'll align TF names to SameDiff names
|
||||
for (TraceEvent te : events) {
|
||||
if (TF_PROFILE_ALIASES.containsKey(te.getName())) {
|
||||
te.setName(TF_PROFILE_ALIASES.get(te.getName()));
|
||||
}
|
||||
|
||||
DifferentialFunction df = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(te.getName());
|
||||
if (df != null) {
|
||||
te.setName(df.opName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
/**
|
||||
* Summarize the specified TraceEvents as a String
|
||||
*
|
||||
* @param events Events to summarize
|
||||
*/
|
||||
public static String summarizeTraceEvents(TraceEvent[] events) {
|
||||
Pair<Long, Map<String, OpStats>> p = aggregateTraceEvents(events);
|
||||
final Map<String, OpStats> stats = p.getSecond();
|
||||
long allOpsUs = p.getFirst();
|
||||
|
||||
//Summarize by op type:
|
||||
List<String> l = new ArrayList<>(stats.keySet());
|
||||
Collections.sort(l, new Comparator<String>() {
|
||||
@Override
|
||||
public int compare(String o1, String o2) {
|
||||
return -Long.compare(stats.get(o1).sumUs, stats.get(o2).sumUs);
|
||||
}
|
||||
});
|
||||
|
||||
//Work out longest name and op name:
|
||||
int longestName = 30;
|
||||
int longestOpName = 30;
|
||||
for (String s : l) {
|
||||
longestName = Math.max(longestName, s.length() + 1);
|
||||
longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1);
|
||||
}
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
|
||||
sb.append(String.format(headerFormat, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
|
||||
String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
|
||||
for (String s : l) {
|
||||
OpStats st = stats.get(s);
|
||||
double pc = (100.0 * st.sumUs) / allOpsUs;
|
||||
INDArray arr = st.timesUs.array();
|
||||
long min = arr.minNumber().longValue();
|
||||
long max = arr.maxNumber().longValue();
|
||||
double std = arr.stdNumber().doubleValue();
|
||||
sb.append(String.format(format, s, st.opName, st.count, st.getSumUs(), pc, min, max, std));
|
||||
}
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private static Pair<Long, Map<String, OpStats>> aggregateTraceEvents(TraceEvent[] events) {
|
||||
//Summarize by op (instance) name:
|
||||
final Map<String, OpStats> stats = new HashMap<>();
|
||||
for (TraceEvent e : events) {
|
||||
if (e.getPh() != Phase.X || e.getDur() == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
OpStats s;
|
||||
String instanceName = (String) e.getArgs().get("name");
|
||||
if (stats.containsKey(instanceName)) {
|
||||
s = stats.get(instanceName);
|
||||
} else {
|
||||
s = new OpStats(e.getName(), 0, new NDArrayList(DataType.LONG, 0), null);
|
||||
stats.put(instanceName, s);
|
||||
}
|
||||
s.count++;
|
||||
s.timesUs.add((double) e.getDur());
|
||||
}
|
||||
|
||||
long allOpsUs = 0;
|
||||
for (OpStats s : stats.values()) {
|
||||
s.sumUs = s.timesUs.array().sumNumber().longValue();
|
||||
allOpsUs += s.sumUs;
|
||||
}
|
||||
|
||||
return new Pair<>(allOpsUs, stats);
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
private static class OpStats {
|
||||
private String opName;
|
||||
private int count;
|
||||
private NDArrayList timesUs;
|
||||
private Long sumUs;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare the specified profile files, sorted by profile 1 % of total time
|
||||
*
|
||||
* @param file1 First profile file
|
||||
* @param file2 Second profile file
|
||||
* @param format1 Format of first profile
|
||||
* @param format2 Format of second profile
|
||||
* @return Comparison summary as a String
|
||||
*/
|
||||
public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2) {
|
||||
return compareProfiles(file1, file2, format1, format2, false, false, null, null, SortBy.PROFILE1_PC);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare the specified profile files or directory
|
||||
*
|
||||
* @param file1 First profile file or directory of profiles
|
||||
* @param file2 Second profile file or directory of profiles
|
||||
* @param format1 Format for first profile file/s
|
||||
* @param format2 Format for second profile file/s
|
||||
* @param firstIsDir True if the first File object is a directory
|
||||
* @param secondIsDir True if the second File object is a directory
|
||||
* @param name1 Name of the first profile (just for display purposes). Optional
|
||||
* @param name2 Name of the second profile (just for display purposes). Optional
|
||||
* @param sortBy What to sort the summary results by
|
||||
* @return Comparison summary as a String
|
||||
*/
|
||||
public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2,
|
||||
boolean firstIsDir, boolean secondIsDir, String name1, String name2, final SortBy sortBy) {
|
||||
|
||||
TraceEvent[] t1 = firstIsDir ? getTraceEventsDir(file1, format1) : getTraceEvents(file1, format1);
|
||||
TraceEvent[] t2 = secondIsDir ? getTraceEventsDir(file2, format2) : getTraceEvents(file2, format2);
|
||||
|
||||
final Pair<Long, Map<String, OpStats>> p1 = aggregateTraceEvents(t1);
|
||||
final Pair<Long, Map<String, OpStats>> p2 = aggregateTraceEvents(t2);
|
||||
|
||||
List<String> l = new ArrayList<>(sortBy != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet());
|
||||
Collections.sort(l, new Comparator<String>() {
|
||||
@Override
|
||||
public int compare(String o1, String o2) {
|
||||
switch (sortBy) {
|
||||
case PROFILE1_PC:
|
||||
return -Long.compare(p1.getSecond().get(o1).sumUs, p1.getSecond().get(o2).sumUs);
|
||||
case PROFILE2_PC:
|
||||
return -Long.compare(p2.getSecond().get(o1).sumUs, p2.getSecond().get(o2).sumUs);
|
||||
case RATIO:
|
||||
|
||||
double m1a = meanTime(p1, o1);
|
||||
double m1b = meanTime(p1, o2);
|
||||
double m2a = meanTime(p2, o1);
|
||||
double m2b = meanTime(p2, o2);
|
||||
double ratio1 = m1a / m2a;
|
||||
double ratio2 = m1b / m2b;
|
||||
return -Double.compare(ratio1, ratio2);
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Set<String> set = new HashSet<>(l);
|
||||
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("1 = ").append(name1 == null ? "Profile 1" : name1).append("\n")
|
||||
.append("2 = ").append(name2 == null ? "Profile 2" : name2).append("\n");
|
||||
|
||||
//Work out longest name and op name:
|
||||
int longestName = 30;
|
||||
int longestOpName = 30;
|
||||
Map<String, OpStats> stats = sortBy == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond();
|
||||
for (String s : l) {
|
||||
longestName = Math.max(longestName, s.length() + 1);
|
||||
longestOpName = Math.max(longestOpName, stats.get(s).opName.length() + 1);
|
||||
}
|
||||
|
||||
String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n";
|
||||
sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)"));
|
||||
String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n";
|
||||
|
||||
for (String s : l) {
|
||||
OpStats s1 = p1.getSecond().get(s);
|
||||
OpStats s2 = p2.getSecond().get(s);
|
||||
|
||||
double m1 = s1 == null ? 0 : s1.getTimesUs().array().meanNumber().doubleValue();
|
||||
double m2 = s2 == null ? 0 : s2.getTimesUs().array().meanNumber().doubleValue();
|
||||
double ratio = m1 / m2;
|
||||
|
||||
double pc1 = s1 == null ? 0 : 100.0 * s1.sumUs / p1.getFirst();
|
||||
double pc2 = s2 == null ? 0 : 100.0 * s2.sumUs / p2.getFirst();
|
||||
|
||||
sb.append(String.format(format, s, s1 != null ? s1.opName : s2.opName,
|
||||
s1 != null ? s1.count : 0,
|
||||
s2 != null ? s2.count : 0,
|
||||
//Ratio of means, means
|
||||
ratio,
|
||||
m1, m2,
|
||||
//Total us, percent of op total
|
||||
s1 != null ? s1.sumUs : 0,
|
||||
s2 != null ? s2.sumUs : 0,
|
||||
pc1, pc2,
|
||||
//Min, max, std
|
||||
s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0,
|
||||
s2 != null ? s2.getTimesUs().array().minNumber().longValue() : 0,
|
||||
s1 != null ? s1.getTimesUs().array().maxNumber().longValue() : 0,
|
||||
s2 != null ? s2.getTimesUs().array().maxNumber().longValue() : 0,
|
||||
s1 != null ? s1.getTimesUs().array().stdNumber().doubleValue() : 0.0,
|
||||
s2 != null ? s2.getTimesUs().array().stdNumber().doubleValue() : 0.0));
|
||||
}
|
||||
|
||||
boolean header = false;
|
||||
String headerFormat2 = null;
|
||||
String format3 = null;
|
||||
for (String s : (sortBy == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet())) {
|
||||
|
||||
if (!set.contains(s)) {
|
||||
Map<String, OpStats> m = sortBy == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond();
|
||||
if (!header) {
|
||||
|
||||
longestName = 30;
|
||||
longestOpName = 30;
|
||||
for(String s2 : m.keySet()){
|
||||
longestName = Math.max(longestName, s2.length()+1);
|
||||
longestOpName = Math.max(longestOpName, m.get(s2).opName.length()+1);
|
||||
}
|
||||
headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
|
||||
format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
|
||||
|
||||
sb.append(" *** Operations not in profile ").append(sortBy == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ")
|
||||
.append(sortBy == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n");
|
||||
sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
|
||||
header = true;
|
||||
}
|
||||
long allOpsUs = sortBy == SortBy.PROFILE2_PC ? p1.getFirst() : p2.getFirst();
|
||||
OpStats st = m.get(s);
|
||||
double pc = (100.0 * st.getTimesUs().array().sumNumber().longValue()) / allOpsUs;
|
||||
INDArray arr = st.timesUs.array();
|
||||
long min = arr.minNumber().longValue();
|
||||
long max = arr.maxNumber().longValue();
|
||||
double std = arr.stdNumber().doubleValue();
|
||||
sb.append(String.format(format3, s, st.opName, st.count, st.getSumUs(), pc, min, max, std));
|
||||
}
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private static double meanTime(Pair<Long, Map<String, OpStats>> p, String name) {
|
||||
if (!p.getSecond().containsKey(name)) {
|
||||
return 0.0;
|
||||
}
|
||||
return p.getSecond().get(name).getTimesUs().array().meanNumber().doubleValue();
|
||||
}
|
||||
|
||||
|
||||
private static Map<String, String> TF_PROFILE_ALIASES = new HashMap<>();
|
||||
|
||||
static {
|
||||
TF_PROFILE_ALIASES.put("_MklSoftmax", "Softmax");
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.autodiff.listeners.profiler.data;
|
||||
|
||||
public enum ColorName {
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.autodiff.listeners.profiler.data;
|
||||
|
||||
/**
|
||||
* Chrome Profiler phase, for details see:
|
||||
* <a href="https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit">
|
||||
* https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit</a>
|
||||
*/
|
||||
public enum Phase {
|
||||
B,
|
||||
E,
|
||||
X,
|
||||
I,
|
||||
C,
|
||||
b,
|
||||
n,
|
||||
e,
|
||||
s,
|
||||
t,
|
||||
f,
|
||||
P,
|
||||
N,
|
||||
O,
|
||||
D,
|
||||
M,
|
||||
V,
|
||||
v,
|
||||
R,
|
||||
c
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.autodiff.listeners.profiler.data;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A TraceEvent, such as an operation execution.<br>
|
||||
* Intended mainly for JSON serialization/deserialization in Chrome profiler format<br>
|
||||
* Profiler format is described here: <a href="https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit">
|
||||
* https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit</a>
|
||||
* See {@link org.nd4j.autodiff.listeners.profiler.ProfilingListener}<br>
|
||||
* See {@link org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer}
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Builder
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class TraceEvent {
|
||||
|
||||
private String name; //Name of event (usually op name)
|
||||
private List<String> categories; //Comma separated list of categories
|
||||
private Phase ph; //Event type - phase (see table for options)
|
||||
private long ts; //Timestamp, in microseconds (us)
|
||||
private Long dur; //Duration, optional
|
||||
private Long tts; //Optional, thlread timestamp, in microseconds
|
||||
private long pid; //Process ID
|
||||
private long tid; //Thread ID
|
||||
private Map<String, Object> args; //Args
|
||||
private ColorName cname; //Optional, color name (must be one of reserved color names: https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html )
|
||||
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.autodiff.listeners.profiler.data;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A simple holder for a list of trace events
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public class TraceEvents {
|
||||
private List<TraceEvent> traceEvents;
|
||||
}
|
|
@ -18,6 +18,7 @@ package org.nd4j.list;
|
|||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -46,7 +47,12 @@ public class NDArrayList extends BaseNDArrayList<Double> {
|
|||
* @param size the initial size of the array
|
||||
*/
|
||||
public NDArrayList(int size) {
|
||||
this.container = Nd4j.create(10L);
|
||||
this(DataType.DOUBLE, size);
|
||||
}
|
||||
|
||||
public NDArrayList(DataType dataType, int size) {
|
||||
Preconditions.checkState(size >= 0, "Size must be non-negative - got %s", size);
|
||||
this.container = Nd4j.create(dataType, Math.max(10L, size));
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
|
@ -84,6 +90,7 @@ public class NDArrayList extends BaseNDArrayList<Double> {
|
|||
* directly, this gives you the relevant subset that reflects the content of the list)
|
||||
* @return the view of the underlying ndarray relative to the collection's real size
|
||||
*/
|
||||
@Override
|
||||
public INDArray array() {
|
||||
if(isEmpty()) {
|
||||
throw new ND4JIllegalStateException("Array is empty!");
|
||||
|
@ -137,6 +144,8 @@ public class NDArrayList extends BaseNDArrayList<Double> {
|
|||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public boolean remove(Object o) {
|
||||
int idx = BooleanIndexing.firstIndex(container,new EqualsCondition((double) o)).getInt(0);
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
package org.nd4j.autodiff.samediff.listeners;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
|
||||
import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class ProfilingListenerTest extends BaseNd4jTest {
|
||||
|
||||
public ProfilingListenerTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
public void testProfilingListenerSimple() throws Exception {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
|
||||
SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2);
|
||||
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
|
||||
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2));
|
||||
SDVariable sm = sd.nn.softmax("predictions", in.mmul("matmul", w).add("addbias", b));
|
||||
SDVariable loss = sd.loss.logLoss("loss", label, sm);
|
||||
|
||||
INDArray i = Nd4j.rand(DataType.FLOAT, 1, 3);
|
||||
INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2);
|
||||
|
||||
|
||||
File dir = testDir.newFolder();
|
||||
File f = new File(dir, "test.json");
|
||||
ProfilingListener listener = ProfilingListener.builder(f)
|
||||
.recordAll()
|
||||
.warmup(5)
|
||||
.build();
|
||||
|
||||
sd.setListeners(listener);
|
||||
|
||||
Map<String,INDArray> ph = new HashMap<>();
|
||||
ph.put("in", i);
|
||||
|
||||
for( int x=0; x<10; x++ ) {
|
||||
sd.outputSingle(ph, "predictions");
|
||||
}
|
||||
|
||||
String content = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
|
||||
System.out.println(content);
|
||||
|
||||
//Should be 2 begins and 2 ends for each entry
|
||||
//5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name
|
||||
String[] opNames = {"mmul", "add", "softmax"};
|
||||
for(String s : opNames){
|
||||
assertEquals(s, 10, StringUtils.countMatches(content, s));
|
||||
}
|
||||
|
||||
|
||||
System.out.println("///////////////////////////////////////////");
|
||||
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.SAMEDIFF);
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
@Test
|
||||
public void testLoadTfProfile(){
|
||||
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
|
||||
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadTfProfileDir(){
|
||||
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
|
||||
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadTfProfileDir2(){
|
||||
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
|
||||
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||
}
|
||||
*/
|
||||
}
|
Loading…
Reference in New Issue