Merge remote-tracking branch 'konduit/master'

master
AlexDBlack 2019-11-23 00:22:27 +11:00
commit 289d9dc141
333 changed files with 9444 additions and 7918 deletions

View File

@ -26,78 +26,13 @@
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>arbiter-ui_2.11</artifactId> <artifactId>arbiter-ui</artifactId>
<name>arbiter-ui</name> <name>arbiter-ui</name>
<properties> <properties>
<java.compile.version>1.8</java.compile.version> <java.compile.version>1.8</java.compile.version>
</properties> </properties>
<profiles> <profiles>
<!-- To build UI templates: run "mvn compile -P buildUiTemplates" -->
<profile>
<id>buildUiTemplates</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<build>
<plugins>
<plugin>
<groupId>com.google.code.play2-maven-plugin</groupId>
<artifactId>play2-maven-plugin</artifactId>
<version>${maven-play2-plugin.version}</version>
<!-- Generate Scala Page Templates
The Play framework template engine ("twirl") uses templates for HTML pages (or in principle any text-based
data: CSV, XML etc). These templates (*.scala.html files) need to be converted to Scala classes using
code generation. This is done here during the Maven compile phase.
However, the Maven Play framework plugin does not allow proper customization of the output directory. Thus,
we generate these Scala classes in the default location (in the target/twirl/main/ directory) and use the
maven resources plugin to copy them to the actual location we want.
To generate the latest versions of these templates (after modifying or adding a new template), just run
"mvn compile" in either the main project directory, or within the deeplearning4j ui module separately.
-->
<executions>
<execution>
<id>GenerateTemplates</id>
<phase>compile</phase>
<configuration>
</configuration>
<goals>
<goal>template-compile</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- Copy the generated Scala templates to the appropriate directory. See templates comment above. -->
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>${maven-resources-plugin.version}</version>
<executions>
<execution>
<id>CopyTemplates</id>
<phase>compile</phase>
<goals>
<goal>copy-resources</goal>
</goals>
<configuration>
<outputDirectory>${basedir}/src/main/scala/org/deeplearning4j/arbiter/ui/views/html</outputDirectory>
<resources>
<resource>
<directory>${basedir}/target/twirl/main/org/deeplearning4j/arbiter/ui/views/html</directory>
<filtering>true</filtering>
</resource>
</resources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
<profile> <profile>
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
@ -115,10 +50,17 @@
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui_2.11</artifactId> <artifactId>deeplearning4j-ui</artifactId>
<version>${dl4j.version}</version> <version>${dl4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<scope>test</scope>
<version>${logback.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-deeplearning4j</artifactId> <artifactId>arbiter-deeplearning4j</artifactId>
@ -134,72 +76,6 @@
</dependencies> </dependencies>
<build> <build>
<extensions>
<extension>
<groupId>org.apache.maven.wagon</groupId>
<artifactId>wagon-http</artifactId>
<version>2.9</version>
</extension>
</extensions>
<plugins>
<!-- Build helper plugin: used to add the multiple independent source directories (Java, Scala, HTML templates) -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<version>${maven-build-helper-plugin.version}</version>
<executions>
<execution>
<id>add-source</id>
<phase>compile</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>src/main/scala</source>
<source>src/main/java</source>
<source>src/main/views</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>${maven-scala-plugin.version}</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
</configuration>
</plugin>
<plugin>
<groupId>com.google.code.sbt-compiler-maven-plugin</groupId>
<artifactId>sbt-compiler-maven-plugin</artifactId>
<version>${sbt-compiler-maven-plugin.version}</version>
</plugin>
<!-- Maven compiler plugin last: should ensure Scala code is compiled before Java code -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId> <!-- Version set by deeplearning4j-parent dependency management -->
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
<pluginManagement> <pluginManagement>
<plugins> <plugins>
<plugin> <plugin>

View File

@ -20,8 +20,6 @@ import lombok.AllArgsConstructor;
import org.apache.commons.compress.utils.IOUtils; import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.api.storage.Persistable; import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.arbiter.ui.module.ArbiterModule; import org.deeplearning4j.arbiter.ui.module.ArbiterModule;
import org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport;
import scala.annotation.meta.field;
import java.io.*; import java.io.*;
import java.lang.reflect.Field; import java.lang.reflect.Field;

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,8 @@
package org.deeplearning4j.arbiter.ui.module; package org.deeplearning4j.arbiter.ui.module;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.ext.web.RoutingContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.Persistable; import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
@ -31,8 +34,8 @@ import org.deeplearning4j.arbiter.ui.UpdateStatus;
import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable; import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable;
import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
import org.deeplearning4j.arbiter.ui.misc.UIUtils; import org.deeplearning4j.arbiter.ui.misc.UIUtils;
import org.deeplearning4j.arbiter.ui.views.html.ArbiterUI;
import org.deeplearning4j.arbiter.util.ObjectUtils; import org.deeplearning4j.arbiter.util.ObjectUtils;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.ui.api.Component; import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.api.*; import org.deeplearning4j.ui.api.*;
import org.deeplearning4j.ui.components.chart.ChartLine; import org.deeplearning4j.ui.components.chart.ChartLine;
@ -48,18 +51,15 @@ import org.deeplearning4j.ui.i18n.I18NResource;
import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.DateTimeFormatter;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import play.libs.Json; import org.nd4j.shade.jackson.core.JsonProcessingException;
import play.mvc.Result;
import play.mvc.Results;
import java.awt.*; import java.awt.*;
import java.text.DecimalFormat; import java.text.DecimalFormat;
import java.util.*;
import java.util.List; import java.util.List;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import static org.deeplearning4j.arbiter.ui.misc.JsonMapper.asJson; import static org.deeplearning4j.arbiter.ui.misc.JsonMapper.asJson;
import static play.mvc.Results.ok;
/** /**
* A Deeplearning4j {@link UIModule}, for integration with DL4J's user interface * A Deeplearning4j {@link UIModule}, for integration with DL4J's user interface
@ -73,8 +73,6 @@ public class ArbiterModule implements UIModule {
private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormat.forPattern("YYYY-MM-dd HH:mm ZZ"); private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormat.forPattern("YYYY-MM-dd HH:mm ZZ");
public static final String ARBITER_UI_TYPE_ID = "ArbiterUI"; public static final String ARBITER_UI_TYPE_ID = "ArbiterUI";
private static final String JSON = "application/json";
private AtomicBoolean loggedArbiterAddress = new AtomicBoolean(false); private AtomicBoolean loggedArbiterAddress = new AtomicBoolean(false);
private Map<String, StatsStorage> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>()); private Map<String, StatsStorage> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>());
private String currentSessionID; private String currentSessionID;
@ -138,17 +136,18 @@ public class ArbiterModule implements UIModule {
@Override @Override
public List<Route> getRoutes() { public List<Route> getRoutes() {
Route r1 = new Route("/arbiter", HttpMethod.GET, FunctionType.Supplier, () -> Results.ok(ArbiterUI.apply())); Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, FunctionType.Supplier, this::getLastUpdateTime); .putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html"));
Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, FunctionType.Function, this::getModelLastUpdateTimes); Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc));
Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, FunctionType.Function, this::getCandidateInfo); Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc));
Route r6 = new Route("/arbiter/config", HttpMethod.GET, FunctionType.Supplier, this::getOptimizationConfig); Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> this.getCandidateInfo(path.get(0), rc));
Route r7 = new Route("/arbiter/results", HttpMethod.GET, FunctionType.Supplier, this::getSummaryResults); Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc));
Route r8 = new Route("/arbiter/summary", HttpMethod.GET, FunctionType.Supplier, this::getSummaryStatus); Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc));
Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc));
Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, FunctionType.Supplier, this::listSessions); Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc));
Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, FunctionType.Supplier, this::currentSession); Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc));
Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, FunctionType.Function, this::setSession); Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc));
return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c); return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c);
} }
@ -197,23 +196,27 @@ public class ArbiterModule implements UIModule {
getDefaultSession(); getDefaultSession();
} }
private Result currentSession() { private void currentSession(RoutingContext rc) {
String sid = currentSessionID == null ? "" : currentSessionID; String sid = currentSessionID == null ? "" : currentSessionID;
return ok(asJson(sid)).as(JSON); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(sid));
} }
private Result listSessions() { private void listSessions(RoutingContext rc) {
return Results.ok(asJson(knownSessionIDs.keySet())).as(JSON); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(knownSessionIDs.keySet()));
} }
private Result setSession(String newSessionID) { private void setSession(String newSessionID, RoutingContext rc) {
log.debug("Arbiter UI: Set to session {}", newSessionID); log.debug("Arbiter UI: Set to session {}", newSessionID);
if (knownSessionIDs.containsKey(newSessionID)) { if (knownSessionIDs.containsKey(newSessionID)) {
currentSessionID = newSessionID; currentSessionID = newSessionID;
return ok(); rc.response().end();
} else { } else {
return Results.badRequest("Unknown session ID: " + newSessionID); rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()).end("Unknown session ID: " + newSessionID);
} }
} }
@ -255,30 +258,39 @@ public class ArbiterModule implements UIModule {
} }
/** /**
* @return Last update time for the page * Return the last update time for the page
*/ */
private Result getLastUpdateTime(){ private void getLastUpdateTime(RoutingContext rc){
//TODO - this forces updates on every request... which is fine, just inefficient //TODO - this forces updates on every request... which is fine, just inefficient
long t = System.currentTimeMillis(); long t = System.currentTimeMillis();
UpdateStatus us = new UpdateStatus(t, t, t); UpdateStatus us = new UpdateStatus(t, t, t);
return ok(Json.toJson(us)); rc.response().putHeader("content-type", "application/json").end(asJson(us));
}
private String asJson(Object o){
try{
return JsonMappers.getMapper().writeValueAsString(o);
} catch (JsonProcessingException e){
throw new RuntimeException("Error converting object to JSON", e);
}
} }
/** /**
* Get the last update time for the specified model IDs * Get the last update time for the specified model IDs
* @param modelIDs Model IDs to get the update time for * @param modelIDs Model IDs to get the update time for
*/ */
private Result getModelLastUpdateTimes(String modelIDs){ private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){
if(currentSessionID == null){ if(currentSessionID == null){
return ok(); rc.response().end();
return;
} }
StatsStorage ss = knownSessionIDs.get(currentSessionID); StatsStorage ss = knownSessionIDs.get(currentSessionID);
if(ss == null){ if(ss == null){
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID);
return ok("-1"); rc.response().end("-1");
return;
} }
String[] split = modelIDs.split(","); String[] split = modelIDs.split(",");
@ -292,7 +304,7 @@ public class ArbiterModule implements UIModule {
} }
} }
return ok(Json.toJson(lastUpdateTimes)); rc.response().putHeader("content-type", "application/json").end(asJson(lastUpdateTimes));
} }
/** /**
@ -301,12 +313,13 @@ public class ArbiterModule implements UIModule {
* @param candidateId ID for the candidate * @param candidateId ID for the candidate
* @return Content/info for the candidate * @return Content/info for the candidate
*/ */
private Result getCandidateInfo(String candidateId){ private void getCandidateInfo(String candidateId, RoutingContext rc){
StatsStorage ss = knownSessionIDs.get(currentSessionID); StatsStorage ss = knownSessionIDs.get(currentSessionID);
if(ss == null){ if(ss == null){
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID);
return ok(); rc.response().end();
return;
} }
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);; GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);;
@ -316,7 +329,10 @@ public class ArbiterModule implements UIModule {
if(p == null){ if(p == null){
String title = "No results found for model " + candidateId + "."; String title = "No results found for model " + candidateId + ".";
ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build(); ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build();
return ok(asJson(ct)).as(JSON); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(ct));
return;
} }
ModelInfoPersistable mip = (ModelInfoPersistable)p; ModelInfoPersistable mip = (ModelInfoPersistable)p;
@ -474,25 +490,27 @@ public class ArbiterModule implements UIModule {
ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components); ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components);
return ok(asJson(cd)).as(JSON); rc.response().putHeader("content-type", "application/json").end(asJson(cd));
} }
/** /**
* Get the optimization configuration - second section in the page * Get the optimization configuration - second section in the page
*/ */
private Result getOptimizationConfig(){ private void getOptimizationConfig(RoutingContext rc){
StatsStorage ss = knownSessionIDs.get(currentSessionID); StatsStorage ss = knownSessionIDs.get(currentSessionID);
if(ss == null){ if(ss == null){
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID);
return ok(); rc.response().end();
return;
} }
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
if(p == null){ if(p == null){
log.info("No static info"); log.debug("No static info");
return ok(); rc.response().end();
return;
} }
List<Component> components = new ArrayList<>(); List<Component> components = new ArrayList<>();
@ -574,14 +592,15 @@ public class ArbiterModule implements UIModule {
ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components); ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components);
return ok(asJson(cd)).as(JSON); rc.response().putHeader("content-type", "application/json").end(asJson(cd));
} }
private Result getSummaryResults(){ private void getSummaryResults(RoutingContext rc){
StatsStorage ss = knownSessionIDs.get(currentSessionID); StatsStorage ss = knownSessionIDs.get(currentSessionID);
if(ss == null){ if(ss == null){
log.debug("getSummaryResults(): Session ID is unknown: {}", currentSessionID); log.debug("getSummaryResults(): Session ID is unknown: {}", currentSessionID);
return ok(); rc.response().end();
return;
} }
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID));
@ -592,24 +611,26 @@ public class ArbiterModule implements UIModule {
table.add(new String[]{mip.getModelIdx().toString(), score, mip.getStatus().toString()}); table.add(new String[]{mip.getModelIdx().toString(), score, mip.getStatus().toString()});
} }
return ok(asJson(table)).as(JSON); rc.response().putHeader("content-type", "application/json").end(asJson(table));
} }
/** /**
* Get summary status information: first section in the page * Get summary status information: first section in the page
*/ */
private Result getSummaryStatus(){ private void getSummaryStatus(RoutingContext rc){
StatsStorage ss = knownSessionIDs.get(currentSessionID); StatsStorage ss = knownSessionIDs.get(currentSessionID);
if(ss == null){ if(ss == null){
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID);
return ok(); rc.response().end();
return;
} }
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
if(p == null){ if(p == null){
log.info("No static info"); log.info("No static info");
return ok(); rc.response().end();
return;
} }
GlobalConfigPersistable gcp = (GlobalConfigPersistable)p; GlobalConfigPersistable gcp = (GlobalConfigPersistable)p;
@ -711,7 +732,7 @@ public class ArbiterModule implements UIModule {
ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components); ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components);
return ok(asJson(cd)).as(JSON); rc.response().putHeader("content-type", "application/json").end(asJson(cd));
} }

View File

@ -1,5 +1,6 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc. ~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2019 Konduit K.K.
~ ~
~ This program and the accompanying materials are made available under the ~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at ~ terms of the Apache License, Version 2.0 which is available at
@ -17,12 +18,6 @@
<html> <html>
<head> <head>
<style type="text/css"> <style type="text/css">
/* Color and style reference.
To change: do find + replace on comment + color
heading background: headingbgcol #063E53 //Old candidates: #3B5998
heading text color: headingtextcolor white
*/
.hd { .hd {
background-color: #000000; background-color: #000000;
@ -116,7 +111,7 @@
} }
.resultsHeadingDiv { .resultsHeadingDiv {
background-color: /*headingbgcol*/#063E53; background-color: #063E53;
color: white; color: white;
font-family: "Open Sans", sans-serif; font-family: "Open Sans", sans-serif;
font-size: 16px; font-size: 16px;

View File

@ -47,13 +47,13 @@ import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener; import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

View File

@ -71,7 +71,7 @@ public class DL4JSystemProperties {
public static final String CRASH_DUMP_OUTPUT_DIRECTORY_PROPERTY = "org.deeplearning4j.crash.reporting.directory"; public static final String CRASH_DUMP_OUTPUT_DIRECTORY_PROPERTY = "org.deeplearning4j.crash.reporting.directory";
/** /**
* Applicability: deeplearning4j-ui_2.xx<br> * Applicability: deeplearning4j-ui<br>
* Description: The DL4J training UI (StatsListener + UIServer.getInstance().attach(ss)) will subsample the number * Description: The DL4J training UI (StatsListener + UIServer.getInstance().attach(ss)) will subsample the number
* of chart points when a lot of data is present - i.e., only a maximum number of points will be shown on each chart. * of chart points when a lot of data is present - i.e., only a maximum number of points will be shown on each chart.
* This is to reduce the UI bandwidth requirements and client-side rendering cost. * This is to reduce the UI bandwidth requirements and client-side rendering cost.
@ -81,7 +81,7 @@ public class DL4JSystemProperties {
/** /**
* Applicability: deeplearning4j-play (deeplearning4j-ui_2.xx)<br> * Applicability: deeplearning4j-vertx (deeplearning4j-ui)<br>
* Description: This property sets the port that the UI will be available on. Default port: 9000. * Description: This property sets the port that the UI will be available on. Default port: 9000.
* Set to 0 for a random port. * Set to 0 for a random port.
*/ */

View File

@ -142,7 +142,7 @@ public class GradientCheckTests extends BaseDL4JTest {
// (a) activation function // (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
// (c) Loss function (with specified output activations) // (c) Loss function (with specified output activations)
Activation[] activFns = {Activation.SIGMOID, Activation.TANH}; Activation[] activFns = {Activation.SIGMOID, Activation.TANH, Activation.MISH};
boolean[] characteristic = {false, true}; //If true: run some backprop steps first boolean[] characteristic = {false, true}; //If true: run some backprop steps first
LossFunction[] lossFunctions = {LossFunction.MCXENT, LossFunction.MSE}; LossFunction[] lossFunctions = {LossFunction.MCXENT, LossFunction.MSE};

View File

@ -83,6 +83,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
new LossMultiLabel(), new LossWasserstein(), new LossMultiLabel(), new LossWasserstein(),
new LossSparseMCXENT()
}; };
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@ -116,7 +117,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
Activation.IDENTITY, // MixtureDensity Activation.IDENTITY, // MixtureDensity
Activation.TANH, // MixtureDensity + tanh Activation.TANH, // MixtureDensity + tanh
Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
Activation.IDENTITY // Wasserstein Activation.IDENTITY, // Wasserstein
Activation.SOFTMAX, //sparse MCXENT
}; };
int[] nOut = new int[] {1, //xent int[] nOut = new int[] {1, //xent
@ -151,6 +153,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
10, // Mixture Density + tanh 10, // Mixture Density + tanh
10, // MultiLabel 10, // MultiLabel
2, // Wasserstein 2, // Wasserstein
4, //sparse MCXENT
}; };
int[] minibatchSizes = new int[] {1, 3}; int[] minibatchSizes = new int[] {1, 3};
@ -233,7 +236,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(), new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(),
new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
new LossMultiLabel(), new LossWasserstein() new LossMultiLabel(), new LossWasserstein(),
new LossSparseMCXENT()
}; };
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@ -266,7 +270,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
Activation.IDENTITY, // MixtureDensity Activation.IDENTITY, // MixtureDensity
Activation.TANH, // MixtureDensity + tanh Activation.TANH, // MixtureDensity + tanh
Activation.TANH, // MultiLabel Activation.TANH, // MultiLabel
Activation.IDENTITY // Wasserstein Activation.IDENTITY, // Wasserstein
Activation.SOFTMAX
}; };
int[] nOut = new int[] {1, //xent int[] nOut = new int[] {1, //xent
@ -300,6 +305,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
10, // Mixture Density + tanh 10, // Mixture Density + tanh
10, // MultiLabel 10, // MultiLabel
2, // Wasserstein 2, // Wasserstein
4
}; };
int[] minibatchSizes = new int[] {1, 3}; int[] minibatchSizes = new int[] {1, 3};
@ -476,6 +482,23 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
break;
case "LossSparseMCXENT":
if (labelsShape.length == 2) {
ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1);
for (int i = 0; i < labelsShape[0]; i++) {
ret[1].putScalar(i, 0, r.nextInt((int) labelsShape[1]));
}
} else if (labelsShape.length == 3) {
ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1, labelsShape[2]);
for (int i = 0; i < labelsShape[0]; i++) {
for (int j = 0; j < labelsShape[2]; j++) {
ret[1].putScalar(i, 0, j, r.nextInt((int) labelsShape[1]));
}
}
} else {
throw new UnsupportedOperationException();
}
break; break;
case "LossHinge": case "LossHinge":
case "LossSquaredHinge": case "LossSquaredHinge":

View File

@ -34,6 +34,7 @@ import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE; import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
import java.util.Random; import java.util.Random;
@ -61,19 +62,12 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
int nOut = 2; int nOut = 2;
int miniBatchSize = 3; int miniBatchSize = 3;
ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT()}; ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT(), new LossSparseMCXENT()};
for (int maskType = 0; maskType < 3; maskType++) { for (int maskType = 0; maskType < 3; maskType++) {
Random r = new Random(12345L); Random r = new Random(12345L);
INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}); INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength});
INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int idx = r.nextInt(nOut);
labels.putScalar(new int[]{i, idx, j}, 1.0);
}
}
INDArray labelMask; INDArray labelMask;
String mt; String mt;
@ -85,13 +79,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
break; break;
case 1: case 1:
//Per time step masking //Per time step masking
labelMask = Nd4j.createUninitialized(miniBatchSize, timeSeriesLength); labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, timeSeriesLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
mt = "PerTimeStep"; mt = "PerTimeStep";
break; break;
case 2: case 2:
//Per output masking: //Per output masking:
labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, nOut, timeSeriesLength}); labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, nOut, timeSeriesLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
mt = "PerOutput"; mt = "PerOutput";
break; break;
@ -101,6 +95,26 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
for (ILossFunction lf : lfs) { for (ILossFunction lf : lfs) {
INDArray labels;
if(lf instanceof LossSparseMCXENT){
labels = Nd4j.zeros(miniBatchSize, 1, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int idx = r.nextInt(nOut);
labels.putScalar(new int[]{i, 0, j}, idx);
}
}
} else {
labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int idx = r.nextInt(nOut);
labels.putScalar(new int[]{i, idx, j}, 1.0);
}
}
}
Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX; Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX;
MultiLayerConfiguration conf = MultiLayerConfiguration conf =

View File

