RL4J Added listener pattern to SyncLearning (#8050)

* Added listener pattern to SyncLearning
  Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
* Did requested changes
  Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

master
parent 0527ab8d98
commit b2145ca780
@@ -125,8 +125,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
        return nshape;
    }

    protected abstract IDataManager getDataManager();

    public abstract NN getNeuralNet();

    public int incrementStep() {
@@ -21,6 +21,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.factory.Nd4j;

/**

@@ -36,6 +37,7 @@ import org.nd4j.linalg.factory.Nd4j;
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
        extends Learning<O, A, AS, NN> {

    protected abstract IDataManager getDataManager();

    public AsyncLearning(AsyncConfiguration conf) {
        super(conf);
@@ -18,67 +18,132 @@ package org.deeplearning4j.rl4j.learning.sync;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.IDataManager;

import java.util.ArrayList;
import java.util.List;

/**
 * Mother class and useful factorisations for all training methods that
 * are not asynchronous.
 *
 * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
 * @author Alexandre Boulanger
 */
@Slf4j
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
        extends Learning<O, A, AS, NN> {

    private int lastSave = -Constants.MODEL_SAVE_FREQ;
    private List<SyncTrainingListener> listeners = new ArrayList<>();

    public SyncLearning(LConfiguration conf) {
        super(conf);
    }

    public void train() {

        try {
            log.info("training starting.");

            getDataManager().writeInfo(this);

            while (getStepCounter() < getConfiguration().getMaxStep()) {
                preEpoch();
                IDataManager.StatEntry statEntry = trainEpoch();
                postEpoch();

                incrementEpoch();

                if (getStepCounter() - lastSave >= Constants.MODEL_SAVE_FREQ) {
                    getDataManager().save(this);
                    lastSave = getStepCounter();
                }

                getDataManager().appendStat(statEntry);
                getDataManager().writeInfo(this);

                log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
            }
        } catch (Exception e) {
            log.error("Training failed.", e);
            e.printStackTrace();
        }

    /**
     * Add a listener at the end of the listener list.
     *
     * @param listener
     */
    public void addListener(SyncTrainingListener listener) {
        listeners.add(listener);
    }

    /**
     * This method will train the model<p>
     * The training stops when:<br>
     * - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
     * OR<br>
     * - a listener explicitly stops it<br>
     * <p>
     * Listeners<br>
     * For a given event, the listeners are called sequentially in the same order as they were added. If one listener
     * returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
     * Events:
     * <ul>
     *   <li>{@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.</li>
     *   <li>{@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training</li>
     *   <li>{@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
     * </ul>
     */
    public void train() {

        log.info("training starting.");

        boolean canContinue = notifyTrainingStarted();
        if (canContinue) {
            while (getStepCounter() < getConfiguration().getMaxStep()) {
                preEpoch();
                canContinue = notifyEpochStarted();
                if (!canContinue) {
                    break;
                }

                IDataManager.StatEntry statEntry = trainEpoch();

                postEpoch();
                canContinue = notifyEpochFinished(statEntry);
                if (!canContinue) {
                    break;
                }

                log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());

                incrementEpoch();
            }
        }

        notifyTrainingFinished();
    }

    private boolean notifyTrainingStarted() {
        SyncTrainingEvent event = new SyncTrainingEvent(this);
        for (SyncTrainingListener listener : listeners) {
            if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
                return false;
            }
        }

        return true;
    }

    private void notifyTrainingFinished() {
        for (SyncTrainingListener listener : listeners) {
            listener.onTrainingEnd();
        }
    }

    private boolean notifyEpochStarted() {
        SyncTrainingEvent event = new SyncTrainingEvent(this);
        for (SyncTrainingListener listener : listeners) {
            if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
                return false;
            }
        }

        return true;
    }

    private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) {
        SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry);
        for (SyncTrainingListener listener : listeners) {
            if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) {
                return false;
            }
        }

        return true;
    }

    protected abstract void preEpoch();

    protected abstract void postEpoch();

    protected abstract IDataManager.StatEntry trainEpoch();

    protected abstract IDataManager.StatEntry trainEpoch(); // TODO: finish removal of IDataManager from Learning
}
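The javadoc above spells out the listener contract that the new train() implements. As an illustration only, not part of this commit, a minimal custom listener might look like the sketch below; the class name EpochLimitListener and the epoch limit are assumptions, while SyncTrainingListener, ListenerResponse, the event types and addListener() come from the code above.

    import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
    import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
    import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;

    // Hypothetical listener: lets training run for a fixed number of epochs, then asks SyncLearning to stop.
    public class EpochLimitListener implements SyncTrainingListener {
        private final int maxEpochs;
        private int startedEpochs = 0;

        public EpochLimitListener(int maxEpochs) {
            this.maxEpochs = maxEpochs;
        }

        @Override
        public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
            return ListenerResponse.CONTINUE;
        }

        @Override
        public void onTrainingEnd() {
            // Called even if a listener stopped the training early.
        }

        @Override
        public ListenerResponse onEpochStart(SyncTrainingEvent event) {
            // Returning STOP here means onEpochEnd will not be called and train() exits its loop.
            return ++startedEpochs > maxEpochs ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
        }

        @Override
        public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
            return ListenerResponse.CONTINUE;
        }
    }

Registered before train() with learning.addListener(new EpochLimitListener(50)), where 'learning' is any SyncLearning subclass (assumed), the listener is notified in the order listeners were added.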
@@ -0,0 +1,22 @@
package org.deeplearning4j.rl4j.learning.sync.listener;

import lombok.Getter;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.util.IDataManager;

/**
 * A subclass of SyncTrainingEvent that is passed to SyncTrainingListener.onEpochEnd()
 */
public class SyncTrainingEpochEndEvent extends SyncTrainingEvent {

    /**
     * The stats of the epoch training
     */
    @Getter
    private final IDataManager.StatEntry statEntry;

    public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) {
        super(learning);
        this.statEntry = statEntry;
    }
}
@@ -0,0 +1,21 @@
package org.deeplearning4j.rl4j.learning.sync.listener;

import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.Learning;

/**
 * SyncTrainingEvent instances are passed as parameters to the events of SyncTrainingListener
 */
public class SyncTrainingEvent {

    /**
     * The source of the event
     */
    @Getter
    private final Learning learning;

    public SyncTrainingEvent(Learning learning) {
        this.learning = learning;
    }
}
@@ -0,0 +1,45 @@
package org.deeplearning4j.rl4j.learning.sync.listener;

/**
 * A listener interface to use with a descendant of {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning}
 */
public interface SyncTrainingListener {

    public enum ListenerResponse {
        /**
         * Tell SyncLearning to continue calling the listeners and the training.
         */
        CONTINUE,

        /**
         * Tell SyncLearning to stop calling the listeners and terminate the training.
         */
        STOP,
    }

    /**
     * Called once when the training starts.
     * @param event
     * @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
     */
    ListenerResponse onTrainingStart(SyncTrainingEvent event);

    /**
     * Called once when the training has finished. This method is called even when the training has been aborted.
     */
    void onTrainingEnd();

    /**
     * Called before the start of every epoch.
     * @param event
     * @return A ListenerResponse telling the source of the event if it should continue or stop the training.
     */
    ListenerResponse onEpochStart(SyncTrainingEvent event);

    /**
     * Called after the end of every epoch.
     * @param event
     * @return A ListenerResponse telling the source of the event if it should continue or stop the training.
     */
    ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event);
}
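Tying the interface above to the event classes, the fragment below sketches a hypothetical early-stopping listener built as an anonymous class; the reward threshold (195.0) and the 'learning' variable are assumptions, the rest is the API defined in this commit.

    // Hypothetical usage: stop training once an epoch's reward reaches a target value.
    SyncTrainingListener earlyStop = new SyncTrainingListener() {
        @Override
        public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
            return ListenerResponse.CONTINUE;
        }

        @Override
        public void onTrainingEnd() { }

        @Override
        public ListenerResponse onEpochStart(SyncTrainingEvent event) {
            return ListenerResponse.CONTINUE;
        }

        @Override
        public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
            // The StatEntry describes the epoch that just finished.
            return event.getStatEntry().getReward() >= 195.0
                    ? ListenerResponse.STOP
                    : ListenerResponse.CONTINUE;
        }
    };
    learning.addListener(earlyStop); // 'learning' is any SyncLearning subclass (assumed)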
@@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import lombok.Getter;
import lombok.Setter;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;

@@ -29,12 +28,13 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.ArrayList;

@@ -53,8 +53,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
    @Getter
    final private QLConfiguration configuration;
    @Getter
    final private IDataManager dataManager;
    @Getter
    final private MDP<O, Integer, DiscreteSpace> mdp;
    @Getter
    final private IDQN currentDQN;

@@ -68,24 +66,31 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
    private int lastAction;
    private INDArray history[] = null;
    private double accuReward = 0;
    private int lastMonitor = -Constants.MONITOR_FREQ;

    /**
     * @deprecated
     * Use QLearningDiscrete(MDP, IDQN, QLConfiguration, int) and add the required listeners with addListener() instead.
     */
    @Deprecated
    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
                             IDataManager dataManager, int epsilonNbStep) {
        this(mdp, dqn, conf, epsilonNbStep);
        addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
    }

    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
                             int epsilonNbStep) {
        super(conf);
        this.configuration = conf;
        this.mdp = mdp;
        this.dataManager = dataManager;
        currentDQN = dqn;
        targetDQN = dqn.clone();
        policy = new DQNPolicy(getCurrentDQN());
        egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(),
                        this);
        mdp.getActionSpace().setSeed(conf.getSeed());
    }

    public void postEpoch() {

        if (getHistoryProcessor() != null)

@@ -97,14 +102,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
        history = null;
        lastAction = 0;
        accuReward = 0;

        if (getStepCounter() - lastMonitor >= Constants.MONITOR_FREQ && getHistoryProcessor() != null
                        && getDataManager().isSaveData()) {
            lastMonitor = getStepCounter();
            int[] shape = getMdp().getObservationSpace().getShape();
            getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + getEpochCounter() + "-"
                            + getStepCounter() + ".mp4", shape);
        }
    }

    /**
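The deprecation note above implies a small migration for calling code. A hedged sketch, since QLearningDiscrete itself is abstract: MyQLearning and MyObservation are hypothetical names, and mdp, dqn, conf, dataManager and epsilonNbStep are assumed to be supplied by the caller.

    // Before (deprecated): the IDataManager is passed to the constructor and wired internally.
    QLearningDiscrete<MyObservation> before = new MyQLearning(mdp, dqn, conf, dataManager, epsilonNbStep);

    // After: construct without the IDataManager and register the listener explicitly.
    QLearningDiscrete<MyObservation> after = new MyQLearning(mdp, dqn, conf, epsilonNbStep);
    after.addListener(DataManagerSyncTrainingListener.builder(dataManager).build());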
@@ -0,0 +1,126 @@
package org.deeplearning4j.rl4j.util;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;

/**
 * DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
 * training process can be fed to the DataManager
 */
@Slf4j
public class DataManagerSyncTrainingListener implements SyncTrainingListener {
    private final IDataManager dataManager;
    private final int saveFrequency;
    private final int monitorFrequency;

    private int lastSave;
    private int lastMonitor;

    private DataManagerSyncTrainingListener(Builder builder) {
        this.dataManager = builder.dataManager;

        this.saveFrequency = builder.saveFrequency;
        this.lastSave = -builder.saveFrequency;

        this.monitorFrequency = builder.monitorFrequency;
        this.lastMonitor = -builder.monitorFrequency;
    }

    @Override
    public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
        try {
            dataManager.writeInfo(event.getLearning());
        } catch (Exception e) {
            log.error("Training failed.", e);
            return ListenerResponse.STOP;
        }
        return ListenerResponse.CONTINUE;
    }

    @Override
    public void onTrainingEnd() {
        // Do nothing
    }

    @Override
    public ListenerResponse onEpochStart(SyncTrainingEvent event) {
        int stepCounter = event.getLearning().getStepCounter();

        if (stepCounter - lastMonitor >= monitorFrequency
                        && event.getLearning().getHistoryProcessor() != null
                        && dataManager.isSaveData()) {
            lastMonitor = stepCounter;
            int[] shape = event.getLearning().getMdp().getObservationSpace().getShape();
            event.getLearning().getHistoryProcessor().startMonitor(dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter() + "-"
                            + stepCounter + ".mp4", shape);
        }

        return ListenerResponse.CONTINUE;
    }

    @Override
    public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
        try {
            int stepCounter = event.getLearning().getStepCounter();
            if (stepCounter - lastSave >= saveFrequency) {
                dataManager.save(event.getLearning());
                lastSave = stepCounter;
            }

            dataManager.appendStat(event.getStatEntry());
            dataManager.writeInfo(event.getLearning());
        } catch (Exception e) {
            log.error("Training failed.", e);
            return ListenerResponse.STOP;
        }

        return ListenerResponse.CONTINUE;
    }

    public static Builder builder(IDataManager dataManager) {
        return new Builder(dataManager);
    }

    public static class Builder {
        private final IDataManager dataManager;
        private int saveFrequency = Constants.MODEL_SAVE_FREQ;
        private int monitorFrequency = Constants.MONITOR_FREQ;

        /**
         * Create a Builder with the given DataManager
         * @param dataManager
         */
        public Builder(IDataManager dataManager) {
            this.dataManager = dataManager;
        }

        /**
         * The number of steps that must elapse after a call to DataManager.save() before it can be called again.
         * @param saveFrequency (Default: 100000)
         */
        public Builder saveFrequency(int saveFrequency) {
            this.saveFrequency = saveFrequency;
            return this;
        }

        /**
         * The number of steps that must elapse after a call to HistoryProcessor.startMonitor() before it can be called again.
         * @param monitorFrequency (Default: 10000)
         */
        public Builder monitorFrequency(int monitorFrequency) {
            this.monitorFrequency = monitorFrequency;
            return this;
        }

        /**
         * Creates a DataManagerSyncTrainingListener with the configured parameters
         * @return An instance of DataManagerSyncTrainingListener
         */
        public DataManagerSyncTrainingListener build() {
            return new DataManagerSyncTrainingListener(this);
        }

    }
}
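A usage sketch for the builder above, not part of the commit: the DataManager construction, the frequency values and the 'learning' variable are assumptions; builder(), saveFrequency(), monitorFrequency() and build() are the methods defined in this class.

    // Save the model at most every 50 000 steps and start video monitoring at most every 20 000 steps.
    IDataManager dataManager = new DataManager(true); // assumed: the existing DataManager(boolean saveData) implements IDataManager
    DataManagerSyncTrainingListener dmListener =
            DataManagerSyncTrainingListener.builder(dataManager)
                    .saveFrequency(50_000)
                    .monitorFrequency(20_000)
                    .build();
    learning.addListener(dmListener); // 'learning' is any SyncLearning subclass (assumed)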
@@ -1,12 +1,13 @@
package org.deeplearning4j.rl4j.learning.sync;

import lombok.AllArgsConstructor;
import lombok.Value;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
@@ -15,61 +16,101 @@ import static org.junit.Assert.assertEquals;
public class SyncLearningTest {

    @Test
    public void refac_checkDataManagerCallsRemainTheSame() {
    public void when_training_expect_listenersToBeCalled() {
        // Arrange
        MockLConfiguration lconfig = new MockLConfiguration(10);
        MockDataManager dataManager = new MockDataManager(false);
        MockSyncLearning sut = new MockSyncLearning(lconfig, dataManager, 2);
        QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
        MockSyncTrainingListener listener = new MockSyncTrainingListener();
        MockSyncLearning sut = new MockSyncLearning(lconfig);
        sut.addListener(listener);

        // Act
        sut.train();

        assertEquals(10, dataManager.statEntries.size());
        for(int i = 0; i < 10; ++i) {
            IDataManager.StatEntry entry = dataManager.statEntries.get(i);
            assertEquals(2, entry.getEpochCounter());
            assertEquals(i+1, entry.getStepCounter());
            assertEquals(1.0, entry.getReward(), 0.0);
        assertEquals(1, listener.onTrainingStartCallCount);
        assertEquals(10, listener.onEpochStartCallCount);
        assertEquals(10, listener.onEpochEndStartCallCount);
        assertEquals(1, listener.onTrainingEndCallCount);
    }

    }
        assertEquals(0, dataManager.isSaveDataCallCount);
        assertEquals(0, dataManager.getVideoDirCallCount);
        assertEquals(11, dataManager.writeInfoCallCount);
        assertEquals(1, dataManager.saveCallCount);

    @Test
    public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
        // Arrange
        QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
        MockSyncTrainingListener listener = new MockSyncTrainingListener();
        MockSyncLearning sut = new MockSyncLearning(lconfig);
        sut.addListener(listener);
        listener.trainingStartCanContinue = false;

        // Act
        sut.train();

        assertEquals(1, listener.onTrainingStartCallCount);
        assertEquals(0, listener.onEpochStartCallCount);
        assertEquals(0, listener.onEpochEndStartCallCount);
        assertEquals(1, listener.onTrainingEndCallCount);
    }

    @Test
    public void when_epochStartCanContinueFalse_expect_trainingStopped() {
        // Arrange
        QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
        MockSyncTrainingListener listener = new MockSyncTrainingListener();
        MockSyncLearning sut = new MockSyncLearning(lconfig);
        sut.addListener(listener);
        listener.nbStepsEpochStartCanContinue = 3;

        // Act
        sut.train();

        assertEquals(1, listener.onTrainingStartCallCount);
        assertEquals(3, listener.onEpochStartCallCount);
        assertEquals(2, listener.onEpochEndStartCallCount);
        assertEquals(1, listener.onTrainingEndCallCount);
    }

    @Test
    public void when_epochEndCanContinueFalse_expect_trainingStopped() {
        // Arrange
        QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
        MockSyncTrainingListener listener = new MockSyncTrainingListener();
        MockSyncLearning sut = new MockSyncLearning(lconfig);
        sut.addListener(listener);
        listener.nbStepsEpochEndCanContinue = 3;

        // Act
        sut.train();

        assertEquals(1, listener.onTrainingStartCallCount);
        assertEquals(3, listener.onEpochStartCallCount);
        assertEquals(3, listener.onEpochEndStartCallCount);
        assertEquals(1, listener.onTrainingEndCallCount);
    }

    public static class MockSyncLearning extends SyncLearning {

        private final IDataManager dataManager;
        private LConfiguration conf;
        private final int epochSteps;

        public MockSyncLearning(LConfiguration conf, IDataManager dataManager, int epochSteps) {
        public MockSyncLearning(LConfiguration conf, IDataManager dataManager) {
            super(conf);
            addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
            this.conf = conf;
        }

        public MockSyncLearning(LConfiguration conf) {
            super(conf);
            this.dataManager = dataManager;
            this.conf = conf;
            this.epochSteps = epochSteps;
        }

        @Override
        protected void preEpoch() {
        }
        protected void preEpoch() { }

        @Override
        protected void postEpoch() {
        }
        protected void postEpoch() { }

        @Override
        protected IDataManager.StatEntry trainEpoch() {
            setStepCounter(getStepCounter() + 1);
            return new MockStatEntry(epochSteps, getStepCounter(), 1.0);
        }

        @Override
        protected IDataManager getDataManager() {
            return dataManager;
            return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
        }

        @Override

@@ -92,41 +133,4 @@ public class SyncLearningTest {
            return null;
        }
    }

    public static class MockLConfiguration implements ILearning.LConfiguration {

        private final int maxStep;

        public MockLConfiguration(int maxStep) {
            this.maxStep = maxStep;
        }

        @Override
        public int getSeed() {
            return 0;
        }

        @Override
        public int getMaxEpochStep() {
            return 0;
        }

        @Override
        public int getMaxStep() {
            return maxStep;
        }

        @Override
        public double getGamma() {
            return 0;
        }
    }

    @AllArgsConstructor
    @Value
    public static class MockStatEntry implements IDataManager.StatEntry {
        int epochCounter;
        int stepCounter;
        double reward;
    }
}
@@ -0,0 +1,65 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning;

import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockMDP;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

public class QLearningDiscreteTest {
    @Test
    public void refac_checkDataManagerCallsRemainTheSame() {
        // Arrange
        QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder()
                .maxStep(10)
                .expRepMaxSize(1)
                .build();
        MockDataManager dataManager = new MockDataManager(true);
        MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3);
        IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
                .build();
        sut.setHistoryProcessor(new MockHistoryProcessor(hpConfig));

        // Act
        sut.train();

        assertEquals(10, dataManager.statEntries.size());
        for(int i = 0; i < 10; ++i) {
            IDataManager.StatEntry entry = dataManager.statEntries.get(i);
            assertEquals(i, entry.getEpochCounter());
            assertEquals(i+1, entry.getStepCounter());
            assertEquals(1.0, entry.getReward(), 0.0);
        }
        assertEquals(4, dataManager.isSaveDataCallCount);
        assertEquals(4, dataManager.getVideoDirCallCount);
        assertEquals(11, dataManager.writeInfoCallCount);
        assertEquals(5, dataManager.saveCallCount);
    }

    public static class MockQLearningDiscrete extends QLearningDiscrete {

        public MockQLearningDiscrete(int maxSteps, QLConfiguration conf,
                                     IDataManager dataManager, int saveFrequency, int monitorFrequency) {
            super(new MockMDP(maxSteps), new MockDQN(), conf, 2);
            addListener(DataManagerSyncTrainingListener.builder(dataManager)
                    .saveFrequency(saveFrequency)
                    .monitorFrequency(monitorFrequency)
                    .build());
        }

        @Override
        protected IDataManager.StatEntry trainEpoch() {
            setStepCounter(getStepCounter() + 1);
            return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
        }
    }
}
@@ -0,0 +1,92 @@
package org.deeplearning4j.rl4j.learning.sync.support;

import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.IOException;
import java.io.OutputStream;

public class MockDQN implements IDQN {
    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[0];
    }

    @Override
    public boolean isRecurrent() {
        return false;
    }

    @Override
    public void reset() {
    }

    @Override
    public void fit(INDArray input, INDArray labels) {
    }

    @Override
    public void fit(INDArray input, INDArray[] labels) {
    }

    @Override
    public INDArray output(INDArray batch) {
        return null;
    }

    @Override
    public INDArray[] outputAll(INDArray batch) {
        return new INDArray[0];
    }

    @Override
    public IDQN clone() {
        return null;
    }

    @Override
    public void copy(NeuralNet from) {
    }

    @Override
    public void copy(IDQN from) {
    }

    @Override
    public Gradient[] gradient(INDArray input, INDArray label) {
        return new Gradient[0];
    }

    @Override
    public Gradient[] gradient(INDArray input, INDArray[] label) {
        return new Gradient[0];
    }

    @Override
    public void applyGradient(Gradient[] gradient, int batchSize) {
    }

    @Override
    public double getLatestScore() {
        return 0;
    }

    @Override
    public void save(OutputStream os) throws IOException {
    }

    @Override
    public void save(String filename) throws IOException {
    }
}
@@ -0,0 +1,79 @@
package org.deeplearning4j.rl4j.learning.sync.support;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;

public class MockMDP implements MDP<Object, Integer, DiscreteSpace> {

    private final int maxSteps;
    private final DiscreteSpace actionSpace = new DiscreteSpace(1);
    private final MockObservationSpace observationSpace = new MockObservationSpace();

    private int currentStep = 0;

    public MockMDP(int maxSteps) {
        this.maxSteps = maxSteps;
    }

    @Override
    public ObservationSpace<Object> getObservationSpace() {
        return observationSpace;
    }

    @Override
    public DiscreteSpace getActionSpace() {
        return actionSpace;
    }

    @Override
    public Object reset() {
        return null;
    }

    @Override
    public void close() {
    }

    @Override
    public StepReply<Object> step(Integer integer) {
        return new StepReply<Object>(null, 1.0, isDone(), null);
    }

    @Override
    public boolean isDone() {
        return currentStep >= maxSteps;
    }

    @Override
    public MDP<Object, Integer, DiscreteSpace> newInstance() {
        return null;
    }

    private static class MockObservationSpace implements ObservationSpace {

        @Override
        public String getName() {
            return null;
        }

        @Override
        public int[] getShape() {
            return new int[0];
        }

        @Override
        public INDArray getLow() {
            return null;
        }

        @Override
        public INDArray getHigh() {
            return null;
        }
    }
}
@@ -0,0 +1,13 @@
package org.deeplearning4j.rl4j.learning.sync.support;

import lombok.AllArgsConstructor;
import lombok.Value;
import org.deeplearning4j.rl4j.util.IDataManager;

@AllArgsConstructor
@Value
public class MockStatEntry implements IDataManager.StatEntry {
    int epochCounter;
    int stepCounter;
    double reward;
}
@@ -0,0 +1,46 @@
package org.deeplearning4j.rl4j.learning.sync.support;

import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;

public class MockSyncTrainingListener implements SyncTrainingListener {

    public int onTrainingStartCallCount = 0;
    public int onTrainingEndCallCount = 0;
    public int onEpochStartCallCount = 0;
    public int onEpochEndStartCallCount = 0;

    public boolean trainingStartCanContinue = true;
    public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE;
    public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE;

    @Override
    public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
        ++onTrainingStartCallCount;
        return trainingStartCanContinue ? ListenerResponse.CONTINUE : ListenerResponse.STOP;
    }

    @Override
    public void onTrainingEnd() {
        ++onTrainingEndCallCount;
    }

    @Override
    public ListenerResponse onEpochStart(SyncTrainingEvent event) {
        ++onEpochStartCallCount;
        if(onEpochStartCallCount >= nbStepsEpochStartCanContinue) {
            return ListenerResponse.STOP;
        }
        return ListenerResponse.CONTINUE;
    }

    @Override
    public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
        ++onEpochEndStartCallCount;
        if(onEpochEndStartCallCount >= nbStepsEpochEndCanContinue) {
            return ListenerResponse.STOP;
        }
        return ListenerResponse.CONTINUE;
    }
}