From 59f1cbf0c64e83ab6594d507cdb96f74306ab398 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Wed, 18 Sep 2019 21:28:13 -0400 Subject: [PATCH] RL4J - AsyncTrainingListener (#8072) * Code clarity: Extracted parts of run() into private methods Signed-off-by: Alexandre Boulanger * Added listener pattern to async learning Signed-off-by: unknown * Merged all listeners logic Signed-off-by: Alexandre Boulanger * Added interface and common data to training events Signed-off-by: Alexandre Boulanger * Fixed missing info log file Signed-off-by: Alexandre Boulanger * Fixed bad merge; removed useless TrainingEvent Signed-off-by: Alexandre Boulanger * Removed param from training start/end event Signed-off-by: Alexandre Boulanger * Removed 'event' classes from the training listener Signed-off-by: Alexandre Boulanger * Reverted changes to QLearningDiscrete.setTarget() --- .../rl4j/learning/IEpochTrainer.java | 31 ++ .../rl4j/learning/ILearning.java | 5 +- .../rl4j/learning/Learning.java | 1 - .../rl4j/learning/async/AsyncGlobal.java | 14 +- .../rl4j/learning/async/AsyncLearning.java | 103 +++-- .../rl4j/learning/async/AsyncThread.java | 160 +++++--- .../learning/async/AsyncThreadDiscrete.java | 6 +- .../rl4j/learning/async/IAsyncGlobal.java | 7 +- .../async/a3c/discrete/A3CDiscrete.java | 10 +- .../async/a3c/discrete/A3CDiscreteConv.java | 25 +- .../async/a3c/discrete/A3CDiscreteDense.java | 28 +- .../async/a3c/discrete/A3CThreadDiscrete.java | 19 +- .../discrete/AsyncNStepQLearningDiscrete.java | 13 +- .../AsyncNStepQLearningDiscreteConv.java | 19 +- .../AsyncNStepQLearningDiscreteDense.java | 21 +- .../AsyncNStepQLearningThreadDiscrete.java | 22 +- .../learning/listener/TrainingListener.java | 72 ++++ .../listener/TrainingListenerList.java | 105 +++++ .../rl4j/learning/sync/SyncLearning.java | 99 ++--- .../listener/SyncTrainingEpochEndEvent.java | 22 -- .../sync/listener/SyncTrainingEvent.java | 21 - 
.../sync/listener/SyncTrainingListener.java | 45 --- .../learning/sync/qlearning/QLearning.java | 2 +- .../qlearning/discrete/QLearningDiscrete.java | 18 +- .../discrete/QLearningDiscreteConv.java | 21 +- .../discrete/QLearningDiscreteDense.java | 19 +- .../deeplearning4j/rl4j/policy/IPolicy.java | 10 + .../deeplearning4j/rl4j/policy/Policy.java | 3 +- .../deeplearning4j/rl4j/util/DataManager.java | 19 +- .../util/DataManagerSyncTrainingListener.java | 126 ------ .../util/DataManagerTrainingListener.java | 83 ++++ .../rl4j/util/IDataManager.java | 2 +- .../learning/async/AsyncLearningTest.java | 127 ++++++ .../rl4j/learning/async/AsyncThreadTest.java | 367 ++++-------------- .../AsyncTrainingListenerListTest.java | 98 +++++ .../listener/TrainingListenerListTest.java | 83 ++++ .../rl4j/learning/sync/SyncLearningTest.java | 50 +-- .../sync/qlearning/QLearningDiscreteTest.java | 65 ---- .../discrete/QLearningDiscreteTest.java | 13 +- .../support/MockSyncTrainingListener.java | 46 --- .../rl4j/support/MockAsyncConfiguration.java | 65 ++++ .../rl4j/support/MockAsyncGlobal.java | 65 ++++ .../rl4j/support/MockDataManager.java | 2 +- .../rl4j/support/MockNeuralNet.java | 74 ++++ .../rl4j/support/MockPolicy.java | 17 + .../rl4j/support/MockTrainingListener.java | 65 ++++ .../util/DataManagerTrainingListenerTest.java | 169 ++++++++ 47 files changed, 1576 insertions(+), 881 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java delete mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEpochEndEvent.java delete mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEvent.java delete mode 100644 
rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingListener.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java delete mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerSyncTrainingListener.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningDiscreteTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockSyncTrainingListener.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java new file mode 100644 index 000000000..72510dcaa --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java @@ -0,0 +1,31 @@ 
+/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning; + +import org.deeplearning4j.rl4j.mdp.MDP; + +/** + * The common API between Learning and AsyncThread. + * + * @author Alexandre Boulanger + */ +public interface IEpochTrainer { + int getStepCounter(); + int getEpochCounter(); + IHistoryProcessor getHistoryProcessor(); + MDP getMdp(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index dc9f78577..e243bdc5e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning; import org.deeplearning4j.rl4j.mdp.MDP; -import org.deeplearning4j.rl4j.policy.Policy; +import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; @@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable; */ public interface ILearning> extends StepCountable { - Policy getPolicy(); + IPolicy getPolicy(); void train(); @@ -38,6 +38,7 @@ public interface ILearning> 
ex MDP getMdp(); + IHistoryProcessor getHistoryProcessor(); interface LConfiguration { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index 798aa094f..89c7fdb59 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java index 4ff461206..fb07baf1e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java @@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.learning.async; import lombok.Getter; -import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.network.NeuralNet; @@ -63,7 +62,6 @@ public class AsyncGlobal extends Thread implements IAsyncG @Getter private NN target; @Getter - @Setter private boolean running = true; public AsyncGlobal(NN initial, AsyncConfiguration a3cc) { @@ -78,7 +76,9 @@ public class AsyncGlobal extends Thread implements IAsyncG } public void enqueue(Gradient[] gradient, Integer nstep) { - queue.add(new Pair<>(gradient, nstep)); + if(running && !isTrainingComplete()) { + queue.add(new Pair<>(gradient, nstep)); + } } @Override @@ -105,4 +105,12 @@ public class AsyncGlobal extends Thread implements 
IAsyncG } + /** + * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded. + */ + public void terminate() { + running = false; + queue.clear(); + } + } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index 24e6bfcb9..40d279182 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -16,33 +16,49 @@ package org.deeplearning4j.rl4j.learning.async; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.listener.*; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.factory.Nd4j; /** + * The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread() + * configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals + * (see setProgressEventInterval(int)) + * * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16. 
- * - * Async learning always follow the same pattern in RL4J - * -launch the Global thread - * -launch the "save threads" - * -periodically evaluate the model of the global thread for monitoring purposes - * + * @author Alexandre Boulanger */ @Slf4j public abstract class AsyncLearning, NN extends NeuralNet> extends Learning { - protected abstract IDataManager getDataManager(); + @Getter(AccessLevel.PROTECTED) + private final TrainingListenerList listeners = new TrainingListenerList(); public AsyncLearning(AsyncConfiguration conf) { super(conf); } + /** + * Add a {@link TrainingListener} listener at the end of the listener list. + * + * @param listener the listener to be added + */ + public void addListener(TrainingListener listener) { + listeners.add(listener); + } + + /** + * Returns the configuration + * @return the configuration (see {@link AsyncConfiguration}) + */ public abstract AsyncConfiguration getConfiguration(); protected abstract AsyncThread newThread(int i, int deviceAffinity); @@ -57,41 +73,80 @@ public abstract class AsyncLearning + * The training stop when:
+ * - A worker thread terminate the AsyncGlobal thread (see {@link AsyncGlobal})
+ * OR
+ * - a listener explicitly stops it
+ *

+ * Listeners
+ * For a given event, the listeners are called sequentially in same the order as they were added. If one listener + * returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.
+ * Events: + *

    + *
  • {@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.
  • + *
  • {@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.
  • + *
+ */ public void train() { - try { - log.info("AsyncLearning training starting."); - launchThreads(); + log.info("AsyncLearning training starting."); - //this is simply for stat purposes - getDataManager().writeInfo(this); - synchronized (this) { - while (!isTrainingComplete() && getAsyncGlobal().isRunning()) { - getPolicy().play(getMdp(), getHistoryProcessor()); - getDataManager().writeInfo(this); - wait(20000); + canContinue = listeners.notifyTrainingStarted(); + if (canContinue) { + launchThreads(); + monitorTraining(); + } + + cleanupPostTraining(); + listeners.notifyTrainingFinished(); + } + + protected void monitorTraining() { + try { + while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) { + canContinue = listeners.notifyTrainingProgress(this); + if(!canContinue) { + return; + } + + synchronized (this) { + wait(progressMonitorFrequency); } } - } catch (Exception e) { - log.error("Training failed.", e); - e.printStackTrace(); + } catch (InterruptedException e) { + log.error("Training interrupted.", e); } } - + protected void cleanupPostTraining() { + // Worker threads stops automatically when the global thread stops + getAsyncGlobal().terminate(); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java index 2749949b6..1d763be0b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java @@ -21,33 +21,31 @@ import lombok.Getter; import lombok.Setter; import lombok.Value; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.rl4j.learning.HistoryProcessor; -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.StepCountable; +import org.deeplearning4j.rl4j.learning.*; +import 
org.deeplearning4j.rl4j.learning.listener.*; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.factory.Nd4j; /** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. - * * This represent a local thread that explore the environment * and calculate a gradient to enqueue to the global thread/model * * It has its own version of a model that it syncs at the start of every * sub epoch * + * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. + * @author Alexandre Boulanger */ @Slf4j public abstract class AsyncThread, NN extends NeuralNet> - extends Thread implements StepCountable { + extends Thread implements StepCountable, IEpochTrainer { + @Getter private int threadNumber; @Getter protected final int deviceNum; @@ -55,12 +53,16 @@ public abstract class AsyncThread mdp; @Getter @Setter private IHistoryProcessor historyProcessor; - @Getter - private int lastMonitor = -Constants.MONITOR_FREQ; - public AsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, int deviceNum) { + private final TrainingListenerList listeners; + + public AsyncThread(IAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { + this.mdp = mdp; + this.listeners = listeners; this.threadNumber = threadNumber; this.deviceNum = deviceNum; } @@ -80,75 +82,106 @@ public abstract class AsyncThread= Constants.MONITOR_FREQ && getHistoryProcessor() != null - && getDataManager().isSaveData()) { - lastMonitor = getStepCounter(); - int[] shape = getMdp().getObservationSpace().getShape(); - getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + threadNumber + "-" - + getEpochCounter() + "-" + getStepCounter() + ".mp4", shape); - } + // Do 
nothing } + /** + * This method will start the worker thread

+ * The thread will stop when:
+ * - The AsyncGlobal thread terminates or reports that the training is complete + * (see {@link AsyncGlobal#isTrainingComplete()}). In such case, the currently running epoch will still be handled normally and + * events will also be fired normally.
+ * OR
+ * - a listener explicitly stops it, in which case, the AsyncGlobal thread will be terminated along with + * all other worker threads
+ *

+ * Listeners
+ * For a given event, the listeners are called sequentially in same the order as they were added. If one listener + * returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse + * TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.
+ * Events: + *

    + *
  • {@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} is called when a new epoch is started.
  • + *
  • {@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} is called at the end of every + * epoch. It will not be called if onNewEpoch() stops the training.
  • + *
+ */ @Override public void run() { + RunContext context = new RunContext<>(); Nd4j.getAffinityManager().unsafeSetDevice(deviceNum); + log.info("ThreadNum-" + threadNumber + " Started!"); - try { - log.info("ThreadNum-" + threadNumber + " Started!"); - getCurrent().reset(); - Learning.InitMdp initMdp = Learning.initMdp(getMdp(), historyProcessor); - O obs = initMdp.getLastObs(); - double rewards = initMdp.getReward(); - int length = initMdp.getSteps(); + boolean canContinue = initWork(context); + if (canContinue) { - preEpoch(); while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) { - int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - length); - SubEpochReturn subEpochReturn = trainSubEpoch(obs, maxSteps); - obs = subEpochReturn.getLastObs(); - stepCounter += subEpochReturn.getSteps(); - length += subEpochReturn.getSteps(); - rewards += subEpochReturn.getReward(); - double score = subEpochReturn.getScore(); - if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) { - postEpoch(); - - IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score); - getDataManager().appendStat(statEntry); - log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); - - getCurrent().reset(); - initMdp = Learning.initMdp(getMdp(), historyProcessor); - obs = initMdp.getLastObs(); - rewards = initMdp.getReward(); - length = initMdp.getSteps(); - epochCounter++; - - preEpoch(); + handleTraining(context); + if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) { + canContinue = finishEpoch(context) && startNewEpoch(context); + if (!canContinue) { + break; + } } } - } catch (Exception e) { - log.error("Thread crashed: " + e.getCause()); - getAsyncGlobal().setRunning(false); - e.printStackTrace(); - } finally { - postEpoch(); } + terminateWork(); + } + + private void initNewEpoch(RunContext context) { + 
getCurrent().reset(); + Learning.InitMdp initMdp = Learning.initMdp(getMdp(), historyProcessor); + + context.obs = initMdp.getLastObs(); + context.rewards = initMdp.getReward(); + context.length = initMdp.getSteps(); + } + + private void handleTraining(RunContext context) { + int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length); + SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); + + context.obs = subEpochReturn.getLastObs(); + stepCounter += subEpochReturn.getSteps(); + context.length += subEpochReturn.getSteps(); + context.rewards += subEpochReturn.getReward(); + context.score = subEpochReturn.getScore(); + } + + private boolean initWork(RunContext context) { + initNewEpoch(context); + preEpoch(); + return listeners.notifyNewEpoch(this); + } + + private boolean startNewEpoch(RunContext context) { + initNewEpoch(context); + epochCounter++; + preEpoch(); + return listeners.notifyNewEpoch(this); + } + + private boolean finishEpoch(RunContext context) { + postEpoch(); + IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.length, context.score); + + log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards); + + return listeners.notifyEpochTrainingResult(this, statEntry); + } + + private void terminateWork() { + postEpoch(); + getAsyncGlobal().terminate(); } protected abstract NN getCurrent(); - protected abstract int getThreadNumber(); - protected abstract IAsyncGlobal getAsyncGlobal(); - protected abstract MDP getMdp(); - protected abstract AsyncConfiguration getConf(); - protected abstract IDataManager getDataManager(); - protected abstract Policy getPolicy(NN net); protected abstract SubEpochReturn trainSubEpoch(O obs, int nstep); @@ -172,4 +205,11 @@ public abstract class AsyncThread { + private O obs; + private double rewards; + private int length; + private double score; + } + } diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index da4da9f22..7458c0c06 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -21,7 +21,9 @@ import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; @@ -44,8 +46,8 @@ public abstract class AsyncThreadDiscrete asyncGlobal, int threadNumber, int deviceNum) { - super(asyncGlobal, threadNumber, deviceNum); + public AsyncThreadDiscrete(IAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { + super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); synchronized (asyncGlobal) { current = (NN)asyncGlobal.getCurrent().clone(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java index 138bff943..df3d476f9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java @@ -23,9 +23,14 @@ import java.util.concurrent.atomic.AtomicInteger; public interface IAsyncGlobal { boolean isRunning(); - void setRunning(boolean value); boolean isTrainingComplete(); void start(); + + /** + 
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded. + */ + void terminate(); + AtomicInteger getT(); NN getCurrent(); NN getTarget(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index 7dbec6210..2b1aa4ef4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. @@ -47,24 +46,19 @@ public abstract class A3CDiscrete extends AsyncLearning policy; - @Getter - final private IDataManager dataManager; - public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CConfiguration conf, - IDataManager dataManager) { + public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CConfiguration conf) { super(conf); this.iActorCritic = iActorCritic; this.mdp = mdp; this.configuration = conf; - this.dataManager = dataManager; policy = new ACPolicy<>(iActorCritic, getRandom()); asyncGlobal = new AsyncGlobal<>(iActorCritic, conf); mdp.getActionSpace().setSeed(conf.getSeed()); } - @Override protected AsyncThread newThread(int i, int deviceNum) { - return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager, deviceNum); + return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), deviceNum, getListeners(), i); } public IActorCritic getNeuralNet() { diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java index 196f3a6c3..0bb835c90 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java @@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; /** @@ -43,24 +44,38 @@ public class A3CDiscreteConv extends A3CDiscrete { final private HistoryProcessor.Configuration hpconf; - public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, + @Deprecated + public A3CDiscreteConv(MDP mdp, IActorCritic actorCritic, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { - super(mdp, IActorCritic, conf, dataManager); + this(mdp, actorCritic, hpconf, conf); + addListener(new DataManagerTrainingListener(dataManager)); + } + public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, + HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { + super(mdp, IActorCritic, conf); this.hpconf = hpconf; setHistoryProcessor(hpconf); } - + @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { - this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, - dataManager); + this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); + } + public A3CDiscreteConv(MDP mdp, 
ActorCriticFactoryCompGraph factory, + HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { + this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } + @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager); } + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, + HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { + this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf); + } @Override public AsyncThread newThread(int i, int deviceNum) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java index c67659589..16b8151df 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java @@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.*; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; /** @@ -33,33 +34,58 @@ import org.deeplearning4j.rl4j.util.IDataManager; */ public class A3CDiscreteDense extends A3CDiscrete { + @Deprecated public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf, IDataManager dataManager) { - super(mdp, IActorCritic, conf, dataManager); + this(mdp, IActorCritic, conf); + addListener(new DataManagerTrainingListener(dataManager)); + } + public 
A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf) { + super(mdp, actorCritic, conf); } + @Deprecated public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + A3CConfiguration conf) { + this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + @Deprecated public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager); } + public A3CDiscreteDense(MDP mdp, + ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { + this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); + } + @Deprecated public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + A3CConfiguration conf) { + this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + @Deprecated public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager); } + public A3CDiscreteDense(MDP mdp, + ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) { + this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf); + } } diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index 3a481b09c..0999e5d42 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -22,13 +22,13 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -46,24 +46,19 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< @Getter final protected A3CDiscrete.A3CConfiguration conf; @Getter - final protected MDP mdp; - @Getter final protected AsyncGlobal asyncGlobal; @Getter final protected int threadNumber; - @Getter - final protected IDataManager dataManager; final private Random random; public A3CThreadDiscrete(MDP mdp, AsyncGlobal asyncGlobal, - A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) { - super(asyncGlobal, threadNumber, deviceNum); + A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners, + int threadNumber) { + super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); this.conf = 
a3cc; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; - this.mdp = mdp; - this.dataManager = dataManager; mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber); random = new Random(conf.getSeed() + threadNumber); } @@ -85,15 +80,15 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< //if recurrent then train as a time serie with a batch size of 1 boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent(); - int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape() + int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); int[] nshape = recurrent ? Learning.makeShape(1, shape, size) : Learning.makeShape(size, shape); INDArray input = Nd4j.create(nshape); INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1); - INDArray logSoftmax = recurrent ? Nd4j.zeros(1, mdp.getActionSpace().getSize(), size) - : Nd4j.zeros(size, mdp.getActionSpace().getSize()); + INDArray logSoftmax = recurrent ? 
Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size) + : Nd4j.zeros(size, getMdp().getActionSpace().getSize()); double r = minTrans.getReward(); for (int i = size - 1; i >= 0; i--) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index bab60fec4..cef53543a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -24,10 +24,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.policy.DQNPolicy; -import org.deeplearning4j.rl4j.policy.Policy; +import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. 
@@ -40,16 +39,12 @@ public abstract class AsyncNStepQLearningDiscrete @Getter final private MDP mdp; @Getter - final private IDataManager dataManager; - @Getter final private AsyncGlobal asyncGlobal; - public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf, - IDataManager dataManager) { + public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { super(conf); this.mdp = mdp; - this.dataManager = dataManager; this.configuration = conf; this.asyncGlobal = new AsyncGlobal<>(dqn, conf); mdp.getActionSpace().setSeed(conf.getSeed()); @@ -57,14 +52,14 @@ public abstract class AsyncNStepQLearningDiscrete @Override public AsyncThread newThread(int i, int deviceNum) { - return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum); + return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, getListeners(), i, deviceNum); } public IDQN getNeuralNet() { return asyncGlobal.getCurrent(); } - public Policy getPolicy() { + public IPolicy getPolicy() { return new DQNPolicy(getNeuralNet()); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java index 257e5fb5d..83274b7f6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java @@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import 
org.deeplearning4j.rl4j.util.IDataManager; /** @@ -35,22 +36,38 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN final private HistoryProcessor.Configuration hpconf; + @Deprecated public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { - super(mdp, dqn, conf, dataManager); + this(mdp, dqn, hpconf, conf); + addListener(new DataManagerTrainingListener(dataManager)); + } + public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, + HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + super(mdp, dqn, conf); this.hpconf = hpconf; setHistoryProcessor(hpconf); } + @Deprecated public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } + public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, + HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); + } + @Deprecated public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); } + public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, + HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); + } @Override public AsyncThread newThread(int i, int deviceNum) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java index 837681981..b58e15902 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java @@ -22,6 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; /** @@ -29,19 +30,37 @@ import org.deeplearning4j.rl4j.util.IDataManager; */ public class AsyncNStepQLearningDiscreteDense extends AsyncNStepQLearningDiscrete { + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf, IDataManager dataManager) { - super(mdp, dqn, conf, dataManager); + super(mdp, dqn, conf); + addListener(new DataManagerTrainingListener(dataManager)); } + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + AsyncNStepQLConfiguration conf) { + super(mdp, dqn, conf); + } + + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + AsyncNStepQLConfiguration conf) { + this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdDense(netConf), conf, 
dataManager); } + public AsyncNStepQLearningDiscreteDense(MDP mdp, + DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) { + this(mdp, new DQNFactoryStdDense(netConf), conf); + } + } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 23d6f79ca..6bd1c8b6d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -19,9 +19,10 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import lombok.Getter; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; +import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.policy.DQNPolicy; @@ -29,7 +30,6 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -44,31 +44,25 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn @Getter final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf; @Getter - final protected MDP mdp; - @Getter final protected IAsyncGlobal 
asyncGlobal; @Getter final protected int threadNumber; - @Getter - final protected IDataManager dataManager; final private Random random; public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, - AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber, - IDataManager dataManager, int deviceNum) { - super(asyncGlobal, threadNumber, deviceNum); + AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, + TrainingListenerList listeners, int threadNumber, int deviceNum) { + super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); this.conf = conf; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; - this.mdp = mdp; - this.dataManager = dataManager; mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber); random = new Random(conf.getSeed() + threadNumber); } public Policy getPolicy(IDQN nn) { - return new EpsGreedy(new DQNPolicy(nn), mdp, conf.getUpdateStart(), conf.getEpsilonNbStep(), + return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(), random, conf.getMinEpsilon(), this); } @@ -81,11 +75,11 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn int size = rewards.size(); - int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape() + int[] shape = getHistoryProcessor() == null ? 
getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); int[] nshape = Learning.makeShape(size, shape); INDArray input = Nd4j.create(nshape); - INDArray targets = Nd4j.create(size, mdp.getActionSpace().getSize()); + INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize()); double r = minTrans.getReward(); for (int i = size - 1; i >= 0; i--) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java new file mode 100644 index 000000000..7eab385e1 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java @@ -0,0 +1,72 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.learning.listener; + +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.util.IDataManager; + +/** + * The base definition of all training event listeners + * + * @author Alexandre Boulanger + */ +public interface TrainingListener { + enum ListenerResponse { + /** + * Tell the learning process to continue calling the listeners and the training. + */ + CONTINUE, + + /** + * Tell the learning process to stop calling the listeners and terminate the training. + */ + STOP, + } + + /** + * Called once when the training starts. + * @return A ListenerResponse telling the source of the event if it should go on or cancel the training. + */ + ListenerResponse onTrainingStart(); + + /** + * Called once when the training has finished. This method is called even when the training has been aborted. + */ + void onTrainingEnd(); + + /** + * Called before the start of every epoch. + * @param trainer A {@link IEpochTrainer} + * @return A ListenerResponse telling the source of the event if it should continue or stop the training. + */ + ListenerResponse onNewEpoch(IEpochTrainer trainer); + + /** + * Called when an epoch has been completed + * @param trainer A {@link IEpochTrainer} + * @param statEntry A {@link org.deeplearning4j.rl4j.util.IDataManager.StatEntry} + * @return A ListenerResponse telling the source of the event if it should continue or stop the training. + */ + ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry); + + /** + * Called regularly to monitor the training progress. + * @param learning A {@link ILearning} + * @return A ListenerResponse telling the source of the event if it should continue or stop the training. 
+ */ + ListenerResponse onTrainingProgress(ILearning learning); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java new file mode 100644 index 000000000..a1c6451d0 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java @@ -0,0 +1,105 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.listener; + +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.util.IDataManager; + +import java.util.ArrayList; +import java.util.List; + +/** + * The base logic to notify training listeners with the different training events. + * + * @author Alexandre Boulanger + */ +public class TrainingListenerList { + protected final List listeners = new ArrayList<>(); + + /** + * Add a listener at the end of the list + * @param listener The listener to be added + */ + public void add(TrainingListener listener) { + listeners.add(listener); + } + + /** + * Notify the listeners that the training has started. 
Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP} + * @return whether or not the source training should be stopped + */ + public boolean notifyTrainingStarted() { + for (TrainingListener listener : listeners) { + if (listener.onTrainingStart() == TrainingListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + /** + * Notify the listeners that the training has finished. + */ + public void notifyTrainingFinished() { + for (TrainingListener listener : listeners) { + listener.onTrainingEnd(); + } + } + + /** + * Notify the listeners that a new epoch has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP} + * @return whether or not the source training should be stopped + */ + public boolean notifyNewEpoch(IEpochTrainer trainer) { + for (TrainingListener listener : listeners) { + if (listener.onNewEpoch(trainer) == TrainingListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + /** + * Notify the listeners that an epoch has been completed and the training results are available. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP} + * @return whether or not the source training should be stopped + */ + public boolean notifyEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) { + for (TrainingListener listener : listeners) { + if (listener.onEpochTrainingResult(trainer, statEntry) == TrainingListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + /** + * Notify the listeners that they update the progress of the training. 
+ */ + public boolean notifyTrainingProgress(ILearning learning) { + for (TrainingListener listener : listeners) { + if (listener.onTrainingProgress(learning) == TrainingListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java index 3464410dc..6e9e0ccf1 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java @@ -16,19 +16,18 @@ package org.deeplearning4j.rl4j.learning.sync; +import lombok.Getter; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent; -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent; -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener; +import org.deeplearning4j.rl4j.learning.listener.*; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager; -import java.util.ArrayList; -import java.util.List; - /** * Mother class and useful factorisations for all training methods that * are not asynchronous. 
@@ -38,9 +37,9 @@ import java.util.List; */ @Slf4j public abstract class SyncLearning, NN extends NeuralNet> - extends Learning { + extends Learning implements IEpochTrainer { - private List listeners = new ArrayList<>(); + private final TrainingListenerList listeners = new TrainingListenerList(); public SyncLearning(LConfiguration conf) { super(conf); @@ -49,12 +48,24 @@ public abstract class SyncLearning * The training stop when:
@@ -64,81 +75,49 @@ public abstract class SyncLearning * Listeners
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener - * returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.
+ * returns {@link TrainingListener.ListenerResponse TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* Events: *
    - *
  • {@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.
  • - *
  • {@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training
  • - *
  • {@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.
  • + *
  • {@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.
  • + *
  • {@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} and {@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} are called for every epoch. onEpochTrainingResult will not be called if onNewEpoch stops the training
  • + *
  • {@link TrainingListener#onTrainingProgress(ILearning) onTrainingProgress()} is called after onEpochTrainingResult()
  • + *
  • {@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.
  • *
*/ public void train() { log.info("training starting."); - boolean canContinue = notifyTrainingStarted(); + boolean canContinue = listeners.notifyTrainingStarted(); if (canContinue) { while (getStepCounter() < getConfiguration().getMaxStep()) { preEpoch(); - canContinue = notifyEpochStarted(); + canContinue = listeners.notifyNewEpoch(this); if (!canContinue) { break; } IDataManager.StatEntry statEntry = trainEpoch(); - - postEpoch(); - canContinue = notifyEpochFinished(statEntry); + canContinue = listeners.notifyEpochTrainingResult(this, statEntry); if (!canContinue) { break; } - log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); + postEpoch(); + if(getEpochCounter() % progressMonitorFrequency == 0) { + canContinue = listeners.notifyTrainingProgress(this); + if (!canContinue) { + break; + } + } + + log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); incrementEpoch(); } } - notifyTrainingFinished(); - } - - private boolean notifyTrainingStarted() { - SyncTrainingEvent event = new SyncTrainingEvent(this); - for (SyncTrainingListener listener : listeners) { - if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) { - return false; - } - } - - return true; - } - - private void notifyTrainingFinished() { - for (SyncTrainingListener listener : listeners) { - listener.onTrainingEnd(); - } - } - - private boolean notifyEpochStarted() { - SyncTrainingEvent event = new SyncTrainingEvent(this); - for (SyncTrainingListener listener : listeners) { - if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) { - return false; - } - } - - return true; - } - - private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) { - SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry); - for (SyncTrainingListener listener : listeners) { - if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) { - return false; - } - } - - 
return true; + listeners.notifyTrainingFinished(); } protected abstract void preEpoch(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEpochEndEvent.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEpochEndEvent.java deleted file mode 100644 index 71a357ec8..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEpochEndEvent.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.listener; - -import lombok.Getter; -import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.util.IDataManager; - -/** - * A subclass of SyncTrainingEvent that is passed to SyncTrainingListener.onEpochEnd() - */ -public class SyncTrainingEpochEndEvent extends SyncTrainingEvent { - - /** - * The stats of the epoch training - */ - @Getter - private final IDataManager.StatEntry statEntry; - - public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) { - super(learning); - this.statEntry = statEntry; - } -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEvent.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEvent.java deleted file mode 100644 index 964040f28..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEvent.java +++ /dev/null @@ -1,21 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.listener; - -import lombok.Getter; -import lombok.Setter; -import org.deeplearning4j.rl4j.learning.Learning; - -/** - * SyncTrainingEvent are passed as parameters to the events of SyncTrainingListener - */ -public class SyncTrainingEvent { - - /** - * The source of the event - */ - @Getter - private final Learning learning; - - public SyncTrainingEvent(Learning learning) { - this.learning = learning; - } -} diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingListener.java deleted file mode 100644 index 852c16036..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingListener.java +++ /dev/null @@ -1,45 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.listener; - -/** - * A listener interface to use with a descendant of {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning} - */ -public interface SyncTrainingListener { - - public enum ListenerResponse { - /** - * Tell SyncLearning to continue calling the listeners and the training. - */ - CONTINUE, - - /** - * Tell SyncLearning to stop calling the listeners and terminate the training. - */ - STOP, - } - - /** - * Called once when the training starts. - * @param event - * @return A ListenerResponse telling the source of the event if it should go on or cancel the training. - */ - ListenerResponse onTrainingStart(SyncTrainingEvent event); - - /** - * Called once when the training has finished. This method is called even when the training has been aborted. - */ - void onTrainingEnd(); - - /** - * Called before the start of every epoch. - * @param event - * @return A ListenerResponse telling the source of the event if it should continue or stop the training. - */ - ListenerResponse onEpochStart(SyncTrainingEvent event); - - /** - * Called after the end of every epoch. - * @param event - * @return A ListenerResponse telling the source of the event if it should continue or stop the training. 
- */ - ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event); -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 525995455..564f654fc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -49,7 +49,7 @@ public abstract class QLearning expReplay; @Getter - @Setter(AccessLevel.PACKAGE) + @Setter(AccessLevel.PROTECTED) protected IExpReplay expReplay; public QLearning(QLConfiguration conf) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 351226e9d..7ea47eba8 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -28,8 +28,6 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener; -import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -64,20 +62,9 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLConfiguration conf, - IDataManager dataManager, int epsilonNbStep) { - this(mdp, dqn, conf, epsilonNbStep); - addListener(DataManagerSyncTrainingListener.builder(dataManager).build()); - } - public QLearningDiscrete(MDP mdp, IDQN dqn, 
QLConfiguration conf, int epsilonNbStep) { super(conf); @@ -186,7 +173,6 @@ public abstract class QLearningDiscrete extends QLearning setTarget(ArrayList> transitions) { if (transitions.size() == 0) throw new IllegalArgumentException("too few transitions"); @@ -194,7 +180,7 @@ public abstract class QLearningDiscrete extends QLearning extends QLearningDiscrete { - + @Deprecated public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { - super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); + this(mdp, dqn, hpconf, conf); + addListener(new DataManagerTrainingListener(dataManager)); + } + public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, + QLConfiguration conf) { + super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); setHistoryProcessor(hpconf); } + @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } + public QLearningDiscreteConv(MDP mdp, DQNFactory factory, + HistoryProcessor.Configuration hpconf, QLConfiguration conf) { + this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); + } + @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); } + public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, + HistoryProcessor.Configuration hpconf, QLConfiguration conf) { + this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); + } + } diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java index 341031aec..ef69ea6fb 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java @@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; /** @@ -31,21 +32,35 @@ import org.deeplearning4j.rl4j.util.IDataManager; public class QLearningDiscreteDense extends QLearningDiscrete { - + @Deprecated public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf, IDataManager dataManager) { - super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep()); + this(mdp, dqn, conf); + addListener(new DataManagerTrainingListener(dataManager)); + } + public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf) { + super(mdp, dqn, conf, conf.getEpsilonNbStep()); } + @Deprecated public QLearningDiscreteDense(MDP mdp, DQNFactory factory, QLearning.QLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } + public QLearningDiscreteDense(MDP mdp, DQNFactory factory, + QLearning.QLConfiguration conf) { + this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + @Deprecated public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, QLearning.QLConfiguration conf, 
IDataManager dataManager) { this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); } + public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, + QLearning.QLConfiguration conf) { + this(mdp, new DQNFactoryStdDense(netConf), conf); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java new file mode 100644 index 000000000..5c9d54d45 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java @@ -0,0 +1,10 @@ +package org.deeplearning4j.rl4j.policy; + +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; + +public interface IPolicy { + > double play(MDP mdp, IHistoryProcessor hp); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index a5d5d261c..1be123f1d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -35,7 +35,7 @@ import org.nd4j.linalg.util.ArrayUtil; * * A Policy responsability is to choose the next action given a state */ -public abstract class Policy { +public abstract class Policy implements IPolicy { public abstract NeuralNet getNeuralNet(); @@ -49,6 +49,7 @@ public abstract class Policy { return play(mdp, new HistoryProcessor(conf)); } + @Override public > double play(MDP mdp, IHistoryProcessor hp) { getNeuralNet().reset(); Learning.InitMdp initMdp = Learning.initMdp(mdp, hp); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java index e9c243eea..4f7d6244e 100644 --- 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java @@ -22,6 +22,7 @@ import lombok.Builder; import lombok.Getter; import lombok.Value; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.rl4j.learning.NeuralNetFetchable; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.learning.Learning; @@ -72,13 +73,13 @@ public class DataManager implements IDataManager { } } - public static void save(String path, Learning learning) throws IOException { + public static void save(String path, ILearning learning) throws IOException { try (BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(path))) { save(os, learning); } } - public static void save(OutputStream os, Learning learning) throws IOException { + public static void save(OutputStream os, ILearning learning) throws IOException { try (ZipOutputStream zipfile = new ZipOutputStream(os)) { @@ -91,7 +92,9 @@ public class DataManager implements IDataManager { zipfile.putNextEntry(dqn); ByteArrayOutputStream bos = new ByteArrayOutputStream(); - learning.getNeuralNet().save(bos); + if(learning instanceof NeuralNetFetchable) { + ((NeuralNetFetchable)learning).getNeuralNet().save(bos); + } bos.flush(); bos.close(); @@ -104,7 +107,9 @@ public class DataManager implements IDataManager { zipfile.putNextEntry(hpconf); ByteArrayOutputStream bos2 = new ByteArrayOutputStream(); - learning.getNeuralNet().save(bos2); + if(learning instanceof NeuralNetFetchable) { + ((NeuralNetFetchable)learning).getNeuralNet().save(bos2); + } bos2.flush(); bos2.close(); @@ -256,13 +261,15 @@ public class DataManager implements IDataManager { return exists; } - public void save(Learning learning) throws IOException { + public void save(ILearning learning) throws IOException { if (!saveData) return; save(getModelDir() + "/" + learning.getStepCounter() + 
".training", learning); - learning.getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model"); + if(learning instanceof NeuralNetFetchable) { + ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model"); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerSyncTrainingListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerSyncTrainingListener.java deleted file mode 100644 index c9166f34e..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerSyncTrainingListener.java +++ /dev/null @@ -1,126 +0,0 @@ -package org.deeplearning4j.rl4j.util; - -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent; -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent; -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener; - -/** - * DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the - * training process can be fed to the DataManager - */ -@Slf4j -public class DataManagerSyncTrainingListener implements SyncTrainingListener { - private final IDataManager dataManager; - private final int saveFrequency; - private final int monitorFrequency; - - private int lastSave; - private int lastMonitor; - - private DataManagerSyncTrainingListener(Builder builder) { - this.dataManager = builder.dataManager; - - this.saveFrequency = builder.saveFrequency; - this.lastSave = -builder.saveFrequency; - - this.monitorFrequency = builder.monitorFrequency; - this.lastMonitor = -builder.monitorFrequency; - } - - @Override - public ListenerResponse onTrainingStart(SyncTrainingEvent event) { - try { - dataManager.writeInfo(event.getLearning()); - } catch (Exception e) { - log.error("Training failed.", e); - return ListenerResponse.STOP; - } - return ListenerResponse.CONTINUE; - } - - @Override - public void 
onTrainingEnd() { - // Do nothing - } - - @Override - public ListenerResponse onEpochStart(SyncTrainingEvent event) { - int stepCounter = event.getLearning().getStepCounter(); - - if (stepCounter - lastMonitor >= monitorFrequency - && event.getLearning().getHistoryProcessor() != null - && dataManager.isSaveData()) { - lastMonitor = stepCounter; - int[] shape = event.getLearning().getMdp().getObservationSpace().getShape(); - event.getLearning().getHistoryProcessor().startMonitor(dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter() + "-" - + stepCounter + ".mp4", shape); - } - - return ListenerResponse.CONTINUE; - } - - @Override - public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) { - try { - int stepCounter = event.getLearning().getStepCounter(); - if (stepCounter - lastSave >= saveFrequency) { - dataManager.save(event.getLearning()); - lastSave = stepCounter; - } - - dataManager.appendStat(event.getStatEntry()); - dataManager.writeInfo(event.getLearning()); - } catch (Exception e) { - log.error("Training failed.", e); - return ListenerResponse.STOP; - } - - return ListenerResponse.CONTINUE; - } - - public static Builder builder(IDataManager dataManager) { - return new Builder(dataManager); - } - - public static class Builder { - private final IDataManager dataManager; - private int saveFrequency = Constants.MODEL_SAVE_FREQ; - private int monitorFrequency = Constants.MONITOR_FREQ; - - /** - * Create a Builder with the given DataManager - * @param dataManager - */ - public Builder(IDataManager dataManager) { - this.dataManager = dataManager; - } - - /** - * A number that represent the number of steps since the last call to DataManager.save() before can it be called again. 
- * @param saveFrequency (Default: 100000) - */ - public Builder saveFrequency(int saveFrequency) { - this.saveFrequency = saveFrequency; - return this; - } - - /** - * A number that represent the number of steps since the last call to HistoryProcessor.startMonitor() before can it be called again. - * @param monitorFrequency (Default: 10000) - */ - public Builder monitorFrequency(int monitorFrequency) { - this.monitorFrequency = monitorFrequency; - return this; - } - - /** - * Creates a DataManagerSyncTrainingListener with the configured parameters - * @return An instance of DataManagerSyncTrainingListener - */ - public DataManagerSyncTrainingListener build() { - return new DataManagerSyncTrainingListener(this); - } - - } -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java new file mode 100644 index 000000000..83b8d71da --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java @@ -0,0 +1,83 @@ +package org.deeplearning4j.rl4j.util; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.async.AsyncThread; +import org.deeplearning4j.rl4j.learning.listener.TrainingListener; + +/** + * DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the + * training process can be fed to the DataManager + */ +@Slf4j +public class DataManagerTrainingListener implements TrainingListener { + private final IDataManager dataManager; + + private int lastSave = -Constants.MODEL_SAVE_FREQ; + + public DataManagerTrainingListener(IDataManager dataManager) { + this.dataManager = dataManager; + } + + @Override + public ListenerResponse onTrainingStart() { + return 
ListenerResponse.CONTINUE; + } + + @Override + public void onTrainingEnd() { + + } + + @Override + public ListenerResponse onNewEpoch(IEpochTrainer trainer) { + IHistoryProcessor hp = trainer.getHistoryProcessor(); + if(hp != null) { + int[] shape = trainer.getMdp().getObservationSpace().getShape(); + String filename = dataManager.getVideoDir() + "/video-"; + if (trainer instanceof AsyncThread) { + filename += ((AsyncThread) trainer).getThreadNumber() + "-"; + } + filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4"; + hp.startMonitor(filename, shape); + } + + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) { + IHistoryProcessor hp = trainer.getHistoryProcessor(); + if(hp != null) { + hp.stopMonitor(); + } + try { + dataManager.appendStat(statEntry); + } catch (Exception e) { + log.error("Training failed.", e); + return ListenerResponse.STOP; + } + + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse onTrainingProgress(ILearning learning) { + try { + int stepCounter = learning.getStepCounter(); + if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) { + dataManager.save(learning); + lastSave = stepCounter; + } + + dataManager.writeInfo(learning); + } catch (Exception e) { + log.error("Training failed.", e); + return ListenerResponse.STOP; + } + + return ListenerResponse.CONTINUE; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java index d265bca5a..c9ad940ab 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java @@ -27,7 +27,7 @@ public interface IDataManager { String getVideoDir(); void appendStat(StatEntry statEntry) throws IOException; void writeInfo(ILearning 
iLearning) throws IOException; - void save(Learning learning) throws IOException; + void save(ILearning learning) throws IOException; //In order for jackson to serialize StatEntry //please use Lombok @Value (see QLStatEntry) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java new file mode 100644 index 000000000..ec0bca94f --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java @@ -0,0 +1,127 @@ +package org.deeplearning4j.rl4j.learning.async; + +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.policy.IPolicy; +import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.deeplearning4j.rl4j.support.*; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class AsyncLearningTest { + + @Test + public void when_training_expect_AsyncGlobalStarted() { + // Arrange + TestContext context = new TestContext(); + context.asyncGlobal.setMaxLoops(1); + + // Act + context.sut.train(); + + // Assert + assertTrue(context.asyncGlobal.hasBeenStarted); + assertTrue(context.asyncGlobal.hasBeenTerminated); + } + + @Test + public void when_trainStartReturnsStop_expect_noTraining() { + // Arrange + TestContext context = new TestContext(); + context.listener.setRemainingTrainingStartCallCount(0); + // Act + context.sut.train(); + + // Assert + assertEquals(1, context.listener.onTrainingStartCallCount); + assertEquals(1, context.listener.onTrainingEndCallCount); + assertEquals(0, context.policy.playCallCount); + assertTrue(context.asyncGlobal.hasBeenTerminated); + } + + @Test + public void when_trainingIsComplete_expect_trainingStop() { + // Arrange + TestContext context = new TestContext(); + + // Act + context.sut.train(); + + // Assert + assertEquals(1, 
context.listener.onTrainingStartCallCount); + assertEquals(1, context.listener.onTrainingEndCallCount); + assertTrue(context.asyncGlobal.hasBeenTerminated); + } + + @Test + public void when_training_expect_onTrainingProgressCalled() { + // Arrange + TestContext context = new TestContext(); + + // Act + context.sut.train(); + + // Assert + assertEquals(1, context.listener.onTrainingProgressCallCount); + } + + + public static class TestContext { + public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1); + public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); + public final MockPolicy policy = new MockPolicy(); + public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy); + public final MockTrainingListener listener = new MockTrainingListener(); + + public TestContext() { + sut.addListener(listener); + asyncGlobal.setMaxLoops(1); + sut.setProgressMonitorFrequency(1); + } + } + + public static class TestAsyncLearning extends AsyncLearning { + private final AsyncConfiguration conf; + private final IAsyncGlobal asyncGlobal; + private final IPolicy policy; + + public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) { + super(conf); + this.conf = conf; + this.asyncGlobal = asyncGlobal; + this.policy = policy; + } + + @Override + public IPolicy getPolicy() { + return policy; + } + + @Override + public AsyncConfiguration getConfiguration() { + return conf; + } + + @Override + protected AsyncThread newThread(int i, int deviceAffinity) { + return null; + } + + @Override + public MDP getMdp() { + return null; + } + + @Override + protected IAsyncGlobal getAsyncGlobal() { + return asyncGlobal; + } + + @Override + public MockNeuralNet getNeuralNet() { + return null; + } + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 
23be44f01..4d9e70b56 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -1,206 +1,135 @@ package org.deeplearning4j.rl4j.learning.async; -import org.deeplearning4j.nn.api.NeuralNetwork; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.support.MockDataManager; -import org.deeplearning4j.rl4j.support.MockHistoryProcessor; -import org.deeplearning4j.rl4j.support.MockMDP; -import org.deeplearning4j.rl4j.support.MockObservationSpace; +import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; public class AsyncThreadTest { @Test - public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() { + public void when_newEpochStarted_expect_neuralNetworkReset() { // Arrange - MockDataManager dataManager = new MockDataManager(false); - MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); - MockNeuralNet neuralNet = new MockNeuralNet(); - MockObservationSpace observationSpace = new MockObservationSpace(); - MockMDP mdp = new MockMDP(observationSpace); - MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2); - MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager); + TestContext context = new TestContext(); + context.listener.setRemainingOnNewEpochCallCount(5); // Act 
- sut.run(); + context.sut.run(); // Assert - assertEquals(4, dataManager.statEntries.size()); - - IDataManager.StatEntry entry = dataManager.statEntries.get(0); - assertEquals(2, entry.getStepCounter()); - assertEquals(0, entry.getEpochCounter()); - assertEquals(2.0, entry.getReward(), 0.0); - - entry = dataManager.statEntries.get(1); - assertEquals(4, entry.getStepCounter()); - assertEquals(1, entry.getEpochCounter()); - assertEquals(2.0, entry.getReward(), 0.0); - - entry = dataManager.statEntries.get(2); - assertEquals(6, entry.getStepCounter()); - assertEquals(2, entry.getEpochCounter()); - assertEquals(2.0, entry.getReward(), 0.0); - - entry = dataManager.statEntries.get(3); - assertEquals(8, entry.getStepCounter()); - assertEquals(3, entry.getEpochCounter()); - assertEquals(2.0, entry.getReward(), 0.0); - - assertEquals(0, dataManager.isSaveDataCallCount); - assertEquals(0, dataManager.getVideoDirCallCount); + assertEquals(6, context.neuralNet.resetCallCount); } @Test - public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() { + public void when_onNewEpochReturnsStop_expect_threadStopped() { // Arrange - MockDataManager dataManager = new MockDataManager(false); - MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); - MockNeuralNet neuralNet = new MockNeuralNet(); - MockObservationSpace observationSpace = new MockObservationSpace(); - MockMDP mdp = new MockMDP(observationSpace); - MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2); - - IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder() - .build(); - MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig); - - - MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager); - sut.setHistoryProcessor(hp); + TestContext context = new TestContext(); + context.listener.setRemainingOnNewEpochCallCount(1); // Act - sut.run(); + context.sut.run(); // Assert - assertEquals(9, 
dataManager.statEntries.size()); - - for(int i = 0; i < 9; ++i) { - IDataManager.StatEntry entry = dataManager.statEntries.get(i); - assertEquals(i + 1, entry.getStepCounter()); - assertEquals(i, entry.getEpochCounter()); - assertEquals(79.0, entry.getReward(), 0.0); - } - - assertEquals(10, dataManager.isSaveDataCallCount); - assertEquals(0, dataManager.getVideoDirCallCount); + assertEquals(2, context.listener.onNewEpochCallCount); + assertEquals(1, context.listener.onEpochTrainingResultCallCount); } @Test - public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() { + public void when_epochTrainingResultReturnsStop_expect_threadStopped() { // Arrange - MockDataManager dataManager = new MockDataManager(true); - MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); - MockNeuralNet neuralNet = new MockNeuralNet(); - MockObservationSpace observationSpace = new MockObservationSpace(); - MockMDP mdp = new MockMDP(observationSpace); - MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2); - - IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder() - .build(); - MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig); - - - MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager); - sut.setHistoryProcessor(hp); + TestContext context = new TestContext(); + context.listener.setRemainingOnEpochTrainingResult(1); // Act - sut.run(); + context.sut.run(); // Assert - assertEquals(9, dataManager.statEntries.size()); - - for(int i = 0; i < 9; ++i) { - IDataManager.StatEntry entry = dataManager.statEntries.get(i); - assertEquals(i + 1, entry.getStepCounter()); - assertEquals(i, entry.getEpochCounter()); - assertEquals(79.0, entry.getReward(), 0.0); - } - - assertEquals(1, dataManager.isSaveDataCallCount); - assertEquals(1, dataManager.getVideoDirCallCount); + assertEquals(2, context.listener.onNewEpochCallCount); + assertEquals(2, 
context.listener.onEpochTrainingResultCallCount); } - public static class MockAsyncGlobal implements IAsyncGlobal { + @Test + public void when_run_expect_preAndPostEpochCalled() { + // Arrange + TestContext context = new TestContext(); - private final int maxLoops; - private int currentLoop = 0; + // Act + context.sut.run(); - public MockAsyncGlobal(int maxLoops) { + // Assert + assertEquals(6, context.sut.preEpochCallCount); + assertEquals(6, context.sut.postEpochCallCount); + } - this.maxLoops = maxLoops; + @Test + public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() { + // Arrange + TestContext context = new TestContext(); + + // Act + context.sut.run(); + + // Assert + assertEquals(5, context.listener.statEntries.size()); + int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 }; + for(int i = 0; i < 5; ++i) { + IDataManager.StatEntry statEntry = context.listener.statEntries.get(i); + assertEquals(expectedStepCounter[i], statEntry.getStepCounter()); + assertEquals(i, statEntry.getEpochCounter()); + assertEquals(2.0, statEntry.getReward(), 0.0001); } + } - @Override - public boolean isRunning() { - return true; - } - - @Override - public void setRunning(boolean value) { - - } - - @Override - public boolean isTrainingComplete() { - return ++currentLoop >= maxLoops; - } - - @Override - public void start() { - - } - - @Override - public AtomicInteger getT() { - return null; - } - - @Override - public NeuralNet getCurrent() { - return null; - } - - @Override - public NeuralNet getTarget() { - return null; - } - - @Override - public void enqueue(Gradient[] gradient, Integer nstep) { + private static class TestContext { + public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); + public final MockNeuralNet neuralNet = new MockNeuralNet(); + public final MockObservationSpace observationSpace = new MockObservationSpace(); + public final MockMDP mdp = new MockMDP(observationSpace); + public final MockAsyncConfiguration config = new 
MockAsyncConfiguration(5, 2); + public final TrainingListenerList listeners = new TrainingListenerList(); + public final MockTrainingListener listener = new MockTrainingListener(); + public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners); + public TestContext() { + asyncGlobal.setMaxLoops(10); + listeners.add(listener); } } public static class MockAsyncThread extends AsyncThread { - IAsyncGlobal asyncGlobal; - private final MockNeuralNet neuralNet; - private final MDP mdp; - private final AsyncConfiguration conf; - private final IDataManager dataManager; + public int preEpochCallCount = 0; + public int postEpochCallCount = 0; - public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) { - super(asyncGlobal, threadNumber, 0); + + private final IAsyncGlobal asyncGlobal; + private final MockNeuralNet neuralNet; + private final AsyncConfiguration conf; + + public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { + super(asyncGlobal, mdp, listeners, threadNumber, 0); this.asyncGlobal = asyncGlobal; this.neuralNet = neuralNet; - this.mdp = mdp; this.conf = conf; - this.dataManager = dataManager; + } + + @Override + protected void preEpoch() { + ++preEpochCallCount; + super.preEpoch(); + } + + @Override + protected void postEpoch() { + ++postEpochCallCount; + super.postEpoch(); } @Override @@ -208,31 +137,16 @@ public class AsyncThreadTest { return neuralNet; } - @Override - protected int getThreadNumber() { - return 0; - } - @Override protected IAsyncGlobal getAsyncGlobal() { return asyncGlobal; } - @Override - protected MDP getMdp() { - return mdp; - } - @Override protected AsyncConfiguration getConf() { return conf; } - @Override - protected IDataManager getDataManager() { - return dataManager; - } - @Override protected 
Policy getPolicy(NeuralNet net) { return null; @@ -244,129 +158,6 @@ public class AsyncThreadTest { } } - public static class MockNeuralNet implements NeuralNet { - @Override - public NeuralNetwork[] getNeuralNetworks() { - return new NeuralNetwork[0]; - } - - @Override - public boolean isRecurrent() { - return false; - } - - @Override - public void reset() { - - } - - @Override - public INDArray[] outputAll(INDArray batch) { - return new INDArray[0]; - } - - @Override - public NeuralNet clone() { - return null; - } - - @Override - public void copy(NeuralNet from) { - - } - - @Override - public Gradient[] gradient(INDArray input, INDArray[] labels) { - return new Gradient[0]; - } - - @Override - public void fit(INDArray input, INDArray[] labels) { - - } - - @Override - public void applyGradient(Gradient[] gradients, int batchSize) { - - } - - @Override - public double getLatestScore() { - return 0; - } - - @Override - public void save(OutputStream os) throws IOException { - - } - - @Override - public void save(String filename) throws IOException { - - } - } - - public static class MockAsyncConfiguration implements AsyncConfiguration { - - private final int nStep; - private final int maxEpochStep; - - public MockAsyncConfiguration(int nStep, int maxEpochStep) { - this.nStep = nStep; - - this.maxEpochStep = maxEpochStep; - } - - @Override - public int getSeed() { - return 0; - } - - @Override - public int getMaxEpochStep() { - return maxEpochStep; - } - - @Override - public int getMaxStep() { - return 0; - } - - @Override - public int getNumThread() { - return 0; - } - - @Override - public int getNstep() { - return nStep; - } - - @Override - public int getTargetDqnUpdateFreq() { - return 0; - } - - @Override - public int getUpdateStart() { - return 0; - } - - @Override - public double getRewardFactor() { - return 0; - } - - @Override - public double getGamma() { - return 0; - } - - @Override - public double getErrorClamp() { - return 0; - } - } } diff --git 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java new file mode 100644 index 000000000..56b8494a0 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java @@ -0,0 +1,98 @@ +package org.deeplearning4j.rl4j.learning.async.listener; + +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.listener.*; +import org.deeplearning4j.rl4j.util.IDataManager; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class AsyncTrainingListenerListTest { + @Test + public void when_listIsEmpty_expect_notifyTrainingStartedReturnTrue() { + // Arrange + TrainingListenerList sut = new TrainingListenerList(); + + // Act + boolean resultTrainingStarted = sut.notifyTrainingStarted(); + boolean resultNewEpoch = sut.notifyNewEpoch(null); + boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null); + + // Assert + assertTrue(resultTrainingStarted); + assertTrue(resultNewEpoch); + assertTrue(resultEpochTrainingResult); + } + + @Test + public void when_firstListerStops_expect_othersListnersNotCalled() { + // Arrange + MockTrainingListener listener1 = new MockTrainingListener(); + listener1.onTrainingResultResponse = TrainingListener.ListenerResponse.STOP; + MockTrainingListener listener2 = new MockTrainingListener(); + TrainingListenerList sut = new TrainingListenerList(); + sut.add(listener1); + sut.add(listener2); + + // Act + sut.notifyEpochTrainingResult(null, null); + + // Assert + assertEquals(1, listener1.onEpochTrainingResultCallCount); + assertEquals(0, listener2.onEpochTrainingResultCallCount); + } + + @Test + public void 
when_allListenersContinue_expect_listReturnsTrue() { + // Arrange + MockTrainingListener listener1 = new MockTrainingListener(); + MockTrainingListener listener2 = new MockTrainingListener(); + TrainingListenerList sut = new TrainingListenerList(); + sut.add(listener1); + sut.add(listener2); + + // Act + boolean resultTrainingProgress = sut.notifyEpochTrainingResult(null, null); + + // Assert + assertTrue(resultTrainingProgress); + } + + private static class MockTrainingListener implements TrainingListener { + + public int onEpochTrainingResultCallCount = 0; + public ListenerResponse onTrainingResultResponse = ListenerResponse.CONTINUE; + public int onTrainingProgressCallCount = 0; + public ListenerResponse onTrainingProgressResponse = ListenerResponse.CONTINUE; + + @Override + public ListenerResponse onTrainingStart() { + return ListenerResponse.CONTINUE; + } + + @Override + public void onTrainingEnd() { + + } + + @Override + public ListenerResponse onNewEpoch(IEpochTrainer trainer) { + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) { + ++onEpochTrainingResultCallCount; + return onTrainingResultResponse; + } + + @Override + public ListenerResponse onTrainingProgress(ILearning learning) { + ++onTrainingProgressCallCount; + return onTrainingProgressResponse; + } + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java new file mode 100644 index 000000000..a926864e4 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java @@ -0,0 +1,83 @@ +package org.deeplearning4j.rl4j.learning.listener; + +import org.deeplearning4j.rl4j.support.MockTrainingListener; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class 
TrainingListenerListTest { + @Test + public void when_listIsEmpty_expect_notifyReturnTrue() { + // Arrange + TrainingListenerList sut = new TrainingListenerList(); + + // Act + boolean resultTrainingStarted = sut.notifyTrainingStarted(); + boolean resultNewEpoch = sut.notifyNewEpoch(null); + boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null); + + // Assert + assertTrue(resultTrainingStarted); + assertTrue(resultNewEpoch); + assertTrue(resultEpochFinished); + } + + @Test + public void when_firstListerStops_expect_othersListnersNotCalled() { + // Arrange + MockTrainingListener listener1 = new MockTrainingListener(); + listener1.setRemainingTrainingStartCallCount(0); + listener1.setRemainingOnNewEpochCallCount(0); + listener1.setRemainingonTrainingProgressCallCount(0); + listener1.setRemainingOnEpochTrainingResult(0); + MockTrainingListener listener2 = new MockTrainingListener(); + TrainingListenerList sut = new TrainingListenerList(); + sut.add(listener1); + sut.add(listener2); + + // Act + sut.notifyTrainingStarted(); + sut.notifyNewEpoch(null); + sut.notifyEpochTrainingResult(null, null); + sut.notifyTrainingProgress(null); + sut.notifyTrainingFinished(); + + // Assert + assertEquals(1, listener1.onTrainingStartCallCount); + assertEquals(0, listener2.onTrainingStartCallCount); + + assertEquals(1, listener1.onNewEpochCallCount); + assertEquals(0, listener2.onNewEpochCallCount); + + assertEquals(1, listener1.onEpochTrainingResultCallCount); + assertEquals(0, listener2.onEpochTrainingResultCallCount); + + assertEquals(1, listener1.onTrainingProgressCallCount); + assertEquals(0, listener2.onTrainingProgressCallCount); + + assertEquals(1, listener1.onTrainingEndCallCount); + assertEquals(1, listener2.onTrainingEndCallCount); + } + + @Test + public void when_allListenersContinue_expect_listReturnsTrue() { + // Arrange + MockTrainingListener listener1 = new MockTrainingListener(); + MockTrainingListener listener2 = new MockTrainingListener(); + 
TrainingListenerList sut = new TrainingListenerList(); + sut.add(listener1); + sut.add(listener2); + + // Act + boolean resultTrainingStarted = sut.notifyTrainingStarted(); + boolean resultNewEpoch = sut.notifyNewEpoch(null); + boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null); + boolean resultProgress = sut.notifyTrainingProgress(null); + + // Assert + assertTrue(resultTrainingStarted); + assertTrue(resultNewEpoch); + assertTrue(resultEpochTrainingResult); + assertTrue(resultProgress); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java index c311ee53a..7e7c3eb01 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -2,12 +2,10 @@ package org.deeplearning4j.rl4j.learning.sync; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; -import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; -import org.deeplearning4j.rl4j.policy.Policy; -import org.deeplearning4j.rl4j.support.MockDataManager; -import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener; +import org.deeplearning4j.rl4j.policy.IPolicy; +import org.deeplearning4j.rl4j.support.MockTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; @@ -19,7 +17,7 @@ public class SyncLearningTest { public void when_training_expect_listenersToBeCalled() { // Arrange QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); - MockSyncTrainingListener listener = new MockSyncTrainingListener(); + MockTrainingListener listener = new 
MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -27,8 +25,8 @@ public class SyncLearningTest { sut.train(); assertEquals(1, listener.onTrainingStartCallCount); - assertEquals(10, listener.onEpochStartCallCount); - assertEquals(10, listener.onEpochEndStartCallCount); + assertEquals(10, listener.onNewEpochCallCount); + assertEquals(10, listener.onEpochTrainingResultCallCount); assertEquals(1, listener.onTrainingEndCallCount); } @@ -36,65 +34,59 @@ public class SyncLearningTest { public void when_trainingStartCanContinueFalse_expect_trainingStopped() { // Arrange QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); - MockSyncTrainingListener listener = new MockSyncTrainingListener(); + MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); - listener.trainingStartCanContinue = false; + listener.setRemainingTrainingStartCallCount(0); // Act sut.train(); assertEquals(1, listener.onTrainingStartCallCount); - assertEquals(0, listener.onEpochStartCallCount); - assertEquals(0, listener.onEpochEndStartCallCount); + assertEquals(0, listener.onNewEpochCallCount); + assertEquals(0, listener.onEpochTrainingResultCallCount); assertEquals(1, listener.onTrainingEndCallCount); } @Test - public void when_epochStartCanContinueFalse_expect_trainingStopped() { + public void when_newEpochCanContinueFalse_expect_trainingStopped() { // Arrange QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); - MockSyncTrainingListener listener = new MockSyncTrainingListener(); + MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); - listener.nbStepsEpochStartCanContinue = 3; + listener.setRemainingOnNewEpochCallCount(2); // Act sut.train(); assertEquals(1, listener.onTrainingStartCallCount); - 
assertEquals(3, listener.onEpochStartCallCount); - assertEquals(2, listener.onEpochEndStartCallCount); + assertEquals(3, listener.onNewEpochCallCount); + assertEquals(2, listener.onEpochTrainingResultCallCount); assertEquals(1, listener.onTrainingEndCallCount); } @Test - public void when_epochEndCanContinueFalse_expect_trainingStopped() { + public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() { // Arrange QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); - MockSyncTrainingListener listener = new MockSyncTrainingListener(); + MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); - listener.nbStepsEpochEndCanContinue = 3; + listener.setRemainingOnEpochTrainingResult(2); // Act sut.train(); assertEquals(1, listener.onTrainingStartCallCount); - assertEquals(3, listener.onEpochStartCallCount); - assertEquals(3, listener.onEpochEndStartCallCount); + assertEquals(3, listener.onNewEpochCallCount); + assertEquals(3, listener.onEpochTrainingResultCallCount); assertEquals(1, listener.onTrainingEndCallCount); } public static class MockSyncLearning extends SyncLearning { - private LConfiguration conf; - - public MockSyncLearning(LConfiguration conf, IDataManager dataManager) { - super(conf); - addListener(DataManagerSyncTrainingListener.builder(dataManager).build()); - this.conf = conf; - } + private final LConfiguration conf; public MockSyncLearning(LConfiguration conf) { super(conf); @@ -119,7 +111,7 @@ public class SyncLearningTest { } @Override - public Policy getPolicy() { + public IPolicy getPolicy() { return null; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningDiscreteTest.java deleted file mode 100644 index 7ac4716d6..000000000 --- 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningDiscreteTest.java +++ /dev/null @@ -1,65 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.qlearning; - -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete; -import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; -import org.deeplearning4j.rl4j.learning.sync.support.MockMDP; -import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; -import org.deeplearning4j.rl4j.mdp.MDP; -import org.deeplearning4j.rl4j.support.MockDataManager; -import org.deeplearning4j.rl4j.support.MockHistoryProcessor; -import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener; -import org.deeplearning4j.rl4j.util.IDataManager; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -public class QLearningDiscreteTest { - @Test - public void refac_checkDataManagerCallsRemainTheSame() { - // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder() - .maxStep(10) - .expRepMaxSize(1) - .build(); - MockDataManager dataManager = new MockDataManager(true); - MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3); - IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder() - .build(); - sut.setHistoryProcessor(new MockHistoryProcessor(hpConfig)); - - // Act - sut.train(); - - assertEquals(10, dataManager.statEntries.size()); - for(int i = 0; i < 10; ++i) { - IDataManager.StatEntry entry = dataManager.statEntries.get(i); - assertEquals(i, entry.getEpochCounter()); - assertEquals(i+1, entry.getStepCounter()); - assertEquals(1.0, entry.getReward(), 0.0); - - } - assertEquals(4, dataManager.isSaveDataCallCount); - assertEquals(4, dataManager.getVideoDirCallCount); - assertEquals(11, dataManager.writeInfoCallCount); - assertEquals(5, dataManager.saveCallCount); - } - - public static class 
MockQLearningDiscrete extends QLearningDiscrete { - - public MockQLearningDiscrete(int maxSteps, QLConfiguration conf, - IDataManager dataManager, int saveFrequency, int monitorFrequency) { - super(new MockMDP(maxSteps), new MockDQN(), conf, 2); - addListener(DataManagerSyncTrainingListener.builder(dataManager) - .saveFrequency(saveFrequency) - .monitorFrequency(monitorFrequency) - .build()); - } - - @Override - protected IDataManager.StatEntry trainEpoch() { - setStepCounter(getStepCounter() + 1); - return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0); - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 5762875aa..1a02d6e50 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -8,6 +8,7 @@ import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; @@ -29,12 +30,11 @@ public class QLearningDiscreteTest { QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, 0, 1.0, 0, 0, 0, 0, true); MockDataManager dataManager = new MockDataManager(false); - TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10); + MockExpReplay expReplay = new MockExpReplay(); + TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10); IHistoryProcessor.Configuration hpConf = new 
IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); sut.setHistoryProcessor(hp); - MockExpReplay expReplay = new MockExpReplay(); - sut.setExpReplay(expReplay); MockEncodable obs = new MockEncodable(1); List> results = new ArrayList<>(); @@ -131,8 +131,11 @@ public class QLearningDiscreteTest { public static class TestQLearningDiscrete extends QLearningDiscrete { public TestQLearningDiscrete(MDP mdp,IDQN dqn, - QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) { - super(mdp, dqn, conf, dataManager, epsilonNbStep); + QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, + int epsilonNbStep) { + super(mdp, dqn, conf, epsilonNbStep); + addListener(new DataManagerTrainingListener(dataManager)); + setExpReplay(expReplay); } @Override diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockSyncTrainingListener.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockSyncTrainingListener.java deleted file mode 100644 index 5f41e280a..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockSyncTrainingListener.java +++ /dev/null @@ -1,46 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.support; - -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent; -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent; -import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener; - -public class MockSyncTrainingListener implements SyncTrainingListener { - - public int onTrainingStartCallCount = 0; - public int onTrainingEndCallCount = 0; - public int onEpochStartCallCount = 0; - public int onEpochEndStartCallCount = 0; - - public boolean trainingStartCanContinue = true; - public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE; - public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE; - - @Override - public 
ListenerResponse onTrainingStart(SyncTrainingEvent event) { - ++onTrainingStartCallCount; - return trainingStartCanContinue ? ListenerResponse.CONTINUE : ListenerResponse.STOP; - } - - @Override - public void onTrainingEnd() { - ++onTrainingEndCallCount; - } - - @Override - public ListenerResponse onEpochStart(SyncTrainingEvent event) { - ++onEpochStartCallCount; - if(onEpochStartCallCount >= nbStepsEpochStartCanContinue) { - return ListenerResponse.STOP; - } - return ListenerResponse.CONTINUE; - } - - @Override - public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) { - ++onEpochEndStartCallCount; - if(onEpochEndStartCallCount >= nbStepsEpochEndCanContinue) { - return ListenerResponse.STOP; - } - return ListenerResponse.CONTINUE; - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java new file mode 100644 index 000000000..a40de0e91 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java @@ -0,0 +1,65 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; + +public class MockAsyncConfiguration implements AsyncConfiguration { + + private final int nStep; + private final int maxEpochStep; + + public MockAsyncConfiguration(int nStep, int maxEpochStep) { + this.nStep = nStep; + + this.maxEpochStep = maxEpochStep; + } + + @Override + public int getSeed() { + return 0; + } + + @Override + public int getMaxEpochStep() { + return maxEpochStep; + } + + @Override + public int getMaxStep() { + return 0; + } + + @Override + public int getNumThread() { + return 0; + } + + @Override + public int getNstep() { + return nStep; + } + + @Override + public int getTargetDqnUpdateFreq() { + return 0; + } + + @Override + public int getUpdateStart() { + return 0; + } + + @Override + public double getRewardFactor() { + 
return 0; + } + + @Override + public double getGamma() { + return 0; + } + + @Override + public double getErrorClamp() { + return 0; + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java new file mode 100644 index 000000000..0bc34d239 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java @@ -0,0 +1,65 @@ +package org.deeplearning4j.rl4j.support; + +import lombok.Setter; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; +import org.deeplearning4j.rl4j.network.NeuralNet; + +import java.util.concurrent.atomic.AtomicInteger; + +public class MockAsyncGlobal implements IAsyncGlobal { + + public boolean hasBeenStarted = false; + public boolean hasBeenTerminated = false; + + @Setter + private int maxLoops; + @Setter + private int numLoopsStopRunning; + private int currentLoop = 0; + + public MockAsyncGlobal() { + maxLoops = Integer.MAX_VALUE; + numLoopsStopRunning = Integer.MAX_VALUE; + } + + @Override + public boolean isRunning() { + return currentLoop < numLoopsStopRunning; + } + + @Override + public void terminate() { + hasBeenTerminated = true; + } + + @Override + public boolean isTrainingComplete() { + return ++currentLoop > maxLoops; + } + + @Override + public void start() { + hasBeenStarted = true; + } + + @Override + public AtomicInteger getT() { + return null; + } + + @Override + public NeuralNet getCurrent() { + return null; + } + + @Override + public NeuralNet getTarget() { + return null; + } + + @Override + public void enqueue(Gradient[] gradient, Integer nstep) { + + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java index 659198f79..d2d74bb89 100644 --- 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java @@ -44,7 +44,7 @@ public class MockDataManager implements IDataManager { } @Override - public void save(Learning learning) throws IOException { + public void save(ILearning learning) throws IOException { ++saveCallCount; } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java new file mode 100644 index 000000000..bdffa59a8 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java @@ -0,0 +1,74 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.io.IOException; +import java.io.OutputStream; + +public class MockNeuralNet implements NeuralNet { + + public int resetCallCount = 0; + + @Override + public NeuralNetwork[] getNeuralNetworks() { + return new NeuralNetwork[0]; + } + + @Override + public boolean isRecurrent() { + return false; + } + + @Override + public void reset() { + ++resetCallCount; + } + + @Override + public INDArray[] outputAll(INDArray batch) { + return new INDArray[0]; + } + + @Override + public NeuralNet clone() { + return null; + } + + @Override + public void copy(NeuralNet from) { + + } + + @Override + public Gradient[] gradient(INDArray input, INDArray[] labels) { + return new Gradient[0]; + } + + @Override + public void fit(INDArray input, INDArray[] labels) { + + } + + @Override + public void applyGradient(Gradient[] gradients, int batchSize) { + + } + + @Override + public double getLatestScore() { + return 0; + } + + @Override + public void save(OutputStream os) throws IOException { + + } + + @Override + public 
void save(String filename) throws IOException { + + } +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java new file mode 100644 index 000000000..28f812f33 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java @@ -0,0 +1,17 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.policy.IPolicy; +import org.deeplearning4j.rl4j.space.ActionSpace; + +public class MockPolicy implements IPolicy { + + public int playCallCount = 0; + + @Override + public > double play(MDP mdp, IHistoryProcessor hp) { + ++playCallCount; + return 0; + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java new file mode 100644 index 000000000..97bf5cc28 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java @@ -0,0 +1,65 @@ +package org.deeplearning4j.rl4j.support; + +import lombok.Setter; +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.listener.*; +import org.deeplearning4j.rl4j.util.IDataManager; + +import java.util.ArrayList; +import java.util.List; + +public class MockTrainingListener implements TrainingListener { + + public int onTrainingStartCallCount = 0; + public int onTrainingEndCallCount = 0; + public int onNewEpochCallCount = 0; + public int onEpochTrainingResultCallCount = 0; + public int onTrainingProgressCallCount = 0; + + @Setter + private int remainingTrainingStartCallCount = Integer.MAX_VALUE; + @Setter + private int remainingOnNewEpochCallCount = Integer.MAX_VALUE; + @Setter + 
private int remainingOnEpochTrainingResult = Integer.MAX_VALUE; + @Setter + private int remainingonTrainingProgressCallCount = Integer.MAX_VALUE; + + public final List statEntries = new ArrayList<>(); + + + @Override + public ListenerResponse onTrainingStart() { + ++onTrainingStartCallCount; + --remainingTrainingStartCallCount; + return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse onNewEpoch(IEpochTrainer trainer) { + ++onNewEpochCallCount; + --remainingOnNewEpochCallCount; + return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) { + ++onEpochTrainingResultCallCount; + --remainingOnEpochTrainingResult; + statEntries.add(statEntry); + return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse onTrainingProgress(ILearning learning) { + ++onTrainingProgressCallCount; + --remainingonTrainingProgressCallCount; + return remainingonTrainingProgressCallCount < 0 ? 
ListenerResponse.STOP : ListenerResponse.CONTINUE; + } + + @Override + public void onTrainingEnd() { + ++onTrainingEndCallCount; + } +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java new file mode 100644 index 000000000..f6da6d378 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java @@ -0,0 +1,169 @@ +package org.deeplearning4j.rl4j.util; + +import lombok.Getter; +import lombok.Setter; +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.listener.TrainingListener; +import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.policy.IPolicy; +import org.deeplearning4j.rl4j.support.MockDataManager; +import org.deeplearning4j.rl4j.support.MockHistoryProcessor; +import org.deeplearning4j.rl4j.support.MockMDP; +import org.deeplearning4j.rl4j.support.MockObservationSpace; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; + +public class DataManagerTrainingListenerTest { + + @Test + public void when_callingOnNewEpochWithoutHistoryProcessor_expect_noException() { + // Arrange + TestTrainer trainer = new TestTrainer(); + DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false)); + + // Act + TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer); + + // Assert + assertEquals(TrainingListener.ListenerResponse.CONTINUE, response); + } + + @Test + public void when_callingOnNewEpochWithHistoryProcessor_expect_startMonitorNotCalled() { + // Arrange + TestTrainer trainer = new TestTrainer(); 
+ IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); + trainer.setHistoryProcessor(hp); + DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false)); + + // Act + TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer); + + // Assert + assertEquals(1, hp.startMonitorCallCount); + assertEquals(TrainingListener.ListenerResponse.CONTINUE, response); + } + + @Test + public void when_callingOnEpochTrainingResultWithoutHistoryProcessor_expect_noException() { + // Arrange + TestTrainer trainer = new TestTrainer(); + DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false)); + + // Act + TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null); + + // Assert + assertEquals(TrainingListener.ListenerResponse.CONTINUE, response); + } + + @Test + public void when_callingOnNewEpochWithHistoryProcessor_expect_stopMonitorNotCalled() { + // Arrange + TestTrainer trainer = new TestTrainer(); + IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); + trainer.setHistoryProcessor(hp); + DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false)); + + // Act + TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null); + + // Assert + assertEquals(1, hp.stopMonitorCallCount); + assertEquals(TrainingListener.ListenerResponse.CONTINUE, response); + } + + @Test + public void when_callingOnEpochTrainingResult_expect_callToDataManagerAppendStat() { + // Arrange + TestTrainer trainer = new TestTrainer(); + MockDataManager dm = new MockDataManager(false); + DataManagerTrainingListener sut = new DataManagerTrainingListener(dm); + MockStatEntry statEntry = new MockStatEntry(0, 0, 0.0); + + // Act + 
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, statEntry); + + // Assert + assertEquals(TrainingListener.ListenerResponse.CONTINUE, response); + assertEquals(1, dm.statEntries.size()); + assertSame(statEntry, dm.statEntries.get(0)); + } + + @Test + public void when_callingOnTrainingProgress_expect_callToDataManagerSaveAndWriteInfo() { + // Arrange + TestTrainer learning = new TestTrainer(); + MockDataManager dm = new MockDataManager(false); + DataManagerTrainingListener sut = new DataManagerTrainingListener(dm); + + // Act + TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning); + + // Assert + assertEquals(TrainingListener.ListenerResponse.CONTINUE, response); + assertEquals(1, dm.writeInfoCallCount); + assertEquals(1, dm.saveCallCount); + } + + @Test + public void when_stepCounterCloseToLastSave_expect_dataManagerSaveNotCalled() { + // Arrange + TestTrainer learning = new TestTrainer(); + MockDataManager dm = new MockDataManager(false); + DataManagerTrainingListener sut = new DataManagerTrainingListener(dm); + + // Act + TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning); + TrainingListener.ListenerResponse response2 = sut.onTrainingProgress(learning); + + // Assert + assertEquals(TrainingListener.ListenerResponse.CONTINUE, response); + assertEquals(TrainingListener.ListenerResponse.CONTINUE, response2); + assertEquals(1, dm.saveCallCount); + } + + private static class TestTrainer implements IEpochTrainer, ILearning + { + @Override + public int getStepCounter() { + return 0; + } + + @Override + public int getEpochCounter() { + return 0; + } + + @Getter + @Setter + private IHistoryProcessor historyProcessor; + + @Override + public IPolicy getPolicy() { + return null; + } + + @Override + public void train() { + + } + + @Override + public LConfiguration getConfiguration() { + return null; + } + + @Override + public MDP getMdp() { + return new MockMDP(new MockObservationSpace()); + } 
+ } +}