@ -75,7 +75,6 @@ public class KerasReshape extends KerasLayer {
List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape); List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape);
this.targetShape = listToLongArray(targetShapeList); this.targetShape = listToLongArray(targetShapeList);
} }
} }
/** /**

View File

@ -43,13 +43,14 @@ import static org.nd4j.linalg.util.ArrayUtil.prodLong;
@Data @Data
@Slf4j @Slf4j
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"}) @JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
public class ReshapePreprocessor extends BaseInputPreProcessor { public class ReshapePreprocessor extends BaseInputPreProcessor {
private long[] inputShape; private long[] inputShape;
private long[] targetShape; private long[] targetShape;
private boolean hasMiniBatchDimension = false; private boolean hasMiniBatchDimension = false;
private int miniBatchSize; private int miniBatchSize;
private long[] staticTargetShape;
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) { public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
this.inputShape = inputShape; this.inputShape = inputShape;
@ -80,10 +81,19 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
// the target shape read from a keras config does not have mini-batch size // the target shape read from a keras config does not have mini-batch size
// included. We prepend it here dynamically. // included. We prepend it here dynamically.
long[] targetShape;
if (staticTargetShape != null){
targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
hasMiniBatchDimension = true;
this.miniBatchSize = miniBatchSize;
}
else{
targetShape = this.targetShape;
}
if (!this.hasMiniBatchDimension) { if (!this.hasMiniBatchDimension) {
targetShape = prependMiniBatchSize(targetShape, miniBatchSize); targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
inputShape = prependMiniBatchSize(inputShape, miniBatchSize); inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
this.hasMiniBatchDimension = true;
this.miniBatchSize = miniBatchSize; this.miniBatchSize = miniBatchSize;
} }
if (this.miniBatchSize != miniBatchSize) { if (this.miniBatchSize != miniBatchSize) {
@ -95,7 +105,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){ if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
} }
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape)); return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
} else { } else {
throw new IllegalStateException("Input shape " + Arrays.toString(input.shape()) throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
+ " and output shape" + Arrays.toString(inputShape) + " do not match"); + " and output shape" + Arrays.toString(inputShape) + " do not match");
@ -122,20 +132,27 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override @Override
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException { public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0); val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
InputType ret;
switch (shape.length) { switch (shape.length) {
case 2: case 2:
return InputType.feedForward(shape[1]); ret = InputType.feedForward(shape[1]);
break;
case 3: case 3:
return InputType.recurrent(shape[2], shape[1]); ret = InputType.recurrent(shape[2], shape[1]);
break;
case 4: case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){ if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
return InputType.convolutional(shape[1], shape[2], shape[3]); ret = InputType.convolutional(shape[1], shape[2], shape[3]);
}else { }else {
return InputType.convolutional(shape[2], shape[3], shape[1]); ret = InputType.convolutional(shape[2], shape[3], shape[1]);
} }
break;
default: default:
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Cannot infer input type for reshape array " + Arrays.toString(shape)); "Cannot infer input type for reshape array " + Arrays.toString(shape));
} }
this.staticTargetShape = ret.getShape();
return ret;
} }
} }

View File

@ -54,7 +54,7 @@ public class KerasLossUtils {
} else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_HINGE())) {
dl4jLoss = LossFunctions.LossFunction.HINGE; dl4jLoss = LossFunctions.LossFunction.HINGE;
} else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) {
throw new UnsupportedKerasConfigurationException("Loss function " + kerasLoss + " not supported yet."); dl4jLoss = LossFunctions.LossFunction.SPARSE_MCXENT;
} else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) {
dl4jLoss = LossFunctions.LossFunction.XENT; dl4jLoss = LossFunctions.LossFunction.XENT;
} else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) {

View File

@ -255,6 +255,17 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
} }
} }
@Test
public void ReshapeEmbeddingConcatTest() throws Exception{
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
ComputationGraphConfiguration config =
new KerasModel().modelBuilder().modelJsonInputStream(is)
.enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
ComputationGraph model = new ComputationGraph(config);
model.init();
model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
}
}
private void runSequentialConfigTest(String path) throws Exception { private void runSequentialConfigTest(String path) throws Exception {
try(InputStream is = Resources.asStream(path)) { try(InputStream is = Resources.asStream(path)) {

View File

@ -49,6 +49,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;
@ -88,7 +89,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Test(expected = IllegalStateException.class) @Test(expected = IllegalStateException.class)
public void fileNotFoundEndToEnd() throws Exception { public void fileNotFoundEndToEnd() throws Exception {
String modelPath = "modelimport/keras/examples/foo/bar.h5"; String modelPath = "modelimport/keras/examples/foo/bar.h5";
importEndModelTest(modelPath, null, true, true, false); importEndModelTest(modelPath, null, true, true, false, false);
} }
/** /**
@ -98,28 +99,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importMnistMlpTfKeras1() throws Exception { public void importMnistMlpTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importMnistMlpThKeras1() throws Exception { public void importMnistMlpThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, false); importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
} }
@Test @Test
public void importMnistMlpTfKeras2() throws Exception { public void importMnistMlpTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importMnistMlpReshapeTfKeras1() throws Exception { public void importMnistMlpReshapeTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, true); importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
} }
/** /**
@ -129,21 +130,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importMnistCnnTfKeras1() throws Exception { public void importMnistCnnTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
} }
@Test @Test
public void importMnistCnnThKeras1() throws Exception { public void importMnistCnnThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, true); importEndModelTest(modelPath, inputsOutputPath, false, true, true, false);
} }
@Test @Test
public void importMnistCnnTfKeras2() throws Exception { public void importMnistCnnTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, true); importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
} }
/** /**
@ -153,28 +154,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importImdbLstmTfKeras1() throws Exception { public void importImdbLstmTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importImdbLstmThKeras1() throws Exception { public void importImdbLstmThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importImdbLstmTfKeras2() throws Exception { public void importImdbLstmTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importImdbLstmThKeras2() throws Exception { public void importImdbLstmThKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, false); importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
} }
/** /**
@ -185,21 +186,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importImdbFasttextTfKeras1() throws Exception { public void importImdbFasttextTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, false, false); importEndModelTest(modelPath, inputsOutputPath, false, false, false, false);
} }
@Test @Test
public void importImdbFasttextThKeras1() throws Exception { public void importImdbFasttextThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, false, false); importEndModelTest(modelPath, inputsOutputPath, false, false, false, false);
} }
@Test @Test
public void importImdbFasttextTfKeras2() throws Exception { public void importImdbFasttextTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
} }
/** /**
@ -209,21 +210,21 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importSimpleLstmTfKeras1() throws Exception { public void importSimpleLstmTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importSimpleLstmThKeras1() throws Exception { public void importSimpleLstmThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
@Test @Test
public void importSimpleLstmTfKeras2() throws Exception { public void importSimpleLstmTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, false, false, false);
} }
@ -235,7 +236,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" + String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" +
"simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5"; "simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
/** /**
@ -246,7 +247,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
/** /**
@ -257,7 +258,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" + String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" +
"simple_rnn_tf_keras_2_inputs_and_outputs.h5"; "simple_rnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
} }
/** /**
@ -267,7 +268,18 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importCnnNoBiasTfKeras2() throws Exception { public void importCnnNoBiasTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, false, true); importEndModelTest(modelPath, inputsOutputPath, true, true, true, false);
}
@Test
public void importSparseXent() throws Exception {
String modelPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_inputs_and_outputs.h5";
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true);
Layer outLayer = net.getOutputLayer();
assertTrue(outLayer instanceof org.deeplearning4j.nn.layers.LossLayer);
LossLayer llConf = (LossLayer) outLayer.getConfig();
assertEquals(new LossSparseMCXENT(), llConf.getLossFn());
} }
/** /**
@ -626,19 +638,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
} }
} }
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
private void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions) throws Exception { boolean checkGradients, boolean enforceTrainingConfig) throws Exception {
importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, false);
}
public void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
boolean checkGradients) throws Exception {
MultiLayerNetwork model; MultiLayerNetwork model;
try(InputStream is = Resources.asStream(modelPath)) { try(InputStream is = Resources.asStream(modelPath)) {
File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
.enforceTrainingConfig(false).buildSequential(); .enforceTrainingConfig(enforceTrainingConfig).buildSequential();
model = kerasModel.getMultiLayerNetwork(); model = kerasModel.getMultiLayerNetwork();
} }
@ -699,6 +706,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
checkGradients(model, input, testLabels); checkGradients(model, input, testLabels);
} }
} }
return model;
} }
private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception { private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception {

View File

@ -34,16 +34,6 @@
<scala.binary.version>2.11</scala.binary.version> <scala.binary.version>2.11</scala.binary.version>
</properties> </properties>
<build> <build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
<pluginManagement> <pluginManagement>
<plugins> <plugins>
<plugin> <plugin>
@ -56,9 +46,14 @@
<include>*.java</include> <include>*.java</include>
<include>**/*.java</include> <include>**/*.java</include>
</includes> </includes>
<systemPropertyVariables> </configuration>
<config.file>../../deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/application.conf</config.file> </plugin>
</systemPropertyVariables> <plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration> </configuration>
</plugin> </plugin>
</plugins> </plugins>

View File

@ -68,7 +68,7 @@
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui_2.11</artifactId> <artifactId>deeplearning4j-ui</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -79,6 +80,17 @@ import java.util.Map;
* .build(); * .build();
* } * }
* </pre> * </pre>
* <br>
* <b>Example to use an instantiated iterator for inference:</b><br>
* <pre>
* {@code
* BertIterator b;
* List<String> forInference;
* Pair<INDArray[],INDArray[]> featuresAndMask = b.featurizeSentences(forInference);
* INDArray[] features = featuresAndMask.getFirst();
* INDArray[] featureMasks = featuresAndMask.getSecond();
* }
* </pre>
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br> * This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
* <br> * <br>
* <u><b>{@link LengthHandling} configuration:</b></u><br> * <u><b>{@link LengthHandling} configuration:</b></u><br>
@ -89,7 +101,7 @@ import java.util.Map;
* <b>CLIP_ONLY</b>: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in * <b>CLIP_ONLY</b>: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in
* a minibatch is shorter than the specified maximum, no padding will occur. For sequences that are shorter than the * a minibatch is shorter than the specified maximum, no padding will occur. For sequences that are shorter than the
* maximum (within the current minibatch) they will be zero padded and masked.<br> * maximum (within the current minibatch) they will be zero padded and masked.<br>
*<br><br> * <br><br>
* <u><b>{@link FeatureArrays} configuration:</b></u><br> * <u><b>{@link FeatureArrays} configuration:</b></u><br>
* Determines what arrays should be included.<br> * Determines what arrays should be included.<br>
* <b>INDICES_MASK</b>: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).<br> * <b>INDICES_MASK</b>: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).<br>
@ -107,8 +119,11 @@ import java.util.Map;
public class BertIterator implements MultiDataSetIterator { public class BertIterator implements MultiDataSetIterator {
public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION} public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION}
public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY} public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY}
public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID} public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID}
public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC} public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC}
protected Task task; protected Task task;
@ -116,12 +131,13 @@ public class BertIterator implements MultiDataSetIterator {
protected int maxTokens = -1; protected int maxTokens = -1;
protected int minibatchSize = 32; protected int minibatchSize = 32;
protected boolean padMinibatches = false; protected boolean padMinibatches = false;
@Getter @Setter @Getter
@Setter
protected MultiDataSetPreProcessor preProcessor; protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null; protected LabeledSentenceProvider sentenceProvider = null;
protected LengthHandling lengthHandling; protected LengthHandling lengthHandling;
protected FeatureArrays featureArrays; protected FeatureArrays featureArrays;
protected Map<String,Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
protected BertSequenceMasker masker = null; protected BertSequenceMasker masker = null;
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
protected String maskToken; protected String maskToken;
@ -130,7 +146,7 @@ public class BertIterator implements MultiDataSetIterator {
protected List<String> vocabKeysAsList; protected List<String> vocabKeysAsList;
protected BertIterator(Builder b){ protected BertIterator(Builder b) {
this.task = b.task; this.task = b.task;
this.tokenizerFactory = b.tokenizerFactory; this.tokenizerFactory = b.tokenizerFactory;
this.maxTokens = b.maxTokens; this.maxTokens = b.maxTokens;
@ -166,10 +182,10 @@ public class BertIterator implements MultiDataSetIterator {
public MultiDataSet next(int num) { public MultiDataSet next(int num) {
Preconditions.checkState(hasNext(), "No next element available"); Preconditions.checkState(hasNext(), "No next element available");
List<Pair<String,String>> list = new ArrayList<>(num); List<Pair<String, String>> list = new ArrayList<>(num);
int count = 0; int mbSize = 0;
if(sentenceProvider != null){ if (sentenceProvider != null) {
while(sentenceProvider.hasNext() && count++ < num) { while (sentenceProvider.hasNext() && mbSize++ < num) {
list.add(sentenceProvider.nextSentence()); list.add(sentenceProvider.nextSentence());
} }
} else { } else {
@ -177,41 +193,60 @@ public class BertIterator implements MultiDataSetIterator {
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented"); throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
} }
//Get and tokenize the sentences for this minibatch
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
int longestSeq = -1;
for(Pair<String,String> p : list){
List<String> tokens = tokenizeSentence(p.getFirst());
tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
longestSeq = Math.max(longestSeq, tokens.size());
}
//Determine output array length... Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
int outLength; List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
switch (lengthHandling){ int outLength = outLTokenizedSentencesPair.getLeft();
case FIXED_LENGTH:
outLength = maxTokens;
break;
case ANY_LENGTH:
outLength = longestSeq;
break;
case CLIP_ONLY:
outLength = Math.min(maxTokens, longestSeq);
break;
default:
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
}
int mb = tokenizedSentences.size(); Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
int mbPadded = padMinibatches ? minibatchSize : mb; INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength);
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray);
if (preProcessor != null)
preProcessor.preProcess(mds);
return mds;
}
/**
* For use during inference. Will convert a given list of sentences to features and feature masks as appropriate.
*
* @param listOnlySentences
* @return Pair of INDArrays[], first element is feature arrays and the second is the masks array
*/
public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
int outLength = outLTokenizedSentencesPair.getLeft();
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
if (preProcessor != null) {
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
preProcessor.preProcess(dummyMDS);
return new Pair<INDArray[],INDArray[]>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
}
return convertMiniBatchFeatures(tokenizedSentences, outLength);
}
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) {
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
int[][] outIdxs = new int[mbPadded][outLength]; int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength]; int[][] outMask = new int[mbPadded][outLength];
for (int i = 0; i < tokenizedSentences.size(); i++) {
for( int i=0; i<tokenizedSentences.size(); i++ ){ Pair<List<String>, String> p = tokenizedSentences.get(i);
Pair<List<String>,String> p = tokenizedSentences.get(i);
List<String> t = p.getFirst(); List<String> t = p.getFirst();
for( int j=0; j<outLength && j<t.size(); j++ ){ for (int j = 0; j < outLength && j < t.size(); j++) {
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encontered: token \"%s\" is not in vocabulary", t.get(j)); Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
int idx = vocabMap.get(t.get(j)); int idx = vocabMap.get(t.get(j));
outIdxs[i][j] = idx; outIdxs[i][j] = idx;
outMask[i][j] = 1; outMask[i][j] = 1;
@ -224,7 +259,7 @@ public class BertIterator implements MultiDataSetIterator {
INDArray outSegmentIdArr; INDArray outSegmentIdArr;
INDArray[] f; INDArray[] f;
INDArray[] fm; INDArray[] fm;
if(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID){ if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
//For now: always segment index 0 (only single s sequence input supported) //For now: always segment index 0 (only single s sequence input supported)
outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength); outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
f = new INDArray[]{outIdxsArr, outSegmentIdArr}; f = new INDArray[]{outIdxsArr, outSegmentIdArr};
@ -233,17 +268,50 @@ public class BertIterator implements MultiDataSetIterator {
f = new INDArray[]{outIdxsArr}; f = new INDArray[]{outIdxsArr};
fm = new INDArray[]{outMaskArr}; fm = new INDArray[]{outMaskArr};
} }
return new Pair<>(f, fm);
}
private Pair<Integer, List<Pair<List<String>, String>>> tokenizeMiniBatch(List<Pair<String, String>> list) {
//Get and tokenize the sentences for this minibatch
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(list.size());
int longestSeq = -1;
for (Pair<String, String> p : list) {
List<String> tokens = tokenizeSentence(p.getFirst());
tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
longestSeq = Math.max(longestSeq, tokens.size());
}
//Determine output array length...
int outLength;
switch (lengthHandling) {
case FIXED_LENGTH:
outLength = maxTokens;
break;
case ANY_LENGTH:
outLength = longestSeq;
break;
case CLIP_ONLY:
outLength = Math.min(maxTokens, longestSeq);
break;
default:
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
}
return new Pair<>(outLength, tokenizedSentences);
}
private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
INDArray[] l = new INDArray[1]; INDArray[] l = new INDArray[1];
INDArray[] lm; INDArray[] lm;
if(task == Task.SEQ_CLASSIFICATION){ int mbSize = tokenizedSentences.size();
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
if (task == Task.SEQ_CLASSIFICATION) {
//Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses] //Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses]
int numClasses; int numClasses;
int[] classLabels = new int[mbPadded]; int[] classLabels = new int[mbPadded];
if(sentenceProvider != null){ if (sentenceProvider != null) {
numClasses = sentenceProvider.numLabelClasses(); numClasses = sentenceProvider.numLabelClasses();
List<String> labels = sentenceProvider.allLabels(); List<String> labels = sentenceProvider.allLabels();
for(int i=0; i<mb; i++ ){ for (int i = 0; i < mbSize; i++) {
String lbl = tokenizedSentences.get(i).getRight(); String lbl = tokenizedSentences.get(i).getRight();
classLabels[i] = labels.indexOf(lbl); classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
@ -252,21 +320,21 @@ public class BertIterator implements MultiDataSetIterator {
throw new RuntimeException(); throw new RuntimeException();
} }
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses); l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
for( int i=0; i<mb; i++ ){ for (int i = 0; i < mbSize; i++) {
l[0].putScalar(i, classLabels[i], 1.0); l[0].putScalar(i, classLabels[i], 1.0);
} }
lm = null; lm = null;
if(padMinibatches && mb != mbPadded){ if (padMinibatches && mbSize != mbPadded) {
INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1); INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1);
lm = new INDArray[]{a}; lm = new INDArray[]{a};
a.get(NDArrayIndex.interval(0, mb), NDArrayIndex.all()).assign(1); a.get(NDArrayIndex.interval(0, mbSize), NDArrayIndex.all()).assign(1);
} }
} else if(task == Task.UNSUPERVISED){ } else if (task == Task.UNSUPERVISED) {
//Unsupervised, masked language model task //Unsupervised, masked language model task
//Output is either 2d, or 3d depending on settings //Output is either 2d, or 3d depending on settings
if(vocabKeysAsList == null){ if (vocabKeysAsList == null) {
String[] arr = new String[vocabMap.size()]; String[] arr = new String[vocabMap.size()];
for(Map.Entry<String,Integer> e : vocabMap.entrySet()){ for (Map.Entry<String, Integer> e : vocabMap.entrySet()) {
arr[e.getValue()] = e.getKey(); arr[e.getValue()] = e.getKey();
} }
vocabKeysAsList = Arrays.asList(arr); vocabKeysAsList = Arrays.asList(arr);
@ -276,31 +344,31 @@ public class BertIterator implements MultiDataSetIterator {
int vocabSize = vocabMap.size(); int vocabSize = vocabMap.size();
INDArray labelArr; INDArray labelArr;
INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength); INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength);
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){ if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength); labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){ } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength); labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){ } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize); labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
} else { } else {
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat); throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
} }
for( int i=0; i<mb; i++ ){ for (int i = 0; i < mbSize; i++) {
List<String> tokens = tokenizedSentences.get(i).getFirst(); List<String> tokens = tokenizedSentences.get(i).getFirst();
Pair<List<String>,boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList); Pair<List<String>, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
List<String> maskedTokens = p.getFirst(); List<String> maskedTokens = p.getFirst();
boolean[] predictionTarget = p.getSecond(); boolean[] predictionTarget = p.getSecond();
int seqLen = Math.min(predictionTarget.length, outLength); int seqLen = Math.min(predictionTarget.length, outLength);
for(int j=0; j<seqLen; j++ ){ for (int j = 0; j < seqLen; j++) {
if(predictionTarget[j]){ if (predictionTarget[j]) {
String oldToken = tokenizedSentences.get(i).getFirst().get(j); //This is target String oldToken = tokenizedSentences.get(i).getFirst().get(j); //This is target
int targetTokenIdx = vocabMap.get(oldToken); int targetTokenIdx = vocabMap.get(oldToken);
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){ if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
labelArr.putScalar(i, j, targetTokenIdx); labelArr.putScalar(i, j, targetTokenIdx);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){ } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
labelArr.putScalar(i, j, targetTokenIdx, 1.0); labelArr.putScalar(i, j, targetTokenIdx, 1.0);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){ } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
labelArr.putScalar(j, i, targetTokenIdx, 1.0); labelArr.putScalar(j, i, targetTokenIdx, 1.0);
} }
@ -309,7 +377,8 @@ public class BertIterator implements MultiDataSetIterator {
//Also update previously created feature label indexes: //Also update previously created feature label indexes:
String newToken = maskedTokens.get(j); String newToken = maskedTokens.get(j);
int newTokenIdx = vocabMap.get(newToken); int newTokenIdx = vocabMap.get(newToken);
outIdxsArr.putScalar(i,j,newTokenIdx); //first element of features is outIdxsArr
featureArray[0].putScalar(i, j, newTokenIdx);
} }
} }
} }
@ -319,19 +388,14 @@ public class BertIterator implements MultiDataSetIterator {
} else { } else {
throw new IllegalStateException("Task not yet implemented: " + task); throw new IllegalStateException("Task not yet implemented: " + task);
} }
return new Pair<>(l, lm);
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f, l, fm, lm);
if(preProcessor != null)
preProcessor.preProcess(mds);
return mds;
} }
private List<String> tokenizeSentence(String sentence) { private List<String> tokenizeSentence(String sentence) {
Tokenizer t = tokenizerFactory.create(sentence); Tokenizer t = tokenizerFactory.create(sentence);
List<String> tokens = new ArrayList<>(); List<String> tokens = new ArrayList<>();
if(prependToken != null) if (prependToken != null)
tokens.add(prependToken); tokens.add(prependToken);
while (t.hasMoreTokens()) { while (t.hasMoreTokens()) {
@ -341,6 +405,16 @@ public class BertIterator implements MultiDataSetIterator {
return tokens; return tokens;
} }
private List<Pair<String, String>> addDummyLabel(List<String> listOnlySentences) {
List<Pair<String, String>> list = new ArrayList<>(listOnlySentences.size());
for (String s : listOnlySentences) {
list.add(new Pair<String, String>(s, null));
}
return list;
}
@Override @Override
public boolean resetSupported() { public boolean resetSupported() {
return true; return true;
@ -353,12 +427,12 @@ public class BertIterator implements MultiDataSetIterator {
@Override @Override
public void reset() { public void reset() {
if(sentenceProvider != null){ if (sentenceProvider != null) {
sentenceProvider.reset(); sentenceProvider.reset();
} }
} }
public static Builder builder(){ public static Builder builder() {
return new Builder(); return new Builder();
} }
@ -373,7 +447,7 @@ public class BertIterator implements MultiDataSetIterator {
protected MultiDataSetPreProcessor preProcessor; protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null; protected LabeledSentenceProvider sentenceProvider = null;
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID; protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
protected Map<String,Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
protected BertSequenceMasker masker = new BertMaskedLMMasker(); protected BertSequenceMasker masker = new BertMaskedLMMasker();
protected UnsupervisedLabelFormat unsupervisedLabelFormat; protected UnsupervisedLabelFormat unsupervisedLabelFormat;
protected String maskToken; protected String maskToken;
@ -382,7 +456,7 @@ public class BertIterator implements MultiDataSetIterator {
/** /**
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
*/ */
public Builder task(Task task){ public Builder task(Task task) {
this.task = task; this.task = task;
return this; return this;
} }
@ -392,18 +466,19 @@ public class BertIterator implements MultiDataSetIterator {
* For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory} * For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory}
* is used * is used
*/ */
public Builder tokenizer(TokenizerFactory tokenizerFactory){ public Builder tokenizer(TokenizerFactory tokenizerFactory) {
this.tokenizerFactory = tokenizerFactory; this.tokenizerFactory = tokenizerFactory;
return this; return this;
} }
/** /**
* Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details. * Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details.
* @param lengthHandling Length handling *
* @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH} * @param lengthHandling Length handling
* @param maxLength Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH}
* @return * @return
*/ */
public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength){ public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) {
this.lengthHandling = lengthHandling; this.lengthHandling = lengthHandling;
this.maxTokens = maxLength; this.maxTokens = maxLength;
return this; return this;
@ -412,9 +487,10 @@ public class BertIterator implements MultiDataSetIterator {
/** /**
* Minibatch size to use (number of examples to train on for each iteration) * Minibatch size to use (number of examples to train on for each iteration)
* See also: {@link #padMinibatches} * See also: {@link #padMinibatches}
* @param minibatchSize Minibatch size *
* @param minibatchSize Minibatch size
*/ */
public Builder minibatchSize(int minibatchSize){ public Builder minibatchSize(int minibatchSize) {
this.minibatchSize = minibatchSize; this.minibatchSize = minibatchSize;
return this; return this;
} }
@ -429,7 +505,7 @@ public class BertIterator implements MultiDataSetIterator {
* Both options should result in exactly the same model. However, some BERT implementations may require exactly an * Both options should result in exactly the same model. However, some BERT implementations may require exactly an
* exact number of examples in all minibatches to function. * exact number of examples in all minibatches to function.
*/ */
public Builder padMinibatches(boolean padMinibatches){ public Builder padMinibatches(boolean padMinibatches) {
this.padMinibatches = padMinibatches; this.padMinibatches = padMinibatches;
return this; return this;
} }
@ -437,7 +513,7 @@ public class BertIterator implements MultiDataSetIterator {
/** /**
* Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null) * Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null)
*/ */
public Builder preProcessor(MultiDataSetPreProcessor preProcessor){ public Builder preProcessor(MultiDataSetPreProcessor preProcessor) {
this.preProcessor = preProcessor; this.preProcessor = preProcessor;
return this; return this;
} }
@ -446,7 +522,7 @@ public class BertIterator implements MultiDataSetIterator {
* Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised * Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised
* use case, the labels will be ignored. * use case, the labels will be ignored.
*/ */
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider){ public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
this.sentenceProvider = sentenceProvider; this.sentenceProvider = sentenceProvider;
return this; return this;
} }
@ -454,7 +530,7 @@ public class BertIterator implements MultiDataSetIterator {
/** /**
* Specify what arrays should be returned. See {@link BertIterator} for more details. * Specify what arrays should be returned. See {@link BertIterator} for more details.
*/ */
public Builder featureArrays(FeatureArrays featureArrays){ public Builder featureArrays(FeatureArrays featureArrays) {
this.featureArrays = featureArrays; this.featureArrays = featureArrays;
return this; return this;
} }
@ -465,7 +541,7 @@ public class BertIterator implements MultiDataSetIterator {
* If using {@link BertWordPieceTokenizerFactory}, * If using {@link BertWordPieceTokenizerFactory},
* this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()} * this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()}
*/ */
public Builder vocabMap(Map<String,Integer> vocabMap){ public Builder vocabMap(Map<String, Integer> vocabMap) {
this.vocabMap = vocabMap; this.vocabMap = vocabMap;
return this; return this;
} }
@ -475,7 +551,7 @@ public class BertIterator implements MultiDataSetIterator {
* masked language model. This can be used to customize how the masking is performed.<br> * masked language model. This can be used to customize how the masking is performed.<br>
* Default: {@link BertMaskedLMMasker} * Default: {@link BertMaskedLMMasker}
*/ */
public Builder masker(BertSequenceMasker masker){ public Builder masker(BertSequenceMasker masker) {
this.masker = masker; this.masker = masker;
return this; return this;
} }
@ -485,7 +561,7 @@ public class BertIterator implements MultiDataSetIterator {
* masked language model. Used to specify the format that the labels should be returned in. * masked language model. Used to specify the format that the labels should be returned in.
* See {@link BertIterator} for more details. * See {@link BertIterator} for more details.
*/ */
public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat){ public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) {
this.unsupervisedLabelFormat = labelFormat; this.unsupervisedLabelFormat = labelFormat;
return this; return this;
} }
@ -497,7 +573,7 @@ public class BertIterator implements MultiDataSetIterator {
* the exact behaviour will depend on what masker is used.<br> * the exact behaviour will depend on what masker is used.<br>
* Note that this must be in the vocabulary map set in {@link #vocabMap} * Note that this must be in the vocabulary map set in {@link #vocabMap}
*/ */
public Builder maskToken(String maskToken){ public Builder maskToken(String maskToken) {
this.maskToken = maskToken; this.maskToken = maskToken;
return this; return this;
} }
@ -510,12 +586,12 @@ public class BertIterator implements MultiDataSetIterator {
* *
* @param prependToken The token to start each sequence with (null: no token will be prepended) * @param prependToken The token to start each sequence with (null: no token will be prepended)
*/ */
public Builder prependToken(String prependToken){ public Builder prependToken(String prependToken) {
this.prependToken = prependToken; this.prependToken = prependToken;
return this; return this;
} }
public BertIterator build(){ public BertIterator build() {
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required"); Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set"); Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set");

View File

@ -26,7 +26,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.resources.Resources; import org.nd4j.resources.Resources;
@ -34,10 +33,7 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Arrays; import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.Random;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -54,6 +50,9 @@ public class TestBertIterator extends BaseDL4JTest {
String toTokenize1 = "I saw a girl with a telescope."; String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
BertIterator b = BertIterator.builder() BertIterator b = BertIterator.builder()
@ -100,12 +99,15 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expF, mds.getFeatures(0)); assertEquals(expF, mds.getFeatures(0));
assertEquals(expM, mds.getFeaturesMaskArray(0)); assertEquals(expM, mds.getFeaturesMaskArray(0));
assertEquals(expF,b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(expM,b.featurizeSentences(forInference).getSecond()[0]);
assertFalse(b.hasNext()); assertFalse(b.hasNext());
b.reset(); b.reset();
assertTrue(b.hasNext()); assertTrue(b.hasNext());
MultiDataSet mds2 = b.next(); MultiDataSet mds2 = b.next();
forInference.set(0,toTokenize2);
//Same thing, but with segment ID also //Same thing, but with segment ID also
b = BertIterator.builder() b = BertIterator.builder()
.tokenizer(t) .tokenizer(t)
@ -118,9 +120,11 @@ public class TestBertIterator extends BaseDL4JTest {
.build(); .build();
mds = b.next(); mds = b.next();
assertEquals(2, mds.getFeatures().length); assertEquals(2, mds.getFeatures().length);
assertEquals(2,b.featurizeSentences(forInference).getFirst().length);
//Segment ID should be all 0s for single segment task //Segment ID should be all 0s for single segment task
INDArray segmentId = expM.like(); INDArray segmentId = expM.like();
assertEquals(segmentId, mds.getFeatures(1)); assertEquals(segmentId, mds.getFeatures(1));
assertEquals(segmentId,b.featurizeSentences(forInference).getFirst()[1]);
} }
@Test(timeout = 20000L) @Test(timeout = 20000L)
@ -157,6 +161,9 @@ public class TestBertIterator extends BaseDL4JTest {
public void testLengthHandling() throws Exception { public void testLengthHandling() throws Exception {
String toTokenize1 = "I saw a girl with a telescope."; String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
@ -205,6 +212,8 @@ public class TestBertIterator extends BaseDL4JTest {
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0)); assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeatures(0));
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0)); assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0,14)), mds.getFeaturesMaskArray(0));
assertEquals(mds.getFeatures(0),b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
//Clip only: clip to maximum, but don't pad if less //Clip only: clip to maximum, but don't pad if less
b = BertIterator.builder() b = BertIterator.builder()
@ -227,6 +236,9 @@ public class TestBertIterator extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
String toTokenize1 = "I saw a girl with a telescope."; String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
@ -288,6 +300,9 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expM, mds.getFeaturesMaskArray(0)); assertEquals(expM, mds.getFeaturesMaskArray(0));
assertEquals(expL, mds.getLabels(0)); assertEquals(expL, mds.getLabels(0));
assertEquals(expLM, mds.getLabelsMaskArray(0)); assertEquals(expLM, mds.getLabelsMaskArray(0));
assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
} }
private static class TestSentenceProvider implements LabeledSentenceProvider { private static class TestSentenceProvider implements LabeledSentenceProvider {

View File

@ -35,15 +35,18 @@
</properties> </properties>
<build> <build>
<plugins> <pluginManagement>
<plugin> <plugins>
<artifactId>maven-compiler-plugin</artifactId> <plugin>
<configuration> <groupId>org.apache.maven.plugins</groupId>
<source>1.8</source> <artifactId>maven-compiler-plugin</artifactId>
<target>1.8</target> <configuration>
</configuration> <source>1.8</source>
</plugin> <target>1.8</target>
</plugins> </configuration>
</plugin>
</plugins>
</pluginManagement>
</build> </build>
<dependencies> <dependencies>
<dependency> <dependency>

View File

@ -28,16 +28,18 @@
<name>deeplearning4j-parallel-wrapper</name> <name>deeplearning4j-parallel-wrapper</name>
<build> <build>
<plugins> <pluginManagement>
<plugin> <plugins>
<groupId>org.apache.maven.plugins</groupId> <plugin>
<artifactId>maven-compiler-plugin</artifactId> <groupId>org.apache.maven.plugins</groupId>
<configuration> <artifactId>maven-compiler-plugin</artifactId>
<source>1.8</source> <configuration>
<target>1.8</target> <source>1.8</source>
</configuration> <target>1.8</target>
</plugin> </configuration>
</plugins> </plugin>
</plugins>
</pluginManagement>
</build> </build>
<properties> <properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
@ -84,7 +86,7 @@
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-play_2.11</artifactId> <artifactId>deeplearning4j-ui</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -79,15 +79,18 @@
<build> <build>
<plugins> <pluginManagement>
<plugin> <plugins>
<artifactId>maven-compiler-plugin</artifactId> <plugin>
<configuration> <groupId>org.apache.maven.plugins</groupId>
<source>1.8</source> <artifactId>maven-compiler-plugin</artifactId>
<target>1.8</target> <configuration>
</configuration> <source>1.8</source>
</plugin> <target>1.8</target>
</plugins> </configuration>
</plugin>
</plugins>
</pluginManagement>
</build> </build>
</project> </project>

View File

@ -67,16 +67,9 @@
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-play_2.11</artifactId> <artifactId>deeplearning4j-ui</artifactId>
<version>${deeplearning4j.version}</version> <version>${deeplearning4j.version}</version>
<scope>test</scope> <scope>test</scope>
<exclusions>
<!-- To avoid clashing net.jpountz.lz4:lz4:1.3.0 and org.lz4:lz4-java:jar:1.4.0 - -->
<exclusion>
<groupId>net.jpountz.lz4</groupId>
<artifactId>lz4</artifactId>
</exclusion>
</exclusions>
</dependency> </dependency>
<dependency> <dependency>

View File

@ -1,17 +0,0 @@
logs
project/project
project/target
target
tmp
.history
dist
/.idea
/*.iml
/out
/.idea_modules
/.classpath
/.project
/RUNNING_PID
/.settings

View File

@ -1,88 +0,0 @@
# Deeplearning4j UI (deeplearning4j-play) Developer Readme
The deeplearning4j UI utilizes an embedded Play server for hosting web pages. This guide documents some basic information
about the design, and how new pages can be added.
## Overview
As of DL4J 0.7.2, version 2.4 is utilized for scala 2.10 compatibility, which was dropped in Play 2.5 and later.
Due to DL4J being managed by Maven, not all features of Play (such as conf/routes approach to routing configuration, and
the default internationalization mechanism) are supported.
Some key concepts:
- The actual HTML pages are defined in the main/views/org.deeplearning4j.ui.views.* directory
- These views are based on Twirl template engine: in practice, templates are mostly HTML with a
little scala code mixed in
- The views themselves need to be rendered to a scala class (under main/scala). This is currently done manually,
see the 'building templates' section.
- Routing is somewhat non-standard in DL4J, compared to a typical Play application
- The DL4J has the concept of a 'UI module': a related set of pages to serve via the UI
- Each module implements the UIModule interface, which defines a set of routes, and how to handle page requests
- Custom modules are supported: this is not enabled by default. See PlayUIServer UI_CUSTOM_MODULE_PROPERTY for details.
- Internationalization (i.e., multiple languages) is supported, but not using the standard Play mechanisms
- The main class and entry point is PlayUIServer
## Building Templates
The current setup (using Maven + Play) means that the HTML templates need to be built manually.
Just run ```mvn compile -P buildUiTemplates``` on the command line to do this - this will generate the Scala classes
from all of the HTML files (in the views directory) and copy them to the appropriate location in the scala directory.
## Adding New Pages
Adding a new (built-in) page to the UI can be done using the following approach:
- Define a new module in org.deeplearning4j.ui.module, that implements the UIModule interface
- Provide routes: these are the endpoints that the user will be able to query, and define what methods will be used
to get the result to return for that page. See for example TrainModule for details
- Each route needs to specify the method to call to get the result (must return a play.mvc.Result object)
- Supplier: used to return results with no args; Function: 1 arg; BiFunction and Function3: 2 and 3 args respectively.
For function/bifunction etc, the arguments are specified with semicolons, like "/myRoute/:myArg"; the called
method should have an appropriate number of arguments to match this.
- Optionally: add code to handle callbacks, storage events, stats storage. See the section below.
- Add the module to the others, in the PlayUIServer. Should be 1 line: ```uiModules.add(new MyNewModule());```
## Assets
Again (due to issues with Maven + Play support), DL4J defines a custom Assets serving system. Assets are static files
such as images, CSS, javascript files, etc.
The contents of /resources/deeplearning4jUiAssets/* get mapped to /assets/* on the running server. To add a new asset,
simply add it to the assets folder, and then reference it as normal in the HTML templates or elsewhere.
## StatsStorage, Callbacks, and Getting Network and Other Info from DL4J
The Deeplearning4j UI supports callbacks from StatsStorage. StatsStorage instances are objects/interfaces that store
training information - some of it may be static/stored, some of it may be streaming in, in real-time.
When a user calls ```UIServer.getInstance().attach(StatsStorage)``` the provided StatsStorage instance will provide
callbacks to the UI whenever something changes. For example, new information from a trained network is added to the
StatsStorage from the StatsListener, the modules that are registered for callbacks (of that type) will be notified.
The UI modules can then query the StatStorage instance (or not) to get the relevant information for displaying in the UI.
Each UI module specifies the callbacks it wants to receive via Strings (UIModule.getCallbackTypeIDs). Each String is a
key that must match the TypeID of the data to listen for; consequently, there is a correspondence between the TypeID of
the generating class (StatsListener, for example) and the TypeID that the UI module wants to receive. UI modules may
specify zero or more callback type IDs. Any information that is not relevant to the UI module (i.e., doesn't match a
specified TypeID for the module) won't be forwarded on to the UI module.
## Internationalization
The Play framework does provide an internationalization mechanism, though this was found to be problematic when running
Play via a Maven project. Consequently, DL4J defines its own I18N implementation, that is functionally similar to the
Play mechanism.
The I18N interface, and the DefaultI18N class define getMessage(String) methods. For example, getMessage()
Currently (as of 0.7.2), the language setting is *server side* not client side - thus only one language can be shown per
UI server. Furthermore, the implementation is currently set to fall back to English.
The actual content for internationalization is present under the resources/dl4j_i18n directory. Each file within this
directory should contain an ISO639 language code (en, ja, de etc) as the extension.
Conceptually, the files may be separated; in practice, the contents of all files (for the relevant language) are
In practice, just add a snippet such as ```@i18n.getMessage("train.pagetitle")``` to the HTML template to get the
appropriate entry for the current language. See the train module UI pages for more examples.
Note also that it is necessary to provide an I18N instance to the templates. See the TrainModule routing section
for an example on how to do this.

View File

@ -1,65 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.api;
import lombok.AllArgsConstructor;
import lombok.Data;
import play.mvc.Http;
import play.mvc.Result;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
/**
* A Route specifies an endpoint that can be queried in the UI - along with how it should be handled
*
* @author Alex Black
*/
@Data
@AllArgsConstructor
public class Route {
private final String route;
private final HttpMethod httpMethod;
private final FunctionType functionType;
private final Supplier<Result> supplier;
private final Function<String, Result> function;
private final BiFunction<String, String, Result> function2;
private final Function<Http.Request, Result> request0Function;
private final BiFunction<Http.Request, String, Result> request1Function;
public Route(String route, HttpMethod method, FunctionType functionType, Supplier<Result> supplier) {
this(route, method, functionType, supplier, null, null, null, null);
}
public Route(String route, HttpMethod method, FunctionType functionType, Function<String, Result> function) {
this(route, method, functionType, null, function, null, null, null);
}
public static Route request0Function(String route, HttpMethod httpMethod, Function<Http.Request, Result> function){
return new Route(route, httpMethod, FunctionType.Request0Function, null, null, null, function, null);
}
public static Route request1Function(String route, HttpMethod httpMethod, BiFunction<Http.Request, String, Result> function){
return new Route(route, httpMethod, FunctionType.Request1Function, null, null, null, null, function);
}
public Route(String route, HttpMethod method, FunctionType functionType,
BiFunction<String, String, Result> function) {
this(route, method, functionType, null, null, function, null, null);
}
}

View File

@ -1,163 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.module.tsne;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.I18NResource;
import play.libs.Files;
import play.libs.Json;
import play.mvc.Http;
import play.mvc.Result;
import play.mvc.Results;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import static play.mvc.Results.badRequest;
import static play.mvc.Results.ok;
/**
* Created by Alex on 25/10/2016.
*/
public class TsneModule implements UIModule {
private static final String UPLOADED_FILE = "UploadedFile";
private Map<String, List<String>> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>());
private List<String> uploadedFileLines = null;
public TsneModule() {
}
@Override
public List<String> getCallbackTypeIDs() {
return Collections.emptyList();
}
@Override
public List<Route> getRoutes() {
Route r1 = new Route("/tsne", HttpMethod.GET, FunctionType.Supplier,
() -> ok(org.deeplearning4j.ui.views.html.tsne.Tsne.apply()));
Route r2 = new Route("/tsne/sessions", HttpMethod.GET, FunctionType.Supplier, this::listSessions);
Route r3 = new Route("/tsne/coords/:sid", HttpMethod.GET, FunctionType.Function, this::getCoords);
Route r4 = Route.request0Function("/tsne/upload", HttpMethod.POST, this::uploadFile);
Route r5 = Route.request1Function("/tsne/post/:sid", HttpMethod.POST, this::postFile);
return Arrays.asList(r1, r2, r3, r4, r5);
}
@Override
public void reportStorageEvents(Collection<StatsStorageEvent> events) {
}
@Override
public void onAttach(StatsStorage statsStorage) {
}
@Override
public void onDetach(StatsStorage statsStorage) {
}
@Override
public List<I18NResource> getInternationalizationResources() {
return Collections.emptyList();
}
private Result listSessions() {
List<String> list = new ArrayList<>(knownSessionIDs.keySet());
if (uploadedFileLines != null) {
list.add(UPLOADED_FILE);
}
return Results.ok(Json.toJson(list));
}
private Result getCoords(String sessionId) {
if (UPLOADED_FILE.equals(sessionId) && uploadedFileLines != null) {
return Results.ok(Json.toJson(uploadedFileLines));
} else if (knownSessionIDs.containsKey(sessionId)) {
return Results.ok(Json.toJson(knownSessionIDs.get(sessionId)));
} else {
return Results.ok();
}
}
private Result uploadFile(Http.Request request) {
Http.MultipartFormData<Files.TemporaryFile> body = request.body().asMultipartFormData();
List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> fileParts = body.getFiles();
if (fileParts.isEmpty()) {
return badRequest("No file uploaded");
}
Http.MultipartFormData.FilePart<Files.TemporaryFile> uploadedFile = fileParts.get(0);
String fileName = uploadedFile.getFilename();
String contentType = uploadedFile.getContentType();
File file = uploadedFile.getRef().path().toFile();
try {
uploadedFileLines = FileUtils.readLines(file, StandardCharsets.UTF_8);
} catch (IOException e) {
return badRequest("Could not read from uploaded file");
}
return ok("File uploaded: " + fileName + ", " + contentType + ", " + file);
}
private Result postFile(Http.Request request, String sid) {
// System.out.println("POST FILE CALLED: " + sid);
Http.MultipartFormData<Files.TemporaryFile> body = request.body().asMultipartFormData();
List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> fileParts = body.getFiles();
if (fileParts.isEmpty()) {
// System.out.println("**** NO FILE ****");
return badRequest("No file uploaded");
}
Http.MultipartFormData.FilePart<Files.TemporaryFile> uploadedFile = fileParts.get(0);
String fileName = uploadedFile.getFilename();
String contentType = uploadedFile.getContentType();
File file = uploadedFile.getRef().path().toFile();
List<String> lines;
try {
// Set to uploadedFileLines as well, as the TSNE UI doesn't allow to properly select Sessions yet
lines = uploadedFileLines = FileUtils.readLines(file);
} catch (IOException e) {
// System.out.println("**** COULD NOT READ FILE ****");
return badRequest("Could not read from uploaded file");
}
knownSessionIDs.put(sid, lines);
return ok("File uploaded: " + fileName + ", " + contentType + ", " + file);
}
}

View File

@ -1,68 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.play.staticroutes;
import org.nd4j.shade.guava.net.HttpHeaders;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.nd4j.linalg.io.ClassPathResource;
import play.mvc.Result;
import play.mvc.StaticFileMimeTypes;
import java.io.InputStream;
import java.util.Optional;
import static play.mvc.Results.ok;
/**
* Simple function for serving assets. Assets are assumed to be in the specified root directory
*
* @author Alex Black
*/
@AllArgsConstructor
@Slf4j
public class Assets {
public static Result assetRequest(String assetsRootDirectory, String s) {
String fullPath;
if(s.startsWith("webjars/")){
fullPath = "META-INF/resources/" + s;
} else {
fullPath = assetsRootDirectory + s;
}
InputStream inputStream;
try {
inputStream = new ClassPathResource(fullPath).getInputStream();
} catch (Throwable t) {
log.warn("Could not find requested UI asset: {}", s, t);
return ok();
}
String fileName = FilenameUtils.getName(fullPath);
Optional<String> contentType = StaticFileMimeTypes.fileMimeTypes().forFileName(fileName);
String ct;
ct = contentType.orElse("application/octet-stream");
return ok(inputStream)
.withHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"")
.as(ct);
}
}

View File

@ -1,38 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.play.staticroutes;
import org.deeplearning4j.ui.i18n.I18NProvider;
import play.mvc.Result;
import java.util.function.BiFunction;
import static play.mvc.Results.ok;
/**
* Route for multi-session internationalization setting
*
* @author Tamas Fenyvesi
*/
public class MultiSessionI18NRoute implements BiFunction<String, String, Result> {
@Override
public Result apply(String sessionId, String languageCode) {
I18NProvider.getInstance(sessionId).setDefaultLanguage(languageCode);
return ok();
}
}

View File

@ -1,350 +0,0 @@
# This is the main configuration file for the application.
# https://www.playframework.com/documentation/latest/ConfigFile
# ~~~~~
# Play uses HOCON as its configuration file format. HOCON has a number
# of advantages over other config formats, but there are two things that
# can be used when modifying settings.
#
# You can include other configuration files in this main application.conf file:
#include "extra-config.conf"
#
# You can declare variables and substitute for them:
#mykey = ${some.value}
#
# And if an environment variable exists when there is no other subsitution, then
# HOCON will fall back to substituting environment variable:
#mykey = ${JAVA_HOME}
## Akka
# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
# ~~~~~
# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
# other streaming HTTP responses.
akka {
# "akka.log-config-on-start" is extraordinarly useful because it log the complete
# configuration at INFO level, including defaults and overrides, so it s worth
# putting at the very top.
#
# Put the following in your conf/logback.xml file:
#
# <logger name="akka.actor" level="INFO" />
#
# And then uncomment this line to debug the configuration.
#
#log-config-on-start = true
}
## Modules
# https://www.playframework.com/documentation/latest/Modules
# ~~~~~
# Control which modules are loaded when Play starts. Note that modules are
# the replacement for "GlobalSettings", which are deprecated in 2.5.x.
# Please see https://www.playframework.com/documentation/latest/GlobalSettings
# for more information.
#
# You can also extend Play functionality by using one of the publically available
# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
play.modules {
# By default, Play will load any class called Module that is defined
# in the root package (the "app" directory), or you can define them
# explicitly below.
# If there are any built-in modules that you want to disable, you can list them here.
#enabled += my.application.Module
# If there are any built-in modules that you want to disable, you can list them here.
#disabled += ""
}
## Internationalisation
# https://www.playframework.com/documentation/latest/JavaI18N
# https://www.playframework.com/documentation/latest/ScalaI18N
# ~~~~~
# Play comes with its own i18n settings, which allow the user's preferred language
# to map through to internal messages, or allow the language to be stored in a cookie.
play.i18n {
# The application languages
langs = [ "en" ]
# Whether the language cookie should be secure or not
#langCookieSecure = true
# Whether the HTTP only attribute of the cookie should be set to true
#langCookieHttpOnly = true
}
## Play HTTP settings
# ~~~~~
play.http {
## Router
# https://www.playframework.com/documentation/latest/JavaRouting
# https://www.playframework.com/documentation/latest/ScalaRouting
# ~~~~~
# Define the Router object to use for this application.
# This router will be looked up first when the application is starting up,
# so make sure this is the entry point.
# Furthermore, it's assumed your route file is named properly.
# So for an application router like `my.application.Router`,
# you may need to define a router file `conf/my.application.routes`.
# Default to Routes in the root package (aka "apps" folder) (and conf/routes)
#router = my.application.Router
## Action Creator
# https://www.playframework.com/documentation/latest/JavaActionCreator
# ~~~~~
#actionCreator = null
## ErrorHandler
# https://www.playframework.com/documentation/latest/JavaRouting
# https://www.playframework.com/documentation/latest/ScalaRouting
# ~~~~~
# If null, will attempt to load a class called ErrorHandler in the root package,
#errorHandler = null
## Filters
# https://www.playframework.com/documentation/latest/ScalaHttpFilters
# https://www.playframework.com/documentation/latest/JavaHttpFilters
# ~~~~~
# Filters run code on every request. They can be used to perform
# common logic for all your actions, e.g. adding common headers.
# Defaults to "Filters" in the root package (aka "apps" folder)
# Alternatively you can explicitly register a class here.
#filters += my.application.Filters
## Session & Flash
# https://www.playframework.com/documentation/latest/JavaSessionFlash
# https://www.playframework.com/documentation/latest/ScalaSessionFlash
# ~~~~~
session {
# Sets the cookie to be sent only over HTTPS.
#secure = true
# Sets the cookie to be accessed only by the server.
#httpOnly = true
# Sets the max-age field of the cookie to 5 minutes.
# NOTE: this only sets when the browser will discard the cookie. Play will consider any
# cookie value with a valid signature to be a valid session forever. To implement a server side session timeout,
# you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
#maxAge = 300
# Sets the domain on the session cookie.
#domain = "example.com"
}
flash {
# Sets the cookie to be sent only over HTTPS.
#secure = true
# Sets the cookie to be accessed only by the server.
#httpOnly = true
}
}
## Netty Provider
# https://www.playframework.com/documentation/latest/SettingsNetty
# ~~~~~
play.server.netty {
# Whether the Netty wire should be logged
#log.wire = true
# If you run Play on Linux, you can use Netty's native socket transport
# for higher performance with less garbage.
#transport = "native"
}
## WS (HTTP Client)
# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
# ~~~~~
# The HTTP client primarily used for REST APIs. The default client can be
# configured directly, but you can also create different client instances
# with customized settings. You must enable this by adding to build.sbt:
#
# libraryDependencies += ws // or javaWs if using java
#
play.ws {
# Sets HTTP requests not to follow 302 requests
#followRedirects = false
# Sets the maximum number of open HTTP connections for the client.
#ahc.maxConnectionsTotal = 50
## WS SSL
# https://www.playframework.com/documentation/latest/WsSSL
# ~~~~~
ssl {
# Configuring HTTPS with Play WS does not require programming. You can
# set up both trustManager and keyManager for mutual authentication, and
# turn on JSSE debugging in development with a reload.
#debug.handshake = true
#trustManager = {
# stores = [
# { type = "JKS", path = "exampletrust.jks" }
# ]
#}
}
}
## Cache
# https://www.playframework.com/documentation/latest/JavaCache
# https://www.playframework.com/documentation/latest/ScalaCache
# ~~~~~
# Play comes with an integrated cache API that can reduce the operational
# overhead of repeated requests. You must enable this by adding to build.sbt:
#
# libraryDependencies += cache
#
play.cache {
# If you want to bind several caches, you can bind the individually
#bindCaches = ["db-cache", "user-cache", "session-cache"]
}
## Filters
# https://www.playframework.com/documentation/latest/Filters
# ~~~~~
# There are a number of built-in filters that can be enabled and configured
# to give Play greater security. You must enable this by adding to build.sbt:
#
# libraryDependencies += filters
#
play.filters {
## CORS filter configuration
# https://www.playframework.com/documentation/latest/CorsFilter
# ~~~~~
# CORS is a protocol that allows web applications to make requests from the browser
# across different domains.
# NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
# dependencies on CORS settings.
cors {
# Filter paths by a whitelist of path prefixes
#pathPrefixes = ["/some/path", ...]
# The allowed origins. If null, all origins are allowed.
#allowedOrigins = ["http://www.example.com"]
# The allowed HTTP methods. If null, all methods are allowed
#allowedHttpMethods = ["GET", "POST"]
}
## CSRF Filter
# https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
# https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
# ~~~~~
# Play supports multiple methods for verifying that a request is not a CSRF request.
# The primary mechanism is a CSRF token. This token gets placed either in the query string
# or body of every form submitted, and also gets placed in the users session.
# Play then verifies that both tokens are present and match.
csrf {
# Sets the cookie to be sent only over HTTPS
#cookie.secure = true
# Defaults to CSRFErrorHandler in the root package.
#errorHandler = MyCSRFErrorHandler
}
## Security headers filter configuration
# https://www.playframework.com/documentation/latest/SecurityHeaders
# ~~~~~
# Defines security headers that prevent XSS attacks.
# If enabled, then all options are set to the below configuration by default:
headers {
# The X-Frame-Options header. If null, the header is not set.
#frameOptions = "DENY"
# The X-XSS-Protection header. If null, the header is not set.
#xssProtection = "1; mode=block"
# The X-Content-Type-Options header. If null, the header is not set.
#contentTypeOptions = "nosniff"
# The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
#permittedCrossDomainPolicies = "master-only"
# The Content-Security-Policy header. If null, the header is not set.
#contentSecurityPolicy = "default-src 'self'"
}
## Allowed hosts filter configuration
# https://www.playframework.com/documentation/latest/AllowedHostsFilter
# ~~~~~
# Play provides a filter that lets you configure which hosts can access your application.
# This is useful to prevent cache poisoning attacks.
hosts {
# Allow requests to example.com, its subdomains, and localhost:9000.
#allowed = [".example.com", "localhost:9000"]
}
}
## Evolutions
# https://www.playframework.com/documentation/latest/Evolutions
# ~~~~~
# Evolutions allows database scripts to be automatically run on startup in dev mode
# for database migrations. You must enable this by adding to build.sbt:
#
# libraryDependencies += evolutions
#
play.evolutions {
# You can disable evolutions for a specific datasource if necessary
#db.default.enabled = false
}
## Database Connection Pool
# https://www.playframework.com/documentation/latest/SettingsJDBC
# ~~~~~
# Play doesn't require a JDBC database to run, but you can easily enable one.
#
# libraryDependencies += jdbc
#
play.db {
# The combination of these two settings results in "db.default" as the
# default JDBC pool:
#config = "db"
#default = "default"
# Play uses HikariCP as the default connection pool. You can override
# settings by changing the prototype:
prototype {
# Sets a fixed JDBC connection pool size of 50
#hikaricp.minimumIdle = 50
#hikaricp.maximumPoolSize = 50
}
}
## JDBC Datasource
# https://www.playframework.com/documentation/latest/JavaDatabase
# https://www.playframework.com/documentation/latest/ScalaDatabase
# ~~~~~
# Once JDBC datasource is set up, you can work with several different
# database options:
#
# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
# EBean: https://playframework.com/documentation/latest/JavaEbean
# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
#
db {
# You can declare as many datasources as you want.
# By convention, the default datasource is named `default`
# https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
default.driver = org.h2.Driver
default.url = "jdbc:h2:mem:play"
#default.username = sa
#default.password = ""
# You can expose this datasource via JNDI if needed (Useful for JPA)
default.jndiName=DefaultDS
# You can turn on SQL logging for any datasource
# https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
#default.logSql=true
}
jpa.default=defaultPersistenceUnit
#Increase default maximum post length - used for remote listener functionality
#Can get response 413 with larger networks without setting this
# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead
#parsers.text.maxLength=10M
play.http.parser.maxMemoryBuffer=10M

View File

@ -1,159 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.views.html.convolutional
import play.twirl.api._
import play.twirl.api.TemplateMagic._
object Activations_Scope0 {
import models._
import controllers._
import play.api.i18n._
import views.html._
import play.api.templates.PlayMagic._
import play.api.mvc._
import play.api.data._
class Activations extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template0[play.twirl.api.HtmlFormat.Appendable] {
/**/
def apply():play.twirl.api.HtmlFormat.Appendable = {
_display_ {
{
Seq[Any](format.raw/*1.1*/("""<!DOCTYPE html>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Neural Network activations</title>
<!-- jQuery -->
<script src="/assets/webjars/jquery/2.2.0/jquery.min.js"></script>
"""),format.raw/*25.68*/("""
"""),format.raw/*26.13*/("""<!-- Latest compiled and minified CSS -->
<link id="bootstrap-style" href="/assets/webjars/bootstrap/2.3.2/css/bootstrap.min.css" rel="stylesheet">
<!-- Latest compiled and minified JavaScript -->
<script src="/assets/webjars/bootstrap/2.3.2/js/bootstrap.min.js"></script>
<style>
body """),format.raw/*33.14*/("""{"""),format.raw/*33.15*/("""
"""),format.raw/*34.13*/("""font-family: 'Roboto', sans-serif;
color: #333;
font-weight: 300;
font-size: 16px;
"""),format.raw/*38.9*/("""}"""),format.raw/*38.10*/("""
"""),format.raw/*40.9*/(""".hd """),format.raw/*40.13*/("""{"""),format.raw/*40.14*/("""
"""),format.raw/*41.13*/("""background-color: #000000;
font-size: 18px;
color: #FFFFFF;
"""),format.raw/*44.9*/("""}"""),format.raw/*44.10*/("""
"""),format.raw/*46.9*/(""".block """),format.raw/*46.16*/("""{"""),format.raw/*46.17*/("""
"""),format.raw/*47.13*/("""width: 250px;
height: 350px;
display: inline-block;
border: 1px solid #DEDEDE;
margin-right: 64px;
"""),format.raw/*52.9*/("""}"""),format.raw/*52.10*/("""
"""),format.raw/*54.9*/(""".hd-small """),format.raw/*54.19*/("""{"""),format.raw/*54.20*/("""
"""),format.raw/*55.13*/("""background-color: #000000;
font-size: 14px;
color: #FFFFFF;
"""),format.raw/*58.9*/("""}"""),format.raw/*58.10*/("""
"""),format.raw/*59.9*/("""</style>
<script type="text/javascript">
setInterval(function () """),format.raw/*62.33*/("""{"""),format.raw/*62.34*/("""
"""),format.raw/*63.13*/("""var d = new Date();
$("#pic").removeAttr("src").attr("src", "/activations/data?timestamp=" + new Date().getTime());
"""),format.raw/*65.9*/("""}"""),format.raw/*65.10*/(""", 3000);
</script>
</head>
<body>
<table style="width: 100%;
padding: 5px;" class="hd">
<tbody>
<tr>
<td style="width: 48px;"><a href="/"><img src="/assets/legacy/deeplearning4j.img" border="0"/></a></td>
<td>DeepLearning4j UI</td>
<td style="width: 128px;">&nbsp; <!-- placeholder for future use --></td>
</tr>
</tbody>
</table>
<br /> <br />
<div style="width: 100%;
text-align: center">
<div id="embed" style="display: inline-block;"> <!-- style="border: 1px solid #CECECE;" -->
<img src="/activations/data" id="pic" />
</div>
</div>
</body>
</html>"""))
}
}
}
def render(): play.twirl.api.HtmlFormat.Appendable = apply()
def f:(() => play.twirl.api.HtmlFormat.Appendable) = () => apply()
def ref: this.type = this
}
}
/**/
object Activations extends Activations_Scope0.Activations
/*
-- GENERATED --
DATE: Tue May 07 21:39:40 AEST 2019
SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/convolutional/Activations.scala.html
HASH: fb7b9cdfdd31f76d21208213b839aeebac3dabe2
MATRIX: 657->0|1724->1098|1766->1112|2133->1451|2162->1452|2204->1466|2362->1597|2391->1598|2430->1610|2462->1614|2491->1615|2533->1629|2655->1724|2684->1725|2723->1737|2758->1744|2787->1745|2829->1759|3016->1919|3045->1920|3084->1932|3122->1942|3151->1943|3193->1957|3315->2052|3344->2053|3381->2063|3494->2148|3523->2149|3565->2163|3730->2301|3759->2302
LINES: 25->1|49->25|50->26|57->33|57->33|58->34|62->38|62->38|64->40|64->40|64->40|65->41|68->44|68->44|70->46|70->46|70->46|71->47|76->52|76->52|78->54|78->54|78->54|79->55|82->58|82->58|83->59|86->62|86->62|87->63|89->65|89->65
-- GENERATED --
*/

View File

@ -1,240 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.views.html.samediff
import play.twirl.api._
import play.twirl.api.TemplateMagic._
object SameDiffUI_Scope0 {
import models._
import controllers._
import play.api.i18n._
import views.html._
import play.api.templates.PlayMagic._
import play.api.mvc._
import play.api.data._
class SameDiffUI extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template0[play.twirl.api.HtmlFormat.Appendable] {
/**/
def apply/*1.2*/():play.twirl.api.HtmlFormat.Appendable = {
_display_ {
{
Seq[Any](format.raw/*1.4*/("""
"""),format.raw/*2.1*/("""<!DOCTYPE html>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2019 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<html lang="en" style="height: 100%">
<head>
<meta charset="utf-8">
<title>SameDiff Graph Visualization</title>
<!-- start: Mobile Specific -->
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- end: Mobile Specific -->
<link id="bootstrap-style" href="/assets/webjars/bootstrap/4.3.1/dist/css/bootstrap.min.css" rel="stylesheet">
<link href="/assets/css/samediff/samediff.css" rel="stylesheet">
<![endif]-->
</head>
<body>
<!-- Start JavaScript-->
<script src="/assets/webjars/jquery/2.2.0/jquery.min.js"></script>
<script src="/assets/webjars/jquery-ui/1.10.2/ui/minified/jquery-ui.min.js"></script>
<script src="/assets/webjars/bootstrap/4.3.1/dist/js/bootstrap.min.js"></script>
<script src="/assets/webjars/jquery-cookie/1.4.1-1/jquery.cookie.js"></script>
<script src="/assets/webjars/flatbuffers/1.9.0/js/flatbuffers.js"></script>
<script src="/assets/webjars/cytoscape/3.3.3/dist/cytoscape.min.js"></script>
<script src="/assets/webjars/dagre/0.8.4/dist/dagre.min.js"></script>
<script src="/assets/webjars/cytoscape-dagre/2.1.0/cytoscape-dagre.js"></script>
<script src="/assets/webjars/cytoscape-cose-bilkent/4.0.0/cytoscape-cose-bilkent.js"></script>
<script src="/assets/webjars/webcola/3.1.3/WebCola/cola.js"></script>
<script src="/assets/webjars/cytoscape-cola/2.3.0/cytoscape-cola.js"></script>
<script src="/assets/webjars/cytoscape-euler/1.2.1/cytoscape-euler.js"></script>
<script src="/assets/webjars/klayjs/0.4.1/klay.js"></script>
<script src="/assets/webjars/cytoscape-klay/3.1.2/cytoscape-klay.js"></script>
<script src="/assets/webjars/weaverjs/1.2.0/dist/weaver.js"></script>
<script src="/assets/webjars/cytoscape-spread/3.0.0/cytoscape-spread.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.pie.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.stack.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.resize.min.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.selection.js"></script>
<script src="/assets/js/samediff/generated/uigraphevents_generated.js"></script>
<script src="/assets/js/samediff/generated/uigraphstatic_generated.js"></script>
<script src="/assets/js/samediff/generated/array_generated.js"></script>
<script src="/assets/js/samediff/generated/utils_generated.js"></script>
<script src="/assets/js/samediff/generated/variable_generated.js"></script>
<script src="/assets/js/samediff/samediff-ui.js"></script>
<script src="/assets/js/samediff/samediff-graph.js"></script>
<script src="/assets/js/samediff/samediff-plots.js"></script>
<script src="/assets/js/samediff/flatbuffers-utils.js"></script>
"""),format.raw/*71.34*/("""
"""),format.raw/*72.9*/("""<div class="container-fluid" style="min-height: 100%">
<div class="row">
<!-- NavBar - Bootstrap classes -->
<nav class="navbar navbar-expand navbar-dark bg-dark" style="width: 100pc">
<a class="navbar-brand" href="#">SameDiff</a>
<button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarNavAltMarkup" aria-controls="navbarNavAltMarkup" aria-expanded="false" aria-label="Toggle navigation">
<span class="navbar-toggler-icon"></span>
</button>
<div class="collapse navbar-collapse" id="navbarNavAltMarkup">
<div class="navbar-nav">
<a id="sdnavgraph" class="nav-item nav-link active" href="#" onclick="samediffSetPage('graph')">
Graph</a>
<a id="sdnavplots" class="nav-item nav-link" href="#" onclick="samediffSetPage('plots')">
Plots</a>
<a id="sdnaveval" class="nav-item nav-link" href="#" onclick="samediffSetPage('evaluation')">
Evaluation</a>
<a id="sdnavperf" class="nav-item nav-link" href="#" onclick="samediffSetPage('performance')">
Performance</a>
<a class="nav-item nav-link" href="#" onclick="toggleSidebar()">Toggle Sidebar</a>
</div>
</div>
</nav>
</div>
<div class="row" style="min-height: 100%">
<!-- Sidebar -->
"""),format.raw/*98.47*/("""
"""),format.raw/*99.17*/("""<div class="col-md-4 col-12" style="min-width: 300px; max-width: 300px; background-color: #e6e6e6; height:100%; min-height:100vh">
"""),format.raw/*100.81*/("""
"""),format.raw/*101.21*/("""<div id="sidebartop" class="row p-2">
<div style="width:auto">
<label class="input-group-btn">
<span class="btn btn-secondary btn-sm">
Select File<input type="file" id="fileselect" style="display: none;" multiple>
</span>
</label>
</div>
<div id="selectedfile" class="w-100">[No File Loaded]</div>
</div>
<div class="sidebarline"></div>
"""),format.raw/*114.85*/("""
"""),format.raw/*115.21*/("""<div id="sidebarmid" class="row p-2">
<div class="w-100"><b>Selected Node:</b></div>
<div id="sidebarmid-content" class="w-100">(None)</div>
</div>
<div class="sidebarline"></div>
<div id="sidebarmid2" class="row p-2">
<div style="width:100%">
<b>Find Node:</b><br>
</div>
<input id="findnodetxt" type="text" oninput="onGraphNodeSearch()">
<div id="findnoderesults">
</div>
</div>
<div class="sidebarline"></div>
"""),format.raw/*134.87*/("""
"""),format.raw/*135.21*/("""<div id="sidebarbottom" class="row p-2">
<br><br>
<strong>Graph Layout:</strong>
<div class="btn-group btn-group-toggle w-100" data-toggle="buttons" style="height: 40px">
<label class="btn btn-secondary active" onclick="setLayout('klay_down')">
<input type="radio" name="options" id="option1" autocomplete="off" checked>Down</label>
<label class="btn btn-secondary" onclick="setLayout('klay_lr')">
<input type="radio" name="options" id="option2" autocomplete="off">Right</label>
<label class="btn btn-secondary" onclick="setLayout('dagre')">
<input type="radio" name="options" id="option3" autocomplete="off">Alt</label>
<label class="btn btn-secondary" onclick="setLayout('cose-bilkent')">
<input type="radio" name="options" id="option3" autocomplete="off">Spread</label>
</div>
<br>
<br>
<br>
</div>
</div>
<!-- Page Content -->
<div id="samediffcontent" class="col-md col-12 main pa-1">
<div id="graphdiv" style="height: 100%;
width: 100%;
display: table"></div>
</div>
</div>
</div>
<!-- Execute once on page load -->
<script>
document.getElementById('fileselect').addEventListener('change', fileSelect, false);
$(document).ready(function () """),format.raw/*167.47*/("""{"""),format.raw/*167.48*/("""
"""),format.raw/*168.21*/("""renderSameDiffGraph();
"""),format.raw/*169.17*/("""}"""),format.raw/*169.18*/(""");
</script>
</body>
</html>
"""))
}
}
}
def render(): play.twirl.api.HtmlFormat.Appendable = apply()
def f:(() => play.twirl.api.HtmlFormat.Appendable) = () => apply()
def ref: this.type = this
}
}
/**/
object SameDiffUI extends SameDiffUI_Scope0.SameDiffUI
/*
-- GENERATED --
DATE: Tue May 07 21:39:40 AEST 2019
SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/samediff/SameDiffUI.scala.html
HASH: d80f43e3f37ebbfb22768edb52560e2d6d6fcb2a
MATRIX: 561->1|657->3|685->5|4621->3938|4658->3948|6437->5729|6483->5747|6664->5959|6715->5981|7384->6685|7435->6707|8217->7526|8268->7548|10114->9365|10144->9366|10195->9388|10264->9428|10294->9429
LINES: 20->1|25->1|26->2|95->71|96->72|122->98|123->99|124->100|125->101|138->114|139->115|158->134|159->135|191->167|191->167|192->168|193->169|193->169
-- GENERATED --
*/

View File

@ -1,175 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.views.html.samediff
import play.twirl.api._
import play.twirl.api.TemplateMagic._
object SameDiffUIBackup_Scope0 {
import models._
import controllers._
import play.api.i18n._
import views.html._
import play.api.templates.PlayMagic._
import play.api.mvc._
import play.api.data._
class SameDiffUIBackup extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template0[play.twirl.api.HtmlFormat.Appendable] {
/**/
def apply/*1.2*/():play.twirl.api.HtmlFormat.Appendable = {
_display_ {
{
Seq[Any](format.raw/*1.4*/("""
"""),format.raw/*2.1*/("""<!DOCTYPE html>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2019 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<html lang="en" style="height: 100%">
<head>
<meta charset="utf-8">
<title>SameDiff Graph Visualization</title>
<!-- start: Mobile Specific -->
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- end: Mobile Specific -->
<link id="bootstrap-style" href="/assets/webjars/bootstrap/2.3.1/css/bootstrap.min.css" rel="stylesheet">
<!-- The HTML5 shim, for IE6-8 support of HTML5 elements -->
<!--[if lt IE 9]>
<script src="http://html5shim.googlecode.com/svn/trunk/html5.js"></script>
<link id="ie-style" href="/assets/css/ie.css" rel="stylesheet"/>
<![endif]-->
<!--[if IE 9]>
<link id="ie9style" href="/assets/css/ie9.css" rel="stylesheet"/>
<![endif]-->
</head>
<body style="height: 100%;
margin: 0;">
<!-- Start JavaScript-->
<script src="/assets/webjars/jquery/2.2.0/jquery.min.js"></script>
<script src="/assets/webjars/jquery-ui/1.10.2/ui/minified/jquery-ui.min.js"></script>
<script src="/assets/webjars/bootstrap/2.3.1/js/bootstrap.min.js"></script>
<script src="/assets/webjars/jquery-cookie/1.4.1-1/jquery.cookie.js"></script>
<script src="/assets/webjars/flatbuffers/1.9.0/js/flatbuffers.js"></script>
<script src="/assets/webjars/cytoscape/3.3.3/dist/cytoscape.min.js"></script>
<script src="/assets/webjars/dagre/0.8.4/dist/dagre.min.js"></script>
<script src="/assets/webjars/cytoscape-dagre/2.1.0/cytoscape-dagre.js"></script>
<script src="/assets/webjars/cytoscape-cose-bilkent/4.0.0/cytoscape-cose-bilkent.js"></script>
<script src="/assets/webjars/webcola/3.1.3/WebCola/cola.js"></script>
<script src="/assets/webjars/cytoscape-cola/2.3.0/cytoscape-cola.js"></script>
<script src="/assets/webjars/cytoscape-euler/1.2.1/cytoscape-euler.js"></script>
<script src="/assets/webjars/klayjs/0.4.1/klay.js"></script>
<script src="/assets/webjars/cytoscape-klay/3.1.2/cytoscape-klay.js"></script>
<script src="/assets/webjars/weaverjs/1.2.0/dist/weaver.js"></script>
<script src="/assets/webjars/cytoscape-spread/3.0.0/cytoscape-spread.js"></script>
<script src="/assets/js/samediff/generated/uigraphevents_generated.js"></script>
<script src="/assets/js/samediff/generated/uigraphstatic_generated.js"></script>
<script src="/assets/js/samediff/generated/array_generated.js"></script>
<script src="/assets/js/samediff/generated/utils_generated.js"></script>
<script src="/assets/js/samediff/generated/variable_generated.js"></script>
<script src="/assets/js/samediff/samediff-ui.js"></script>
<script src="/assets/js/samediff/flatbuffers-utils.js"></script>
<div class="container-fluid-full">
<div class="row-fluid">
<input type="file" id="file" name="file" />
<output id="list"></output>
"""),format.raw/*77.78*/("""
"""),format.raw/*78.17*/("""<div class="row-fluid">
Layout:
<button onclick="setLayout('dagre')">Dagre</button>
<button onclick="setLayout('klay_down')">Klay (Down)</button>
<button onclick="setLayout('klay_lr')">Klay (Right)</button>
"""),format.raw/*83.76*/("""
"""),format.raw/*84.78*/("""
"""),format.raw/*85.74*/("""
"""),format.raw/*86.21*/("""<button onclick="setLayout('cose-bilkent')">CoSE Bilkent</button>
<button onclick="setLayout('breadthfirst')">Breadth First</button>
"""),format.raw/*88.74*/("""
"""),format.raw/*89.17*/("""</div>
</div>
</div>
<div id="graphdiv" style="height: calc(100% - 60px);
width: 100%;
display: table">
</div>
<!-- Execute once on page load -->
<script>
document.getElementById('file').addEventListener('change', fileSelect, false);
$(document).ready(function () """),format.raw/*102.47*/("""{"""),format.raw/*102.48*/("""
"""),format.raw/*103.21*/("""renderSameDiffGraph();
"""),format.raw/*104.17*/("""}"""),format.raw/*104.18*/(""");
</script>
</body>
</html>
"""))
}
}
}
def render(): play.twirl.api.HtmlFormat.Appendable = apply()
def f:(() => play.twirl.api.HtmlFormat.Appendable) = () => apply()
def ref: this.type = this
}
}
/**/
object SameDiffUIBackup extends SameDiffUIBackup_Scope0.SameDiffUIBackup
/*
-- GENERATED --
DATE: Sat Jan 26 18:11:00 AEDT 2019
SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/samediff/SameDiffUIBackup.scala.html
HASH: 5bbca606352004bcd370a7bff06c376141f0d082
MATRIX: 573->1|669->3|697->5|4607->3948|4653->3966|4993->4333|5043->4412|5093->4487|5143->4509|5346->4737|5392->4755|5813->5147|5843->5148|5894->5170|5963->5210|5993->5211
LINES: 20->1|25->1|26->2|101->77|102->78|107->83|108->84|109->85|110->86|112->88|113->89|126->102|126->102|127->103|128->104|128->104
-- GENERATED --
*/

View File

@ -1,334 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.views.html.training
import play.twirl.api._
import play.twirl.api.TemplateMagic._
object TrainingModel_Scope0 {
import models._
import controllers._
import play.api.i18n._
import views.html._
import play.api.templates.PlayMagic._
import play.api.mvc._
import play.api.data._
class TrainingModel extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template1[org.deeplearning4j.ui.api.I18N,play.twirl.api.HtmlFormat.Appendable] {
/**/
def apply/*1.2*/(i18n: org.deeplearning4j.ui.api.I18N):play.twirl.api.HtmlFormat.Appendable = {
_display_ {
{
Seq[Any](format.raw/*1.40*/("""
"""),format.raw/*2.1*/("""<!DOCTYPE html>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<html lang="en">
<head>
<meta charset="utf-8">
<title>"""),_display_(/*24.17*/i18n/*24.21*/.getMessage("train.pagetitle")),format.raw/*24.51*/("""</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="/assets/webjars/coreui__coreui/2.1.9/dist/css/coreui.min.css">
<link rel="stylesheet" href="/assets/css/style.css">
<script src="/assets/webjars/jquery/3.4.1/dist/jquery.min.js"></script>
<script src="/assets/webjars/popper.js/1.12.9/dist/umd/popper.min.js"></script>
<script src="/assets/webjars/bootstrap/4.3.1/dist/js/bootstrap.min.js"></script>
<script src="/assets/webjars/coreui__coreui/2.1.9/dist/js/coreui.min.js"></script>
<!-- Icons -->
<link rel="stylesheet" href="/assets/webjars/coreui__icons/0.3.0/css/coreui-icons.min.css"></script>
<link rel="shortcut icon" href="/assets/img/favicon.ico">
</head>
<body class="app sidebar-show aside-menu-show">
<header class="app-header navbar">
<a class="header-text" href="#"><span>"""),_display_(/*46.56*/i18n/*46.60*/.getMessage("train.pagetitle")),format.raw/*46.90*/("""</span></a>
<div id="sessionSelectDiv" style="display:none; float:right;">
<div style="color:white;">"""),_display_(/*48.52*/i18n/*48.56*/.getMessage("train.session.label")),format.raw/*48.90*/("""</div>
<select id="sessionSelect" onchange='selectNewSession()'>
<option>(Session ID)</option>
</select>
</div>
<div id="workerSelectDiv" style="display:none; float:right">
<div style="color:white;">"""),_display_(/*54.52*/i18n/*54.56*/.getMessage("train.session.worker.label")),format.raw/*54.97*/("""</div>
<select id="workerSelect" onchange='selectNewWorker()'>
<option>(Worker ID)</option>
</select>
</div>
</header>
<!-- End Header -->
<div class="app-body">
<div class="sidebar">
<nav class="sidebar-nav">
<ul class="nav">
<li class="nav-item"><a class="nav-link" href="overview"><i class="nav-icon cui-chart"></i>"""),_display_(/*66.117*/i18n/*66.121*/.getMessage("train.nav.overview")),format.raw/*66.154*/("""</a></li>
<li class="nav-item"><a class="nav-link" href="model"><i class="nav-icon cui-graph"></i>"""),_display_(/*67.114*/i18n/*67.118*/.getMessage("train.nav.model")),format.raw/*67.148*/("""</a></li>
<li class="nav-item"><a class="nav-link" href="system"><i class="nav-icon cui-speedometer"></i>"""),_display_(/*68.121*/i18n/*68.125*/.getMessage("train.nav.system")),format.raw/*68.156*/("""</a></li>
<li class="nav-item nav-dropdown">
<a class="nav-link nav-dropdown-toggle" href="#">
<i class="nav-icon cui-globe"></i> """),_display_(/*71.69*/i18n/*71.73*/.getMessage("train.nav.language")),format.raw/*71.106*/("""
"""),format.raw/*72.29*/("""</a>
<ul class="nav-dropdown-items">
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('en', 'overview')"><i class="icon-file-alt"></i>English</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('de', 'overview')"><i class="icon-file-alt"></i>Deutsch</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ja', 'overview')"><i class="icon-file-alt"></i>日本語</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('zh', 'overview')"><i class="icon-file-alt"></i>中文</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ko', 'overview')"><i class="icon-file-alt"></i>한글</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ru', 'overview')"><i class="icon-file-alt"></i>русский</a></li>
</ul>
</li>
</ul>
</nav>
</div>
<style>
/* Graph */
#layers """),format.raw/*89.25*/("""{"""),format.raw/*89.26*/("""
"""),format.raw/*90.21*/("""height: 725px; /* IE8 */
height: 90vh;
width: 100%;
border: 2px solid #eee;
"""),format.raw/*94.17*/("""}"""),format.raw/*94.18*/("""
"""),format.raw/*95.17*/("""</style>
<!-- Start Content -->
<div id="content">
<div class="row">
<div class="col">
<div id="layers"></div>
</div>
<!-- Start Layer Details -->
<div class="col" id="layerDetails" style="width:50pc">
<div class="box">
<div class="chart-header">
<h2><b>"""),_display_(/*109.49*/i18n/*109.53*/.getMessage("train.model.layerInfoTable.title")),format.raw/*109.100*/("""</b></h2>
</div>
<div class="box-content">
<table class="table table-bordered table-striped table-condensed" id="layerInfo"></table>
</div>
</div>
<div class="box">
<div class="chart-header">
<h2><b>"""),_display_(/*118.49*/i18n/*118.53*/.getMessage("train.overview.chart.updateRatioTitle")),format.raw/*118.105*/("""</b></h2><p id="updateRatioTitleLog10"><b>: log<sub>10</sub></b></p>
<ul class="nav tab-menu nav-tabs" style="position:absolute; margin-top: -60px; right: 27px;">
<li id="mmRatioTab"><a href="javascript:void(0);" onclick="setSelectMeanMagChart('ratios')">"""),_display_(/*120.134*/i18n/*120.138*/.getMessage("train.model.meanmag.btn.ratio")),format.raw/*120.182*/("""</a></li>
<li id="mmParamTab"><a href="javascript:void(0);" onclick="setSelectMeanMagChart('paramMM')">"""),_display_(/*121.135*/i18n/*121.139*/.getMessage("train.model.meanmag.btn.param")),format.raw/*121.183*/("""</a></li>
<li id="mmUpdateTab"><a href="javascript:void(0);" onclick="setSelectMeanMagChart('updateMM')">"""),_display_(/*122.137*/i18n/*122.141*/.getMessage("train.model.meanmag.btn.update")),format.raw/*122.186*/("""</a></li>
</ul>
</div>
<div class="box-content">
<div id="meanmag" class="center" style="height: 300px;" ></div>
<p id="hoverdata"><span id="updateRatioTitleSmallLog10"><b>log<sub>
10</sub> """),_display_(/*128.51*/i18n/*128.55*/.getMessage("train.overview.chart.updateRatioTitleShort")),format.raw/*128.112*/("""</b></span> <span id="yMeanMagnitudes">0</span>, <b>"""),_display_(/*128.165*/i18n/*128.169*/.getMessage("train.overview.charts.iteration")),format.raw/*128.215*/(""":</b> <span id="xMeanMagnitudes">0</span></p>
</div>
</div>
<div class="box">
<div class="chart-header">
<h2><b>"""),_display_(/*133.49*/i18n/*133.53*/.getMessage("train.model.activationsChart.title")),format.raw/*133.102*/("""</b></h2>
</div>
<div class="box-content">
<div id="activations" class="center" style="height: 300px;" ></div>
<p id="hoverdata"><b>"""),_display_(/*137.63*/i18n/*137.67*/.getMessage("train.model.activationsChart.titleShort")),format.raw/*137.121*/("""
"""),format.raw/*138.33*/(""":</b> <span id="yActivations">0</span>
, <b>"""),_display_(/*139.47*/i18n/*139.51*/.getMessage("train.overview.charts.iteration")),format.raw/*139.97*/("""
"""),format.raw/*140.33*/(""":</b> <span id="xActivations">0</span></p>
</div>
</div>
<div class="box">
<div class="chart-header">
<h2><b>"""),_display_(/*146.49*/i18n/*146.53*/.getMessage("train.model.paramHistChart.title")),format.raw/*146.100*/("""</b></h2>
<div id="paramhistSelected" style="float: left"></div>
<div id="paramHistButtonsDiv" style="float: right"></div>
</div>
<div class="box-content">
<div id="parametershistogram" class="center" style="height: 300px;"></div>
</div>
</div>
<div class="box">
<div class="chart-header">
<h2><b>"""),_display_(/*157.49*/i18n/*157.53*/.getMessage("train.model.updateHistChart.title")),format.raw/*157.101*/("""</b></h2>
<div id="updatehistSelected" style="float: left"></div>
<div id="updateHistButtonsDiv" style="float: right"></div>
</div>
<div class="box-content">
<div id="updateshistogram" class="center" style="height: 300px;"></div>
</div>
</div>
<div class="box">
<div class="chart-header">
<h2><b>"""),_display_(/*168.49*/i18n/*168.53*/.getMessage("train.model.lrChart.title")),format.raw/*168.93*/("""</b></h2>
</div>
<div class="box-content">
<div id="learningrate" class="center" style="height: 300px;" ></div>
<p id="hoverdata"><b>"""),_display_(/*172.63*/i18n/*172.67*/.getMessage("train.model.lrChart.titleShort")),format.raw/*172.112*/("""
"""),format.raw/*173.33*/(""":</b> <span id="yLearningRate">0</span>
, <b>"""),_display_(/*174.47*/i18n/*174.51*/.getMessage("train.overview.charts.iteration")),format.raw/*174.97*/("""
"""),format.raw/*175.33*/(""":</b> <span id="xLearningRate">0</span></p>
</div>
</div>
</div>
<!-- End Layer Details-->
<!-- Begin Zero State -->
<div class="col" id="zeroState">
<div class="box">
<div class="chart-header">
<h2><b>Getting Started</b></h2>
</div>
<div class="box-content">
<div class="page-header">
<h1>Layer Visualization UI</h1>
</div>
<div class="row-fluid">
<div class="span12">
<h2>Overview</h2>
<p>
The layer visualization UI renders network structure dynamically. Users can inspect node layer parameters by clicking on the various elements of the GUI to see general information as well as overall network information such as performance.
</p>
<h2>Actions</h2>
<p>On the <b>left</b>, you will find an interactive layer visualization.</p>
<p>
<ul>
<li><b>Clicking</b> - Click on a layer to load network performance metrics.</li>
<li><b>Scrolling</b>
- Drag the GUI with your mouse or touchpad to move the model around. </li>
</ul>
</p>
</div>
</div>
</div>
</div>
</div>
<!-- End Zero State-->
</div>
</div>
<!-- End Content -->
</div> <!-- End Container -->
</div> <!-- End Row Fluid-->
<!-- Start JavaScript-->
<script src="/assets/webjars/modernizr/2.8.3/modernizr.min.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.pie.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.stack.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.resize.min.js"></script>
<script src="/assets/webjars/chosen/0.9.8/chosen/chosen.jquery.min.js"></script>
<script src="/assets/webjars/uniform/2.1.2/jquery.uniform.min.js"></script>
<script src="/assets/webjars/noty/2.2.2/jquery.noty.packaged.js"></script>
<script src="/assets/webjars/jquery-raty/2.5.2/jquery.raty.min.js"></script>
<script src="/assets/webjars/imagesloaded/2.1.1/jquery.imagesloaded.min.js"></script>
<script src="/assets/webjars/masonry/3.1.5/masonry.pkgd.min.js"></script>
<script src="/assets/webjars/jquery-knob/1.2.2/jquery.knob.min.js"></script>
<script src="/assets/webjars/jquery.sparkline/2.1.2/jquery.sparkline.min.js"></script>
<script src="/assets/webjars/retinajs/0.0.2/retina.js"></script>
<script src="/assets/webjars/dagre/0.8.4/dist/dagre.min.js"></script>
<script src="/assets/webjars/cytoscape/3.3.3/dist/cytoscape.min.js"></script>
<script src="/assets/webjars/cytoscape-dagre/2.1.0/cytoscape-dagre.js"></script>
<script src="/assets/webjars/github-com-jboesch-Gritter/1.7.4/jquery.gritter.js"></script>
<script src="/assets/js/train/model.js"></script> <!-- Charts and tables are generated here! -->
<script src="/assets/js/train/model-graph.js"></script> <!-- Layer graph generated here! -->
<script src="/assets/js/train/train.js"></script> <!-- Common (lang selection, etc) -->
<script src="/assets/js/counter.js"></script>
<!-- Execute once on page load -->
<script>
$(document).ready(function () """),format.raw/*253.46*/("""{"""),format.raw/*253.47*/("""
"""),format.raw/*254.20*/("""renderModelGraph();
renderModelPage(true);
"""),format.raw/*256.16*/("""}"""),format.raw/*256.17*/(""");
</script>
<!-- Execute periodically (every 2 sec) -->
<script>
setInterval(function () """),format.raw/*261.41*/("""{"""),format.raw/*261.42*/("""
"""),format.raw/*262.21*/("""renderModelPage(false);
"""),format.raw/*263.17*/("""}"""),format.raw/*263.18*/(""", 2000);
</script>
</body>
</html>
"""))
}
}
}
def render(i18n:org.deeplearning4j.ui.api.I18N): play.twirl.api.HtmlFormat.Appendable = apply(i18n)
def f:((org.deeplearning4j.ui.api.I18N) => play.twirl.api.HtmlFormat.Appendable) = (i18n) => apply(i18n)
def ref: this.type = this
}
}
/**/
object TrainingModel extends TrainingModel_Scope0.TrainingModel
/*
-- GENERATED --
DATE: Tue May 07 21:39:41 AEST 2019
SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingModel.scala.html
HASH: e97fc151ff2478d42c14c4d98e2cf87aa2ee049f
MATRIX: 598->1|731->39|759->41|1677->932|1690->936|1741->966|2753->1951|2766->1955|2817->1985|2988->2129|3001->2133|3056->2167|3409->2493|3422->2497|3484->2538|4020->3046|4034->3050|4089->3083|4241->3207|4255->3211|4307->3241|4466->3372|4480->3376|4533->3407|4778->3625|4791->3629|4846->3662|4904->3692|6344->5104|6373->5105|6423->5127|6607->5283|6636->5284|6682->5302|7246->5838|7260->5842|7330->5889|7824->6355|7838->6359|7913->6411|8280->6749|8295->6753|8362->6797|8536->6942|8551->6946|8618->6990|8794->7137|8809->7141|8877->7186|9326->7607|9340->7611|9420->7668|9502->7721|9517->7725|9586->7771|9900->8057|9914->8061|9986->8110|10295->8391|10309->8395|10386->8449|10449->8483|10563->8569|10577->8573|10645->8619|10708->8653|11029->8946|11043->8950|11113->8997|11843->9699|11857->9703|11928->9751|12657->10452|12671->10456|12733->10496|13043->10778|13057->10782|13125->10827|13188->10861|13303->10948|13317->10952|13385->10998|13448->11032|18024->15579|18054->15580|18104->15601|18212->15680|18242->15681|18413->15823|18443->15824|18494->15846|18564->15887|18594->15888
LINES: 20->1|25->1|26->2|48->24|48->24|48->24|70->46|70->46|70->46|72->48|72->48|72->48|78->54|78->54|78->54|90->66|90->66|90->66|91->67|91->67|91->67|92->68|92->68|92->68|95->71|95->71|95->71|96->72|113->89|113->89|114->90|118->94|118->94|119->95|133->109|133->109|133->109|142->118|142->118|142->118|144->120|144->120|144->120|145->121|145->121|145->121|146->122|146->122|146->122|152->128|152->128|152->128|152->128|152->128|152->128|157->133|157->133|157->133|161->137|161->137|161->137|162->138|163->139|163->139|163->139|164->140|170->146|170->146|170->146|181->157|181->157|181->157|192->168|192->168|192->168|196->172|196->172|196->172|197->173|198->174|198->174|198->174|199->175|277->253|277->253|278->254|280->256|280->256|285->261|285->261|286->262|287->263|287->263
-- GENERATED --
*/

View File

@ -1,288 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.views.html.training
import play.twirl.api._
import play.twirl.api.TemplateMagic._
object TrainingOverview_Scope0 {
import models._
import controllers._
import play.api.i18n._
import views.html._
import play.api.templates.PlayMagic._
import play.api.mvc._
import play.api.data._
class TrainingOverview extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template1[org.deeplearning4j.ui.api.I18N,play.twirl.api.HtmlFormat.Appendable] {
/**/
def apply/*1.2*/(i18n: org.deeplearning4j.ui.api.I18N):play.twirl.api.HtmlFormat.Appendable = {
_display_ {
{
Seq[Any](format.raw/*1.40*/("""
"""),format.raw/*2.1*/("""<!DOCTYPE html>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<html lang="en">
<head>
<meta charset="utf-8">
<title>"""),_display_(/*24.17*/i18n/*24.21*/.getMessage("train.pagetitle")),format.raw/*24.51*/("""</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="/assets/webjars/coreui__coreui/2.1.9/dist/css/coreui.min.css">
<link rel="stylesheet" href="/assets/css/style.css">
<script src="/assets/webjars/jquery/3.4.1/dist/jquery.min.js"></script>
<script src="/assets/webjars/popper.js/1.12.9/dist/umd/popper.min.js"></script>
<script src="/assets/webjars/bootstrap/4.3.1/dist/js/bootstrap.min.js"></script>
<script src="/assets/webjars/coreui__coreui/2.1.9/dist/js/coreui.min.js"></script>
<!-- Icons -->
<link rel="stylesheet" href="/assets/webjars/coreui__icons/0.3.0/css/coreui-icons.min.css"></script>
<link rel="shortcut icon" href="/assets/img/favicon.ico">
</head>
<body class="app sidebar-show aside-menu-show">
<header class="app-header navbar">
<a class="header-text" href="#"><span>"""),_display_(/*46.56*/i18n/*46.60*/.getMessage("train.pagetitle")),format.raw/*46.90*/("""</span></a>
<div id="sessionSelectDiv" style="display:none; float:right;">
<div style="color:white;">"""),_display_(/*48.52*/i18n/*48.56*/.getMessage("train.session.label")),format.raw/*48.90*/("""</div>
<select id="sessionSelect" onchange='selectNewSession()'>
<option>(Session ID)</option>
</select>
</div>
<div id="workerSelectDiv" style="display:none; float:right">
<div style="color:white;">"""),_display_(/*54.52*/i18n/*54.56*/.getMessage("train.session.worker.label")),format.raw/*54.97*/("""</div>
<select id="workerSelect" onchange='selectNewWorker()'>
<option>(Worker ID)</option>
</select>
</div>
</header>
<div class="app-body">
<div class="sidebar">
<nav class="sidebar-nav">
<ul class="nav">
<li class="nav-item"><a class="nav-link" href="overview"><i class="nav-icon cui-chart"></i>"""),_display_(/*64.117*/i18n/*64.121*/.getMessage("train.nav.overview")),format.raw/*64.154*/("""</a></li>
<li class="nav-item"><a class="nav-link" href="model"><i class="nav-icon cui-graph"></i>"""),_display_(/*65.114*/i18n/*65.118*/.getMessage("train.nav.model")),format.raw/*65.148*/("""</a></li>
<li class="nav-item"><a class="nav-link" href="system"><i class="nav-icon cui-speedometer"></i>"""),_display_(/*66.121*/i18n/*66.125*/.getMessage("train.nav.system")),format.raw/*66.156*/("""</a></li>
<li class="nav-item nav-dropdown">
<a class="nav-link nav-dropdown-toggle" href="#">
<i class="nav-icon cui-globe"></i> """),_display_(/*69.69*/i18n/*69.73*/.getMessage("train.nav.language")),format.raw/*69.106*/("""
"""),format.raw/*70.29*/("""</a>
<ul class="nav-dropdown-items">
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('en', 'overview')"><i class="icon-file-alt"></i>English</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('de', 'overview')"><i class="icon-file-alt"></i>Deutsch</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ja', 'overview')"><i class="icon-file-alt"></i>日本語</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('zh', 'overview')"><i class="icon-file-alt"></i>中文</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ko', 'overview')"><i class="icon-file-alt"></i>한글</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ru', 'overview')"><i class="icon-file-alt"></i>русский</a></li>
</ul>
</li>
</ul>
</nav>
</div>
<main id="content" class="main">
<div class="row">
<div class="col-8 chart-box">
<div class="chart-header">
<h2><b>"""),_display_(/*89.37*/i18n/*89.41*/.getMessage("train.overview.chart.scoreTitle")),format.raw/*89.87*/("""</b></h2>
</div>
<div>
<div id="scoreiterchart" class="center" style="height: 300px;" ></div>
<p id="hoverdata"><b>"""),_display_(/*93.51*/i18n/*93.55*/.getMessage("train.overview.chart.scoreTitleShort")),format.raw/*93.106*/("""
"""),format.raw/*94.33*/(""":</b> <span id="y">0</span>, <b>"""),_display_(/*94.66*/i18n/*94.70*/.getMessage("train.overview.charts.iteration")),format.raw/*94.116*/("""
"""),format.raw/*95.33*/(""":</b> <span id="x">
0</span></p>
</div>
</div>
<!-- End Score Chart-->
<!-- Start Model Table-->
<div class="col-4 chart-box">
<div class="chart-header">
<h2><b>"""),_display_(/*103.37*/i18n/*103.41*/.getMessage("train.overview.perftable.title")),format.raw/*103.86*/("""</b></h2>
</div>
<div>
<table class="table table-bordered table-striped table-condensed">
<tr>
<td>"""),_display_(/*108.42*/i18n/*108.46*/.getMessage("train.overview.modeltable.modeltype")),format.raw/*108.96*/("""</td>
<td id="modelType">Loading...</td>
</tr>
<tr>
<td>"""),_display_(/*112.42*/i18n/*112.46*/.getMessage("train.overview.modeltable.nLayers")),format.raw/*112.94*/("""</td>
<td id="nLayers">Loading...</td>
</tr>
<tr>
<td>"""),_display_(/*116.42*/i18n/*116.46*/.getMessage("train.overview.modeltable.nParams")),format.raw/*116.94*/("""</td>
<td id="nParams">Loading...</td>
</tr>
<tr>
<td>"""),_display_(/*120.42*/i18n/*120.46*/.getMessage("train.overview.perftable.startTime")),format.raw/*120.95*/("""</td>
<td id="startTime">Loading...</td>
</tr>
<tr>
<td>"""),_display_(/*124.42*/i18n/*124.46*/.getMessage("train.overview.perftable.totalRuntime")),format.raw/*124.98*/("""</td>
<td id="totalRuntime">Loading...</td>
</tr>
<tr>
<td>"""),_display_(/*128.42*/i18n/*128.46*/.getMessage("train.overview.perftable.lastUpdate")),format.raw/*128.96*/("""</td>
<td id="lastUpdate">Loading...</td>
</tr>
<tr>
<td>"""),_display_(/*132.42*/i18n/*132.46*/.getMessage("train.overview.perftable.totalParamUpdates")),format.raw/*132.103*/("""</td>
<td id="totalParamUpdates">Loading...</td>
</tr>
<tr>
<td>"""),_display_(/*136.42*/i18n/*136.46*/.getMessage("train.overview.perftable.updatesPerSec")),format.raw/*136.99*/("""</td>
<td id="updatesPerSec">Loading...</td>
</tr>
<tr>
<td>"""),_display_(/*140.42*/i18n/*140.46*/.getMessage("train.overview.perftable.examplesPerSec")),format.raw/*140.100*/("""</td>
<td id="examplesPerSec">Loading...</td>
</tr>
</table>
</div>
</div>
<!--End Model Table -->
</div>
<div class="row">
<!--Start Ratio Table -->
<div class="col chart-box">
<div class="chart-header">
<h2><b>"""),_display_(/*154.37*/i18n/*154.41*/.getMessage("train.overview.chart.updateRatioTitle")),format.raw/*154.93*/(""": log<sub>10</sub></b></h2>
</div>
<div class="box-content">
<div id="updateRatioChart" class="center" style="height: 300px;" ></div>
<p id="hoverdata"><b>"""),_display_(/*158.51*/i18n/*158.55*/.getMessage("train.overview.chart.updateRatioTitleShort")),format.raw/*158.112*/("""
"""),format.raw/*159.33*/(""":</b> <span id="yRatio">0</span>, <b>log<sub>
10</sub> """),_display_(/*160.43*/i18n/*160.47*/.getMessage("train.overview.chart.updateRatioTitleShort")),format.raw/*160.104*/("""
"""),format.raw/*161.33*/(""":</b> <span id="yLogRatio">0</span>
, <b>"""),_display_(/*162.39*/i18n/*162.43*/.getMessage("train.overview.charts.iteration")),format.raw/*162.89*/(""":</b> <span id="xRatio">
0</span></p>
</div>
</div>
<!--End Ratio Table -->
<!--Start Variance Table -->
<div class="col chart-box">
<div class="chart-header">
<h2><b>"""),_display_(/*171.37*/i18n/*171.41*/.getMessage("train.overview.chart.stdevTitle")),format.raw/*171.87*/(""": log<sub>10</sub></b></h2>
<ul class="nav tab-menu nav-tabs" style="position:absolute; margin-top: -30px; right: 22px;">
<li class="active" id="stdevActivations"><a href="javascript:void(0);" onclick="selectStdevChart('stdevActivations')">"""),_display_(/*173.152*/i18n/*173.156*/.getMessage("train.overview.chart.stdevBtn.activations")),format.raw/*173.212*/("""</a></li>
<li id="stdevGradients"><a href="javascript:void(0);" onclick="selectStdevChart('stdevGradients')">"""),_display_(/*174.133*/i18n/*174.137*/.getMessage("train.overview.chart.stdevBtn.gradients")),format.raw/*174.191*/("""</a></li>
<li id="stdevUpdates"><a href="javascript:void(0);" onclick="selectStdevChart('stdevUpdates')">"""),_display_(/*175.129*/i18n/*175.133*/.getMessage("train.overview.chart.stdevBtn.updates")),format.raw/*175.185*/("""</a></li>
</ul>
</div>
<div class="box-content">
<div id="stdevChart" class="center" style="height: 300px;" ></div>
<p id="hoverdata"><b>"""),_display_(/*180.51*/i18n/*180.55*/.getMessage("train.overview.chart.stdevTitleShort")),format.raw/*180.106*/("""
"""),format.raw/*181.33*/(""":</b> <span id="yStdev">0</span>, <b>log<sub>
10</sub> """),_display_(/*182.43*/i18n/*182.47*/.getMessage("train.overview.chart.stdevTitleShort")),format.raw/*182.98*/("""
"""),format.raw/*183.33*/(""":</b> <span id="yLogStdev">0</span>
, <b>"""),_display_(/*184.39*/i18n/*184.43*/.getMessage("train.overview.charts.iteration")),format.raw/*184.89*/(""":</b> <span id="xStdev">
0</span></p>
</div>
</div>
<!-- End Variance Table -->
</div>
</main>
</div>
<script src="/assets/webjars/fullcalendar/1.6.4/fullcalendar.min.js"></script>
<script src="/assets/webjars/excanvas/3/excanvas.js"></script>
<script src="/assets/webjars/retinajs/0.0.2/retina.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.pie.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.stack.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.resize.min.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.selection.js"></script>
<script src="/assets/js/train/overview.js"></script> <!-- Charts and tables are generated here! -->
<script src="/assets/js/train/train.js"></script> <!-- Common (lang selection, etc) -->
<script src="/assets/js/counter.js"></script>
<!-- Execute once on page load -->
<script>
$(document).ready(function () """),format.raw/*208.47*/("""{"""),format.raw/*208.48*/("""
"""),format.raw/*209.21*/("""renderOverviewPage(true);
"""),format.raw/*210.17*/("""}"""),format.raw/*210.18*/(""");
</script>
<!-- Execute periodically (every 2 sec) -->
<script>
setInterval(function () """),format.raw/*215.41*/("""{"""),format.raw/*215.42*/("""
"""),format.raw/*216.21*/("""renderOverviewPage(false);
"""),format.raw/*217.17*/("""}"""),format.raw/*217.18*/(""", 2000);
</script>
</body>
</html>
"""))
}
}
}
def render(i18n:org.deeplearning4j.ui.api.I18N): play.twirl.api.HtmlFormat.Appendable = apply(i18n)
def f:((org.deeplearning4j.ui.api.I18N) => play.twirl.api.HtmlFormat.Appendable) = (i18n) => apply(i18n)
def ref: this.type = this
}
}
/**/
object TrainingOverview extends TrainingOverview_Scope0.TrainingOverview
/*
-- GENERATED --
DATE: Tue May 07 21:39:41 AEST 2019
SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingOverview.scala.html
HASH: 2c9149a6ec8ab4fcc1b316e6ecd5befab2c45af1
MATRIX: 604->1|737->39|765->41|1683->932|1696->936|1747->966|2759->1951|2772->1955|2823->1985|2994->2129|3007->2133|3062->2167|3415->2493|3428->2497|3490->2538|3995->3015|4009->3019|4064->3052|4216->3176|4230->3180|4282->3210|4441->3341|4455->3345|4508->3376|4753->3594|4766->3598|4821->3631|4879->3661|6460->5215|6473->5219|6540->5265|6791->5489|6804->5493|6877->5544|6939->5578|6999->5611|7012->5615|7080->5661|7142->5695|7536->6061|7550->6065|7617->6110|7894->6359|7908->6363|7980->6413|8205->6610|8219->6614|8289->6662|8512->6857|8526->6861|8596->6909|8819->7104|8833->7108|8904->7157|9129->7354|9143->7358|9217->7410|9445->7610|9459->7614|9531->7664|9757->7862|9771->7866|9851->7923|10084->8128|10098->8132|10173->8185|10402->8386|10416->8390|10493->8444|11042->8965|11056->8969|11130->9021|11422->9285|11436->9289|11516->9346|11579->9380|11696->9469|11710->9473|11790->9530|11853->9564|11956->9639|11970->9643|12038->9689|12444->10067|12458->10071|12526->10117|12858->10420|12873->10424|12952->10480|13124->10623|13139->10627|13216->10681|13384->10820|13399->10824|13474->10876|13777->11151|13791->11155|13865->11206|13928->11240|14045->11329|14059->11333|14132->11384|14195->11418|14298->11493|14312->11497|14380->11543|15679->12813|15709->12814|15760->12836|15832->12879|15862->12880|16027->13016|16057->13017|16108->13039|16181->13083|16211->13084
LINES: 20->1|25->1|26->2|48->24|48->24|48->24|70->46|70->46|70->46|72->48|72->48|72->48|78->54|78->54|78->54|88->64|88->64|88->64|89->65|89->65|89->65|90->66|90->66|90->66|93->69|93->69|93->69|94->70|113->89|113->89|113->89|117->93|117->93|117->93|118->94|118->94|118->94|118->94|119->95|127->103|127->103|127->103|132->108|132->108|132->108|136->112|136->112|136->112|140->116|140->116|140->116|144->120|144->120|144->120|148->124|148->124|148->124|152->128|152->128|152->128|156->132|156->132|156->132|160->136|160->136|160->136|164->140|164->140|164->140|178->154|178->154|178->154|182->158|182->158|182->158|183->159|184->160|184->160|184->160|185->161|186->162|186->162|186->162|195->171|195->171|195->171|197->173|197->173|197->173|198->174|198->174|198->174|199->175|199->175|199->175|204->180|204->180|204->180|205->181|206->182|206->182|206->182|207->183|208->184|208->184|208->184|232->208|232->208|233->209|234->210|234->210|239->215|239->215|240->216|241->217|241->217
-- GENERATED --
*/

View File

@ -1,349 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.views.html.training
import play.twirl.api._
import play.twirl.api.TemplateMagic._
object TrainingSystem_Scope0 {
import models._
import controllers._
import play.api.i18n._
import views.html._
import play.api.templates.PlayMagic._
import play.api.mvc._
import play.api.data._
class TrainingSystem extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template1[org.deeplearning4j.ui.api.I18N,play.twirl.api.HtmlFormat.Appendable] {
/**/
def apply/*1.2*/(i18n: org.deeplearning4j.ui.api.I18N):play.twirl.api.HtmlFormat.Appendable = {
_display_ {
{
Seq[Any](format.raw/*1.40*/("""
"""),format.raw/*2.1*/("""<!DOCTYPE html>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<html lang="en">
<head>
<meta charset="utf-8">
<title>"""),_display_(/*24.17*/i18n/*24.21*/.getMessage("train.pagetitle")),format.raw/*24.51*/("""</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="/assets/webjars/coreui__coreui/2.1.9/dist/css/coreui.min.css">
<link rel="stylesheet" href="/assets/css/style.css">
<script src="/assets/webjars/jquery/3.4.1/dist/jquery.min.js"></script>
<script src="/assets/webjars/popper.js/1.12.9/dist/umd/popper.min.js"></script>
<script src="/assets/webjars/bootstrap/4.3.1/dist/js/bootstrap.min.js"></script>
<script src="/assets/webjars/coreui__coreui/2.1.9/dist/js/coreui.min.js"></script>
<!-- Icons -->
<link rel="stylesheet" href="/assets/webjars/coreui__icons/0.3.0/css/coreui-icons.min.css"></script>
<link rel="shortcut icon" href="/assets/img/favicon.ico">
</head>
<body class="app sidebar-show aside-menu-show">
<header class="app-header navbar">
<a class="header-text" href="#"><span>"""),_display_(/*46.56*/i18n/*46.60*/.getMessage("train.pagetitle")),format.raw/*46.90*/("""</span></a>
<div id="sessionSelectDiv" style="display:none; float:right;">
<div style="color:white;">"""),_display_(/*48.52*/i18n/*48.56*/.getMessage("train.session.label")),format.raw/*48.90*/("""</div>
<select id="sessionSelect" onchange='selectNewSession()'>
<option>(Session ID)</option>
</select>
</div>
<div id="workerSelectDiv" style="display:none; float:right">
<div style="color:white;">"""),_display_(/*54.52*/i18n/*54.56*/.getMessage("train.session.worker.label")),format.raw/*54.97*/("""</div>
<select id="workerSelect" onchange='selectNewWorker()'>
<option>(Worker ID)</option>
</select>
</div>
</header>
<div class="app-body">
<div class="sidebar">
<nav class="sidebar-nav">
<ul class="nav">
<li class="nav-item"><a class="nav-link" href="overview"><i class="nav-icon cui-chart"></i>"""),_display_(/*64.117*/i18n/*64.121*/.getMessage("train.nav.overview")),format.raw/*64.154*/("""</a></li>
<li class="nav-item"><a class="nav-link" href="model"><i class="nav-icon cui-graph"></i>"""),_display_(/*65.114*/i18n/*65.118*/.getMessage("train.nav.model")),format.raw/*65.148*/("""</a></li>
<li class="nav-item"><a class="nav-link" href="system"><i class="nav-icon cui-speedometer"></i>"""),_display_(/*66.121*/i18n/*66.125*/.getMessage("train.nav.system")),format.raw/*66.156*/("""</a></li>
<li class="nav-item nav-dropdown">
<a class="nav-link nav-dropdown-toggle" href="#">
<i class="nav-icon cui-globe"></i> """),_display_(/*69.69*/i18n/*69.73*/.getMessage("train.nav.language")),format.raw/*69.106*/("""
"""),format.raw/*70.29*/("""</a>
<ul class="nav-dropdown-items">
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('en', 'overview')"><i class="icon-file-alt"></i>English</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('de', 'overview')"><i class="icon-file-alt"></i>Deutsch</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ja', 'overview')"><i class="icon-file-alt"></i>日本語</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('zh', 'overview')"><i class="icon-file-alt"></i>中文</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ko', 'overview')"><i class="icon-file-alt"></i>한글</a></li>
<li class="nav-item"><a class="nav-link" href="javascript:void(0);" onclick="languageSelect('ru', 'overview')"><i class="icon-file-alt"></i>русский</a></li>
</ul>
</li>
</ul>
</nav>
</div>
<!-- Start Content -->
<main id="content" class="main">
<div class="row-fluid">
<div class="box span12">
<div class="chart-header">
<h2><b>"""),_display_(/*92.41*/i18n/*92.45*/.getMessage("train.system.title")),format.raw/*92.78*/("""</b></h2>
<div class="btn-group" style="margin-top: -11px; position:absolute; right: 40px;">
<button class="btn dropdown-toggle btn-primary" data-toggle="dropdown">"""),_display_(/*94.105*/i18n/*94.109*/.getMessage("train.system.selectMachine")),format.raw/*94.150*/(""" """),format.raw/*94.151*/("""<span class="caret"></span></button>
<ul class="dropdown-menu" id="systemTab"></ul>
</div>
</div>
<div class="box-content">
<!--Start System Information -->
<div class="tab-content">
<div class="tab-pane active">
<!-- System Memory Utilization Chart -->
<div class="row-fluid">
<div class="box span12" id="systemMemoryChart">
<div class="chart-header">
<h2><b>"""),_display_(/*109.61*/i18n/*109.65*/.getMessage("train.system.chart.systemMemoryTitle")),format.raw/*109.116*/(""" """),format.raw/*109.117*/("""%</b></h2>
</div>
<div class="box-content">
<div id="systemMemoryChartPlot" class="center" style="height: 300px;" ></div>
<p id="hoverdata"><b>"""),_display_(/*113.75*/i18n/*113.79*/.getMessage("train.system.chart.memoryShort")),format.raw/*113.124*/(""":</b> <span id="y">0</span>, <b>
"""),_display_(/*114.58*/i18n/*114.62*/.getMessage("train.overview.charts.iteration")),format.raw/*114.108*/(""":</b> <span id="x">0</span></p>
</div>
</div>
<!-- GPU Memory Utlization Chart -->
<div class="box span6" id="gpuMemoryChart">
<div class="chart-header">
<h2><b>"""),_display_(/*121.61*/i18n/*121.65*/.getMessage("train.system.chart.gpuMemoryTitle")),format.raw/*121.113*/(""" """),format.raw/*121.114*/("""%</b></h2>
</div>
<div class="box-content">
<div id="gpuMemoryChartPlot" class="center" style="height: 300px;" ></div>
<p id="hoverdata"><b>"""),_display_(/*125.75*/i18n/*125.79*/.getMessage("train.system.chart.memoryShort")),format.raw/*125.124*/(""":</b> <span id="y2">0</span>, <b>
"""),_display_(/*126.58*/i18n/*126.62*/.getMessage("train.overview.charts.iteration")),format.raw/*126.108*/(""":</b> <span id="x2">0</span></p>
</div>
</div>
</div>
<!-- Tables -->
<div class="row-fluid">
<!-- Hardware Information -->
<div class="box span12">
<div class="chart-header">
<h2><b>"""),_display_(/*138.61*/i18n/*138.65*/.getMessage("train.system.hwTable.title")),format.raw/*138.106*/("""</b></h2>
</div>
<div class="box-content">
<table class="table table-striped">
<thead>
<tr>
<th>"""),_display_(/*144.70*/i18n/*144.74*/.getMessage("train.system.hwTable.jvmCurrent")),format.raw/*144.120*/("""</th>
<th>"""),_display_(/*145.70*/i18n/*145.74*/.getMessage("train.system.hwTable.jvmMax")),format.raw/*145.116*/("""</th>
<th>"""),_display_(/*146.70*/i18n/*146.74*/.getMessage("train.system.hwTable.offHeapCurrent")),format.raw/*146.124*/("""</th>
<th>"""),_display_(/*147.70*/i18n/*147.74*/.getMessage("train.system.hwTable.offHeapMax")),format.raw/*147.120*/("""</th>
<th>"""),_display_(/*148.70*/i18n/*148.74*/.getMessage("train.system.hwTable.jvmProcs")),format.raw/*148.118*/("""</th>
<th>"""),_display_(/*149.70*/i18n/*149.74*/.getMessage("train.system.hwTable.computeDevices")),format.raw/*149.124*/("""</th>
</tr>
</thead>
<tbody>
<tr>
<td id="currentBytesJVM">Loading...</td>
<td id="maxBytesJVM">Loading...</td>
<td id="currentBytesOffHeap">Loading...</td>
<td id="maxBytesOffHeap">Loading...</td>
<td id="jvmAvailableProcessors">Loading...</td>
<td id="nComputeDevices">Loading...</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
<div class="row-fluid">
<!-- Software Information -->
<div class="box span12">
<div class="chart-header">
<h2><b>"""),_display_(/*173.61*/i18n/*173.65*/.getMessage("train.system.swTable.title")),format.raw/*173.106*/("""</b></h2>
</div>
<div class="box-content">
<table class="table table-striped">
<thead>
<tr>
<th>"""),_display_(/*179.70*/i18n/*179.74*/.getMessage("train.system.swTable.hostname")),format.raw/*179.118*/("""</th>
<th>"""),_display_(/*180.70*/i18n/*180.74*/.getMessage("train.system.swTable.os")),format.raw/*180.112*/("""</th>
<th>"""),_display_(/*181.70*/i18n/*181.74*/.getMessage("train.system.swTable.osArch")),format.raw/*181.116*/("""</th>
<th>"""),_display_(/*182.70*/i18n/*182.74*/.getMessage("train.system.swTable.jvmName")),format.raw/*182.117*/("""</th>
<th>"""),_display_(/*183.70*/i18n/*183.74*/.getMessage("train.system.swTable.jvmVersion")),format.raw/*183.120*/("""</th>
<th>"""),_display_(/*184.70*/i18n/*184.74*/.getMessage("train.system.swTable.nd4jBackend")),format.raw/*184.121*/("""</th>
<th>"""),_display_(/*185.70*/i18n/*185.74*/.getMessage("train.system.swTable.nd4jDataType")),format.raw/*185.122*/("""</th>
</tr>
</thead>
<tbody>
<tr>
<td id="hostName">Loading...</td>
<td id="OS">Loading...</td>
<td id="OSArchitecture">Loading...</td>
<td id="jvmName">Loading...</td>
<td id="jvmVersion">Loading...</td>
<td id="nd4jBackend">Loading...</td>
<td id="nd4jDataType">Loading...</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
"""),format.raw/*205.73*/("""
"""),format.raw/*206.82*/("""
"""),format.raw/*207.73*/("""
"""),format.raw/*208.79*/("""
"""),format.raw/*209.88*/("""
"""),format.raw/*210.59*/("""
"""),format.raw/*211.78*/("""
"""),format.raw/*212.92*/("""
"""),format.raw/*213.68*/("""
"""),format.raw/*214.69*/("""
"""),format.raw/*215.79*/("""
"""),format.raw/*216.70*/("""
"""),format.raw/*217.69*/("""
"""),format.raw/*218.68*/("""
"""),format.raw/*219.69*/("""
"""),format.raw/*220.108*/("""
"""),format.raw/*221.70*/("""
"""),format.raw/*222.69*/("""
"""),format.raw/*223.65*/("""
"""),format.raw/*224.59*/("""
"""),format.raw/*225.55*/("""
"""),format.raw/*226.51*/("""
"""),format.raw/*227.37*/("""</div>
<!-- End System Tab -->
</div>
</div>
</div>
<!-- End System Tab -->
</div><!-- End Row Fluid-->
</main><!-- End Content -->
</div><!-- End Container-->
</div><!-- End Row Fluid-->
<!-- Start JavaScript-->
<script src="/assets/webjars/jquery/2.2.0/jquery.min.js"></script>
<script src="/assets/webjars/jquery-ui/1.10.2/ui/minified/jquery-ui.min.js"></script>
<script src="/assets/webjars/jquery-migrate/1.2.1/jquery-migrate.min.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.pie.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.stack.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.resize.min.js"></script>
<script src="/assets/webjars/chosen/0.9.8/chosen/chosen.jquery.min.js"></script>
<script src="/assets/webjars/uniform/2.1.2/jquery.uniform.min.js"></script>
<script src="/assets/webjars/noty/2.2.2/jquery.noty.packaged.js"></script>
<script src="/assets/webjars/jquery-raty/2.5.2/jquery.raty.min.js"></script>
<script src="/assets/webjars/imagesloaded/2.1.1/jquery.imagesloaded.min.js"></script>
<script src="/assets/webjars/masonry/3.1.5/masonry.pkgd.min.js"></script>
<script src="/assets/webjars/jquery-knob/1.2.2/jquery.knob.min.js"></script>
<script src="/assets/webjars/jquery.sparkline/2.1.2/jquery.sparkline.min.js"></script>
<script src="/assets/webjars/retinajs/0.0.2/retina.js"></script>
<script src="/assets/js/train/system.js"></script> <!-- Charts and tables are generated here! -->
<script src="/assets/js/train/train.js"></script> <!-- Common (lang selection, etc) -->
<script src="/assets/js/counter.js"></script>
<!-- Execute once on page load -->
<script>
$(document).ready(function () """),format.raw/*262.47*/("""{"""),format.raw/*262.48*/("""
"""),format.raw/*263.21*/("""renderSystemPage(true);
renderTabs();
selectMachine();
/* Default GPU to hidden */
$("#gpuTable").hide();
$("#gpuMemoryChart").hide();
"""),format.raw/*269.17*/("""}"""),format.raw/*269.18*/(""");
</script>
<!--Execute periodically (every 2 sec) -->
<script>
setInterval(function () """),format.raw/*274.41*/("""{"""),format.raw/*274.42*/("""
"""),format.raw/*275.21*/("""renderSystemPage(false);
"""),format.raw/*276.17*/("""}"""),format.raw/*276.18*/(""", 2000);
</script>
<!--End JavaScript-->
</body>
</html>
"""))
}
}
}
def render(i18n:org.deeplearning4j.ui.api.I18N): play.twirl.api.HtmlFormat.Appendable = apply(i18n)
def f:((org.deeplearning4j.ui.api.I18N) => play.twirl.api.HtmlFormat.Appendable) = (i18n) => apply(i18n)
def ref: this.type = this
}
}
/**/
object TrainingSystem extends TrainingSystem_Scope0.TrainingSystem
/*
-- GENERATED --
DATE: Tue May 07 21:39:41 AEST 2019
SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/training/TrainingSystem.scala.html
HASH: a5fe2fb79cf26f96000f494801fcbb6b9f354dd5
MATRIX: 600->1|733->39|761->41|1679->932|1692->936|1743->966|2755->1951|2768->1955|2819->1985|2990->2129|3003->2133|3058->2167|3411->2493|3424->2497|3486->2538|3991->3015|4005->3019|4060->3052|4212->3176|4226->3180|4278->3210|4437->3341|4451->3345|4504->3376|4749->3594|4762->3598|4817->3631|4875->3661|6525->5284|6538->5288|6592->5321|6851->5552|6865->5556|6928->5597|6958->5598|7821->6433|7835->6437|7909->6488|7940->6489|8316->6837|8330->6841|8398->6886|8517->6977|8531->6981|8600->7027|9078->7477|9092->7481|9163->7529|9194->7530|9567->7875|9581->7879|9649->7924|9769->8016|9783->8020|9852->8066|10487->8673|10501->8677|10565->8718|11024->9149|11038->9153|11107->9199|11211->9275|11225->9279|11290->9321|11394->9397|11408->9401|11481->9451|11585->9527|11599->9531|11668->9577|11772->9653|11786->9657|11853->9701|11957->9777|11971->9781|12044->9831|13693->11452|13707->11456|13771->11497|14230->11928|14244->11932|14311->11976|14415->12052|14429->12056|14490->12094|14594->12170|14608->12174|14673->12216|14777->12292|14791->12296|14857->12339|14961->12415|14975->12419|15044->12465|15148->12541|15162->12545|15232->12592|15336->12668|15350->12672|15421->12720|16821->14119|16892->14202|16967->14276|17046->14356|17129->14445|17208->14505|17287->14584|17370->14677|17457->14746|17548->14816|17643->14896|17734->14967|17821->15037|17908->15106|17999->15176|18095->15285|18186->15356|18273->15426|18356->15492|18435->15552|18510->15608|18581->15660|18648->15698|20855->17876|20885->17877|20936->17899|21222->18156|21252->18157|21420->18296|21450->18297|21501->18319|21572->18361|21602->18362
LINES: 20->1|25->1|26->2|48->24|48->24|48->24|70->46|70->46|70->46|72->48|72->48|72->48|78->54|78->54|78->54|88->64|88->64|88->64|89->65|89->65|89->65|90->66|90->66|90->66|93->69|93->69|93->69|94->70|116->92|116->92|116->92|118->94|118->94|118->94|118->94|133->109|133->109|133->109|133->109|137->113|137->113|137->113|138->114|138->114|138->114|145->121|145->121|145->121|145->121|149->125|149->125|149->125|150->126|150->126|150->126|162->138|162->138|162->138|168->144|168->144|168->144|169->145|169->145|169->145|170->146|170->146|170->146|171->147|171->147|171->147|172->148|172->148|172->148|173->149|173->149|173->149|197->173|197->173|197->173|203->179|203->179|203->179|204->180|204->180|204->180|205->181|205->181|205->181|206->182|206->182|206->182|207->183|207->183|207->183|208->184|208->184|208->184|209->185|209->185|209->185|229->205|230->206|231->207|232->208|233->209|234->210|235->211|236->212|237->213|238->214|239->215|240->216|241->217|242->218|243->219|244->220|245->221|246->222|247->223|248->224|249->225|250->226|251->227|286->262|286->262|287->263|293->269|293->269|298->274|298->274|299->275|300->276|300->276
-- GENERATED --
*/

View File

@ -1,296 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.views.html.tsne
import play.twirl.api._
import play.twirl.api.TemplateMagic._
object Tsne_Scope0 {
import models._
import controllers._
import play.api.i18n._
import views.html._
import play.api.templates.PlayMagic._
import play.api.mvc._
import play.api.data._
class Tsne extends BaseScalaTemplate[play.twirl.api.HtmlFormat.Appendable,Format[play.twirl.api.HtmlFormat.Appendable]](play.twirl.api.HtmlFormat) with play.twirl.api.Template0[play.twirl.api.HtmlFormat.Appendable] {
/**/
def apply():play.twirl.api.HtmlFormat.Appendable = {
_display_ {
{
Seq[Any](format.raw/*1.1*/("""<!DOCTYPE html>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<html lang="en">
<head>
<meta charset="utf-8" />
<title>T-SNE renders</title>
<!-- jQuery -->
<script src="/assets/webjars/jquery/2.2.0/jquery.min.js"></script>
<link href='/assets/legacy/roboto.css' rel='stylesheet' type='text/css'>
<link src="/assets/webjars/bootstrap/2.3.2/css/bootstrap.min.css" rel="stylesheet" ></link>
<!-- Latest compiled and minified JavaScript -->
"""),format.raw/*30.69*/("""
"""),format.raw/*31.9*/("""<script src="/assets/webjars/bootstrap/2.3.2/js/bootstrap.min.js"></script>
<!-- d3 -->
<script src="/assets/webjars/d3js/3.3.5/d3.min.js" charset="utf-8"></script>
<script src="/assets/legacy/jquery-fileupload.js"></script>
<!-- Booststrap Notify plugin-->
<script src="/assets/webjars/bootstrap-notify/3.1.3-1/bootstrap-notify.min.js"></script>
<script src="/assets/legacy/common.js"></script>
<!-- dl4j plot setup -->
<script src="/assets/legacy/renderTsne.js"></script>
<style>
.hd """),format.raw/*50.13*/("""{"""),format.raw/*50.14*/("""
"""),format.raw/*51.13*/("""background-color: #000000;
font-size: 18px;
color: #FFFFFF;
"""),format.raw/*54.9*/("""}"""),format.raw/*54.10*/("""
"""),format.raw/*56.9*/(""".block """),format.raw/*56.16*/("""{"""),format.raw/*56.17*/("""
"""),format.raw/*57.13*/("""width: 250px;
height: 350px;
display: inline-block;
border: 1px solid #DEDEDE;
margin-right: 64px;
"""),format.raw/*62.9*/("""}"""),format.raw/*62.10*/("""
"""),format.raw/*64.9*/(""".hd-small """),format.raw/*64.19*/("""{"""),format.raw/*64.20*/("""
"""),format.raw/*65.13*/("""background-color: #000000;
font-size: 14px;
color: #FFFFFF;
"""),format.raw/*68.9*/("""}"""),format.raw/*68.10*/("""
"""),format.raw/*70.9*/("""body """),format.raw/*70.14*/("""{"""),format.raw/*70.15*/("""
"""),format.raw/*71.13*/("""font-family: 'Roboto', sans-serif;
color: #333;
font-weight: 300;
font-size: 16px;
"""),format.raw/*75.9*/("""}"""),format.raw/*75.10*/("""
"""),format.raw/*77.9*/("""#wrap """),format.raw/*77.15*/("""{"""),format.raw/*77.16*/("""
"""),format.raw/*78.13*/("""width: 800px;
margin-left: auto;
margin-right: auto;
"""),format.raw/*81.9*/("""}"""),format.raw/*81.10*/("""
"""),format.raw/*83.9*/("""#embed """),format.raw/*83.16*/("""{"""),format.raw/*83.17*/("""
"""),format.raw/*84.13*/("""margin-top: 10px;
"""),format.raw/*85.9*/("""}"""),format.raw/*85.10*/("""
"""),format.raw/*87.9*/("""h1 """),format.raw/*87.12*/("""{"""),format.raw/*87.13*/("""
"""),format.raw/*88.13*/("""text-align: center;
font-weight: normal;
"""),format.raw/*90.9*/("""}"""),format.raw/*90.10*/("""
"""),format.raw/*92.9*/(""".tt """),format.raw/*92.13*/("""{"""),format.raw/*92.14*/("""
"""),format.raw/*93.13*/("""margin-top: 10px;
background-color: #EEE;
border-bottom: 1px solid #333;
padding: 5px;
"""),format.raw/*97.9*/("""}"""),format.raw/*97.10*/("""
"""),format.raw/*99.9*/(""".txth """),format.raw/*99.15*/("""{"""),format.raw/*99.16*/("""
"""),format.raw/*100.13*/("""color: #F55;
"""),format.raw/*101.9*/("""}"""),format.raw/*101.10*/("""
"""),format.raw/*103.9*/(""".cit """),format.raw/*103.14*/("""{"""),format.raw/*103.15*/("""
"""),format.raw/*104.13*/("""font-family: courier;
padding-left: 20px;
font-size: 14px;
"""),format.raw/*107.9*/("""}"""),format.raw/*107.10*/("""
"""),format.raw/*108.9*/("""</style>
<script>
$(document).ready(function () """),format.raw/*111.39*/("""{"""),format.raw/*111.40*/("""
"""),format.raw/*112.13*/("""$('#filenamebutton').click(function () """),format.raw/*112.52*/("""{"""),format.raw/*112.53*/("""
"""),format.raw/*113.17*/("""document.getElementById('form').reset();
$('#form').hide();
var filename = $('#filename').val();
$('#filename').val('');
updateFileName(filename);
drawTsne();
"""),format.raw/*119.13*/("""}"""),format.raw/*119.14*/(""");
$('#form').fileUpload("""),format.raw/*121.35*/("""{"""),format.raw/*121.36*/("""
"""),format.raw/*122.17*/("""success: function (data, textStatus, jqXHR) """),format.raw/*122.61*/("""{"""),format.raw/*122.62*/("""
"""),format.raw/*123.21*/("""var fullPath = document.getElementById('form').value;
var filename = data['name'];
if (fullPath) """),format.raw/*125.35*/("""{"""),format.raw/*125.36*/("""
"""),format.raw/*126.25*/("""var startIndex = (fullPath.indexOf('\\') >= 0 ? fullPath.lastIndexOf('\\') : fullPath.lastIndexOf('/'));
var filename = fullPath.substring(startIndex);
if (filename.indexOf('\\') === 0 || filename.indexOf('/') === 0) """),format.raw/*128.90*/("""{"""),format.raw/*128.91*/("""
"""),format.raw/*129.29*/("""filename = filename.substring(1);
"""),format.raw/*130.25*/("""}"""),format.raw/*130.26*/("""
"""),format.raw/*131.21*/("""}"""),format.raw/*131.22*/("""
"""),format.raw/*133.21*/("""document.getElementById('form').reset();
updateFileName(filename);
drawTsne();
"""),format.raw/*137.17*/("""}"""),format.raw/*137.18*/(""", error: function (err) """),format.raw/*137.42*/("""{"""),format.raw/*137.43*/("""
"""),format.raw/*138.21*/("""console.log(err);
drawTsne();
"""),format.raw/*140.17*/("""}"""),format.raw/*140.18*/("""
"""),format.raw/*141.13*/("""}"""),format.raw/*141.14*/(""");
function updateFileName(name) """),format.raw/*144.43*/("""{"""),format.raw/*144.44*/("""
"""),format.raw/*145.17*/("""$.ajax("""),format.raw/*145.24*/("""{"""),format.raw/*145.25*/("""
"""),format.raw/*146.21*/("""url: '/tsne/upload',
type: 'POST',
dataType: 'json',
data: JSON.stringify("""),format.raw/*149.42*/("""{"""),format.raw/*149.43*/(""""url": name"""),format.raw/*149.54*/("""}"""),format.raw/*149.55*/("""),
cache: false,
success: function (data, textStatus, jqXHR) """),format.raw/*151.65*/("""{"""),format.raw/*151.66*/("""
"""),format.raw/*152.25*/("""setSessionId("UploadedFile");
drawTsne();
"""),format.raw/*154.21*/("""}"""),format.raw/*154.22*/(""",
error: function (jqXHR, textStatus, errorThrown) """),format.raw/*155.70*/("""{"""),format.raw/*155.71*/("""
"""),format.raw/*156.25*/("""// Handle errors here
console.log('ERRORS: ' + textStatus);
drawTsne();
"""),format.raw/*159.21*/("""}"""),format.raw/*159.22*/("""
"""),format.raw/*160.17*/("""}"""),format.raw/*160.18*/(""");
"""),format.raw/*161.13*/("""}"""),format.raw/*161.14*/("""
"""),format.raw/*164.9*/("""}"""),format.raw/*164.10*/(""") ;
</script>
</head>
<body>
<table style="width: 100%;
padding: 5px;" class="hd">
<tbody>
<tr>
<td style="width: 48px;"><a href="/"><img src="/assets/legacy/deeplearning4j.img" border="0"/></a></td>
<td>DeepLearning4j UI</td>
<td style="width: 512px;
text-align: right;" class="hd-small">&nbsp; Available sessions:
<select class="selectpicker" id="sessionSelect" onchange="selectNewSession()" style="color: #000000;
display: inline-block;
width: 256px;">
<option value="0" selected="selected">Pick a session to track</option>
</select>
</td>
<td style="width: 256px;">&nbsp; <!-- placeholder for future use --></td>
</tr>
</tbody>
</table>
<br />
<div style="text-align: center">
<div id="embed" style="display: inline-block;
width: 1024px;
height: 700px;
border: 1px solid #DEDEDE;"></div>
</div>
<br/>
<br/>
<div style="text-align: center;
width: 100%;
position: fixed;
bottom: 0px;
left: 0px;
margin-bottom: 15px;">
<div style="display: inline-block;
margin-right: 48px;">
<h5>Upload a file to UI server.</h5>
<form encType="multipart/form-data" action="/tsne/upload" method="POST" id="form">
<div>
<input name="file" type="file" style="width: 300px;
display: inline-block;" />
<input type="submit" value="Upload file" style="display: inline-block;"/>
</div>
</form>
</div>
"""),format.raw/*219.53*/("""
"""),format.raw/*220.96*/("""
"""),format.raw/*221.42*/("""
"""),format.raw/*222.59*/("""
"""),format.raw/*223.68*/("""
"""),format.raw/*224.27*/("""
"""),format.raw/*225.23*/("""
"""),format.raw/*226.9*/("""</div>
</body>
</html>"""))
}
}
}
def render(): play.twirl.api.HtmlFormat.Appendable = apply()
def f:(() => play.twirl.api.HtmlFormat.Appendable) = () => apply()
def ref: this.type = this
}
}
/**/
object Tsne extends Tsne_Scope0.Tsne
/*
-- GENERATED --
DATE: Tue May 07 21:39:41 AEST 2019
SOURCE: c:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/views/org/deeplearning4j/ui/views/tsne/Tsne.scala.html
HASH: 9c47ea63f8745ed1b8d2d67657f185532c1a0c78
MATRIX: 634->0|1937->1335|1974->1345|2608->1951|2637->1952|2679->1966|2801->2061|2830->2062|2869->2074|2904->2081|2933->2082|2975->2096|3162->2256|3191->2257|3230->2269|3268->2279|3297->2280|3339->2294|3461->2389|3490->2390|3529->2402|3562->2407|3591->2408|3633->2422|3791->2553|3820->2554|3859->2566|3893->2572|3922->2573|3964->2587|4079->2675|4108->2676|4147->2688|4182->2695|4211->2696|4253->2710|4307->2737|4336->2738|4375->2750|4406->2753|4435->2754|4477->2768|4567->2831|4596->2832|4635->2844|4667->2848|4696->2849|4738->2863|4900->2998|4929->2999|4968->3011|5002->3017|5031->3018|5074->3032|5124->3054|5154->3055|5194->3067|5228->3072|5258->3073|5301->3087|5423->3181|5453->3182|5491->3192|5588->3260|5618->3261|5661->3275|5729->3314|5759->3315|5806->3333|6092->3590|6122->3591|6191->3631|6221->3632|6268->3650|6341->3694|6371->3695|6422->3717|6590->3856|6620->3857|6675->3883|6971->4150|7001->4151|7060->4181|7148->4240|7178->4241|7229->4263|7259->4264|7312->4288|7481->4428|7511->4429|7564->4453|7594->4454|7645->4476|7742->4544|7772->4545|7815->4559|7845->4560|7924->4610|7954->4611|8001->4629|8037->4636|8067->4637|8118->4659|8284->4796|8314->4797|8354->4808|8384->4809|8516->4912|8546->4913|8601->4939|8718->5027|8748->5028|8849->5100|8879->5101|8934->5127|9106->5270|9136->5271|9183->5289|9213->5290|9258->5306|9288->5307|9330->5321|9360->5322|11465->7438|11512->7535|11559->7578|11610->7638|11661->7707|11708->7735|11751->7759|11789->7769
LINES: 25->1|54->30|55->31|74->50|74->50|75->51|78->54|78->54|80->56|80->56|80->56|81->57|86->62|86->62|88->64|88->64|88->64|89->65|92->68|92->68|94->70|94->70|94->70|95->71|99->75|99->75|101->77|101->77|101->77|102->78|105->81|105->81|107->83|107->83|107->83|108->84|109->85|109->85|111->87|111->87|111->87|112->88|114->90|114->90|116->92|116->92|116->92|117->93|121->97|121->97|123->99|123->99|123->99|124->100|125->101|125->101|127->103|127->103|127->103|128->104|131->107|131->107|132->108|135->111|135->111|136->112|136->112|136->112|137->113|143->119|143->119|145->121|145->121|146->122|146->122|146->122|147->123|149->125|149->125|150->126|152->128|152->128|153->129|154->130|154->130|155->131|155->131|157->133|161->137|161->137|161->137|161->137|162->138|164->140|164->140|165->141|165->141|168->144|168->144|169->145|169->145|169->145|170->146|173->149|173->149|173->149|173->149|175->151|175->151|176->152|178->154|178->154|179->155|179->155|180->156|183->159|183->159|184->160|184->160|185->161|185->161|188->164|188->164|243->219|244->220|245->221|246->222|247->223|248->224|249->225|250->226
-- GENERATED --
*/

View File

@ -1,172 +0,0 @@
@()
<!DOCTYPE html>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2019 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<html lang="en" style="height: 100%">
<head>
<meta charset="utf-8">
<title>SameDiff Graph Visualization</title>
<!-- start: Mobile Specific -->
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- end: Mobile Specific -->
<link id="bootstrap-style" href="/assets/webjars/bootstrap/4.3.1/dist/css/bootstrap.min.css" rel="stylesheet">
<link href="/assets/css/samediff/samediff.css" rel="stylesheet">
<![endif]-->
</head>
<body>
<!-- Start JavaScript-->
<script src="/assets/webjars/jquery/2.2.0/jquery.min.js"></script>
<script src="/assets/webjars/jquery-ui/1.10.2/ui/minified/jquery-ui.min.js"></script>
<script src="/assets/webjars/bootstrap/4.3.1/dist/js/bootstrap.min.js"></script>
<script src="/assets/webjars/jquery-cookie/1.4.1-1/jquery.cookie.js"></script>
<script src="/assets/webjars/flatbuffers/1.9.0/js/flatbuffers.js"></script>
<script src="/assets/webjars/cytoscape/3.3.3/dist/cytoscape.min.js"></script>
<script src="/assets/webjars/dagre/0.8.4/dist/dagre.min.js"></script>
<script src="/assets/webjars/cytoscape-dagre/2.1.0/cytoscape-dagre.js"></script>
<script src="/assets/webjars/cytoscape-cose-bilkent/4.0.0/cytoscape-cose-bilkent.js"></script>
<script src="/assets/webjars/webcola/3.1.3/WebCola/cola.js"></script>
<script src="/assets/webjars/cytoscape-cola/2.3.0/cytoscape-cola.js"></script>
<script src="/assets/webjars/cytoscape-euler/1.2.1/cytoscape-euler.js"></script>
<script src="/assets/webjars/klayjs/0.4.1/klay.js"></script>
<script src="/assets/webjars/cytoscape-klay/3.1.2/cytoscape-klay.js"></script>
<script src="/assets/webjars/weaverjs/1.2.0/dist/weaver.js"></script>
<script src="/assets/webjars/cytoscape-spread/3.0.0/cytoscape-spread.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.pie.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.stack.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.resize.min.js"></script>
<script src="/assets/webjars/flot/0.8.3/jquery.flot.selection.js"></script>
<script src="/assets/js/samediff/generated/uigraphevents_generated.js"></script>
<script src="/assets/js/samediff/generated/uigraphstatic_generated.js"></script>
<script src="/assets/js/samediff/generated/array_generated.js"></script>
<script src="/assets/js/samediff/generated/utils_generated.js"></script>
<script src="/assets/js/samediff/generated/variable_generated.js"></script>
<script src="/assets/js/samediff/samediff-ui.js"></script>
<script src="/assets/js/samediff/samediff-graph.js"></script>
<script src="/assets/js/samediff/samediff-plots.js"></script>
<script src="/assets/js/samediff/flatbuffers-utils.js"></script>
@*<div class="wrapper">*@
<div class="container-fluid" style="min-height: 100%">
<div class="row">
<!-- NavBar - Bootstrap classes -->
<nav class="navbar navbar-expand navbar-dark bg-dark" style="width: 100pc">
<a class="navbar-brand" href="#">SameDiff</a>
<button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarNavAltMarkup" aria-controls="navbarNavAltMarkup" aria-expanded="false" aria-label="Toggle navigation">
<span class="navbar-toggler-icon"></span>
</button>
<div class="collapse navbar-collapse" id="navbarNavAltMarkup">
<div class="navbar-nav">
<a id="sdnavgraph" class="nav-item nav-link active" href="#" onclick="samediffSetPage('graph')">
Graph</a>
<a id="sdnavplots" class="nav-item nav-link" href="#" onclick="samediffSetPage('plots')">
Plots</a>
<a id="sdnaveval" class="nav-item nav-link" href="#" onclick="samediffSetPage('evaluation')">
Evaluation</a>
<a id="sdnavperf" class="nav-item nav-link" href="#" onclick="samediffSetPage('performance')">
Performance</a>
<a class="nav-item nav-link" href="#" onclick="toggleSidebar()">Toggle Sidebar</a>
</div>
</div>
</nav>
</div>
<div class="row" style="min-height: 100%">
<!-- Sidebar -->
@*<div id="samediffsidebar">*@
<div class="col-md-4 col-12" style="min-width: 300px; max-width: 300px; background-color: #e6e6e6; height:100%; min-height:100vh">
@*<div id="sidebartop" style="position: absolute; top: 0">*@
<div id="sidebartop" class="row p-2">
<div style="width:auto">
<label class="input-group-btn">
<span class="btn btn-secondary btn-sm">
Select File<input type="file" id="fileselect" style="display: none;" multiple>
</span>
</label>
</div>
<div id="selectedfile" class="w-100">[No File Loaded]</div>
</div>
<div class="sidebarline"></div>
@*<div id="sidebarmid" style="position: absolute; top: 300px">*@
<div id="sidebarmid" class="row p-2">
<div class="w-100"><b>Selected Node:</b></div>
<div id="sidebarmid-content" class="w-100">(None)</div>
</div>
<div class="sidebarline"></div>
<div id="sidebarmid2" class="row p-2">
<div style="width:100%">
<b>Find Node:</b><br>
</div>
<input id="findnodetxt" type="text" oninput="onGraphNodeSearch()">
<div id="findnoderesults">
</div>
</div>
<div class="sidebarline"></div>
@*<div id="sidebarbottom" style="position: absolute; bottom: 0">*@
<div id="sidebarbottom" class="row p-2">
<br><br>
<strong>Graph Layout:</strong>
<div class="btn-group btn-group-toggle w-100" data-toggle="buttons" style="height: 40px">
<label class="btn btn-secondary active" onclick="setLayout('klay_down')">
<input type="radio" name="options" id="option1" autocomplete="off" checked>Down</label>
<label class="btn btn-secondary" onclick="setLayout('klay_lr')">
<input type="radio" name="options" id="option2" autocomplete="off">Right</label>
<label class="btn btn-secondary" onclick="setLayout('dagre')">
<input type="radio" name="options" id="option3" autocomplete="off">Alt</label>
<label class="btn btn-secondary" onclick="setLayout('cose-bilkent')">
<input type="radio" name="options" id="option3" autocomplete="off">Spread</label>
</div>
<br>
<br>
<br>
</div>
</div>
<!-- Page Content -->
<div id="samediffcontent" class="col-md col-12 main pa-1">
<div id="graphdiv" style="height: 100%;
width: 100%;
display: table"></div>
</div>
</div>
</div>
<!-- Execute once on page load -->
<script>
document.getElementById('fileselect').addEventListener('change', fileSelect, false);
$(document).ready(function () {
renderSameDiffGraph();
});
</script>
</body>
</html>

View File

@ -1,96 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.play;
import org.deeplearning4j.ui.api.UIServer;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.util.Arrays;
@Ignore
public class TestSameDiffUI {
@Ignore
@Test
public void testSameDiff() throws Exception {
//
// File f = new File("C:/Temp/SameDiffUI/ui_data.bin");
// f.getParentFile().mkdirs();
// f.delete();
//
// SameDiff sd = SameDiff.create();
// SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
// SDVariable w = sd.var("w", DataType.FLOAT, 3,4);
// SDVariable b = sd.var("b", DataType.FLOAT, 1, 4);
//
// SDVariable z = in.mmul(w).add(b);
// SDVariable a = sd.nn().tanh(z);
//
// LogFileWriter lfw = new LogFileWriter(f);
// lfw.writeGraphStructure(sd);
// lfw.writeFinishStaticMarker();
//
// //Append a number of events
// lfw.registerEventName("accuracy");
// lfw.registerEventName("precision");
// long t = System.currentTimeMillis();
// for( int iter=0; iter<50; iter++) {
// double d = Math.cos(0.1*iter);
// d *= d;
// lfw.writeScalarEvent("accuracy", t + iter, iter, 0, d);
//
// double prec = Math.min(0.05 * iter, 1.0);
// lfw.writeScalarEvent("precision", t+iter, iter, 0, prec);
// }
//
// //Add some histograms:
// lfw.registerEventName("histogramDiscrete");
// lfw.registerEventName("histogramEqualSpacing");
// lfw.registerEventName("histogramCustomBins");
// for( int i=0; i<3; i++ ){
// INDArray discreteY = Nd4j.createFromArray(0, 1, 2);
// lfw.writeHistogramEventDiscrete("histogramDiscrete", t+i, i, 0, Arrays.asList("zero", "one", "two"), discreteY);
//
// INDArray eqSpacingY = Nd4j.createFromArray(-0.5 + 0.5 * i, 0.75 * i + i, 1.0 * i + 1.0);
// lfw.writeHistogramEventEqualSpacing("histogramEqualSpacing", t+i, i, 0, 0.0, 1.0, eqSpacingY);
//
// INDArray customBins = Nd4j.createFromArray(new double[][]{
// {0.0, 0.5, 0.9},
// {0.2, 0.55, 1.0}
// });
// System.out.println(Arrays.toString(customBins.data().asFloat()));
// System.out.println(customBins.shapeInfoToString());
// lfw.writeHistogramEventCustomBins("histogramCustomBins", t+i, i, 0, customBins, eqSpacingY);
// }
UIServer uiServer = UIServer.getInstance();
Thread.sleep(1_000_000_000);
}
}

View File

@ -134,7 +134,7 @@
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui_2.11</artifactId> <artifactId>deeplearning4j-ui</artifactId>
<version>${deeplearning4j.version}</version> <version>${deeplearning4j.version}</version>
</dependency> </dependency>
</dependencies> </dependencies>

View File

@ -1,5 +1,6 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc. ~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2019 Konduit K.K.
~ ~
~ This program and the accompanying materials are made available under the ~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at ~ terms of the Apache License, Version 2.0 which is available at
@ -22,14 +23,14 @@
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-ui_2.11</artifactId> <artifactId>deeplearning4j-ui</artifactId>
<name>deeplearning4j-ui</name> <name>deeplearning4j-ui</name>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-play_2.11</artifactId> <artifactId>deeplearning4j-vertx</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
@ -75,4 +76,19 @@
</profile> </profile>
</profiles> </profiles>
<build>
<pluginManagement>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</pluginManagement>
</build>
</project> </project>

View File

@ -1,46 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui;
import org.slf4j.Logger;
import java.awt.*;
import java.net.URI;
public class UiUtils {
private UiUtils() {}
public static void tryOpenBrowser(String path, Logger log) {
try {
UiUtils.openBrowser(new URI(path));
} catch (Exception e) {
// log.error("Could not open browser",e);
System.out.println("Browser could not be launched automatically.\nUI path: " + path);
}
}
public static void openBrowser(URI uri) throws Exception {
if (Desktop.isDesktopSupported()) {
Desktop.getDesktop().browse(uri);
} else {
throw new UnsupportedOperationException(
"Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
}
}
}

View File

@ -1,100 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui;
import lombok.NonNull;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.concurrent.LinkedBlockingQueue;
/**
* This is simple wrapper for sending state updates generated by `TrainingListener`s.
* Basic idea is simple: network processing should be handled in background, without slowing caller thread
*
* @author raver119@gmail.com
*/
public class WebReporter {
private static final WebReporter INSTANCE = new WebReporter();
protected LinkedBlockingQueue<Pair<WebTarget, Entity>> queue = new LinkedBlockingQueue<>();
private static final Logger log = LoggerFactory.getLogger(WebReporter.class);
private WebReporter() {
ReporterThread thread = new ReporterThread(queue);
thread.start();
throw new RuntimeException("Not implemented");
}
public static WebReporter getInstance() {
return INSTANCE;
}
/**
* This method queues UI report for sending
*
* @param target
* @param entity
*/
public void queueReport(WebTarget target, Entity entity) {
queue.add(Pair.makePair(target, entity));
}
/**
* This method immediately sends UI report to specified target using POST request
*
* @param target
* @param entity
*/
public void postReport(WebTarget target, Entity entity) {
Response resp = target.request(MediaType.APPLICATION_JSON).accept(MediaType.APPLICATION_JSON).post(entity);
log.debug("{}", resp);
}
private class ReporterThread extends Thread implements Runnable {
private LinkedBlockingQueue<Pair<WebTarget, Entity>> queue;
public ReporterThread(@NonNull LinkedBlockingQueue<Pair<WebTarget, Entity>> queue) {
this.queue = queue;
this.setName("DL4j Ui WebReporter thread");
this.setDaemon(true);
}
@Override
public void run() {
while (true) {
try {
Pair<WebTarget, Entity> pair = queue.take();
postReport(pair.getFirst(), pair.getSecond());
} catch (Exception e) {
log.error("Exception caught but ignored: " + e.getMessage());
e.printStackTrace();
} finally {
try {
Thread.sleep(100);
} catch (Exception e) {
}
}
}
}
}
}

View File

@ -1,38 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.exception;
import org.apache.commons.lang.exception.ExceptionUtils;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
@Provider
public class GenericExceptionMapper implements ExceptionMapper<Throwable> {
@Override
public Response toResponse(Throwable ex) {
return Response.status(500)
.entity("Error occurred\n\n" + ex.getMessage() + "\n" + ExceptionUtils.getStackTrace(ex))
.type(MediaType.APPLICATION_JSON).build();
}
}

View File

@ -1,35 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.exception;
import com.fasterxml.jackson.core.JsonProcessingException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
/**
* @author Adam Gibson
*/
@Provider
public class JsonExceptionMapper implements ExceptionMapper<JsonProcessingException> {
@Override
public Response toResponse(JsonProcessingException exception) {
return Response.status(500).entity("WTF").type(MediaType.APPLICATION_JSON).build();
}
}

View File

@ -1,46 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.providers;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.serde.jackson.ndarray.NDArrayDeSerializer;
import org.nd4j.shade.serde.jackson.ndarray.NDArraySerializer;
import javax.ws.rs.ext.ContextResolver;
import javax.ws.rs.ext.Provider;
/**
* @author Adam Gibson
*/
@Provider
public class ObjectMapperProvider implements ContextResolver<ObjectMapper> {
@Override
public ObjectMapper getContext(Class<?> type) {
final ObjectMapper result = new ObjectMapper();
result.registerModule(module());
return result;
}
public static SimpleModule module() {
SimpleModule module = new SimpleModule("nd4j");
module.addDeserializer(INDArray.class, new NDArrayDeSerializer());
module.addSerializer(INDArray.class, new NDArraySerializer());
return module;
}
}

View File

@ -26,10 +26,9 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage; import org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage;
import org.deeplearning4j.util.UIDProvider; import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -48,7 +48,6 @@ import org.deeplearning4j.ui.weights.ConvolutionalIterationListener;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -57,7 +56,6 @@ import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;

View File

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc. ~ Copyright (c) 2019 Konduit K.K.
~ ~
~ This program and the accompanying materials are made available under the ~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at ~ terms of the Apache License, Version 2.0 which is available at
@ -14,7 +14,6 @@
~ ~
~ SPDX-License-Identifier: Apache-2.0 ~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" <project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
@ -25,165 +24,43 @@
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-play_2.11</artifactId> <artifactId>deeplearning4j-vertx</artifactId>
<name>deeplearning4j-play</name>
<properties>
<!-- Default scala versions, may be overwritten by build profiles -->
<scala.version>2.11.12</scala.version>
<scala.binary.version>2.11</scala.binary.version>
</properties>
<profiles>
<!-- To build UI templates: run "mvn compile -P buildUiTemplates" -->
<profile>
<id>buildUiTemplates</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<build>
<plugins>
<plugin>
<groupId>com.google.code.play2-maven-plugin</groupId>
<artifactId>play2-maven-plugin</artifactId>
<version>${maven-play2-plugin.version}</version>
<!-- Generate Scala Page Templates
The Play framework template engine ("twirl") uses templates for HTML pages (or in principle any text-based
data: CSV, XML etc). These templates (*.scala.html files) need to be converted to Scala classes using
code generation. This is done here during the Maven compile phase.
However, the Maven Play framework plugin does not allow proper customization of the output directory. Thus,
we generate these Scala classes in the default location (in the target/twirl/main/ directory) and use the
maven resources plugin to copy them to the actual location we want.
To generate the latest versions of these templates (after modifying or adding a new template), just run
"mvn compile" in either the main project directory, or within the deeplearning4j ui module separately.
-->
<executions>
<execution>
<id>GenerateTemplates</id>
<phase>compile</phase>
<configuration>
</configuration>
<goals>
<goal>template-compile</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- Copy the generated Scala templates to the appropriate directory. See templates comment above. -->
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>${maven-resources-plugin.version}</version>
<executions>
<execution>
<id>CopyTemplates</id>
<phase>compile</phase>
<goals>
<goal>copy-resources</goal>
</goals>
<configuration>
<outputDirectory>${basedir}/src/main/scala/org/deeplearning4j/ui/views/html</outputDirectory>
<resources>
<resource>
<directory>${basedir}/target/twirl/main/org/deeplearning4j/ui/views/html</directory>
<filtering>true</filtering>
</resource>
</resources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
<dependencies> <dependencies>
<dependency>
<groupId>io.vertx</groupId>
<artifactId>vertx-core</artifactId>
<version>${vertx.version}</version>
</dependency>
<dependency>
<groupId>io.vertx</groupId>
<artifactId>vertx-web</artifactId>
<version>${vertx.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui-model</artifactId> <artifactId>deeplearning4j-ui-model</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>javax.ws.rs</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>javax.ws.rs-api</artifactId> <artifactId>logback-classic</artifactId>
<version>${ws.rs.version}</version> <scope>test</scope>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-java_2.11</artifactId>
<version>${playframework.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</exclusion>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>jul-to-slf4j</artifactId>
</exclusion>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>jcl-over-slf4j</artifactId>
</exclusion>
<exclusion>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
</exclusion>
</exclusions>
</dependency> </dependency>
<dependency> <dependency>
<groupId>net.jodah</groupId> <groupId>org.freemarker</groupId>
<artifactId>typetools</artifactId> <artifactId>freemarker</artifactId>
<version>${jodah.typetools.version}</version> <version>2.3.29</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>org.fusesource.leveldbjni</groupId>
<artifactId>leveldbjni-all</artifactId>
<version>${leveldb.version}</version>
</dependency> </dependency>
<dependency> <dependency>
@ -192,24 +69,8 @@
<version>${jcommander.version}</version> <version>${jcommander.version}</version>
</dependency> </dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <!-- Version set by deeplearning4j-parent dependency management -->
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>jul-to-slf4j</artifactId> <!-- Version set by deeplearning4j-parent dependency management -->
<scope>compile</scope>
<version>1.7.21</version>
</dependency>
<!-- Test Scope Dependencies -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId> <!-- Version set by deeplearning4j-parent dependency management -->
<scope>test</scope>
</dependency>
<!-- WebJars Dependencies --> <!-- WebJars Dependencies -->
<dependency> <dependency>
@ -306,11 +167,6 @@
<artifactId>cytoscape-cola</artifactId> <artifactId>cytoscape-cola</artifactId>
<version>2.3.0</version> <version>2.3.0</version>
</dependency> </dependency>
<dependency>
<groupId>org.webjars.npm</groupId>
<artifactId>webcola</artifactId>
<version>3.1.3</version>
</dependency>
<dependency> <dependency>
<groupId>org.webjars.npm</groupId> <groupId>org.webjars.npm</groupId>
<artifactId>cytoscape-cose-bilkent</artifactId> <artifactId>cytoscape-cose-bilkent</artifactId>
@ -442,68 +298,8 @@
<artifactId>flatbuffers</artifactId> <artifactId>flatbuffers</artifactId>
<version>1.9.0</version> <version>1.9.0</version>
</dependency> </dependency>
<!-- Workaround for Java 9+ -->
<dependency>
<groupId>javax.xml.bind</groupId>
<artifactId>jaxb-api</artifactId>
<version>${jaxb.version}</version>
<scope>provided</scope>
</dependency>
</dependencies> </dependencies>
<build>
<plugins>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
</configuration>
</plugin>
<!-- Build helper plugin: used to add the multiple independent source directories (Java, Scala, HTML templates) -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>add-source</id>
<phase>compile</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>src/main/scala</source>
<source>src/main/java</source>
<source>src/main/views</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>com.google.code.sbt-compiler-maven-plugin</groupId>
<artifactId>sbt-compiler-maven-plugin</artifactId>
<version>${sbt-compiler-maven-plugin.version}</version>
</plugin>
<!-- Maven compiler plugin last: should ensure Scala code is compiled before Java code -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId> <!-- Version set by deeplearning4j-parent dependency management -->
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>
<dependencyManagement> <dependencyManagement>
<dependencies> <dependencies>
<!-- Override webjars dependencies with fixed versions. <!-- Override webjars dependencies with fixed versions.
@ -619,4 +415,19 @@
</dependency> </dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>
</project>
<build>
<pluginManagement>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</pluginManagement>
</build>
</project>

View File

@ -1,5 +1,5 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,243 +14,174 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.ui.play; package org.deeplearning4j.ui;
import com.beust.jcommander.JCommander; import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException; import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.impl.MimeMapping;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.Data; import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent; import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener; import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.api.storage.StatsStorageRouter; import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.config.DL4JSystemProperties; import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.ui.SameDiffModule;
import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.DefaultI18N;
import org.deeplearning4j.ui.i18n.I18NProvider; import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.module.SameDiffModule;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule; import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule; import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule; import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule; import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.tsne.TsneModule; import org.deeplearning4j.ui.module.tsne.TsneModule;
import org.deeplearning4j.ui.play.staticroutes.Assets;
import org.deeplearning4j.ui.storage.FileStatsStorage; import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener; import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.util.DL4JFileUtils; import org.deeplearning4j.util.DL4JFileUtils;
import org.nd4j.linalg.function.Function; import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import java.io.File; import java.io.File;
import java.util.*; import java.util.*;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import static play.mvc.Results.ok;
/**
* A UI server based on the Play framework
*
* @author Alex Black
*/
@Slf4j @Slf4j
@Data public class VertxUIServer extends AbstractVerticle implements UIServer {
public class PlayUIServer extends UIServer {
/**
* @deprecated Use {@link DL4JSystemProperties#UI_SERVER_PORT_PROPERTY}
*/
@Deprecated
public static final String UI_SERVER_PORT_PROPERTY = DL4JSystemProperties.UI_SERVER_PORT_PROPERTY;
public static final int DEFAULT_UI_PORT = 9000; public static final int DEFAULT_UI_PORT = 9000;
public static final String ASSETS_ROOT_DIRECTORY = "deeplearning4jUiAssets/"; public static final String ASSETS_ROOT_DIRECTORY = "deeplearning4jUiAssets/";
private Server server; @Getter
private boolean stopped; private static VertxUIServer instance;
@Getter
private static AtomicBoolean multiSession = new AtomicBoolean(false);
@Getter
@Setter
private static Function<String, StatsStorage> statsStorageProvider;
private static Integer instancePort;
public static VertxUIServer getInstance() {
return getInstance(null, multiSession.get(), null);
}
public static VertxUIServer getInstance(Integer port, boolean multiSession, Function<String, StatsStorage> statsStorageProvider){
if (instance == null || instance.isStopped()) {
VertxUIServer.multiSession.set(multiSession);
VertxUIServer.setStatsStorageProvider(statsStorageProvider);
instancePort = port;
Vertx vertx = Vertx.vertx();
//Launch UI server verticle and wait for it to start
CountDownLatch l = new CountDownLatch(1);
vertx.deployVerticle(VertxUIServer.class.getName(), res -> {
l.countDown();
});
try {
l.await(5000, TimeUnit.MILLISECONDS);
} catch (InterruptedException e){ } //Ignore
} else if (!instance.isStopped()) {
if (instance.multiSession.get() && !instance.isMultiSession()) {
throw new RuntimeException("Cannot return multi-session instance." +
" UIServer has already started in single-session mode at " + instance.getAddress() +
" You may stop the UI server instance, and start a new one.");
} else if (!instance.multiSession.get() && instance.isMultiSession()) {
throw new RuntimeException("Cannot return single-session instance." +
" UIServer has already started in multi-session mode at " + instance.getAddress() +
" You may stop the UI server instance, and start a new one.");
}
}
return instance;
}
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
private RemoteReceiverModule remoteReceiverModule;
private StatsStorageLoader statsStorageLoader;
//typeIDModuleMap: Records which modules are registered for which type IDs
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
private HttpServer server;
private AtomicBoolean shutdown = new AtomicBoolean(false);
private long uiProcessingDelay = 500; //500ms. TODO make configurable
private final BlockingQueue<StatsStorageEvent> eventQueue = new LinkedBlockingQueue<>(); private final BlockingQueue<StatsStorageEvent> eventQueue = new LinkedBlockingQueue<>();
private List<Pair<StatsStorage, StatsStorageListener>> listeners = new CopyOnWriteArrayList<>(); private List<Pair<StatsStorage, StatsStorageListener>> listeners = new CopyOnWriteArrayList<>();
private List<StatsStorage> statsStorageInstances = new CopyOnWriteArrayList<>(); private List<StatsStorage> statsStorageInstances = new CopyOnWriteArrayList<>();
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
private RemoteReceiverModule remoteReceiverModule;
//typeIDModuleMap: Records which modules are registered for which type IDs
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
private long uiProcessingDelay = 500; //500ms. TODO make configurable
private final AtomicBoolean shutdown = new AtomicBoolean(false);
private Thread uiEventRoutingThread; private Thread uiEventRoutingThread;
@Parameter(names = {"-r", "--enableRemote"}, description = "Whether to enable remote or not", arity = 1) public VertxUIServer() {
private boolean enableRemote; instance = this;
@Parameter(names = {"-p", "--uiPort"}, description = "Custom HTTP port for UI", arity = 1)
private int port = DEFAULT_UI_PORT;
@Parameter(names = {"-f", "--customStatsFile"}, description = "Path to create custom stats file (remote only)", arity = 1)
private String customStatsFile;
@Parameter(names = {"-m", "--multiSession"}, description = "Whether to enable multiple separate browser sessions or not", arity = 1)
private boolean multiSession;
private StatsStorageLoader statsStorageLoader;
public PlayUIServer() {
this(DEFAULT_UI_PORT);
} }
public PlayUIServer(int port) { public static void stopInstance(){
this(port, false); if(instance == null)
return;
instance.stop();
} }
/** @Override
* Create {@code PlayUIServer} at given port in given mode. Note: to start the server, run {@link #runMain(String[])} public void start() throws Exception {
* @param port port that the server will listen on //Create REST endpoints
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs. File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId uploadDir.mkdirs();
*/ Router r = Router.router(vertx);
public PlayUIServer(int port, boolean multiSession) { r.route().handler(BodyHandler.create() //NOTE: Setting this is required to receive request body content at all
this.port = port; .setUploadsDirectory(uploadDir.getAbsolutePath()));
this.multiSession = multiSession; r.get("/assets/*").handler(rc -> {
} String path = rc.request().path();
path = path.substring(8); //Remove "/assets/", which is 8 characters
/** String mime;
* Auto-attach StatsStorage if an unknown session ID is passed as URL path parameter in multi-session mode String newPath;
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID if (path.contains("webjars")) {
*/ newPath = "META-INF/resources/" + path.substring(path.indexOf("webjars"));
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
if (statsStorageProvider != null) {
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
}
}
public void runMain(String[] args) {
JCommander jcmdr = new JCommander(this);
try {
jcmdr.parse(args);
} catch (ParameterException e) {
//User provides invalid input -> print the usage info
jcmdr.usage();
try {
Thread.sleep(500);
} catch (Exception e2) {
}
System.exit(1);
}
if(((DefaultI18N)I18NProvider.getInstance()).noI18NData()){
Throwable t = ((DefaultI18N)I18NProvider.getInstance()).getLanguageLoadingException();
if(t != null){
throw new RuntimeException("Exception encountered during loading UI Language (Internationalization) data", t);
}
log.error("Error loading UI Language (Internationalization) data: no language resource data files were" +
"found on the classpath. This usually occurs when running DL4J's UI from an uber-jar, which was " +
"built incorrectly (without language resource files). See https://deeplearning4j.org/docs/latest/deeplearning4j-nn-visualization#issues" +
"for more details");
System.exit(1);
}
String portProperty = System.getProperty(DL4JSystemProperties.UI_SERVER_PORT_PROPERTY);
if (portProperty != null) {
try {
port = Integer.parseInt(portProperty);
} catch (NumberFormatException e) {
log.warn("Could not parse UI port property \"{}\" with value \"{}\"", DL4JSystemProperties.UI_SERVER_PORT_PROPERTY,
portProperty, e);
}
}
//Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret
String crypto = System.getProperty("play.crypto.secret");
if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) {
byte[] newCrypto = new byte[1024];
new Random().nextBytes(newCrypto);
String base64 = Base64.getEncoder().encodeToString(newCrypto);
System.setProperty("play.crypto.secret", base64);
}
try {
server = Server.forRouter(Mode.PROD, port, this::createRouter);
} catch (Throwable e){
if(e.getMessage().contains("'play.crypto.provider")){
//Usual cause: user's uber-jar does not include application.conf
log.error("Error starting UI server due to missing play.crypto.provider config: This usually occurs due to missing" +
" application.conf file. DL4J's UI (based on the Play framework) requires this file in order" +
" to run. File can be missing due to incorrect creation of uber-jars that do not include resource" +
" files. See https://deeplearning4j.org/docs/latest/deeplearning4j-nn-visualization#issues for more information", e);
} else { } else {
log.error("Unknown error when starting UI server",e); newPath = ASSETS_ROOT_DIRECTORY + (path.startsWith("/") ? path.substring(1) : path);
} }
throw e; mime = MimeMapping.getMimeTypeForFilename(FilenameUtils.getName(newPath));
}
log.info("DL4J UI Server started at {}", getAddress()); //System.out.println("PATH: " + path + " - mime = " + mime);
rc.response()
.putHeader("content-type", mime)
.sendFile(newPath);
});
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
uiEventRoutingThread.setDaemon(true);
uiEventRoutingThread.start();
if (enableRemote && customStatsFile != null) {
enableRemoteListener(new FileStatsStorage(new File(customStatsFile)), true);
}
else if(enableRemote) {
try {
File tempStatsFile = DL4JFileUtils.createTempFile("dl4j", "UIstats");
tempStatsFile.delete();
tempStatsFile.deleteOnExit();
enableRemoteListener(new FileStatsStorage(tempStatsFile), true);
} catch(Exception e) {
log.error("Failed to create temporary file for stats storage",e);
System.exit(1);
}
}
setStopped(false); if (isMultiSession()) {
} r.get("/setlang/:sessionId/:to").handler(
rc -> {
protected Router createRouter(BuiltInComponents builtInComponents){ String sid = rc.request().getParam("sessionID");
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); String to = rc.request().getParam("to");
I18NProvider.getInstance(sid).setDefaultLanguage(to);
//Set up index page and assets routing rc.response().end();
//The definitions and FunctionUtil may look a bit weird here... this is used to translate implementation independent });
// definitions (i.e., Java Supplier, Function etc interfaces) to the Play-specific versions
//This way, routing is not directly dependent ot Play API. Furthermore, Play 2.5 switches to using these Java interfaces
// anyway; thus switching 2.5 should be as simple as removing the FunctionUtil calls...
if (multiSession) {
routingDsl.GET("/setlang/:sessionId/:to").routingTo((request, sid, to) -> {
I18NProvider.getInstance(sid.toString()).setDefaultLanguage(to.toString());
return ok();
});
} else { } else {
routingDsl.GET("/setlang/:to").routingTo((request, to) -> { r.get("/setlang/:to").handler(rc -> {
I18NProvider.getInstance().setDefaultLanguage(to.toString()); String to = rc.request().getParam("to");
return ok(); I18NProvider.getInstance().setDefaultLanguage(to);
rc.response().end();
}); });
} }
routingDsl.GET("/assets/*file").routingTo((request, file) -> Assets.assetRequest(ASSETS_ROOT_DIRECTORY, file.toString()));
uiModules.add(new DefaultModule(multiSession)); //For: navigation page "/"
uiModules.add(new TrainModule(multiSession, statsStorageLoader, this::getAddress)); uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
uiModules.add(new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress));
uiModules.add(new ConvolutionalListenerModule()); uiModules.add(new ConvolutionalListenerModule());
uiModules.add(new TsneModule()); uiModules.add(new TsneModule());
uiModules.add(new SameDiffModule()); uiModules.add(new SameDiffModule());
@ -262,27 +193,19 @@ public class PlayUIServer extends UIServer {
for (UIModule m : uiModules) { for (UIModule m : uiModules) {
List<Route> routes = m.getRoutes(); List<Route> routes = m.getRoutes();
for (Route r : routes) { for (Route route : routes) {
RoutingDsl.PathPatternMatcher ppm = routingDsl.match(r.getHttpMethod().name(), r.getRoute()); switch (route.getHttpMethod()) {
switch (r.getFunctionType()) { case GET:
case Supplier: r.get(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc));
ppm.routingTo(request -> r.getSupplier().get());
break; break;
case Function: case PUT:
ppm.routingTo((request, arg) -> r.getFunction().apply(arg.toString())); r.put(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc));
break; break;
case BiFunction: case POST:
ppm.routingTo((request, arg0, arg1) -> r.getFunction2().apply(arg0.toString(), arg1.toString())); r.post(route.getRoute()).handler(rc -> route.getConsumer().accept(extractArgsFromRoute(route.getRoute(), rc), rc));
break; break;
case Request0Function:
ppm.routingTo(request -> r.getRequest0Function().apply(request));
break;
case Request1Function:
ppm.routingTo((request, arg0) -> r.getRequest1Function().apply(request, arg0.toString()));
break;
case Function3:
default: default:
throw new RuntimeException("Not yet implemented"); throw new IllegalStateException("Unknown or not supported HTTP method: " + route.getHttpMethod());
} }
} }
@ -297,20 +220,41 @@ public class PlayUIServer extends UIServer {
list.add(m); list.add(m);
} }
} }
Router router = routingDsl.build();
return router;
}
@Override //Check port property
public String getAddress() { int port = instancePort == null ? DEFAULT_UI_PORT : instancePort;
String addr = server.mainAddress().toString(); String portProp = System.getenv(DL4JSystemProperties.UI_SERVER_PORT_PROPERTY);
if (addr.startsWith("/0:0:0:0:0:0:0:0")) { if(portProp != null && !portProp.isEmpty()){
int last = addr.lastIndexOf(':'); try{
if (last > 0) { port = Integer.parseInt(portProp);
addr = "http://localhost:" + addr.substring(last + 1); } catch (NumberFormatException e){
log.warn("Error parsing port property {}={}", DL4JSystemProperties.UI_SERVER_PORT_PROPERTY, portProp);
} }
} }
return addr;
server = vertx.createHttpServer()
.requestHandler(r)
.listen(port);
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
uiEventRoutingThread.setDaemon(true);
uiEventRoutingThread.start();
}
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
if (!path.contains(":")) {
return Collections.emptyList();
}
String[] split = path.split("/");
List<String> out = new ArrayList<>();
for (String s : split) {
if (s.startsWith(":")) {
String s2 = s.substring(1);
out.add(rc.request().getParam(s2));
}
}
return out;
} }
private void modulesViaServiceLoader(List<UIModule> uiModules) { private void modulesViaServiceLoader(List<UIModule> uiModules) {
@ -326,32 +270,49 @@ public class PlayUIServer extends UIServer {
UIModule m = iter.next(); UIModule m = iter.next();
Class<?> c = m.getClass(); Class<?> c = m.getClass();
boolean foundExisting = false; boolean foundExisting = false;
for(UIModule mExisting : uiModules){ for (UIModule mExisting : uiModules) {
if(mExisting.getClass() == c){ if (mExisting.getClass() == c) {
foundExisting = true; foundExisting = true;
break; break;
} }
} }
if(!foundExisting) { if (!foundExisting) {
log.debug("Loaded UI module via service loader: {}", m.getClass()); log.debug("Loaded UI module via service loader: {}", m.getClass());
uiModules.add(m); uiModules.add(m);
} }
} }
} }
@Override
public void stop() {
server.close();
shutdown.set(true);
}
public static void main(String[] args) {
new PlayUIServer().runMain(args); @Override
public boolean isStopped() {
return shutdown.get();
}
@Override
public boolean isMultiSession() {
return multiSession.get();
}
@Override
public String getAddress() {
return "https://localhost:" + server.actualPort();
} }
@Override @Override
public int getPort() { public int getPort() {
return port; return server.actualPort();
} }
@Override @Override
public synchronized void attach(StatsStorage statsStorage) { public void attach(StatsStorage statsStorage) {
if (statsStorage == null) if (statsStorage == null)
throw new IllegalArgumentException("StatsStorage cannot be null"); throw new IllegalArgumentException("StatsStorage cannot be null");
if (statsStorageInstances.contains(statsStorage)) if (statsStorageInstances.contains(statsStorage))
@ -369,13 +330,13 @@ public class PlayUIServer extends UIServer {
} }
@Override @Override
public synchronized void detach(StatsStorage statsStorage) { public void detach(StatsStorage statsStorage) {
if (statsStorage == null) if (statsStorage == null)
throw new IllegalArgumentException("StatsStorage cannot be null"); throw new IllegalArgumentException("StatsStorage cannot be null");
if (!statsStorageInstances.contains(statsStorage)) if (!statsStorageInstances.contains(statsStorage))
return; //No op return; //No op
boolean found = false; boolean found = false;
for (Iterator<Pair<StatsStorage, StatsStorageListener>> iterator = listeners.iterator(); iterator.hasNext();) { for (Iterator<Pair<StatsStorage, StatsStorageListener>> iterator = listeners.iterator(); iterator.hasNext(); ) {
Pair<StatsStorage, StatsStorageListener> p = iterator.next(); Pair<StatsStorage, StatsStorageListener> p = iterator.next();
if (p.getFirst() == statsStorage) { //Same object, not equality if (p.getFirst() == statsStorage) { //Same object, not equality
statsStorage.deregisterStatsStorageListener(p.getSecond()); statsStorage.deregisterStatsStorageListener(p.getSecond());
@ -433,15 +394,6 @@ public class PlayUIServer extends UIServer {
return remoteReceiverModule.isEnabled(); return remoteReceiverModule.isEnabled();
} }
@Override
public void stop() {
if (server != null) {
server.stop();
setStopped(true);
}
}
private class StatsEventRouterRunnable implements Runnable { private class StatsEventRouterRunnable implements Runnable {
@ -455,7 +407,7 @@ public class PlayUIServer extends UIServer {
} }
private void runHelper() throws Exception { private void runHelper() throws Exception {
log.debug("PlayUIServer.StatsEventRouterRunnable started"); log.info("VertxUIServer.StatsEventRouterRunnable started");
//Idea: collect all event stats, and route them to the appropriate modules //Idea: collect all event stats, and route them to the appropriate modules
while (!shutdown.get()) { while (!shutdown.get()) {
@ -523,4 +475,39 @@ public class PlayUIServer extends UIServer {
} }
} }
//==================================================================================================================
// CLI Launcher
@Data
private static class CLIParams {
@Parameter(names = {"-r", "--enableRemote"}, description = "Whether to enable remote or not", arity = 1)
private boolean cliEnableRemote;
@Parameter(names = {"-p", "--uiPort"}, description = "Custom HTTP port for UI", arity = 1)
private int cliPort = DEFAULT_UI_PORT;
@Parameter(names = {"-f", "--customStatsFile"}, description = "Path to create custom stats file (remote only)", arity = 1)
private String cliCustomStatsFile;
@Parameter(names = {"-m", "--multiSession"}, description = "Whether to enable multiple separate browser sessions or not", arity = 1)
private boolean cliMultiSession;
}
public void main(String[] args){
CLIParams d = new CLIParams();
new JCommander(d).parse(args);
instancePort = d.getCliPort();
UIServer.getInstance(d.isCliMultiSession(), null);
if(d.isCliEnableRemote()){
try {
File tempStatsFile = DL4JFileUtils.createTempFile("dl4j", "UIstats");
tempStatsFile.delete();
tempStatsFile.deleteOnExit();
enableRemoteListener(new FileStatsStorage(tempStatsFile), true);
} catch(Exception e) {
log.error("Failed to create temporary file for stats storage",e);
System.exit(1);
}
}
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -22,5 +23,5 @@ package org.deeplearning4j.ui.api;
* @author Alex Black * @author Alex Black
*/ */
public enum HttpMethod { public enum HttpMethod {
GET, PUT, POST, DELETE, HEAD, OPTIONS, PATCH GET, PUT, POST
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,8 @@
package org.deeplearning4j.ui.api; package org.deeplearning4j.ui.api;
import java.util.Map;
/** /**
* Interface to handle user interface internationalization. * Interface to handle user interface internationalization.
* Internationalization support is bulit into Play framework, but this doesn't seem to function with a Java + Maven * Internationalization support is bulit into Play framework, but this doesn't seem to function with a Java + Maven
@ -62,4 +65,11 @@ public interface I18N {
*/ */
void setDefaultLanguage(String langCode); void setDefaultLanguage(String langCode);
/**
* Get all internationalization messages for the specified language code
* @param langCode Language code
* @return Messages
*/
Map<String,String> getMessages(String langCode);
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,18 +17,22 @@
package org.deeplearning4j.ui.api; package org.deeplearning4j.ui.api;
import io.vertx.ext.web.RoutingContext;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.util.List;
import java.util.function.*;
/** /**
* Enumeration for the type of function. Mainly used in specifying {@link Route} instances<br> * A Route specifies an endpoint that can be queried in the UI - along with how it should be handled
* Supplier: No args<br>
* Function: 1 arg<br>
* BiFunction: 2 args<br>
* Function3: 3 args<br>
* Request0Function: Supplier + request, no args (as Function<Request,Result>)<br>
* Request1Function: Supplier + request + 1 args (as BiFunction<Request,String, Result>)<br>
* *
* @author Alex Black * @author Alex Black
*/ */
public enum FunctionType { @Data
Supplier, Function, BiFunction, Function3, @AllArgsConstructor
Request0Function, Request1Function public class Route {
private final String route;
private final HttpMethod httpMethod;
private final BiConsumer<List<String>, RoutingContext> consumer;
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -78,6 +79,8 @@ public interface UIModule {
*/ */
void onDetach(StatsStorage statsStorage); void onDetach(StatsStorage statsStorage);
/**
* @return List of internationalization resources
*/
List<I18NResource> getInternationalizationResources(); List<I18NResource> getInternationalizationResources();
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,10 +17,9 @@
package org.deeplearning4j.ui.api; package org.deeplearning4j.ui.api;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter; import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.ui.play.PlayUIServer; import org.deeplearning4j.ui.VertxUIServer;
import org.nd4j.linalg.function.Function; import org.nd4j.linalg.function.Function;
import java.util.List; import java.util.List;
@ -29,10 +29,7 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
@Slf4j public interface UIServer {
public abstract class UIServer {
private static UIServer uiServer;
/** /**
* Get (and, initialize if necessary) the UI server. * Get (and, initialize if necessary) the UI server.
@ -41,99 +38,81 @@ public abstract class UIServer {
* @return UI instance for this JVM * @return UI instance for this JVM
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session) * @throws RuntimeException if the instance has already started in a different mode (multi/single-session)
*/ */
public static synchronized UIServer getInstance() throws RuntimeException { static UIServer getInstance() throws RuntimeException {
return getInstance(false, null); return getInstance(false, null);
} }
/** /**
* Get (and, initialize if necessary) the UI server. * Get (and, initialize if necessary) the UI server.
* Singleton pattern - all calls to getInstance() will return the same UI instance. * Singleton pattern - all calls to getInstance() will return the same UI instance.
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs. *
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId * @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID. * @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed * <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
* as URL path parameter in multi-session mode, or leave it {@code null}. * as URL path parameter in multi-session mode, or leave it {@code null}.
* @return UI instance for this JVM * @return UI instance for this JVM
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session) * @throws RuntimeException if the instance has already started in a different mode (multi/single-session)
*/ */
public static synchronized UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider) static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider) throws RuntimeException {
throws RuntimeException { return VertxUIServer.getInstance(null, multiSession, statsStorageProvider);
if (uiServer == null || uiServer.isStopped()) {
PlayUIServer playUIServer = new PlayUIServer(PlayUIServer.DEFAULT_UI_PORT, multiSession);
playUIServer.setMultiSession(multiSession);
if (statsStorageProvider != null) {
playUIServer.autoAttachStatsStorageBySessionId(statsStorageProvider);
}
playUIServer.runMain(new String[] {});
uiServer = playUIServer;
} else if (!uiServer.isStopped()) {
if (multiSession && !uiServer.isMultiSession()) {
throw new RuntimeException("Cannot return multi-session instance." +
" UIServer has already started in single-session mode at " + uiServer.getAddress() +
" You may stop the UI server instance, and start a new one.");
} else if (!multiSession && uiServer.isMultiSession()) {
throw new RuntimeException("Cannot return single-session instance." +
" UIServer has already started in multi-session mode at " + uiServer.getAddress() +
" You may stop the UI server instance, and start a new one.");
}
}
return uiServer;
} }
/** /**
* Stop UIServer instance, if already running * Stop UIServer instance, if already running
*/ */
public static void stopInstance() { static void stopInstance() {
if (uiServer != null) { VertxUIServer.stopInstance();
uiServer.stop();
}
} }
public abstract boolean isStopped(); boolean isStopped();
/** /**
* Check if the instance initialized with one of the factory methods * Check if the instance initialized with one of the factory methods
* ({@link #getInstance()} or {@link #getInstance(boolean, Function)}) is in multi-session mode * ({@link #getInstance()} or {@link #getInstance(boolean, Function)}) is in multi-session mode
*
* @return {@code true} if the instance is in multi-session * @return {@code true} if the instance is in multi-session
*/ */
public abstract boolean isMultiSession(); boolean isMultiSession();
/** /**
* Get the address of the UI * Get the address of the UI
* *
* @return Address of the UI * @return Address of the UI
*/ */
public abstract String getAddress(); String getAddress();
/** /**
* Get the current port for the UI * Get the current port for the UI
*/ */
public abstract int getPort(); int getPort();
/** /**
* Attach the given StatsStorage instance to the UI, so the data can be visualized * Attach the given StatsStorage instance to the UI, so the data can be visualized
* @param statsStorage StatsStorage instance to attach to the UI *
* @param statsStorage StatsStorage instance to attach to the UI
*/ */
public abstract void attach(StatsStorage statsStorage); void attach(StatsStorage statsStorage);
/** /**
* Detach the specified StatsStorage instance from the UI * Detach the specified StatsStorage instance from the UI
* @param statsStorage StatsStorage instance to detach. If not attached: no op. *
* @param statsStorage StatsStorage instance to detach. If not attached: no op.
*/ */
public abstract void detach(StatsStorage statsStorage); void detach(StatsStorage statsStorage);
/** /**
* Check whether the specified StatsStorage instance is attached to the UI instance * Check whether the specified StatsStorage instance is attached to the UI instance
* *
* @param statsStorage StatsStorage instance to attach * @param statsStorage StatsStorage instance to attach
* @return True if attached * @return True if attached
*/ */
public abstract boolean isAttached(StatsStorage statsStorage); boolean isAttached(StatsStorage statsStorage);
/** /**
* @return A list of all StatsStorage instances currently attached * @return A list of all StatsStorage instances currently attached
*/ */
public abstract List<StatsStorage> getStatsStorageInstances(); List<StatsStorage> getStatsStorageInstances();
/** /**
* Enable the remote listener functionality, storing all data in memory, and attaching the instance to the UI. * Enable the remote listener functionality, storing all data in memory, and attaching the instance to the UI.
@ -142,31 +121,31 @@ public abstract class UIServer {
* *
* @see #enableRemoteListener(StatsStorageRouter, boolean) * @see #enableRemoteListener(StatsStorageRouter, boolean)
*/ */
public abstract void enableRemoteListener(); void enableRemoteListener();
/** /**
* Enable the remote listener functionality, storing the received results in the specified StatsStorageRouter. * Enable the remote listener functionality, storing the received results in the specified StatsStorageRouter.
* If the StatsStorageRouter is a {@link StatsStorage} instance, it may (optionally) be attached to the UI, * If the StatsStorageRouter is a {@link StatsStorage} instance, it may (optionally) be attached to the UI,
* as if {@link #attach(StatsStorage)} was called on it. * as if {@link #attach(StatsStorage)} was called on it.
* *
* @param statsStorage StatsStorageRouter to post the received results to * @param statsStorage StatsStorageRouter to post the received results to
* @param attach Whether to attach the given StatsStorage instance to the UI server * @param attach Whether to attach the given StatsStorage instance to the UI server
*/ */
public abstract void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach); void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach);
/** /**
* Disable the remote listener functionality (disabled by default) * Disable the remote listener functionality (disabled by default)
*/ */
public abstract void disableRemoteListener(); void disableRemoteListener();
/** /**
* @return Whether the remote listener functionality is currently enabled * @return Whether the remote listener functionality is currently enabled
*/ */
public abstract boolean isRemoteListenerEnabled(); boolean isRemoteListenerEnabled();
/** /**
* Stop/shut down the UI server. * Stop/shut down the UI server.
*/ */
public abstract void stop(); void stop();
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -23,12 +24,8 @@ import org.deeplearning4j.ui.api.UIModule;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.nio.charset.Charset; import java.nio.charset.StandardCharsets;
import java.util.Map; import java.util.*;
import java.util.HashMap;
import java.util.List;
import java.util.Collections;
import java.util.ServiceLoader;
/** /**
* Default internationalization implementation.<br> * Default internationalization implementation.<br>
@ -124,32 +121,12 @@ public class DefaultI18N implements I18N {
} }
} }
public Throwable getLanguageLoadingException(){
return languageLoadingException;
}
public boolean noI18NData(){
if(messagesByLanguage.isEmpty()){
return true;
}
boolean noData = true;
for(Map.Entry<String,Map<String,String>> e : messagesByLanguage.entrySet()){
if(e.getValue() == null || e.getValue().isEmpty()){
continue;
}
noData = false;
break;
}
return noData;
}
private void parseFile(I18NResource r, Map<String,String> results){ private void parseFile(I18NResource r, Map<String,String> results){
List<String> lines; List<String> lines;
try (InputStream is = r.getInputStream()){ try (InputStream is = r.getInputStream()){
lines = IOUtils.readLines(is, Charset.forName("UTF-8")); lines = IOUtils.readLines(is, StandardCharsets.UTF_8);
} catch (IOException e){ } catch (IOException e){
log.debug("Error parsing UI I18N content file; skipping: {}", r.getResource(), e.getMessage()); log.debug("Error parsing UI I18N content file; skipping: {} - {}", r.getResource(), e.getMessage());
return; return;
} }
@ -201,4 +178,15 @@ public class DefaultI18N implements I18N {
this.currentLanguage = langCode; this.currentLanguage = langCode;
log.debug("UI: Set language to {}", langCode); log.debug("UI: Set language to {}", langCode);
} }
@Override
public Map<String, String> getMessages(String langCode) {
//Start with map for default language
//Then overwrite with the actual language - so any missing are reported in default language
Map<String,String> ret = new HashMap<>(messagesByLanguage.get(FALLBACK_LANGUAGE));
if(!langCode.equals(FALLBACK_LANGUAGE)){
ret.putAll(messagesByLanguage.get(langCode));
}
return ret;
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -26,10 +27,14 @@ import org.deeplearning4j.ui.api.I18N;
*/ */
public class I18NProvider { public class I18NProvider {
/**
* Current I18N instance
*/
private static I18N i18n = DefaultI18N.getInstance(); private static I18N i18n = DefaultI18N.getInstance();
/** /**
* Get the current/global I18N instance (used in single-session mode) * Get the current/global I18N instance (used in single-session mode)
*
* @return global instance * @return global instance
*/ */
public static I18N getInstance() { public static I18N getInstance() {
@ -39,6 +44,7 @@ public class I18NProvider {
/** /**
* Get instance for session (used in multi-session mode) * Get instance for session (used in multi-session mode)
*
* @param sessionId session * @param sessionId session
* @return instance for session * @return instance for session
*/ */
@ -48,9 +54,9 @@ public class I18NProvider {
/** /**
* Remove I18N instance for session * Remove I18N instance for session
*
* @param sessionId session ID * @param sessionId session ID
* @return the previous value associated with {@code sessionId}, * @return the previous value associated with {@code sessionId} or null if there was no mapping for {@code sessionId}
* or null if there was no mapping for {@code sessionId}
*/ */
public static synchronized I18N removeInstance(String sessionId) { public static synchronized I18N removeInstance(String sessionId) {
return DefaultI18N.removeInstance(sessionId); return DefaultI18N.removeInstance(sessionId);

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -26,12 +27,9 @@ import java.io.InputStream;
@AllArgsConstructor @AllArgsConstructor
@Data @Data
public class I18NResource { public class I18NResource {
private final String resource; private final String resource;
public InputStream getInputStream() throws IOException { public InputStream getInputStream() throws IOException {
return new ClassPathResource(resource).getInputStream(); return new ClassPathResource(resource).getInputStream();
} }
} }

View File

@ -14,28 +14,18 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.ui; package org.deeplearning4j.ui.module;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent; import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.I18NResource; import org.deeplearning4j.ui.i18n.I18NResource;
import play.libs.Json;
import play.mvc.Http;
import play.mvc.Result;
import play.mvc.Results;
import java.io.File; import java.util.Collection;
import java.io.IOException; import java.util.Collections;
import java.util.*; import java.util.List;
import static play.mvc.Controller.request;
import static play.mvc.Results.badRequest;
import static play.mvc.Results.ok;
/** /**
* Created by Alex on 25/10/2016. * Created by Alex on 25/10/2016.
@ -53,9 +43,10 @@ public class SameDiffModule implements UIModule {
@Override @Override
public List<Route> getRoutes() { public List<Route> getRoutes() {
Route r1 = new Route("/samediff", HttpMethod.GET, FunctionType.Supplier, // Route r1 = new Route("/samediff", HttpMethod.GET, FunctionType.Supplier,
() -> ok(org.deeplearning4j.ui.views.html.samediff.SameDiffUI.apply())); // () -> ok(org.deeplearning4j.ui.views.html.samediff.SameDiffUI.apply()));
return Arrays.asList(r1); Route r1 = new Route("/samediff", HttpMethod.GET, (path,rc) -> rc.response().sendFile("templates/SameDiffUI.html"));
return Collections.singletonList(r1);
} }
@Override @Override

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,31 +17,28 @@
package org.deeplearning4j.ui.module.convolutional; package org.deeplearning4j.ui.module.convolutional;
import io.vertx.core.buffer.Buffer;
import io.vertx.ext.web.RoutingContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.Persistable; import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent; import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener; import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.I18NResource; import org.deeplearning4j.ui.i18n.I18NResource;
import org.deeplearning4j.ui.weights.ConvolutionListenerPersistable; import org.deeplearning4j.ui.weights.ConvolutionListenerPersistable;
import play.mvc.Result;
import javax.imageio.ImageIO; import javax.imageio.ImageIO;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static play.mvc.Results.ok;
/** /**
* Used for plotting results from the ConvolutionalIterationListener * Used for plotting results from the ConvolutionalIterationListener
* *
@ -63,9 +61,8 @@ public class ConvolutionalListenerModule implements UIModule {
@Override @Override
public List<Route> getRoutes() { public List<Route> getRoutes() {
Route r = new Route("/activations", HttpMethod.GET, FunctionType.Supplier, Route r = new Route("/activations", HttpMethod.GET, (path, rc) -> rc.response().sendFile("templates/Activations.html"));
() -> ok(org.deeplearning4j.ui.views.html.convolutional.Activations.apply())); Route r2 = new Route("/activations/data", HttpMethod.GET, (path, rc) -> this.getImage(rc));
Route r2 = new Route("/activations/data", HttpMethod.GET, FunctionType.Supplier, this::getImage);
return Arrays.asList(r, r2); return Arrays.asList(r, r2);
} }
@ -87,12 +84,12 @@ public class ConvolutionalListenerModule implements UIModule {
@Override @Override
public void onAttach(StatsStorage statsStorage) { public void onAttach(StatsStorage statsStorage) {
//No-op
} }
@Override @Override
public void onDetach(StatsStorage statsStorage) { public void onDetach(StatsStorage statsStorage) {
//No-op
} }
@Override @Override
@ -100,7 +97,7 @@ public class ConvolutionalListenerModule implements UIModule {
return Collections.emptyList(); return Collections.emptyList();
} }
private Result getImage() { private void getImage(RoutingContext rc) {
if (lastTimeStamp > 0 && lastStorage != null) { if (lastTimeStamp > 0 && lastStorage != null) {
Persistable p = lastStorage.getStaticInfo(lastSessionID, TYPE_ID, lastWorkerID); Persistable p = lastStorage.getStaticInfo(lastSessionID, TYPE_ID, lastWorkerID);
if (p instanceof ConvolutionListenerPersistable) { if (p instanceof ConvolutionListenerPersistable) {
@ -112,12 +109,19 @@ public class ConvolutionalListenerModule implements UIModule {
} catch (IOException e) { } catch (IOException e) {
log.warn("Error displaying image", e); log.warn("Error displaying image", e);
} }
return ok(baos.toByteArray()).as("image/jpg");
rc.response()
.putHeader("content-type", "image/jpg")
.end(Buffer.buffer(baos.toByteArray()));
} else { } else {
return ok(new byte[0]).as("image/jpeg"); rc.response()
.putHeader("content-type", "image/jpg")
.end(Buffer.buffer(new byte[0]));
} }
} else { } else {
return ok(new byte[0]).as("image/jpeg"); rc.response()
.putHeader("content-type", "image/jpg")
.end(Buffer.buffer(new byte[0]));
} }
} }
} }

View File

@ -1,5 +1,5 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2019 Konduit, KK.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,23 +16,18 @@
package org.deeplearning4j.ui.module.defaultModule; package org.deeplearning4j.ui.module.defaultModule;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent; import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.I18NResource; import org.deeplearning4j.ui.i18n.I18NResource;
import org.nd4j.linalg.function.Supplier;
import java.io.File;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static play.mvc.Results.ok;
import static play.mvc.Results.redirect;
/** /**
* Landing page - i.e., "/" route * Landing page - i.e., "/" route
* @author Alex Black * @author Alex Black
@ -59,27 +54,28 @@ public class DefaultModule implements UIModule {
@Override @Override
public List<Route> getRoutes() { public List<Route> getRoutes() {
//TODO Route r = new Route("/", HttpMethod.GET, (params, rc) -> rc.response()
// Route r = new Route("/", HttpMethod.GET, FunctionType.Supplier, () -> ok(org.deeplearning4j.ui.views.html.defaultPage.DefaultPage.apply())); .putHeader("location", "/train" + (multiSession ? "" : "/overview"))
Route r = multiSession ? new Route("/", HttpMethod.GET, FunctionType.Supplier, () -> redirect("/train")) .setStatusCode(HttpResponseStatus.FOUND.code())
: new Route("/", HttpMethod.GET, FunctionType.Supplier, () -> redirect("/train/overview")); .end()
);
return Collections.singletonList(r); return Collections.singletonList(r);
} }
@Override @Override
public void reportStorageEvents(Collection<StatsStorageEvent> events) { public void reportStorageEvents(Collection<StatsStorageEvent> events) {
//Nop-op
} }
@Override @Override
public void onAttach(StatsStorage statsStorage) { public void onAttach(StatsStorage statsStorage) {
//No-op
} }
@Override @Override
public void onDetach(StatsStorage statsStorage) { public void onDetach(StatsStorage statsStorage) {
//No-op
} }
@Override @Override

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,23 +15,24 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.ui.module.remote; package org.deeplearning4j.ui.module.remote;
import com.fasterxml.jackson.databind.JsonNode; import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.RoutingContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.*; import org.deeplearning4j.api.storage.*;
import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.I18NResource; import org.deeplearning4j.ui.i18n.I18NResource;
import play.mvc.Http;
import play.mvc.Result;
import play.mvc.Results;
import javax.xml.bind.DatatypeConverter; import javax.xml.bind.DatatypeConverter;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
/** /**
@ -70,7 +72,7 @@ public class RemoteReceiverModule implements UIModule {
@Override @Override
public List<Route> getRoutes() { public List<Route> getRoutes() {
Route r = Route.request0Function("/remoteReceive", HttpMethod.POST, this::receiveData); Route r = new Route("/remoteReceive", HttpMethod.POST, (path, rc) -> this.receiveData(rc));
return Collections.singletonList(r); return Collections.singletonList(r);
} }
@ -95,47 +97,48 @@ public class RemoteReceiverModule implements UIModule {
return Collections.emptyList(); return Collections.emptyList();
} }
private Result receiveData(Http.Request request) { private void receiveData(RoutingContext rc) {
if (!enabled.get()) { if (!enabled.get()) {
return Results.forbidden( rc.response().setStatusCode(HttpResponseStatus.FORBIDDEN.code())
"UI server remote listening is currently disabled. Use UIServer.getInstance().enableRemoteListener()"); .end("UI server remote listening is currently disabled. Use UIServer.getInstance().enableRemoteListener()");
return;
} }
if (statsStorage == null) { if (statsStorage == null) {
return Results.internalServerError( rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
"UI Server remote listener: no StatsStorage instance is set/available to store results"); .end("UI Server remote listener: no StatsStorage instance is set/available to store results");
return;
} }
JsonNode jn = request.body().asJson(); JsonObject jo = rc.getBodyAsJson();
JsonNode type = jn.get("type"); Map<String,Object> map = jo.getMap();
JsonNode dataClass = jn.get("class"); String type = (String) map.get("type");
JsonNode data = jn.get("data"); String dataClass = (String) map.get("class");
String data = (String) map.get("data");
if (type == null || dataClass == null || data == null) { if (type == null || dataClass == null || data == null) {
log.warn("Received incorrectly formatted data from remote listener (has type = " + (type != null) log.warn("Received incorrectly formatted data from remote listener (has type = " + (type != null)
+ ", has data class = " + (dataClass != null) + ", has data = " + (data != null) + ")"); + ", has data class = " + (dataClass != null) + ", has data = " + (data != null) + ")");
return Results.badRequest("Received incorrectly formatted data"); rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.end("Received incorrectly formatted data");
return;
} }
String dc = dataClass.asText(); switch (type.toLowerCase()) {
String content = data.asText();
switch (type.asText().toLowerCase()) {
case "metadata": case "metadata":
StorageMetaData meta = getMetaData(dc, content); StorageMetaData meta = getMetaData(dataClass, data);
if (meta != null) { if (meta != null) {
statsStorage.putStorageMetaData(meta); statsStorage.putStorageMetaData(meta);
} }
break; break;
case "staticinfo": case "staticinfo":
Persistable staticInfo = getPersistable(dc, content); Persistable staticInfo = getPersistable(dataClass, data);
if (staticInfo != null) { if (staticInfo != null) {
statsStorage.putStaticInfo(staticInfo); statsStorage.putStaticInfo(staticInfo);
} }
break; break;
case "update": case "update":
Persistable update = getPersistable(dc, content); Persistable update = getPersistable(dataClass, data);
if (update != null) { if (update != null) {
statsStorage.putUpdate(update); statsStorage.putUpdate(update);
} }
@ -144,11 +147,10 @@ public class RemoteReceiverModule implements UIModule {
} }
return Results.ok("Receiver got data: "); rc.response().end();
} }
private StorageMetaData getMetaData(String dataClass, String content) { private StorageMetaData getMetaData(String dataClass, String content) {
StorageMetaData meta; StorageMetaData meta;
try { try {
Class<?> c = Class.forName(dataClass); Class<?> c = Class.forName(dataClass);
@ -201,4 +203,4 @@ public class RemoteReceiverModule implements UIModule {
return p; return p;
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,10 +17,18 @@
package org.deeplearning4j.ui.module.train; package org.deeplearning4j.ui.module.train;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateExceptionHandler;
import freemarker.template.Version;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.ext.web.RoutingContext;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.api.storage.Persistable; import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent; import org.deeplearning4j.api.storage.StatsStorageEvent;
@ -32,7 +41,11 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.ui.api.*; import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.I18N;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.DefaultI18N;
import org.deeplearning4j.ui.i18n.I18NProvider; import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.i18n.I18NResource; import org.deeplearning4j.ui.i18n.I18NResource;
import org.deeplearning4j.ui.stats.StatsListener; import org.deeplearning4j.ui.stats.StatsListener;
@ -40,38 +53,25 @@ import org.deeplearning4j.ui.stats.api.Histogram;
import org.deeplearning4j.ui.stats.api.StatsInitializationReport; import org.deeplearning4j.ui.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.stats.api.StatsReport; import org.deeplearning4j.ui.stats.api.StatsReport;
import org.deeplearning4j.ui.stats.api.StatsType; import org.deeplearning4j.ui.stats.api.StatsType;
import org.deeplearning4j.ui.views.html.training.TrainingModel;
import org.deeplearning4j.ui.views.html.training.TrainingOverview;
import org.deeplearning4j.ui.views.html.training.TrainingSystem;
import org.eclipse.collections.impl.list.mutable.primitive.LongArrayList;
import org.nd4j.linalg.function.Function; import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.function.Supplier; import org.nd4j.linalg.function.Supplier;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple; import org.nd4j.linalg.primitives.Triple;
import org.nd4j.resources.Resources;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import play.mvc.Result;
import play.mvc.Results;
import java.io.File;
import java.io.StringReader;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.text.DateFormat; import java.text.DateFormat;
import java.text.DecimalFormat; import java.text.DecimalFormat;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.Date; import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.Set;
import java.util.HashSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static play.mvc.Results.*;
/** /**
* Main DL4J Training UI * Main DL4J Training UI
* *
@ -81,15 +81,9 @@ import static play.mvc.Results.*;
public class TrainModule implements UIModule { public class TrainModule implements UIModule {
public static final double NAN_REPLACEMENT_VALUE = 0.0; //UI front-end chokes on NaN in JSON public static final double NAN_REPLACEMENT_VALUE = 0.0; //UI front-end chokes on NaN in JSON
public static final int DEFAULT_MAX_CHART_POINTS = 512; public static final int DEFAULT_MAX_CHART_POINTS = 512;
/**
* @deprecated Use {@link DL4JSystemProperties#CHART_MAX_POINTS_PROPERTY}
*/
@Deprecated
public static final String CHART_MAX_POINTS_PROPERTY = DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY;
private static final DecimalFormat df2 = new DecimalFormat("#.00"); private static final DecimalFormat df2 = new DecimalFormat("#.00");
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
private static final ObjectMapper JSON = new ObjectMapper();
private final Supplier<String> addressSupplier; private final Supplier<String> addressSupplier;
private enum ModelType { private enum ModelType {
@ -106,15 +100,19 @@ public class TrainModule implements UIModule {
private final boolean multiSession; private final boolean multiSession;
private final Function<String, Boolean> sessionLoader; private final Function<String, Boolean> sessionLoader;
private final Configuration configuration;
public TrainModule() { public TrainModule() {
this(false, null, null); this(false, null, null);
} }
/** /**
* TrainModule * TrainModule
* @param multiSession multi-session mode *
* @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter * @param multiSession multi-session mode
* in multi-session mode * @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter
* in multi-session mode
* @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules) * @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules)
*/ */
public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) { public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) {
@ -135,6 +133,19 @@ public class TrainModule implements UIModule {
} else { } else {
maxChartPoints = DEFAULT_MAX_CHART_POINTS; maxChartPoints = DEFAULT_MAX_CHART_POINTS;
} }
configuration = new Configuration(new Version(2, 3, 29));
configuration.setDefaultEncoding("UTF-8");
configuration.setLocale(Locale.US);
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
configuration.setClassForTemplateLoading(TrainModule.class, "");
try {
File dir = Resources.asFile("templates/TrainingOverview.html.ftl").getParentFile();
configuration.setDirectoryForTemplateLoading(dir);
} catch (Throwable t) {
throw new RuntimeException(t);
}
} }
@Override @Override
@ -144,63 +155,87 @@ public class TrainModule implements UIModule {
@Override @Override
public List<Route> getRoutes() { public List<Route> getRoutes() {
Route r0, r0a, r0b, r2, r2a, r3, r3a, r3b, r4, r4a, r6b, r6d, r7a; List<Route> r = new ArrayList<>();
r0 = new Route("/train/multisession", HttpMethod.GET, FunctionType.Supplier, () -> ok(multiSession ? "true" : "false")); r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
if (multiSession) { if (multiSession) {
r0a = new Route("/train", HttpMethod.GET, FunctionType.Supplier, this::listSessions); r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
r0b = new Route("/train/:sessionId", HttpMethod.GET, FunctionType.Function, r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
(sessionId) -> redirect("/train/" + sessionId + "/overview")); rc.response()
r2 = new Route("/train/:sessionId/overview", HttpMethod.GET, FunctionType.Function, .putHeader("location", path.get(0) + "/overview")
(sessionId) -> knownSessionIDs.containsKey(sessionId) .setStatusCode(HttpResponseStatus.FOUND.code())
? ok(TrainingOverview.apply(getI18N(sessionId))) : sessionNotFound(sessionId, "overview")); .end();
r2a = new Route("/train/:sessionId/overview/data", HttpMethod.GET, FunctionType.Function, }));
this::getOverviewDataForSession); r.add(new Route("/train/:sessionId/overview", HttpMethod.GET, (path, rc) -> {
r3 = new Route("/train/:sessionId/model", HttpMethod.GET, FunctionType.Function, if (knownSessionIDs.containsKey(path.get(0))) {
(sessionId) -> knownSessionIDs.containsKey(sessionId) renderFtl("TrainingOverview.html.ftl", rc);
? ok(TrainingModel.apply(getI18N(sessionId))) : sessionNotFound(sessionId, "model")); } else {
r3a = new Route("/train/:sessionId/model/graph", HttpMethod.GET, FunctionType.Function, sessionNotFound(path.get(0), rc.request().path(), rc);
this::getModelGraphForSession); }
r3b = new Route("/train/:sessionId/model/data/:layerId", HttpMethod.GET, FunctionType.BiFunction, }));
this::getModelDataForSession); r.add(new Route("/train/:sessionId/overview/data", HttpMethod.GET, (path, rc) -> getOverviewDataForSession(path.get(0), rc)));
r4 = new Route("/train/:sessionId/system", HttpMethod.GET, FunctionType.Function, r.add(new Route("/train/:sessionId/model", HttpMethod.GET, (path, rc) -> {
(sessionId) -> knownSessionIDs.containsKey(sessionId) if (knownSessionIDs.containsKey(path.get(0))) {
? ok(TrainingSystem.apply(getI18N(sessionId))) : sessionNotFound(sessionId, "system")); renderFtl("TrainingModel.html.ftl", rc);
r4a = new Route("/train/:sessionId/system/data", HttpMethod.GET, FunctionType.Function, } else {
this::getSystemDataForSession); sessionNotFound(path.get(0), rc.request().path(), rc);
r6b = new Route("/train/:sessionId/info", HttpMethod.GET, FunctionType.Function, this::sessionInfoForSession); }
}));
r.add(new Route("/train/:sessionId/model/graph", HttpMethod.GET, (path, rc) -> this.getModelGraphForSession(path.get(0), rc)));
r.add(new Route("/train/:sessionId/model/data/:layerId", HttpMethod.GET, (path, rc) -> this.getModelDataForSession(path.get(0), path.get(1), rc)));
r.add(new Route("/train/:sessionId/system", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.renderFtl("TrainingSystem.html.ftl", rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
} else { } else {
r0a = new Route("/train/sessions/current", HttpMethod.GET, FunctionType.Supplier, r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
() -> ok(currentSessionID == null ? "" : currentSessionID)); r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)));
r0b = new Route("/train/sessions/set/:to", HttpMethod.GET, FunctionType.Function, this::setSession); r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc)));
r2 = new Route("/train/overview", HttpMethod.GET, FunctionType.Supplier, r.add(new Route("/train/overview/data", HttpMethod.GET, (path, rc) -> this.getOverviewData(rc)));
() -> ok(TrainingOverview.apply(getI18N()))); r.add(new Route("/train/model", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingModel.html.ftl", rc)));
r2a = new Route("/train/overview/data", HttpMethod.GET, FunctionType.Supplier, this::getOverviewData); r.add(new Route("/train/model/graph", HttpMethod.GET, (path, rc) -> this.getModelGraph(rc)));
r3 = new Route("/train/model", HttpMethod.GET, FunctionType.Supplier, r.add(new Route("/train/model/data/:layerId", HttpMethod.GET, (path, rc) -> this.getModelData(path.get(0), rc)));
() -> ok(TrainingModel.apply(getI18N()))); r.add(new Route("/train/system", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingSystem.html.ftl", rc)));
r3a = new Route("/train/model/graph", HttpMethod.GET, FunctionType.Supplier, this::getModelGraph); r.add(new Route("/train/sessions/info", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc)));
r3b = new Route("/train/model/data/:layerId", HttpMethod.GET, FunctionType.Function, this::getModelData); r.add(new Route("/train/system/data", HttpMethod.GET, (path, rc) -> this.getSystemData(rc)));
r4 = new Route("/train/system", HttpMethod.GET, FunctionType.Supplier,
() -> ok(TrainingSystem.apply(getI18N())));
r6b = new Route("/train/sessions/info", HttpMethod.GET, FunctionType.Supplier, this::sessionInfo);
r4a = new Route("/train/system/data", HttpMethod.GET, FunctionType.Supplier, this::getSystemData);
} }
// common for single- and multi-session mode
r.add(new Route("/train/sessions/lastUpdate/:sessionId", HttpMethod.GET, (path, rc) -> this.getLastUpdateForSession(path.get(0), rc)));
r.add(new Route("/train/workers/setByIdx/:to", HttpMethod.GET, (path, rc) -> this.setWorkerByIdx(path.get(0), rc)));
r6d = new Route("/train/sessions/lastUpdate/:sessionId", HttpMethod.GET, FunctionType.Function, return r;
this::getLastUpdateForSession); // common for single- and multi-session mode
r7a = new Route("/train/workers/setByIdx/:to", HttpMethod.GET, FunctionType.Function,
this::setWorkerByIdx);
return Arrays.asList(r0, r0a, r0b, r2, r2a, r3, r3a, r3b, r4, r4a, r6b, r6d, r7a);
} }
/** /**
* List training sessions * Render a single Freemarker .ftl file from the /templates/ directory
* @return HTML list of training sessions * @param file File to render
* @param rc Routing context
*/ */
private synchronized Result listSessions() { private void renderFtl(String file, RoutingContext rc) {
Map<String, String> input = DefaultI18N.getInstance().getMessages(DefaultI18N.getInstance().getDefaultLanguage());
String html;
try {
String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8);
Template template = new Template(FilenameUtils.getName(file), new StringReader(content), configuration);
StringWriter stringWriter = new StringWriter();
template.process(input, stringWriter);
html = stringWriter.toString();
} catch (Throwable t) {
log.error("", t);
throw new RuntimeException(t);
}
rc.response().end(html);
}
/**
* List training sessions. Returns a HTML list of training sessions
*/
private synchronized void listSessions(RoutingContext rc) {
StringBuilder sb = new StringBuilder("<!DOCTYPE html>\n" + StringBuilder sb = new StringBuilder("<!DOCTYPE html>\n" +
"<html lang=\"en\">\n" + "<html lang=\"en\">\n" +
"<head>\n" + "<head>\n" +
@ -216,7 +251,7 @@ public class TrainModule implements UIModule {
if (!knownSessionIDs.isEmpty()) { if (!knownSessionIDs.isEmpty()) {
sb.append(" <ul>"); sb.append(" <ul>");
for (String sessionId : knownSessionIDs.keySet()) { for (String sessionId : knownSessionIDs.keySet()) {
sb.append(" <li><a href=\"train/" + sessionId + "\">" + sessionId + "</a></li>\n"); sb.append(" <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n");
} }
sb.append(" </ul>"); sb.append(" </ul>");
} else { } else {
@ -225,25 +260,29 @@ public class TrainModule implements UIModule {
sb.append(" </body>\n" + sb.append(" </body>\n" +
"</html>\n"); "</html>\n");
return ok(sb.toString()).as("text/html; charset=utf-8");
rc.response()
.putHeader("content-type", "text/html; charset=utf-8")
.end(sb.toString());
} }
/** /**
* Load StatsStorage via provider, or return "not found" * Load StatsStorage via provider, or return "not found"
* @param sessionId session ID to look fo with provider *
* @param sessionId session ID to look fo with provider
* @param targetPath one of overview / model / system, or null * @param targetPath one of overview / model / system, or null
* @return temporaryRedirect, ok, or notFound
*/ */
private Result sessionNotFound(String sessionId, String targetPath) { private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
if (sessionLoader != null && sessionLoader.apply(sessionId)) { if (sessionLoader != null && sessionLoader.apply(sessionId)) {
if (targetPath != null) { if (targetPath != null) {
return temporaryRedirect("./" + targetPath); rc.reroute("./" + targetPath);
} else { } else {
return ok(); rc.response().end();
} }
} else { } else {
return notFound("Unknown session ID: " + sessionId); rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code())
.end("Unknown session ID: " + sessionId);
} }
} }
@ -252,8 +291,8 @@ public class TrainModule implements UIModule {
for (StatsStorageEvent sse : events) { for (StatsStorageEvent sse : events) {
if (StatsListener.TYPE_ID.equals(sse.getTypeID())) { if (StatsListener.TYPE_ID.equals(sse.getTypeID())) {
if (sse.getEventType() == StatsStorageListener.EventType.PostStaticInfo if (sse.getEventType() == StatsStorageListener.EventType.PostStaticInfo
&& StatsListener.TYPE_ID.equals(sse.getTypeID()) && StatsListener.TYPE_ID.equals(sse.getTypeID())
&& !knownSessionIDs.containsKey(sse.getSessionID())) { && !knownSessionIDs.containsKey(sse.getSessionID())) {
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage()); knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
if (multiSession) { if (multiSession) {
log.info("Adding training session {}/train/{} of StatsStorage instance {}", log.info("Adding training session {}/train/{} of StatsStorage instance {}",
@ -287,7 +326,7 @@ public class TrainModule implements UIModule {
} }
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID); List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
for (Persistable update: latestUpdates) { for (Persistable update : latestUpdates) {
long updateTime = update.getTimeStamp(); long updateTime = update.getTimeStamp();
if (lastUpdateForSession.containsKey(sessionID) && lastUpdateForSession.get(sessionID) < updateTime) { if (lastUpdateForSession.containsKey(sessionID) && lastUpdateForSession.get(sessionID) < updateTime) {
lastUpdateForSession.put(sessionID, updateTime); lastUpdateForSession.put(sessionID, updateTime);
@ -311,7 +350,7 @@ public class TrainModule implements UIModule {
currentSessionID = null; currentSessionID = null;
} }
} }
for(String s : toRemove) { for (String s : toRemove) {
knownSessionIDs.remove(s); knownSessionIDs.remove(s);
if (multiSession) { if (multiSession) {
log.info("Removing training session {}/train/{} of StatsStorage instance {}.", log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
@ -349,11 +388,7 @@ public class TrainModule implements UIModule {
if (sessionId == null) if (sessionId == null)
return null; return null;
Map<Integer, String> idxToId = workerIdxToName.get(sessionId); Map<Integer, String> idxToId = workerIdxToName.computeIfAbsent(sessionId, k -> Collections.synchronizedMap(new HashMap<>()));
if (idxToId == null) {
idxToId = Collections.synchronizedMap(new HashMap<>());
workerIdxToName.put(sessionId, idxToId);
}
if (idxToId.containsKey(workerIdx)) { if (idxToId.containsKey(workerIdx)) {
return idxToId.get(workerIdx); return idxToId.get(workerIdx);
@ -389,9 +424,9 @@ public class TrainModule implements UIModule {
/** /**
* Display, for each session: session ID, start time, number of workers, last update * Display, for each session: session ID, start time, number of workers, last update
* @return info for each session as JSON * Returns info for each session as JSON
*/ */
private synchronized Result sessionInfo() { private synchronized void sessionInfo(RoutingContext rc) {
Map<String, Object> dataEachSession = new HashMap<>(); Map<String, Object> dataEachSession = new HashMap<>();
for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) { for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
@ -400,13 +435,16 @@ public class TrainModule implements UIModule {
Map<String, Object> dataThisSession = sessionData(sid, ss); Map<String, Object> dataThisSession = sessionData(sid, ss);
dataEachSession.put(sid, dataThisSession); dataEachSession.put(sid, dataThisSession);
} }
return Results.ok(asJson(dataEachSession)).as("application/json"); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(dataEachSession));
} }
/** /**
* Extract session data from {@link StatsStorage} * Extract session data from {@link StatsStorage}
*
* @param sid session ID * @param sid session ID
* @param ss {@code StatsStorage} instance * @param ss {@code StatsStorage} instance
* @return session data map * @return session data map
*/ */
private static Map<String, Object> sessionData(String sid, StatsStorage ss) { private static Map<String, Object> sessionData(String sid, StatsStorage ss) {
@ -460,11 +498,12 @@ public class TrainModule implements UIModule {
} }
/** /**
* Display, for given session: session ID, start time, number of workers, last update * Display, for given session: session ID, start time, number of workers, last update.
* Returns info for session as JSON
*
* @param sessionId session ID * @param sessionId session ID
* @return info for session as JSON
*/ */
private synchronized Result sessionInfoForSession(String sessionId) { private synchronized void sessionInfoForSession(String sessionId, RoutingContext rc) {
Map<String, Object> dataEachSession = new HashMap<>(); Map<String, Object> dataEachSession = new HashMap<>();
StatsStorage ss = knownSessionIDs.get(sessionId); StatsStorage ss = knownSessionIDs.get(sessionId);
@ -472,33 +511,37 @@ public class TrainModule implements UIModule {
Map<String, Object> dataThisSession = sessionData(sessionId, ss); Map<String, Object> dataThisSession = sessionData(sessionId, ss);
dataEachSession.put(sessionId, dataThisSession); dataEachSession.put(sessionId, dataThisSession);
} }
return Results.ok(asJson(dataEachSession)).as("application/json"); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(dataEachSession));
} }
private synchronized Result setSession(String newSessionID) { private synchronized void setSession(String newSessionID, RoutingContext rc) {
if (knownSessionIDs.containsKey(newSessionID)) { if (knownSessionIDs.containsKey(newSessionID)) {
currentSessionID = newSessionID; currentSessionID = newSessionID;
currentWorkerIdx = 0; currentWorkerIdx = 0;
return ok(); rc.response().end();
} else { } else {
return Results.badRequest("Unknown session ID: " + newSessionID); rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()).end();
} }
} }
private Result getLastUpdateForSession(String sessionID) { private void getLastUpdateForSession(String sessionID, RoutingContext rc) {
Long lastUpdate = lastUpdateForSession.get(sessionID); Long lastUpdate = lastUpdateForSession.get(sessionID);
if (lastUpdate != null) if (lastUpdate != null) {
return ok(String.valueOf(lastUpdate)); rc.response().end(String.valueOf(lastUpdate));
return ok("-1"); return;
}
rc.response().end("-1");
} }
private Result setWorkerByIdx(String newWorkerIdx) { private void setWorkerByIdx(String newWorkerIdx, RoutingContext rc) {
try { try {
currentWorkerIdx = Integer.parseInt(newWorkerIdx); currentWorkerIdx = Integer.parseInt(newWorkerIdx);
} catch (NumberFormatException e) { } catch (NumberFormatException e) {
log.debug("Invalid call to setWorkerByIdx", e); log.debug("Invalid call to setWorkerByIdx", e);
} }
return ok(); rc.response().end();
} }
private static double fixNaN(double d) { private static double fixNaN(double d) {
@ -534,6 +577,7 @@ public class TrainModule implements UIModule {
/** /**
* Get last update time for given session ID, checking for null values * Get last update time for given session ID, checking for null values
*
* @param sessionId session ID * @param sessionId session ID
* @return last update time for session if found, or {@code null} * @return last update time for session if found, or {@code null}
*/ */
@ -547,6 +591,7 @@ public class TrainModule implements UIModule {
/** /**
* Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session * Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session
*
* @param sessionId session ID * @param sessionId session ID
* @return {@link I18N} instance * @return {@link I18N} instance
*/ */
@ -554,20 +599,12 @@ public class TrainModule implements UIModule {
return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance(); return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
} }
/**
* Get global {@link I18N} instance, used if {@link #multiSession} is {@code false} private void getOverviewData(RoutingContext rc) {
* @return {@link I18N} instance getOverviewDataForSession(currentSessionID, rc);
*/
private I18N getI18N() {
return I18NProvider.getInstance();
} }
private synchronized void getOverviewDataForSession(String sessionId, RoutingContext rc) {
private Result getOverviewData() {
return getOverviewDataForSession(currentSessionID);
}
private synchronized Result getOverviewDataForSession(String sessionId) {
Long lastUpdateTime = getLastUpdateTime(sessionId); Long lastUpdateTime = getLastUpdateTime(sessionId);
I18N i18N = getI18N(sessionId); I18N i18N = getI18N(sessionId);
@ -587,19 +624,19 @@ public class TrainModule implements UIModule {
//Get scores info //Get scores info
long[] allTimes = (noData ? null : ss.getAllUpdateTimes(sessionId, StatsListener.TYPE_ID, wid)); long[] allTimes = (noData ? null : ss.getAllUpdateTimes(sessionId, StatsListener.TYPE_ID, wid));
List<Persistable> updates = null; List<Persistable> updates = null;
if(allTimes != null && allTimes.length > maxChartPoints){ if (allTimes != null && allTimes.length > maxChartPoints) {
int subsamplingFrequency = allTimes.length / maxChartPoints; int subsamplingFrequency = allTimes.length / maxChartPoints;
LongArrayList timesToQuery = new LongArrayList(maxChartPoints+2); LongArrayList timesToQuery = new LongArrayList(maxChartPoints + 2);
int i=0; int i = 0;
for(; i<allTimes.length; i+= subsamplingFrequency){ for (; i < allTimes.length; i += subsamplingFrequency) {
timesToQuery.add(allTimes[i]); timesToQuery.add(allTimes[i]);
} }
if((i-subsamplingFrequency) != allTimes.length-1){ if ((i - subsamplingFrequency) != allTimes.length - 1) {
//Also add final point //Also add final point
timesToQuery.add(allTimes[allTimes.length-1]); timesToQuery.add(allTimes[allTimes.length - 1]);
} }
updates = ss.getUpdates(sessionId, StatsListener.TYPE_ID, wid, timesToQuery.toArray()); updates = ss.getUpdates(sessionId, StatsListener.TYPE_ID, wid, timesToQuery.toLongArray());
} else if(allTimes != null) { } else if (allTimes != null) {
//Don't subsample //Don't subsample
updates = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, wid, 0); updates = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, wid, 0);
} }
@ -754,12 +791,12 @@ public class TrainModule implements UIModule {
//----- Performance Info ----- //----- Performance Info -----
String[][] perfInfo = new String[][] {{i18N.getMessage("train.overview.perftable.startTime"), ""}, String[][] perfInfo = new String[][]{{i18N.getMessage("train.overview.perftable.startTime"), ""},
{i18N.getMessage("train.overview.perftable.totalRuntime"), ""}, {i18N.getMessage("train.overview.perftable.totalRuntime"), ""},
{i18N.getMessage("train.overview.perftable.lastUpdate"), ""}, {i18N.getMessage("train.overview.perftable.lastUpdate"), ""},
{i18N.getMessage("train.overview.perftable.totalParamUpdates"), ""}, {i18N.getMessage("train.overview.perftable.totalParamUpdates"), ""},
{i18N.getMessage("train.overview.perftable.updatesPerSec"), ""}, {i18N.getMessage("train.overview.perftable.updatesPerSec"), ""},
{i18N.getMessage("train.overview.perftable.examplesPerSec"), ""}}; {i18N.getMessage("train.overview.perftable.examplesPerSec"), ""}};
if (last != null) { if (last != null) {
perfInfo[2][1] = String.valueOf(dateFormat.format(new Date(last.getTimeStamp()))); perfInfo[2][1] = String.valueOf(dateFormat.format(new Date(last.getTimeStamp())));
@ -772,9 +809,9 @@ public class TrainModule implements UIModule {
// ----- Model Info ----- // ----- Model Info -----
String[][] modelInfo = new String[][] {{i18N.getMessage("train.overview.modeltable.modeltype"), ""}, String[][] modelInfo = new String[][]{{i18N.getMessage("train.overview.modeltable.modeltype"), ""},
{i18N.getMessage("train.overview.modeltable.nLayers"), ""}, {i18N.getMessage("train.overview.modeltable.nLayers"), ""},
{i18N.getMessage("train.overview.modeltable.nParams"), ""}}; {i18N.getMessage("train.overview.modeltable.nParams"), ""}};
if (!noData) { if (!noData) {
Persistable p = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, wid); Persistable p = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, wid);
if (p != null) { if (p != null) {
@ -803,28 +840,40 @@ public class TrainModule implements UIModule {
result.put("model", modelInfo); result.put("model", modelInfo);
return Results.ok(asJson(result)).as("application/json"); String json = asJson(result);
rc.response()
.putHeader("content-type", "application/json")
.end(json);
} }
private Result getModelGraph() { private void getModelGraph(RoutingContext rc) {
return getModelGraphForSession(currentSessionID); getModelGraphForSession(currentSessionID, rc);
} }
private Result getModelGraphForSession(String sessionId) { private void getModelGraphForSession(String sessionId, RoutingContext rc) {
boolean noData = (sessionId == null || !knownSessionIDs.containsKey(sessionId)); boolean noData = (sessionId == null || !knownSessionIDs.containsKey(sessionId));
StatsStorage ss = (noData ? null : knownSessionIDs.get(sessionId)); StatsStorage ss = (noData ? null : knownSessionIDs.get(sessionId));
List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST
: ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID)); : ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID));
if (allStatic.isEmpty()) { if (allStatic.isEmpty()) {
return ok(); rc.response().end();
return;
} }
TrainModuleUtils.GraphInfo gi = getGraphInfo(getConfig(sessionId)); TrainModuleUtils.GraphInfo gi = getGraphInfo(getConfig(sessionId));
if (gi == null) if (gi == null) {
return ok(); rc.response().end();
return Results.ok(asJson(gi)).as("application/json"); return;
}
String json = asJson(gi);
rc.response()
.putHeader("content-type", "application/json")
.end(json);
} }
private TrainModuleUtils.GraphInfo getGraphInfo(Triple<MultiLayerConfiguration, private TrainModuleUtils.GraphInfo getGraphInfo(Triple<MultiLayerConfiguration,
@ -848,7 +897,7 @@ public class TrainModule implements UIModule {
boolean noData = (sessionId == null || !knownSessionIDs.containsKey(sessionId)); boolean noData = (sessionId == null || !knownSessionIDs.containsKey(sessionId));
StatsStorage ss = (noData ? null : knownSessionIDs.get(sessionId)); StatsStorage ss = (noData ? null : knownSessionIDs.get(sessionId));
List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST
: ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID)); : ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID));
if (allStatic.isEmpty()) if (allStatic.isEmpty())
return null; return null;
@ -865,7 +914,7 @@ public class TrainModule implements UIModule {
} else { } else {
try { try {
NeuralNetConfiguration layer = NeuralNetConfiguration layer =
NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class); NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
return new Triple<>(null, null, layer); return new Triple<>(null, null, layer);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); e.printStackTrace();
@ -875,12 +924,11 @@ public class TrainModule implements UIModule {
} }
private void getModelData(String layerId, RoutingContext rc) {
private Result getModelData(String layerId) { getModelDataForSession(currentSessionID, layerId, rc);
return getModelDataForSession(currentSessionID, layerId);
} }
private Result getModelDataForSession(String sessionId, String layerId) { private void getModelDataForSession(String sessionId, String layerId, RoutingContext rc) {
Long lastUpdateTime = getLastUpdateTime(sessionId); Long lastUpdateTime = getLastUpdateTime(sessionId);
int layerIdx = Integer.parseInt(layerId); //TODO validation int layerIdx = Integer.parseInt(layerId); //TODO validation
@ -898,12 +946,18 @@ public class TrainModule implements UIModule {
Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig(sessionId); Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig(sessionId);
if (conf == null) { if (conf == null) {
return Results.ok(asJson(result)).as("application/json"); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(result));
return;
} }
TrainModuleUtils.GraphInfo gi = getGraphInfo(conf); TrainModuleUtils.GraphInfo gi = getGraphInfo(conf);
if (gi == null) { if (gi == null) {
return Results.ok(asJson(result)).as("application/json"); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(result));
return;
} }
@ -918,19 +972,19 @@ public class TrainModule implements UIModule {
List<Persistable> updates = null; List<Persistable> updates = null;
List<Integer> iterationCounts = null; List<Integer> iterationCounts = null;
boolean needToHandleLegacyIterCounts = false; boolean needToHandleLegacyIterCounts = false;
if(allTimes != null && allTimes.length > maxChartPoints){ if (allTimes != null && allTimes.length > maxChartPoints) {
int subsamplingFrequency = allTimes.length / maxChartPoints; int subsamplingFrequency = allTimes.length / maxChartPoints;
LongArrayList timesToQuery = new LongArrayList(maxChartPoints+2); LongArrayList timesToQuery = new LongArrayList(maxChartPoints + 2);
int i=0; int i = 0;
for(; i<allTimes.length; i+= subsamplingFrequency){ for (; i < allTimes.length; i += subsamplingFrequency) {
timesToQuery.add(allTimes[i]); timesToQuery.add(allTimes[i]);
} }
if((i-subsamplingFrequency) != allTimes.length-1){ if ((i - subsamplingFrequency) != allTimes.length - 1) {
//Also add final point //Also add final point
timesToQuery.add(allTimes[allTimes.length-1]); timesToQuery.add(allTimes[allTimes.length - 1]);
} }
updates = ss.getUpdates(sessionId, StatsListener.TYPE_ID, wid, timesToQuery.toArray()); updates = ss.getUpdates(sessionId, StatsListener.TYPE_ID, wid, timesToQuery.toLongArray());
} else if(allTimes != null) { } else if (allTimes != null) {
//Don't subsample //Don't subsample
updates = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, wid, 0); updates = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, wid, 0);
} }
@ -994,14 +1048,16 @@ public class TrainModule implements UIModule {
Map<String, Object> updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate); Map<String, Object> updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate);
result.put("updateHist", updateHistograms); result.put("updateHist", updateHistograms);
return Results.ok(asJson(result)).as("application/json"); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(result));
} }
private Result getSystemData() { private void getSystemData(RoutingContext rc) {
return getSystemDataForSession(currentSessionID); getSystemDataForSession(currentSessionID, rc);
} }
private Result getSystemDataForSession(String sessionId) { private void getSystemDataForSession(String sessionId, RoutingContext rc) {
Long lastUpdate = getLastUpdateTime(sessionId); Long lastUpdate = getLastUpdateTime(sessionId);
I18N i18n = getI18N(sessionId); I18N i18n = getI18N(sessionId);
@ -1013,9 +1069,9 @@ public class TrainModule implements UIModule {
boolean noData = (ss == null); boolean noData = (ss == null);
List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST
: ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID)); : ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID));
List<Persistable> latestUpdates = (noData ? Collections.EMPTY_LIST List<Persistable> latestUpdates = (noData ? Collections.EMPTY_LIST
: ss.getLatestUpdateAllWorkers(sessionId, StatsListener.TYPE_ID)); : ss.getLatestUpdateAllWorkers(sessionId, StatsListener.TYPE_ID));
long lastUpdateTime = -1; long lastUpdateTime = -1;
@ -1029,7 +1085,7 @@ public class TrainModule implements UIModule {
long fromTime = lastUpdateTime - 5 * 60 * 1000; //TODO Make configurable long fromTime = lastUpdateTime - 5 * 60 * 1000; //TODO Make configurable
List<Persistable> lastNMinutes = List<Persistable> lastNMinutes =
(noData ? null : ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, fromTime)); (noData ? null : ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, fromTime));
Map<String, Object> mem = getMemory(allStatic, lastNMinutes, i18n); Map<String, Object> mem = getMemory(allStatic, lastNMinutes, i18n);
Pair<Map<String, Object>, Map<String, Object>> hwSwInfo = getHardwareSoftwareInfo(allStatic, i18n); Pair<Map<String, Object>, Map<String, Object>> hwSwInfo = getHardwareSoftwareInfo(allStatic, i18n);
@ -1040,7 +1096,9 @@ public class TrainModule implements UIModule {
ret.put("hardware", hwSwInfo.getFirst()); ret.put("hardware", hwSwInfo.getFirst());
ret.put("software", hwSwInfo.getSecond()); ret.put("software", hwSwInfo.getSecond());
return Results.ok(asJson(ret)).as("application/json"); rc.response()
.putHeader("content-type", "application/json")
.end(asJson(ret));
} }
private static String getLayerType(Layer layer) { private static String getLayerType(Layer layer) {
@ -1055,11 +1113,11 @@ public class TrainModule implements UIModule {
} }
private static String[][] getLayerInfoTable(String sessionId, int layerIdx, TrainModuleUtils.GraphInfo gi, I18N i18N, boolean noData, private static String[][] getLayerInfoTable(String sessionId, int layerIdx, TrainModuleUtils.GraphInfo gi, I18N i18N, boolean noData,
StatsStorage ss, String wid) { StatsStorage ss, String wid) {
List<String[]> layerInfoRows = new ArrayList<>(); List<String[]> layerInfoRows = new ArrayList<>();
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerName"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerName"),
gi.getVertexNames().get(layerIdx)}); gi.getVertexNames().get(layerIdx)});
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerType"), ""}); layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerType"), ""});
if (!noData) { if (!noData) {
Persistable p = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, wid); Persistable p = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, wid);
@ -1104,7 +1162,7 @@ public class TrainModule implements UIModule {
layerType = gi.getVertexTypes().get(layerIdx); layerType = gi.getVertexTypes().get(layerIdx);
Map<String, String> map = gi.getVertexInfo().get(layerIdx); Map<String, String> map = gi.getVertexInfo().get(layerIdx);
for (Map.Entry<String, String> entry : map.entrySet()) { for (Map.Entry<String, String> entry : map.entrySet()) {
layerInfoRows.add(new String[] {entry.getKey(), entry.getValue()}); layerInfoRows.add(new String[]{entry.getKey(), entry.getValue()});
} }
} }
@ -1116,30 +1174,30 @@ public class TrainModule implements UIModule {
String activationFn = null; String activationFn = null;
if (layer instanceof FeedForwardLayer) { if (layer instanceof FeedForwardLayer) {
FeedForwardLayer ffl = (FeedForwardLayer) layer; FeedForwardLayer ffl = (FeedForwardLayer) layer;
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerNIn"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerNIn"),
String.valueOf(ffl.getNIn())}); String.valueOf(ffl.getNIn())});
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerSize"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerSize"),
String.valueOf(ffl.getNOut())}); String.valueOf(ffl.getNOut())});
} }
if (layer instanceof BaseLayer) { if (layer instanceof BaseLayer) {
BaseLayer bl = (BaseLayer) layer; BaseLayer bl = (BaseLayer) layer;
activationFn = bl.getActivationFn().toString(); activationFn = bl.getActivationFn().toString();
val nParams = layer.initializer().numParams(nnc); long nParams = layer.initializer().numParams(nnc);
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerNParams"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerNParams"),
String.valueOf(nParams)}); String.valueOf(nParams)});
if (nParams > 0) { if (nParams > 0) {
try { try {
String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInitFn()); String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInitFn());
layerInfoRows.add(new String[] { layerInfoRows.add(new String[]{
i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str}); i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str});
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
IUpdater u = bl.getIUpdater(); IUpdater u = bl.getIUpdater();
String us = (u == null ? "" : u.getClass().getSimpleName()); String us = (u == null ? "" : u.getClass().getSimpleName());
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerUpdater"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerUpdater"),
us}); us});
//TODO: Maybe L1/L2, dropout, updater-specific values etc //TODO: Maybe L1/L2, dropout, updater-specific values etc
} }
@ -1160,21 +1218,21 @@ public class TrainModule implements UIModule {
stride = ssl.getStride(); stride = ssl.getStride();
padding = ssl.getPadding(); padding = ssl.getPadding();
activationFn = null; activationFn = null;
layerInfoRows.add(new String[] { layerInfoRows.add(new String[]{
i18N.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"), i18N.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"),
ssl.getPoolingType().toString()}); ssl.getPoolingType().toString()});
} }
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnKernel"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerCnnKernel"),
Arrays.toString(kernel)}); Arrays.toString(kernel)});
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnStride"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerCnnStride"),
Arrays.toString(stride)}); Arrays.toString(stride)});
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnPadding"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerCnnPadding"),
Arrays.toString(padding)}); Arrays.toString(padding)});
} }
if (activationFn != null) { if (activationFn != null) {
layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerActivationFn"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerActivationFn"),
activationFn}); activationFn});
} }
} }
layerInfoRows.get(1)[1] = layerType; layerInfoRows.get(1)[1] = layerType;
@ -1187,10 +1245,10 @@ public class TrainModule implements UIModule {
//TODO float precision for smaller transfers? //TODO float precision for smaller transfers?
//First: iteration. Second: ratios, by parameter //First: iteration. Second: ratios, by parameter
private static MeanMagnitudes getLayerMeanMagnitudes(int layerIdx, TrainModuleUtils.GraphInfo gi, private static MeanMagnitudes getLayerMeanMagnitudes(int layerIdx, TrainModuleUtils.GraphInfo gi,
List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) { List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
if (gi == null) { if (gi == null) {
return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(), return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(),
Collections.emptyMap()); Collections.emptyMap());
} }
String layerName = gi.getVertexNames().get(layerIdx); String layerName = gi.getVertexNames().get(layerIdx);
@ -1201,7 +1259,7 @@ public class TrainModule implements UIModule {
String layerType = gi.getVertexTypes().get(layerIdx); String layerType = gi.getVertexTypes().get(layerIdx);
if ("input".equalsIgnoreCase(layerType)) { //TODO better checking - other vertices, etc if ("input".equalsIgnoreCase(layerType)) { //TODO better checking - other vertices, etc
return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(), return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(),
Collections.emptyMap()); Collections.emptyMap());
} }
List<Integer> iterCounts = new ArrayList<>(); List<Integer> iterCounts = new ArrayList<>();
@ -1283,7 +1341,7 @@ public class TrainModule implements UIModule {
private static Triple<int[], float[], float[]> EMPTY_TRIPLE = new Triple<>(new int[0], new float[0], new float[0]); private static Triple<int[], float[], float[]> EMPTY_TRIPLE = new Triple<>(new int[0], new float[0], new float[0]);
private static Triple<int[], float[], float[]> getLayerActivations(int index, TrainModuleUtils.GraphInfo gi, private static Triple<int[], float[], float[]> getLayerActivations(int index, TrainModuleUtils.GraphInfo gi,
List<Persistable> updates, List<Integer> iterationCounts) { List<Persistable> updates, List<Integer> iterationCounts) {
if (gi == null) { if (gi == null) {
return EMPTY_TRIPLE; return EMPTY_TRIPLE;
} }
@ -1344,6 +1402,7 @@ public class TrainModule implements UIModule {
} }
private static final Map<String, Object> EMPTY_LR_MAP = new HashMap<>(); private static final Map<String, Object> EMPTY_LR_MAP = new HashMap<>();
static { static {
EMPTY_LR_MAP.put("iterCounts", new int[0]); EMPTY_LR_MAP.put("iterCounts", new int[0]);
EMPTY_LR_MAP.put("paramNames", Collections.EMPTY_LIST); EMPTY_LR_MAP.put("paramNames", Collections.EMPTY_LIST);
@ -1351,7 +1410,7 @@ public class TrainModule implements UIModule {
} }
private static Map<String, Object> getLayerLearningRates(int layerIdx, TrainModuleUtils.GraphInfo gi, private static Map<String, Object> getLayerLearningRates(int layerIdx, TrainModuleUtils.GraphInfo gi,
List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) { List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
if (gi == null) { if (gi == null) {
return Collections.emptyMap(); return Collections.emptyMap();
} }
@ -1424,7 +1483,7 @@ public class TrainModule implements UIModule {
private static Map<String, Object> getHistograms(int layerIdx, TrainModuleUtils.GraphInfo gi, StatsType statsType, private static Map<String, Object> getHistograms(int layerIdx, TrainModuleUtils.GraphInfo gi, StatsType statsType,
Persistable p) { Persistable p) {
if (p == null) if (p == null)
return null; return null;
if (!(p instanceof StatsReport)) if (!(p instanceof StatsReport))
@ -1475,7 +1534,7 @@ public class TrainModule implements UIModule {
} }
private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWorkers, private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWorkers,
List<Persistable> updatesLastNMinutes, I18N i18n) { List<Persistable> updatesLastNMinutes, I18N i18n) {
Map<String, Object> ret = new HashMap<>(); Map<String, Object> ret = new HashMap<>();
@ -1612,7 +1671,7 @@ public class TrainModule implements UIModule {
} }
private static Pair<Map<String, Object>, Map<String, Object>> getHardwareSoftwareInfo( private static Pair<Map<String, Object>, Map<String, Object>> getHardwareSoftwareInfo(
List<Persistable> staticInfoAllWorkers, I18N i18n) { List<Persistable> staticInfoAllWorkers, I18N i18n) {
Map<String, Object> retHw = new HashMap<>(); Map<String, Object> retHw = new HashMap<>();
Map<String, Object> retSw = new HashMap<>(); Map<String, Object> retSw = new HashMap<>();
@ -1641,24 +1700,24 @@ public class TrainModule implements UIModule {
String[] deviceDescription = sr.getHwDeviceDescription(); String[] deviceDescription = sr.getHwDeviceDescription();
long[] devTotalMem = sr.getHwDeviceTotalMemory(); long[] devTotalMem = sr.getHwDeviceTotalMemory();
hwInfo.add(new String[] {i18n.getMessage("train.system.hwTable.jvmMax"), hwInfo.add(new String[]{i18n.getMessage("train.system.hwTable.jvmMax"),
String.valueOf(sr.getHwJvmMaxMemory())}); String.valueOf(sr.getHwJvmMaxMemory())});
hwInfo.add(new String[] {i18n.getMessage("train.system.hwTable.offHeapMax"), hwInfo.add(new String[]{i18n.getMessage("train.system.hwTable.offHeapMax"),
String.valueOf(sr.getHwOffHeapMaxMemory())}); String.valueOf(sr.getHwOffHeapMaxMemory())});
hwInfo.add(new String[] {i18n.getMessage("train.system.hwTable.jvmProcs"), hwInfo.add(new String[]{i18n.getMessage("train.system.hwTable.jvmProcs"),
String.valueOf(sr.getHwJvmAvailableProcessors())}); String.valueOf(sr.getHwJvmAvailableProcessors())});
hwInfo.add(new String[] {i18n.getMessage("train.system.hwTable.computeDevices"), hwInfo.add(new String[]{i18n.getMessage("train.system.hwTable.computeDevices"),
String.valueOf(numDevices)}); String.valueOf(numDevices)});
for (int i = 0; i < numDevices; i++) { for (int i = 0; i < numDevices; i++) {
String label = i18n.getMessage("train.system.hwTable.deviceName") + " (" + i + ")"; String label = i18n.getMessage("train.system.hwTable.deviceName") + " (" + i + ")";
String name = (deviceDescription == null || i >= deviceDescription.length ? String.valueOf(i) String name = (deviceDescription == null || i >= deviceDescription.length ? String.valueOf(i)
: deviceDescription[i]); : deviceDescription[i]);
hwInfo.add(new String[] {label, name}); hwInfo.add(new String[]{label, name});
String memLabel = i18n.getMessage("train.system.hwTable.deviceMemory") + " (" + i + ")"; String memLabel = i18n.getMessage("train.system.hwTable.deviceMemory") + " (" + i + ")";
String memBytes = String memBytes =
(devTotalMem == null || i >= devTotalMem.length ? "-" : String.valueOf(devTotalMem[i])); (devTotalMem == null || i >= devTotalMem.length ? "-" : String.valueOf(devTotalMem[i]));
hwInfo.add(new String[] {memLabel, memBytes}); hwInfo.add(new String[]{memLabel, memBytes});
} }
retHw.put(String.valueOf(count), hwInfo); retHw.put(String.valueOf(count), hwInfo);
@ -1690,13 +1749,13 @@ public class TrainModule implements UIModule {
datatype = datatype.toLowerCase(); datatype = datatype.toLowerCase();
List<String[]> swInfo = new ArrayList<>(); List<String[]> swInfo = new ArrayList<>();
swInfo.add(new String[] {i18n.getMessage("train.system.swTable.os"), sr.getSwOsName()}); swInfo.add(new String[]{i18n.getMessage("train.system.swTable.os"), sr.getSwOsName()});
swInfo.add(new String[] {i18n.getMessage("train.system.swTable.hostname"), sr.getSwHostName()}); swInfo.add(new String[]{i18n.getMessage("train.system.swTable.hostname"), sr.getSwHostName()});
swInfo.add(new String[] {i18n.getMessage("train.system.swTable.osArch"), sr.getSwArch()}); swInfo.add(new String[]{i18n.getMessage("train.system.swTable.osArch"), sr.getSwArch()});
swInfo.add(new String[] {i18n.getMessage("train.system.swTable.jvmName"), sr.getSwJvmName()}); swInfo.add(new String[]{i18n.getMessage("train.system.swTable.jvmName"), sr.getSwJvmName()});
swInfo.add(new String[] {i18n.getMessage("train.system.swTable.jvmVersion"), sr.getSwJvmVersion()}); swInfo.add(new String[]{i18n.getMessage("train.system.swTable.jvmVersion"), sr.getSwJvmVersion()});
swInfo.add(new String[] {i18n.getMessage("train.system.swTable.nd4jBackend"), nd4jBackend}); swInfo.add(new String[]{i18n.getMessage("train.system.swTable.nd4jBackend"), nd4jBackend});
swInfo.add(new String[] {i18n.getMessage("train.system.swTable.nd4jDataType"), datatype}); swInfo.add(new String[]{i18n.getMessage("train.system.swTable.nd4jDataType"), datatype});
retSw.put(String.valueOf(count), swInfo); retSw.put(String.valueOf(count), swInfo);
@ -1717,10 +1776,10 @@ public class TrainModule implements UIModule {
} }
private static final String asJson(Object o){ private static final String asJson(Object o) {
try{ try {
return JSON.writeValueAsString(o); return JsonMappers.getMapper().writeValueAsString(o);
} catch (Exception e){ } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@ -1737,10 +1796,9 @@ public class TrainModule implements UIModule {
return files; return files;
} }
private static void addAll(List<I18NResource> to, String prefix, String... suffixes){ private static void addAll(List<I18NResource> to, String prefix, String... suffixes) {
for(String s : suffixes){ for (String s : suffixes) {
to.add(new I18NResource("dl4j_i18n/" + prefix + "." + s)); to.add(new I18NResource("dl4j_i18n/" + prefix + "." + s));
} }
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -117,7 +118,6 @@ public class TrainModuleUtils {
vertexToIndexMap.put(s, vertexCount++); vertexToIndexMap.put(s, vertexCount++);
} }
int layerCount = 0;
for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) { for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
GraphVertex gv = entry.getValue(); GraphVertex gv = entry.getValue();
layerNames.add(entry.getKey()); layerNames.add(entry.getKey());
@ -266,7 +266,6 @@ public class TrainModuleUtils {
private static Map<String, String> getLayerInfo(NeuralNetConfiguration c, Layer layer) { private static Map<String, String> getLayerInfo(NeuralNetConfiguration c, Layer layer) {
Map<String, String> map = new LinkedHashMap<>(); Map<String, String> map = new LinkedHashMap<>();
if (layer instanceof FeedForwardLayer) { if (layer instanceof FeedForwardLayer) {
@ -278,8 +277,8 @@ public class TrainModuleUtils {
} }
if (layer instanceof ConvolutionLayer) { if (layer instanceof ConvolutionLayer) {
org.deeplearning4j.nn.conf.layers.ConvolutionLayer layer1 = ConvolutionLayer layer1 =
(org.deeplearning4j.nn.conf.layers.ConvolutionLayer) layer; (ConvolutionLayer) layer;
map.put("Kernel size", Arrays.toString(layer1.getKernelSize())); map.put("Kernel size", Arrays.toString(layer1.getKernelSize()));
map.put("Stride", Arrays.toString(layer1.getStride())); map.put("Stride", Arrays.toString(layer1.getStride()));
map.put("Padding", Arrays.toString(layer1.getPadding())); map.put("Padding", Arrays.toString(layer1.getPadding()));

View File

@ -0,0 +1,141 @@
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.ui.module.tsne;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.ext.web.FileUpload;
import io.vertx.ext.web.RoutingContext;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.I18NResource;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
public class TsneModule implements UIModule {
private static final String UPLOADED_FILE = "UploadedFile";
private Map<String, List<String>> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>());
private List<String> uploadedFileLines = null;
public TsneModule() {
}
@Override
public List<String> getCallbackTypeIDs() {
return Collections.emptyList();
}
@Override
public List<Route> getRoutes() {
Route r1 = new Route("/tsne", HttpMethod.GET, (path, rc) -> rc.response().sendFile("templates/Tsne.html"));
Route r2 = new Route("/tsne/sessions", HttpMethod.GET, (path, rc) -> this.listSessions(rc));
Route r3 = new Route("/tsne/coords/:sid", HttpMethod.GET, (path, rc) -> this.getCoords(path.get(0), rc));
Route r4 = new Route("/tsne/upload", HttpMethod.POST, (path, rc) -> this.uploadFile(rc));
Route r5 = new Route("/tsne/post/:sid", HttpMethod.POST, (path, rc) -> this.postFile(path.get(0), rc));
return Arrays.asList(r1, r2, r3, r4, r5);
}
@Override
public void reportStorageEvents(Collection<StatsStorageEvent> events) {
//No-op
}
@Override
public void onAttach(StatsStorage statsStorage) {
//No-op
}
@Override
public void onDetach(StatsStorage statsStorage) {
//No-op
}
@Override
public List<I18NResource> getInternationalizationResources() {
return Collections.emptyList();
}
private String asJson(Object o){
try{
return JsonMappers.getMapper().writeValueAsString(o);
} catch (JsonProcessingException e){
throw new RuntimeException("Error converting object to JSON", e);
}
}
private void listSessions(RoutingContext rc) {
List<String> list = new ArrayList<>(knownSessionIDs.keySet());
if (uploadedFileLines != null) {
list.add(UPLOADED_FILE);
}
rc.response()
.putHeader("content-type", "application/json")
.end(asJson(list));
}
private void getCoords(String sessionId, RoutingContext rc) {
if (UPLOADED_FILE.equals(sessionId) && uploadedFileLines != null) {
rc.response()
.putHeader("content-type", "application/json")
.end(asJson(uploadedFileLines));
} else if (knownSessionIDs.containsKey(sessionId)) {
rc.response().putHeader("content-type", "application/json")
.end(asJson(knownSessionIDs.get(sessionId)));
} else {
rc.response().end();
}
}
private void uploadFile(RoutingContext rc) {
postFile(null, rc);
}
private void postFile(String sid, RoutingContext rc) {
Set<FileUpload> files = rc.fileUploads();
if(files == null || files.isEmpty()){
rc.response().end();
return;
}
FileUpload u = files.iterator().next();
File f = new File(u.uploadedFileName());
List<String> lines;
try {
lines = FileUtils.readLines(f, StandardCharsets.UTF_8);
} catch (IOException e) {
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()).end("Could not read from uploaded file");
return;
}
if(sid == null){
uploadedFileLines = lines;
} else {
knownSessionIDs.put(sid, lines);
}
rc.response().end("File uploaded: " + u.fileName() + ", " + u.contentType());
}
}

View File

@ -14,9 +14,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
################################################################################ ################################################################################
org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule #org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule
org.deeplearning4j.ui.module.defaultModule.DefaultModule org.deeplearning4j.ui.module.defaultModule.DefaultModule
org.deeplearning4j.ui.module.remote.RemoteReceiverModule #org.deeplearning4j.ui.module.remote.RemoteReceiverModule
org.deeplearning4j.ui.module.train.TrainModule org.deeplearning4j.ui.module.train.TrainModule
org.deeplearning4j.ui.module.tsne.TsneModule #org.deeplearning4j.ui.module.tsne.TsneModule

Some files were not shown because too many files have changed in this diff Show More