simplify TrainModule to use getters of VertxUIServer instance for multi-session mode, session loader and server address
Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>
This commit is contained in:
		
							parent
							
								
									ef7c21c204
								
							
						
					
					
						commit
						f1ebced7a1
					
				@ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
 | 
			
		||||
    private static Integer instancePort;
 | 
			
		||||
    private static Thread autoStopThread;
 | 
			
		||||
 | 
			
		||||
    private TrainModule trainModule;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
 | 
			
		||||
     * @param port TCP socket port for {@link HttpServer} to listen
 | 
			
		||||
@ -208,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
 | 
			
		||||
 | 
			
		||||
    private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
 | 
			
		||||
    private RemoteReceiverModule remoteReceiverModule;
 | 
			
		||||
    private StatsStorageLoader statsStorageLoader;
 | 
			
		||||
    /**
 | 
			
		||||
     * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
 | 
			
		||||
     */
 | 
			
		||||
    @Getter
 | 
			
		||||
    private Function<String, Boolean> statsStorageLoader;
 | 
			
		||||
 | 
			
		||||
    //typeIDModuleMap: Records which modules are registered for which type IDs
 | 
			
		||||
    private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
 | 
			
		||||
@ -248,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
 | 
			
		||||
     */
 | 
			
		||||
    public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
 | 
			
		||||
        if (statsStorageProvider != null) {
 | 
			
		||||
            this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
 | 
			
		||||
            if (trainModule != null) {
 | 
			
		||||
                this.trainModule.setSessionLoader(this.statsStorageLoader);
 | 
			
		||||
            }
 | 
			
		||||
            this.statsStorageLoader = (sessionId) -> {
 | 
			
		||||
                log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
 | 
			
		||||
                StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
 | 
			
		||||
                if (statsStorage != null) {
 | 
			
		||||
                    if (statsStorage.sessionExists(sessionId)) {
 | 
			
		||||
                        attach(statsStorage);
 | 
			
		||||
                        return true;
 | 
			
		||||
                    }
 | 
			
		||||
                    log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
 | 
			
		||||
                            "Session ID (" + sessionId + ") does not exist in StatsStorage.");
 | 
			
		||||
                    return false;
 | 
			
		||||
                } else {
 | 
			
		||||
                    log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
 | 
			
		||||
                            "StatsStorageProvider returned null.");
 | 
			
		||||
                    return false;
 | 
			
		||||
                }
 | 
			
		||||
            };
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -303,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
 | 
			
		||||
        trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
 | 
			
		||||
        uiModules.add(trainModule);
 | 
			
		||||
        uiModules.add(new TrainModule());
 | 
			
		||||
        uiModules.add(new ConvolutionalListenerModule());
 | 
			
		||||
        uiModules.add(new TsneModule());
 | 
			
		||||
        uiModules.add(new SameDiffModule());
 | 
			
		||||
