RL4J Added listener pattern to SyncLearning (#8050)

* Added listener pattern to SyncLearning

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Did requested changes

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2019-08-01 22:43:45 -04:00 committed by Alex Black
parent 0527ab8d98
commit b2145ca780
14 changed files with 704 additions and 129 deletions

View File

@ -125,8 +125,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
return nshape;
}
protected abstract IDataManager getDataManager();
public abstract NN getNeuralNet();
public int incrementStep() {

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.factory.Nd4j;
/**
@ -36,6 +37,7 @@ import org.nd4j.linalg.factory.Nd4j;
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> {
protected abstract IDataManager getDataManager();
public AsyncLearning(AsyncConfiguration conf) {
super(conf);

View File

@ -18,67 +18,132 @@ package org.deeplearning4j.rl4j.learning.sync;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
*
* Mother class and useful factorisations for all training methods that
* are not asynchronous.
*
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
* @author Alexandre Boulanger
*/
@Slf4j
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> {
extends Learning<O, A, AS, NN> {
private int lastSave = -Constants.MODEL_SAVE_FREQ;
private List<SyncTrainingListener> listeners = new ArrayList<>();
public SyncLearning(LConfiguration conf) {
super(conf);
}
public void train() {
try {
log.info("training starting.");
getDataManager().writeInfo(this);
while (getStepCounter() < getConfiguration().getMaxStep()) {
preEpoch();
IDataManager.StatEntry statEntry = trainEpoch();
postEpoch();
incrementEpoch();
if (getStepCounter() - lastSave >= Constants.MODEL_SAVE_FREQ) {
getDataManager().save(this);
lastSave = getStepCounter();
}
getDataManager().appendStat(statEntry);
getDataManager().writeInfo(this);
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
}
} catch (Exception e) {
log.error("Training failed.", e);
e.printStackTrace();
}
/**
* Add a listener at the end of the listener list.
*
* @param listener
*/
public void addListener(SyncTrainingListener listener) {
listeners.add(listener);
}
/**
* This method will train the model<p>
* The training stop when:<br>
* - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
* OR<br>
* - a listener explicitly stops it<br>
* <p>
* Listeners<br>
* For a given event, the listeners are called sequentially in the same order as they were added. If one listener
* returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* Events:
* <ul>
* <li>{@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.</li>
* <li>{@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training</li>
* <li>{@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
* </ul>
*/
public void train() {
log.info("training starting.");
boolean canContinue = notifyTrainingStarted();
if (canContinue) {
while (getStepCounter() < getConfiguration().getMaxStep()) {
preEpoch();
canContinue = notifyEpochStarted();
if (!canContinue) {
break;
}
IDataManager.StatEntry statEntry = trainEpoch();
postEpoch();
canContinue = notifyEpochFinished(statEntry);
if (!canContinue) {
break;
}
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
incrementEpoch();
}
}
notifyTrainingFinished();
}
private boolean notifyTrainingStarted() {
SyncTrainingEvent event = new SyncTrainingEvent(this);
for (SyncTrainingListener listener : listeners) {
if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
private void notifyTrainingFinished() {
for (SyncTrainingListener listener : listeners) {
listener.onTrainingEnd();
}
}
private boolean notifyEpochStarted() {
SyncTrainingEvent event = new SyncTrainingEvent(this);
for (SyncTrainingListener listener : listeners) {
if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) {
SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry);
for (SyncTrainingListener listener : listeners) {
if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
protected abstract void preEpoch();
protected abstract void postEpoch();
protected abstract IDataManager.StatEntry trainEpoch();
protected abstract IDataManager.StatEntry trainEpoch(); // TODO: finish removal of IDataManager from Learning
}

View File

@ -0,0 +1,22 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
 * Event handed to {@code SyncTrainingListener.onEpochEnd()}. It is a {@link SyncTrainingEvent}
 * that additionally carries the statistics of the epoch training.
 */
public class SyncTrainingEpochEndEvent extends SyncTrainingEvent {

    /** The stats of the epoch training. */
    private final IDataManager.StatEntry statEntry;

    /**
     * @param learning  the source of the event
     * @param statEntry the stats of the epoch training
     */
    public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) {
        super(learning);
        this.statEntry = statEntry;
    }

    /**
     * @return the stats of the epoch training, exactly as given to the constructor
     */
    public IDataManager.StatEntry getStatEntry() {
        return statEntry;
    }
}

View File

@ -0,0 +1,21 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.Learning;
/**
 * Base event object passed as parameter to the callbacks of SyncTrainingListener.
 */
public class SyncTrainingEvent {

    /** The source of the event. */
    private final Learning learning;

    /**
     * @param learning the source of the event
     */
    public SyncTrainingEvent(Learning learning) {
        this.learning = learning;
    }

    /**
     * @return the source of the event, exactly as given to the constructor
     */
    public Learning getLearning() {
        return learning;
    }
}

View File

@ -0,0 +1,45 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
/**
 * Callback interface used to observe, and optionally stop, the training performed by a descendant of
 * {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning}.
 */
public interface SyncTrainingListener {

    /**
     * The answer a listener returns to SyncLearning after handling an event.
     */
    enum ListenerResponse {
        /** Keep calling the listeners and go on with the training. */
        CONTINUE,

        /** Stop calling the listeners and terminate the training. */
        STOP
    }

    /**
     * Invoked a single time, when the training starts.
     *
     * @param event the event describing the training that is starting
     * @return {@link ListenerResponse#CONTINUE} to go on, {@link ListenerResponse#STOP} to cancel the training
     */
    ListenerResponse onTrainingStart(SyncTrainingEvent event);

    /**
     * Invoked a single time, when the training has finished. It is invoked even when the training
     * has been aborted.
     */
    void onTrainingEnd();

    /**
     * Invoked before the start of every epoch.
     *
     * @param event the event describing the training whose epoch is about to start
     * @return {@link ListenerResponse#CONTINUE} to go on, {@link ListenerResponse#STOP} to stop the training
     */
    ListenerResponse onEpochStart(SyncTrainingEvent event);

    /**
     * Invoked after the end of every epoch.
     *
     * @param event the event describing the finished epoch, including its stats
     * @return {@link ListenerResponse#CONTINUE} to go on, {@link ListenerResponse#STOP} to stop the training
     */
    ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event);
}

View File

@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import lombok.Getter;
import lombok.Setter;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
@ -29,12 +28,13 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.ArrayList;
@ -53,8 +53,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Getter
final private QLConfiguration configuration;
@Getter
final private IDataManager dataManager;
@Getter
final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final private IDQN currentDQN;
@ -68,24 +66,31 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
private int lastAction;
private INDArray history[] = null;
private double accuReward = 0;
private int lastMonitor = -Constants.MONITOR_FREQ;
/**
* @deprecated
* Use QLearningDiscrete(MDP, IDQN, QLConfiguration, int) and add the required listeners with addListener() instead.
*/
@Deprecated
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
IDataManager dataManager, int epsilonNbStep) {
this(mdp, dqn, conf, epsilonNbStep);
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep) {
super(conf);
this.configuration = conf;
this.mdp = mdp;
this.dataManager = dataManager;
currentDQN = dqn;
targetDQN = dqn.clone();
policy = new DQNPolicy(getCurrentDQN());
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(),
this);
this);
mdp.getActionSpace().setSeed(conf.getSeed());
}
public void postEpoch() {
if (getHistoryProcessor() != null)
@ -97,14 +102,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
history = null;
lastAction = 0;
accuReward = 0;
if (getStepCounter() - lastMonitor >= Constants.MONITOR_FREQ && getHistoryProcessor() != null
&& getDataManager().isSaveData()) {
lastMonitor = getStepCounter();
int[] shape = getMdp().getObservationSpace().getShape();
getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + getEpochCounter() + "-"
+ getStepCounter() + ".mp4", shape);
}
}
/**

View File

@ -0,0 +1,126 @@
package org.deeplearning4j.rl4j.util;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
/**
 * A {@link SyncTrainingListener} that feeds the training process of a SyncLearning to an
 * {@link IDataManager}: it writes the training info, appends the stats of each epoch, saves the
 * learning at a configurable step interval and starts the history processor's monitor at another
 * configurable step interval.
 * <p>
 * Instances are created through {@link #builder(IDataManager)}.
 */
@Slf4j
public class DataManagerSyncTrainingListener implements SyncTrainingListener {

    private final IDataManager dataManager;
    private final int saveFrequency;
    private final int monitorFrequency;

    /** Step counter value at the last call to {@code dataManager.save()}. */
    private int lastSave;
    /** Step counter value at the last call to {@code startMonitor()}. */
    private int lastMonitor;

    private DataManagerSyncTrainingListener(Builder builder) {
        dataManager = builder.dataManager;
        saveFrequency = builder.saveFrequency;
        monitorFrequency = builder.monitorFrequency;

        // Begin one full period "in the past" so that a check made at step 0 is already due.
        lastSave = -saveFrequency;
        lastMonitor = -monitorFrequency;
    }

    /**
     * Writes the training info through the data manager.
     *
     * @return STOP if writing the info threw an exception, CONTINUE otherwise
     */
    @Override
    public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
        ListenerResponse response = ListenerResponse.CONTINUE;
        try {
            dataManager.writeInfo(event.getLearning());
        } catch (Exception e) {
            log.error("Training failed.", e);
            response = ListenerResponse.STOP;
        }
        return response;
    }

    @Override
    public void onTrainingEnd() {
        // Nothing to do when the training ends
    }

    /**
     * Starts the monitor of the history processor when a monitoring is due (see
     * {@link #isMonitorDue(SyncTrainingEvent, int)}).
     *
     * @return always CONTINUE
     */
    @Override
    public ListenerResponse onEpochStart(SyncTrainingEvent event) {
        final int currentStep = event.getLearning().getStepCounter();

        if (isMonitorDue(event, currentStep)) {
            lastMonitor = currentStep;

            int[] observationShape = event.getLearning().getMdp().getObservationSpace().getShape();
            String videoFile = dataManager.getVideoDir()
                    + "/video-" + event.getLearning().getEpochCounter()
                    + "-" + currentStep + ".mp4";
            event.getLearning().getHistoryProcessor().startMonitor(videoFile, observationShape);
        }

        return ListenerResponse.CONTINUE;
    }

    /**
     * Tells if the monitor should be started now. The checks are made in this order, stopping at the
     * first one that fails: enough steps elapsed since the last monitoring, the learning has a
     * history processor, the data manager is saving data.
     */
    private boolean isMonitorDue(SyncTrainingEvent event, int currentStep) {
        if (currentStep - lastMonitor < monitorFrequency) {
            return false;
        }
        if (event.getLearning().getHistoryProcessor() == null) {
            return false;
        }
        return dataManager.isSaveData();
    }

    /**
     * Saves the learning when enough steps elapsed since the last save, then appends the stats of
     * the epoch and writes the training info.
     *
     * @return STOP if any of the data manager calls threw an exception, CONTINUE otherwise
     */
    @Override
    public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
        try {
            final int currentStep = event.getLearning().getStepCounter();
            final boolean saveDue = currentStep - lastSave >= saveFrequency;
            if (saveDue) {
                dataManager.save(event.getLearning());
                lastSave = currentStep;
            }

            dataManager.appendStat(event.getStatEntry());
            dataManager.writeInfo(event.getLearning());
            return ListenerResponse.CONTINUE;
        } catch (Exception e) {
            log.error("Training failed.", e);
            return ListenerResponse.STOP;
        }
    }

    /**
     * Entry point to configure and create a DataManagerSyncTrainingListener.
     *
     * @param dataManager the data manager the listener will feed
     * @return a new {@link Builder}
     */
    public static Builder builder(IDataManager dataManager) {
        return new Builder(dataManager);
    }

    /**
     * Builder of {@link DataManagerSyncTrainingListener}. The frequencies default to
     * {@code Constants.MODEL_SAVE_FREQ} and {@code Constants.MONITOR_FREQ}.
     */
    public static class Builder {
        private final IDataManager dataManager;
        private int saveFrequency = Constants.MODEL_SAVE_FREQ;
        private int monitorFrequency = Constants.MONITOR_FREQ;

        /**
         * Create a Builder with the given DataManager.
         *
         * @param dataManager the data manager the listener will feed
         */
        public Builder(IDataManager dataManager) {
            this.dataManager = dataManager;
        }

        /**
         * Sets the number of steps that must elapse after a call to DataManager.save() before it can
         * be called again.
         *
         * @param saveFrequency the number of steps (defaults to Constants.MODEL_SAVE_FREQ)
         * @return this builder
         */
        public Builder saveFrequency(int saveFrequency) {
            this.saveFrequency = saveFrequency;
            return this;
        }

        /**
         * Sets the number of steps that must elapse after a call to HistoryProcessor.startMonitor()
         * before it can be called again.
         *
         * @param monitorFrequency the number of steps (defaults to Constants.MONITOR_FREQ)
         * @return this builder
         */
        public Builder monitorFrequency(int monitorFrequency) {
            this.monitorFrequency = monitorFrequency;
            return this;
        }

        /**
         * Creates a DataManagerSyncTrainingListener with the configured parameters.
         *
         * @return An instance of DataManagerSyncTrainingListener
         */
        public DataManagerSyncTrainingListener build() {
            return new DataManagerSyncTrainingListener(this);
        }
    }
}

