diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index d3927e8c9..798aa094f 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -125,8 +125,6 @@ public abstract class Learning return nshape; } - protected abstract IDataManager getDataManager(); - public abstract NN getNeuralNet(); public int incrementStep() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index a8948b23e..8866f0c40 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -21,6 +21,7 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.factory.Nd4j; /** @@ -36,6 +37,7 @@ import org.nd4j.linalg.factory.Nd4j; public abstract class AsyncLearning, NN extends NeuralNet> extends Learning { + protected abstract IDataManager getDataManager(); public AsyncLearning(AsyncConfiguration conf) { super(conf); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java index fb69f2fa9..3464410dc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java @@ -18,67 +18,132 @@ package org.deeplearning4j.rl4j.learning.sync; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent; +import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent; +import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.IDataManager; +import java.util.ArrayList; +import java.util.List; + /** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16. - * * Mother class and useful factorisations for all training methods that * are not asynchronous. * + * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16. 
+ * @author Alexandre Boulanger */ @Slf4j public abstract class SyncLearning, NN extends NeuralNet> - extends Learning { + extends Learning { - private int lastSave = -Constants.MODEL_SAVE_FREQ; + private List listeners = new ArrayList<>(); public SyncLearning(LConfiguration conf) { super(conf); } - public void train() { - - try { - log.info("training starting."); - - getDataManager().writeInfo(this); - - - while (getStepCounter() < getConfiguration().getMaxStep()) { - preEpoch(); - IDataManager.StatEntry statEntry = trainEpoch(); - postEpoch(); - - incrementEpoch(); - - if (getStepCounter() - lastSave >= Constants.MODEL_SAVE_FREQ) { - getDataManager().save(this); - lastSave = getStepCounter(); - } - - getDataManager().appendStat(statEntry); - getDataManager().writeInfo(this); - - log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); - } - } catch (Exception e) { - log.error("Training failed.", e); - e.printStackTrace(); - } - - + /** + * Add a listener at the end of the listener list. + * + * @param listener + */ + public void addListener(SyncTrainingListener listener) { + listeners.add(listener); } + /** + * This method will train the model
+ * <p>
+ * The training stops when:<br>
+ * - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
+ * OR<br>
+ * - a listener explicitly stops it
+ * <p>
+ * Listeners<br>
+ * For a given event, the listeners are called sequentially in the same order as they were added. If one listener
+ * returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.
+ * <p>
+ * Events:
+ * <ul>
+ *   <li>{@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.</li>
+ *   <li>{@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training.</li>
+ *   <li>{@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
+ * </ul>
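+ * <p>
+ * A minimal usage sketch (illustrative only; {@code learning} stands for any concrete subclass of this class and
+ * {@code myListener} for any {@link SyncTrainingListener} implementation):
+ * <pre>{@code
+ * learning.addListener(myListener); // listeners are notified in the order they were added
+ * learning.train();                 // runs until maxStep is reached or a listener returns STOP
+ * }</pre>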
+ */ + public void train() { + + log.info("training starting."); + + boolean canContinue = notifyTrainingStarted(); + if (canContinue) { + while (getStepCounter() < getConfiguration().getMaxStep()) { + preEpoch(); + canContinue = notifyEpochStarted(); + if (!canContinue) { + break; + } + + IDataManager.StatEntry statEntry = trainEpoch(); + + postEpoch(); + canContinue = notifyEpochFinished(statEntry); + if (!canContinue) { + break; + } + + log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); + + incrementEpoch(); + } + } + + notifyTrainingFinished(); + } + + private boolean notifyTrainingStarted() { + SyncTrainingEvent event = new SyncTrainingEvent(this); + for (SyncTrainingListener listener : listeners) { + if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + private void notifyTrainingFinished() { + for (SyncTrainingListener listener : listeners) { + listener.onTrainingEnd(); + } + } + + private boolean notifyEpochStarted() { + SyncTrainingEvent event = new SyncTrainingEvent(this); + for (SyncTrainingListener listener : listeners) { + if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) { + SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry); + for (SyncTrainingListener listener : listeners) { + if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } protected abstract void preEpoch(); protected abstract void postEpoch(); - protected abstract IDataManager.StatEntry trainEpoch(); - + protected abstract IDataManager.StatEntry trainEpoch(); // TODO: finish removal of IDataManager from Learning } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEpochEndEvent.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEpochEndEvent.java new file mode 100644 index 000000000..71a357ec8 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEpochEndEvent.java @@ -0,0 +1,22 @@ +package org.deeplearning4j.rl4j.learning.sync.listener; + +import lombok.Getter; +import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.util.IDataManager; + +/** + * A subclass of SyncTrainingEvent that is passed to SyncTrainingListener.onEpochEnd() + */ +public class SyncTrainingEpochEndEvent extends SyncTrainingEvent { + + /** + * The stats of the epoch training + */ + @Getter + private final IDataManager.StatEntry statEntry; + + public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) { + super(learning); + this.statEntry = statEntry; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEvent.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEvent.java new file mode 100644 index 000000000..964040f28 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingEvent.java @@ -0,0 +1,21 @@ +package org.deeplearning4j.rl4j.learning.sync.listener; + +import lombok.Getter; +import lombok.Setter; +import org.deeplearning4j.rl4j.learning.Learning; + +/** + * SyncTrainingEvent are passed as parameters to the events of SyncTrainingListener + */ 
+public class SyncTrainingEvent { + + /** + * The source of the event + */ + @Getter + private final Learning learning; + + public SyncTrainingEvent(Learning learning) { + this.learning = learning; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingListener.java new file mode 100644 index 000000000..852c16036 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/listener/SyncTrainingListener.java @@ -0,0 +1,45 @@ +package org.deeplearning4j.rl4j.learning.sync.listener; + +/** + * A listener interface to use with a descendant of {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning} + */ +public interface SyncTrainingListener { + + public enum ListenerResponse { + /** + * Tell SyncLearning to continue calling the listeners and the training. + */ + CONTINUE, + + /** + * Tell SyncLearning to stop calling the listeners and terminate the training. + */ + STOP, + } + + /** + * Called once when the training starts. + * @param event + * @return A ListenerResponse telling the source of the event if it should go on or cancel the training. + */ + ListenerResponse onTrainingStart(SyncTrainingEvent event); + + /** + * Called once when the training has finished. This method is called even when the training has been aborted. + */ + void onTrainingEnd(); + + /** + * Called before the start of every epoch. + * @param event + * @return A ListenerResponse telling the source of the event if it should continue or stop the training. + */ + ListenerResponse onEpochStart(SyncTrainingEvent event); + + /** + * Called after the end of every epoch. + * @param event + * @return A ListenerResponse telling the source of the event if it should continue or stop the training. 
+ */ + ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index eb2fcc5a5..351226e9d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; import lombok.Getter; import lombok.Setter; -import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.sync.Transition; @@ -29,12 +28,13 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.Constants; +import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import java.util.ArrayList; @@ -53,8 +53,6 @@ public abstract class QLearningDiscrete extends QLearning mdp; @Getter final private IDQN currentDQN; @@ -68,24 +66,31 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) { + this(mdp, dqn, conf, epsilonNbStep); + addListener(DataManagerSyncTrainingListener.builder(dataManager).build()); + } + + public QLearningDiscrete(MDP mdp, IDQN dqn, QLConfiguration conf, + int epsilonNbStep) { super(conf); this.configuration = conf; this.mdp = mdp; - this.dataManager = dataManager; currentDQN = dqn; targetDQN = dqn.clone(); policy = new DQNPolicy(getCurrentDQN()); egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(), - this); + this); mdp.getActionSpace().setSeed(conf.getSeed()); } - public void postEpoch() { if (getHistoryProcessor() != null) @@ -97,14 +102,6 @@ public abstract class QLearningDiscrete extends QLearning= Constants.MONITOR_FREQ && getHistoryProcessor() != null - && getDataManager().isSaveData()) { - lastMonitor = getStepCounter(); - int[] shape = getMdp().getObservationSpace().getShape(); - getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + getEpochCounter() + "-" - + getStepCounter() + ".mp4", shape); - } } /** diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerSyncTrainingListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerSyncTrainingListener.java new file mode 100644 index 000000000..c9166f34e --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerSyncTrainingListener.java @@ -0,0 +1,126 @@ +package org.deeplearning4j.rl4j.util; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent; +import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent; +import 
org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener; + +/** + * DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the + * training process can be fed to the DataManager + */ +@Slf4j +public class DataManagerSyncTrainingListener implements SyncTrainingListener { + private final IDataManager dataManager; + private final int saveFrequency; + private final int monitorFrequency; + + private int lastSave; + private int lastMonitor; + + private DataManagerSyncTrainingListener(Builder builder) { + this.dataManager = builder.dataManager; + + this.saveFrequency = builder.saveFrequency; + this.lastSave = -builder.saveFrequency; + + this.monitorFrequency = builder.monitorFrequency; + this.lastMonitor = -builder.monitorFrequency; + } + + @Override + public ListenerResponse onTrainingStart(SyncTrainingEvent event) { + try { + dataManager.writeInfo(event.getLearning()); + } catch (Exception e) { + log.error("Training failed.", e); + return ListenerResponse.STOP; + } + return ListenerResponse.CONTINUE; + } + + @Override + public void onTrainingEnd() { + // Do nothing + } + + @Override + public ListenerResponse onEpochStart(SyncTrainingEvent event) { + int stepCounter = event.getLearning().getStepCounter(); + + if (stepCounter - lastMonitor >= monitorFrequency + && event.getLearning().getHistoryProcessor() != null + && dataManager.isSaveData()) { + lastMonitor = stepCounter; + int[] shape = event.getLearning().getMdp().getObservationSpace().getShape(); + event.getLearning().getHistoryProcessor().startMonitor(dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter() + "-" + + stepCounter + ".mp4", shape); + } + + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) { + try { + int stepCounter = event.getLearning().getStepCounter(); + if (stepCounter - lastSave >= saveFrequency) { + dataManager.save(event.getLearning()); + lastSave = stepCounter; + } + + dataManager.appendStat(event.getStatEntry()); + dataManager.writeInfo(event.getLearning()); + } catch (Exception e) { + log.error("Training failed.", e); + return ListenerResponse.STOP; + } + + return ListenerResponse.CONTINUE; + } + + public static Builder builder(IDataManager dataManager) { + return new Builder(dataManager); + } + + public static class Builder { + private final IDataManager dataManager; + private int saveFrequency = Constants.MODEL_SAVE_FREQ; + private int monitorFrequency = Constants.MONITOR_FREQ; + + /** + * Create a Builder with the given DataManager + * @param dataManager + */ + public Builder(IDataManager dataManager) { + this.dataManager = dataManager; + } + + /** + * A number that represent the number of steps since the last call to DataManager.save() before can it be called again. + * @param saveFrequency (Default: 100000) + */ + public Builder saveFrequency(int saveFrequency) { + this.saveFrequency = saveFrequency; + return this; + } + + /** + * A number that represent the number of steps since the last call to HistoryProcessor.startMonitor() before can it be called again. 
+ * @param monitorFrequency (Default: 10000) + */ + public Builder monitorFrequency(int monitorFrequency) { + this.monitorFrequency = monitorFrequency; + return this; + } + + /** + * Creates a DataManagerSyncTrainingListener with the configured parameters + * @return An instance of DataManagerSyncTrainingListener + */ + public DataManagerSyncTrainingListener build() { + return new DataManagerSyncTrainingListener(this); + } + + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java index cace83d71..c311ee53a 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -1,12 +1,13 @@ package org.deeplearning4j.rl4j.learning.sync; -import lombok.AllArgsConstructor; -import lombok.Value; -import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; +import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; +import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.support.MockDataManager; +import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; @@ -15,61 +16,101 @@ import static org.junit.Assert.assertEquals; public class SyncLearningTest { @Test - public void refac_checkDataManagerCallsRemainTheSame() { + public void when_training_expect_listenersToBeCalled() { // Arrange - MockLConfiguration lconfig = new MockLConfiguration(10); - MockDataManager dataManager = new MockDataManager(false); - MockSyncLearning sut = new MockSyncLearning(lconfig, dataManager, 2); + QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + MockSyncTrainingListener listener = new MockSyncTrainingListener(); + MockSyncLearning sut = new MockSyncLearning(lconfig); + sut.addListener(listener); // Act sut.train(); - assertEquals(10, dataManager.statEntries.size()); - for(int i = 0; i < 10; ++i) { - IDataManager.StatEntry entry = dataManager.statEntries.get(i); - assertEquals(2, entry.getEpochCounter()); - assertEquals(i+1, entry.getStepCounter()); - assertEquals(1.0, entry.getReward(), 0.0); + assertEquals(1, listener.onTrainingStartCallCount); + assertEquals(10, listener.onEpochStartCallCount); + assertEquals(10, listener.onEpochEndStartCallCount); + assertEquals(1, listener.onTrainingEndCallCount); + } - } - assertEquals(0, dataManager.isSaveDataCallCount); - assertEquals(0, dataManager.getVideoDirCallCount); - assertEquals(11, dataManager.writeInfoCallCount); - assertEquals(1, dataManager.saveCallCount); + @Test + public void when_trainingStartCanContinueFalse_expect_trainingStopped() { + // Arrange + QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + MockSyncTrainingListener listener = new MockSyncTrainingListener(); + MockSyncLearning sut = new MockSyncLearning(lconfig); + sut.addListener(listener); + listener.trainingStartCanContinue = false; + + // Act + sut.train(); + + assertEquals(1, listener.onTrainingStartCallCount); + assertEquals(0, listener.onEpochStartCallCount); + assertEquals(0, 
listener.onEpochEndStartCallCount); + assertEquals(1, listener.onTrainingEndCallCount); + } + + @Test + public void when_epochStartCanContinueFalse_expect_trainingStopped() { + // Arrange + QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + MockSyncTrainingListener listener = new MockSyncTrainingListener(); + MockSyncLearning sut = new MockSyncLearning(lconfig); + sut.addListener(listener); + listener.nbStepsEpochStartCanContinue = 3; + + // Act + sut.train(); + + assertEquals(1, listener.onTrainingStartCallCount); + assertEquals(3, listener.onEpochStartCallCount); + assertEquals(2, listener.onEpochEndStartCallCount); + assertEquals(1, listener.onTrainingEndCallCount); + } + + @Test + public void when_epochEndCanContinueFalse_expect_trainingStopped() { + // Arrange + QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + MockSyncTrainingListener listener = new MockSyncTrainingListener(); + MockSyncLearning sut = new MockSyncLearning(lconfig); + sut.addListener(listener); + listener.nbStepsEpochEndCanContinue = 3; + + // Act + sut.train(); + + assertEquals(1, listener.onTrainingStartCallCount); + assertEquals(3, listener.onEpochStartCallCount); + assertEquals(3, listener.onEpochEndStartCallCount); + assertEquals(1, listener.onTrainingEndCallCount); } public static class MockSyncLearning extends SyncLearning { - private final IDataManager dataManager; private LConfiguration conf; - private final int epochSteps; - public MockSyncLearning(LConfiguration conf, IDataManager dataManager, int epochSteps) { + public MockSyncLearning(LConfiguration conf, IDataManager dataManager) { + super(conf); + addListener(DataManagerSyncTrainingListener.builder(dataManager).build()); + this.conf = conf; + } + + public MockSyncLearning(LConfiguration conf) { super(conf); - this.dataManager = dataManager; this.conf = conf; - this.epochSteps = epochSteps; } @Override - protected void preEpoch() { - - } + protected void preEpoch() { } @Override - protected void postEpoch() { - - } + protected void postEpoch() { } @Override protected IDataManager.StatEntry trainEpoch() { setStepCounter(getStepCounter() + 1); - return new MockStatEntry(epochSteps, getStepCounter(), 1.0); - } - - @Override - protected IDataManager getDataManager() { - return dataManager; + return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0); } @Override @@ -92,41 +133,4 @@ public class SyncLearningTest { return null; } } - - public static class MockLConfiguration implements ILearning.LConfiguration { - - private final int maxStep; - - public MockLConfiguration(int maxStep) { - this.maxStep = maxStep; - } - - @Override - public int getSeed() { - return 0; - } - - @Override - public int getMaxEpochStep() { - return 0; - } - - @Override - public int getMaxStep() { - return maxStep; - } - - @Override - public double getGamma() { - return 0; - } - } - - @AllArgsConstructor - @Value - public static class MockStatEntry implements IDataManager.StatEntry { - int epochCounter; - int stepCounter; - double reward; - } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningDiscreteTest.java new file mode 100644 index 000000000..7ac4716d6 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningDiscreteTest.java @@ -0,0 +1,65 @@ +package 
org.deeplearning4j.rl4j.learning.sync.qlearning; + +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete; +import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; +import org.deeplearning4j.rl4j.learning.sync.support.MockMDP; +import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.support.MockDataManager; +import org.deeplearning4j.rl4j.support.MockHistoryProcessor; +import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener; +import org.deeplearning4j.rl4j.util.IDataManager; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class QLearningDiscreteTest { + @Test + public void refac_checkDataManagerCallsRemainTheSame() { + // Arrange + QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder() + .maxStep(10) + .expRepMaxSize(1) + .build(); + MockDataManager dataManager = new MockDataManager(true); + MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3); + IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder() + .build(); + sut.setHistoryProcessor(new MockHistoryProcessor(hpConfig)); + + // Act + sut.train(); + + assertEquals(10, dataManager.statEntries.size()); + for(int i = 0; i < 10; ++i) { + IDataManager.StatEntry entry = dataManager.statEntries.get(i); + assertEquals(i, entry.getEpochCounter()); + assertEquals(i+1, entry.getStepCounter()); + assertEquals(1.0, entry.getReward(), 0.0); + + } + assertEquals(4, dataManager.isSaveDataCallCount); + assertEquals(4, dataManager.getVideoDirCallCount); + assertEquals(11, dataManager.writeInfoCallCount); + assertEquals(5, dataManager.saveCallCount); + } + + public static class MockQLearningDiscrete extends QLearningDiscrete { + + public MockQLearningDiscrete(int maxSteps, QLConfiguration conf, + IDataManager dataManager, int saveFrequency, int monitorFrequency) { + super(new MockMDP(maxSteps), new MockDQN(), conf, 2); + addListener(DataManagerSyncTrainingListener.builder(dataManager) + .saveFrequency(saveFrequency) + .monitorFrequency(monitorFrequency) + .build()); + } + + @Override + protected IDataManager.StatEntry trainEpoch() { + setStepCounter(getStepCounter() + 1); + return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0); + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java new file mode 100644 index 000000000..7d088f060 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java @@ -0,0 +1,92 @@ +package org.deeplearning4j.rl4j.learning.sync.support; + +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.io.IOException; +import java.io.OutputStream; + +public class MockDQN implements IDQN { + @Override + public NeuralNetwork[] getNeuralNetworks() { + return new NeuralNetwork[0]; + } + + @Override + public boolean isRecurrent() { + return false; + } + + @Override + public void reset() { + + } + + @Override + public void fit(INDArray input, INDArray labels) { + + } + + @Override + public void fit(INDArray input, INDArray[] labels) { + + } + + 
@Override + public INDArray output(INDArray batch) { + return null; + } + + @Override + public INDArray[] outputAll(INDArray batch) { + return new INDArray[0]; + } + + @Override + public IDQN clone() { + return null; + } + + @Override + public void copy(NeuralNet from) { + + } + + @Override + public void copy(IDQN from) { + + } + + @Override + public Gradient[] gradient(INDArray input, INDArray label) { + return new Gradient[0]; + } + + @Override + public Gradient[] gradient(INDArray input, INDArray[] label) { + return new Gradient[0]; + } + + @Override + public void applyGradient(Gradient[] gradient, int batchSize) { + + } + + @Override + public double getLatestScore() { + return 0; + } + + @Override + public void save(OutputStream os) throws IOException { + + } + + @Override + public void save(String filename) throws IOException { + + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockMDP.java new file mode 100644 index 000000000..4352b9bee --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockMDP.java @@ -0,0 +1,79 @@ +package org.deeplearning4j.rl4j.learning.sync.support; + +import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.deeplearning4j.rl4j.space.ObservationSpace; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class MockMDP implements MDP { + + private final int maxSteps; + private final DiscreteSpace actionSpace = new DiscreteSpace(1); + private final MockObservationSpace observationSpace = new MockObservationSpace(); + + private int currentStep = 0; + + public MockMDP(int maxSteps) { + + this.maxSteps = maxSteps; + } + + @Override + public ObservationSpace getObservationSpace() { + return observationSpace; + } + + @Override + public DiscreteSpace getActionSpace() { + return actionSpace; + } + + @Override + public Object reset() { + return null; + } + + @Override + public void close() { + + } + + @Override + public StepReply step(Integer integer) { + return new StepReply(null, 1.0, isDone(), null); + } + + @Override + public boolean isDone() { + return currentStep >= maxSteps; + } + + @Override + public MDP newInstance() { + return null; + } + + private static class MockObservationSpace implements ObservationSpace { + + @Override + public String getName() { + return null; + } + + @Override + public int[] getShape() { + return new int[0]; + } + + @Override + public INDArray getLow() { + return null; + } + + @Override + public INDArray getHigh() { + return null; + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockStatEntry.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockStatEntry.java new file mode 100644 index 000000000..540b869d3 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockStatEntry.java @@ -0,0 +1,13 @@ +package org.deeplearning4j.rl4j.learning.sync.support; + +import lombok.AllArgsConstructor; +import lombok.Value; +import org.deeplearning4j.rl4j.util.IDataManager; + +@AllArgsConstructor +@Value +public class MockStatEntry implements IDataManager.StatEntry { + int epochCounter; + int stepCounter; + double reward; +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockSyncTrainingListener.java 
b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockSyncTrainingListener.java new file mode 100644 index 000000000..5f41e280a --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockSyncTrainingListener.java @@ -0,0 +1,46 @@ +package org.deeplearning4j.rl4j.learning.sync.support; + +import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent; +import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent; +import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener; + +public class MockSyncTrainingListener implements SyncTrainingListener { + + public int onTrainingStartCallCount = 0; + public int onTrainingEndCallCount = 0; + public int onEpochStartCallCount = 0; + public int onEpochEndStartCallCount = 0; + + public boolean trainingStartCanContinue = true; + public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE; + public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE; + + @Override + public ListenerResponse onTrainingStart(SyncTrainingEvent event) { + ++onTrainingStartCallCount; + return trainingStartCanContinue ? ListenerResponse.CONTINUE : ListenerResponse.STOP; + } + + @Override + public void onTrainingEnd() { + ++onTrainingEndCallCount; + } + + @Override + public ListenerResponse onEpochStart(SyncTrainingEvent event) { + ++onEpochStartCallCount; + if(onEpochStartCallCount >= nbStepsEpochStartCanContinue) { + return ListenerResponse.STOP; + } + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) { + ++onEpochEndStartCallCount; + if(onEpochEndStartCallCount >= nbStepsEpochEndCanContinue) { + return ListenerResponse.STOP; + } + return ListenerResponse.CONTINUE; + } +}
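Below is a hedged, self-contained sketch of how the new listener API could be wired from user code, using only the types introduced in this diff plus the existing org.deeplearning4j.rl4j.util.DataManager. RewardThresholdListener, attachTo(), and the chosen threshold and frequencies are illustrative placeholders, not part of this PR.

import java.io.IOException;

import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
import org.deeplearning4j.rl4j.util.DataManager;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;

/**
 * Illustrative listener that stops the training once an epoch reaches a target reward.
 */
public class RewardThresholdListener implements SyncTrainingListener {

    private final double rewardThreshold;

    public RewardThresholdListener(double rewardThreshold) {
        this.rewardThreshold = rewardThreshold;
    }

    @Override
    public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
        return ListenerResponse.CONTINUE;
    }

    @Override
    public void onTrainingEnd() {
        // Called even when the training was stopped by a listener.
    }

    @Override
    public ListenerResponse onEpochStart(SyncTrainingEvent event) {
        return ListenerResponse.CONTINUE;
    }

    @Override
    public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
        // Returning STOP skips the remaining listeners and ends the training loop.
        return event.getStatEntry().getReward() >= rewardThreshold
                ? ListenerResponse.STOP
                : ListenerResponse.CONTINUE;
    }

    /**
     * Wiring sketch: persistence/monitoring and the stop condition are both attached as plain listeners,
     * replacing the DataManager plumbing that used to live inside the training loop.
     */
    public static void attachTo(SyncLearning<?, ?, ?, ?> learning) throws IOException {
        learning.addListener(DataManagerSyncTrainingListener.builder(new DataManager(true))
                .saveFrequency(100000)   // steps between calls to DataManager.save()
                .monitorFrequency(10000) // steps between calls to HistoryProcessor.startMonitor()
                .build());
        learning.addListener(new RewardThresholdListener(195.0));
    }
}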