@ -597,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
 | 
			
		||||
     */
 | 
			
		||||
    private class StatsStorageLoader implements Function<String, Boolean> {
 | 
			
		||||
 | 
			
		||||
        Function<String, StatsStorage> statsStorageProvider;
 | 
			
		||||
 | 
			
		||||
        StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) {
 | 
			
		||||
            this.statsStorageProvider = statsStorageProvider;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public Boolean apply(String sessionId) {
 | 
			
		||||
            log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
 | 
			
		||||
            StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
 | 
			
		||||
            if (statsStorage != null) {
 | 
			
		||||
                if (statsStorage.sessionExists(sessionId)) {
 | 
			
		||||
                    attach(statsStorage);
 | 
			
		||||
                    return true;
 | 
			
		||||
                }
 | 
			
		||||
                log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
 | 
			
		||||
                        "Session ID (" + sessionId + ") does not exist in StatsStorage.");
 | 
			
		||||
                return false;
 | 
			
		||||
            } else {
 | 
			
		||||
                log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
 | 
			
		||||
                        "StatsStorageProvider returned null.");
 | 
			
		||||
                return false;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    //==================================================================================================================
 | 
			
		||||
    // CLI Launcher
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext;
 | 
			
		||||
import it.unimi.dsi.fastutil.longs.LongArrayList;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import lombok.Setter;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import org.apache.commons.io.FileUtils;
 | 
			
		||||
import org.apache.commons.io.FilenameUtils;
 | 
			
		||||
@ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex;
 | 
			
		||||
import org.deeplearning4j.nn.conf.graph.LayerVertex;
 | 
			
		||||
import org.deeplearning4j.nn.conf.layers.*;
 | 
			
		||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
 | 
			
		||||
import org.deeplearning4j.ui.VertxUIServer;
 | 
			
		||||
import org.deeplearning4j.ui.api.HttpMethod;
 | 
			
		||||
import org.deeplearning4j.ui.api.I18N;
 | 
			
		||||
import org.deeplearning4j.ui.api.Route;
 | 
			
		||||
@ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport;
 | 
			
		||||
import org.deeplearning4j.ui.model.stats.api.StatsReport;
 | 
			
		||||
import org.deeplearning4j.ui.model.stats.api.StatsType;
 | 
			
		||||
import org.nd4j.common.function.Function;
 | 
			
		||||
import org.nd4j.common.function.Supplier;
 | 
			
		||||
import org.nd4j.linalg.learning.config.IUpdater;
 | 
			
		||||
import org.nd4j.common.primitives.Pair;
 | 
			
		||||
import org.nd4j.common.primitives.Triple;
 | 
			
		||||
@ -86,8 +84,6 @@ public class TrainModule implements UIModule {
 | 
			
		||||
    private static final DecimalFormat df2 = new DecimalFormat("#.00");
 | 
			
		||||
    private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
 | 
			
		||||
 | 
			
		||||
    private final Supplier<String> addressSupplier;
 | 
			
		||||
 | 
			
		||||
    private enum ModelType {
 | 
			
		||||
        MLN, CG, Layer
 | 
			
		||||
    }
 | 
			
		||||
@ -99,29 +95,14 @@ public class TrainModule implements UIModule {
 | 
			
		||||
    private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
 | 
			
		||||
    private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
 | 
			
		||||
    private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
 | 
			
		||||
    private final boolean multiSession;
 | 
			
		||||
    @Getter @Setter
 | 
			
		||||
    private Function<String, Boolean> sessionLoader;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    private final Configuration configuration;
 | 
			
		||||
 | 
			
		||||
    public TrainModule() {
 | 
			
		||||
        this(false, null, null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * TrainModule
 | 
			
		||||
     *
 | 
			
		||||
     * @param multiSession    multi-session mode
 | 
			
		||||
     * @param sessionLoader   StatsStorage loader to call if an unknown session ID is passed as URL path parameter
 | 
			
		||||
     *                        in multi-session mode
 | 
			
		||||
     * @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules)
 | 
			
		||||
     */
 | 
			
		||||
    public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) {
 | 
			
		||||
        this.multiSession = multiSession;
 | 
			
		||||
        this.sessionLoader = sessionLoader;
 | 
			
		||||
        this.addressSupplier = addressSupplier;
 | 
			
		||||
    public TrainModule() {
 | 
			
		||||
        String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
 | 
			
		||||
        int value = DEFAULT_MAX_CHART_POINTS;
 | 
			
		||||
        if (maxChartPointsProp != null) {
 | 
			
		||||
@ -159,8 +140,9 @@ public class TrainModule implements UIModule {
 | 
			
		||||
    @Override
 | 
			
		||||
    public List<Route> getRoutes() {
 | 
			
		||||
        List<Route> r = new ArrayList<>();
 | 
			
		||||
        r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
 | 
			
		||||
        if (multiSession) {
 | 
			
		||||
        r.add(new Route("/train/multisession", HttpMethod.GET,
 | 
			
		||||
                (path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false")));
 | 
			
		||||
        if (VertxUIServer.getInstance().isMultiSession()) {
 | 
			
		||||
            r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
 | 
			
		||||
            r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
 | 
			
		||||
                rc.response()
 | 
			
		||||
@ -264,7 +246,9 @@ public class TrainModule implements UIModule {
 | 
			
		||||
        if (!knownSessionIDs.isEmpty()) {
 | 
			
		||||
            sb.append("        <ul>");
 | 
			
		||||
            for (String sessionId : knownSessionIDs.keySet()) {
 | 
			
		||||
                sb.append("            <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n");
 | 
			
		||||
                sb.append("            <li><a href=\"train/")
 | 
			
		||||
                        .append(sessionId).append("\">")
 | 
			
		||||
                        .append(sessionId).append("</a></li>\n");
 | 
			
		||||
            }
 | 
			
		||||
            sb.append("        </ul>");
 | 
			
		||||
        } else {
 | 
			
		||||
@ -284,9 +268,11 @@ public class TrainModule implements UIModule {
 | 
			
		||||
     *
 | 
			
		||||
     * @param sessionId  session ID to look fo with provider
 | 
			
		||||
     * @param targetPath one of overview / model / system, or null
 | 
			
		||||
     * @param rc routing context
 | 
			
		||||
     */
 | 
			
		||||
    private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
 | 
			
		||||
        if (sessionLoader != null && sessionLoader.apply(sessionId)) {
 | 
			
		||||
        Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
 | 
			
		||||
        if (loader != null && loader.apply(sessionId)) {
 | 
			
		||||
            if (targetPath != null) {
 | 
			
		||||
                rc.reroute(targetPath);
 | 
			
		||||
            } else {
 | 
			
		||||
@ -306,9 +292,9 @@ public class TrainModule implements UIModule {
 | 
			
		||||
                        && StatsListener.TYPE_ID.equals(sse.getTypeID())
 | 
			
		||||
                        && !knownSessionIDs.containsKey(sse.getSessionID())) {
 | 
			
		||||
                    knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
 | 
			
		||||
                    if (multiSession) {
 | 
			
		||||
                    if (VertxUIServer.getInstance().isMultiSession()) {
 | 
			
		||||
                        log.info("Adding training session {}/train/{} of StatsStorage instance {}",
 | 
			
		||||
                                addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage());
 | 
			
		||||
                                VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage());
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
@ -332,9 +318,9 @@ public class TrainModule implements UIModule {
 | 
			
		||||
                if (!StatsListener.TYPE_ID.equals(typeID))
 | 
			
		||||
                    continue;
 | 
			
		||||
                knownSessionIDs.put(sessionID, statsStorage);
 | 
			
		||||
                if (multiSession) {
 | 
			
		||||
                if (VertxUIServer.getInstance().isMultiSession()) {
 | 
			
		||||
                    log.info("Adding training session {}/train/{} of StatsStorage instance {}",
 | 
			
		||||
                            addressSupplier.get(), sessionID, statsStorage);
 | 
			
		||||
                            VertxUIServer.getInstance().getAddress(), sessionID, statsStorage);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
 | 
			
		||||
@ -364,9 +350,9 @@ public class TrainModule implements UIModule {
 | 
			
		||||
        }
 | 
			
		||||
        for (String s : toRemove) {
 | 
			
		||||
            knownSessionIDs.remove(s);
 | 
			
		||||
            if (multiSession) {
 | 
			
		||||
            if (VertxUIServer.getInstance().isMultiSession()) {
 | 
			
		||||
                log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
 | 
			
		||||
                        addressSupplier.get(), s, statsStorage);
 | 
			
		||||
                        VertxUIServer.getInstance().getAddress(), s, statsStorage);
 | 
			
		||||
            }
 | 
			
		||||
            lastUpdateForSession.remove(s);
 | 
			
		||||
        }
 | 
			
		||||
@ -602,13 +588,13 @@ public class TrainModule implements UIModule {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session
 | 
			
		||||
     * Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session
 | 
			
		||||
     *
 | 
			
		||||
     * @param sessionId session ID
 | 
			
		||||
     * @return {@link I18N} instance
 | 
			
		||||
     */
 | 
			
		||||
    private I18N getI18N(String sessionId) {
 | 
			
		||||
        return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
 | 
			
		||||
        return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user