View File

@ -1,12 +1,13 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.AllArgsConstructor;
import lombok.Value;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
@ -15,61 +16,101 @@ import static org.junit.Assert.assertEquals;
public class SyncLearningTest {
@Test
public void refac_checkDataManagerCallsRemainTheSame() {
public void when_training_expect_listenersToBeCalled() {
// Arrange
MockLConfiguration lconfig = new MockLConfiguration(10);
MockDataManager dataManager = new MockDataManager(false);
MockSyncLearning sut = new MockSyncLearning(lconfig, dataManager, 2);
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
// Act
sut.train();
assertEquals(10, dataManager.statEntries.size());
for(int i = 0; i < 10; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(2, entry.getEpochCounter());
assertEquals(i+1, entry.getStepCounter());
assertEquals(1.0, entry.getReward(), 0.0);
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(10, listener.onEpochStartCallCount);
assertEquals(10, listener.onEpochEndStartCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
}
assertEquals(0, dataManager.isSaveDataCallCount);
assertEquals(0, dataManager.getVideoDirCallCount);
assertEquals(11, dataManager.writeInfoCallCount);
assertEquals(1, dataManager.saveCallCount);
@Test
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.trainingStartCanContinue = false;
// Act
sut.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(0, listener.onEpochStartCallCount);
assertEquals(0, listener.onEpochEndStartCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
@Test
public void when_epochStartCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.nbStepsEpochStartCanContinue = 3;
// Act
sut.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(3, listener.onEpochStartCallCount);
assertEquals(2, listener.onEpochEndStartCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
@Test
public void when_epochEndCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.nbStepsEpochEndCanContinue = 3;
// Act
sut.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(3, listener.onEpochStartCallCount);
assertEquals(3, listener.onEpochEndStartCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
public static class MockSyncLearning extends SyncLearning {
private final IDataManager dataManager;
private LConfiguration conf;
private final int epochSteps;
public MockSyncLearning(LConfiguration conf, IDataManager dataManager, int epochSteps) {
public MockSyncLearning(LConfiguration conf, IDataManager dataManager) {
super(conf);
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
this.conf = conf;
}
public MockSyncLearning(LConfiguration conf) {
super(conf);
this.dataManager = dataManager;
this.conf = conf;
this.epochSteps = epochSteps;
}
@Override
protected void preEpoch() {
}
protected void preEpoch() { }
@Override
protected void postEpoch() {
}
protected void postEpoch() { }
@Override
protected IDataManager.StatEntry trainEpoch() {
setStepCounter(getStepCounter() + 1);
return new MockStatEntry(epochSteps, getStepCounter(), 1.0);
}
@Override
protected IDataManager getDataManager() {
return dataManager;
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
}
@Override
@ -92,41 +133,4 @@ public class SyncLearningTest {
return null;
}
}
public static class MockLConfiguration implements ILearning.LConfiguration {
private final int maxStep;
public MockLConfiguration(int maxStep) {
this.maxStep = maxStep;
}
@Override
public int getSeed() {
return 0;
}
@Override
public int getMaxEpochStep() {
return 0;
}
@Override
public int getMaxStep() {
return maxStep;
}
@Override
public double getGamma() {
return 0;
}
}
@AllArgsConstructor
@Value
public static class MockStatEntry implements IDataManager.StatEntry {
int epochCounter;
int stepCounter;
double reward;
}
}

View File

@ -0,0 +1,65 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockMDP;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class QLearningDiscreteTest {

    /**
     * Trains a QLearningDiscrete whose data manager is attached through a
     * DataManagerSyncTrainingListener and checks the calls received by the data manager.
     */
    @Test
    public void refac_checkDataManagerCallsRemainTheSame() {
        // Arrange
        MockDataManager dataManager = new MockDataManager(true);
        QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder()
                .maxStep(10)
                .expRepMaxSize(1)
                .build();
        MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3);

        IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder().build();
        MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConfig);
        sut.setHistoryProcessor(historyProcessor);

        // Act
        sut.train();

        // Assert: one stat entry per epoch, each epoch advancing the step counter by one
        assertEquals(10, dataManager.statEntries.size());
        for (int epoch = 0; epoch < 10; epoch++) {
            IDataManager.StatEntry entry = dataManager.statEntries.get(epoch);
            assertEquals(epoch, entry.getEpochCounter());
            assertEquals(epoch + 1, entry.getStepCounter());
            assertEquals(1.0, entry.getReward(), 0.0);
        }

        // Assert: number of calls made to the data manager
        assertEquals(4, dataManager.isSaveDataCallCount);
        assertEquals(4, dataManager.getVideoDirCallCount);
        assertEquals(11, dataManager.writeInfoCallCount);
        assertEquals(5, dataManager.saveCallCount);
    }

    /**
     * QLearningDiscrete built on mock collaborators. Its epoch training only advances the step
     * counter by one and reports a reward of 1.0.
     */
    public static class MockQLearningDiscrete extends QLearningDiscrete {

        public MockQLearningDiscrete(int maxSteps, QLConfiguration conf,
                                     IDataManager dataManager, int saveFrequency, int monitorFrequency) {
            super(new MockMDP(maxSteps), new MockDQN(), conf, 2);

            DataManagerSyncTrainingListener listener = DataManagerSyncTrainingListener.builder(dataManager)
                    .saveFrequency(saveFrequency)
                    .monitorFrequency(monitorFrequency)
                    .build();
            addListener(listener);
        }

        @Override
        protected IDataManager.StatEntry trainEpoch() {
            int nextStep = getStepCounter() + 1;
            setStepCounter(nextStep);
            return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
        }
    }
}

View File

@ -0,0 +1,92 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
import java.io.OutputStream;
/**
 * Do-nothing {@link IDQN} for tests: every operation is a no-op and every query returns an empty
 * array, {@code null}, {@code false} or zero.
 */
public class MockDQN implements IDQN {

    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[] {};
    }

    @Override
    public boolean isRecurrent() {
        return false;
    }

    @Override
    public void reset() {
        // no-op
    }

    @Override
    public void fit(INDArray input, INDArray labels) {
        // no-op
    }

    @Override
    public void fit(INDArray input, INDArray[] labels) {
        // no-op
    }

    @Override
    public INDArray output(INDArray batch) {
        return null;
    }

    @Override
    public INDArray[] outputAll(INDArray batch) {
        return new INDArray[] {};
    }

    /**
     * @return always {@code null}; this mock does not produce a copy of itself
     */
    @Override
    public IDQN clone() {
        return null;
    }

    @Override
    public void copy(NeuralNet from) {
        // no-op
    }

    @Override
    public void copy(IDQN from) {
        // no-op
    }

    @Override
    public Gradient[] gradient(INDArray input, INDArray label) {
        return new Gradient[] {};
    }

    @Override
    public Gradient[] gradient(INDArray input, INDArray[] label) {
        return new Gradient[] {};
    }

    @Override
    public void applyGradient(Gradient[] gradient, int batchSize) {
        // no-op
    }

    @Override
    public double getLatestScore() {
        return 0.0;
    }

    @Override
    public void save(OutputStream os) throws IOException {
        // no-op
    }

    @Override
    public void save(String filename) throws IOException {
        // no-op
    }
}

View File

@ -0,0 +1,79 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
 * Minimal {@link MDP} for tests. It has a single discrete action, an empty observation shape, and
 * reports itself as done once {@code maxSteps} steps have been taken since construction or since
 * the last {@link #reset()}.
 */
public class MockMDP implements MDP<Object, Integer, DiscreteSpace> {

    private final int maxSteps;
    private final DiscreteSpace actionSpace = new DiscreteSpace(1);
    private final MockObservationSpace observationSpace = new MockObservationSpace();

    /** Number of calls to {@link #step(Integer)} since construction or the last {@link #reset()}. */
    private int currentStep = 0;

    /**
     * @param maxSteps the number of steps after which {@link #isDone()} returns {@code true}
     */
    public MockMDP(int maxSteps) {
        this.maxSteps = maxSteps;
    }

    @Override
    public ObservationSpace<Object> getObservationSpace() {
        return observationSpace;
    }

    @Override
    public DiscreteSpace getActionSpace() {
        return actionSpace;
    }

    /**
     * Rewinds the step count so that the MDP can be run again.
     *
     * @return always {@code null}; this mock has no observation to give
     */
    @Override
    public Object reset() {
        currentStep = 0;
        return null;
    }

    @Override
    public void close() {
        // no-op
    }

    /**
     * Takes one step. The action is ignored.
     *
     * @return a reply with a null observation, a reward of 1.0 and a done flag that becomes
     *         {@code true} on the {@code maxSteps}-th step
     */
    @Override
    public StepReply<Object> step(Integer integer) {
        // Count the step before computing the done flag; without this the counter never moves
        // and isDone() could never change its answer.
        ++currentStep;
        return new StepReply<Object>(null, 1.0, isDone(), null);
    }

    @Override
    public boolean isDone() {
        return currentStep >= maxSteps;
    }

    /**
     * @return always {@code null}; this mock does not create new instances
     */
    @Override
    public MDP<Object, Integer, DiscreteSpace> newInstance() {
        return null;
    }

    /** Observation space with an empty shape and no name or bounds. */
    private static class MockObservationSpace implements ObservationSpace<Object> {

        @Override
        public String getName() {
            return null;
        }

        @Override
        public int[] getShape() {
            return new int[0];
        }

        @Override
        public INDArray getLow() {
            return null;
        }

        @Override
        public INDArray getHigh() {
            return null;
        }
    }
}

View File

@ -0,0 +1,13 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import lombok.AllArgsConstructor;
import lombok.Value;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
 * Stat entry used by the sync-learning tests. The tests build it from three values -- epoch counter,
 * step counter and reward, in that order -- and read them back through the
 * {@link IDataManager.StatEntry} accessors.
 * <p>
 * NOTE(review): the constructor and the accessors are generated by the Lombok annotations below;
 * {@code @Value} presumably supplies an all-args constructor on its own -- confirm before removing
 * either annotation.
 */
@AllArgsConstructor
@Value
public class MockStatEntry implements IDataManager.StatEntry {
// Epoch counter reported by this entry
int epochCounter;
// Step counter reported by this entry
int stepCounter;
// Reward reported by this entry
double reward;
}

View File

@ -0,0 +1,46 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
/**
 * {@link SyncTrainingListener} for tests. It counts how many times each callback is invoked and can
 * be configured, through its public fields, to answer STOP at a chosen moment.
 */
public class MockSyncTrainingListener implements SyncTrainingListener {

    // Number of times each callback has been invoked
    public int onTrainingStartCallCount = 0;
    public int onTrainingEndCallCount = 0;
    public int onEpochStartCallCount = 0;
    public int onEpochEndStartCallCount = 0;

    /** When false, onTrainingStart() answers STOP. */
    public boolean trainingStartCanContinue = true;
    /** onEpochStart() answers STOP once it has been invoked at least this many times. */
    public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE;
    /** onEpochEnd() answers STOP once it has been invoked at least this many times. */
    public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE;

    @Override
    public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
        onTrainingStartCallCount++;
        if (trainingStartCanContinue) {
            return ListenerResponse.CONTINUE;
        }
        return ListenerResponse.STOP;
    }

    @Override
    public void onTrainingEnd() {
        onTrainingEndCallCount++;
    }

    @Override
    public ListenerResponse onEpochStart(SyncTrainingEvent event) {
        onEpochStartCallCount++;
        return onEpochStartCallCount < nbStepsEpochStartCanContinue
                ? ListenerResponse.CONTINUE
                : ListenerResponse.STOP;
    }

    @Override
    public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
        onEpochEndStartCallCount++;
        return onEpochEndStartCallCount < nbStepsEpochEndCanContinue
                ? ListenerResponse.CONTINUE
                : ListenerResponse.STOP;
    }
}