RL4J Added listener pattern to SyncLearning (#8050)
* Added listener pattern to SyncLearning Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Did requested changes Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
0527ab8d98
commit
b2145ca780
|
@ -125,8 +125,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
||||||
return nshape;
|
return nshape;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract IDataManager getDataManager();
|
|
||||||
|
|
||||||
public abstract NN getNeuralNet();
|
public abstract NN getNeuralNet();
|
||||||
|
|
||||||
public int incrementStep() {
|
public int incrementStep() {
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -36,6 +37,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
extends Learning<O, A, AS, NN> {
|
extends Learning<O, A, AS, NN> {
|
||||||
|
|
||||||
|
protected abstract IDataManager getDataManager();
|
||||||
|
|
||||||
public AsyncLearning(AsyncConfiguration conf) {
|
public AsyncLearning(AsyncConfiguration conf) {
|
||||||
super(conf);
|
super(conf);
|
||||||
|
|
|
@ -18,67 +18,132 @@ package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
|
|
||||||
*
|
|
||||||
* Mother class and useful factorisations for all training methods that
|
* Mother class and useful factorisations for all training methods that
|
||||||
* are not asynchronous.
|
* are not asynchronous.
|
||||||
*
|
*
|
||||||
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
|
||||||
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
extends Learning<O, A, AS, NN> {
|
extends Learning<O, A, AS, NN> {
|
||||||
|
|
||||||
private int lastSave = -Constants.MODEL_SAVE_FREQ;
|
private List<SyncTrainingListener> listeners = new ArrayList<>();
|
||||||
|
|
||||||
public SyncLearning(LConfiguration conf) {
|
public SyncLearning(LConfiguration conf) {
|
||||||
super(conf);
|
super(conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a listener at the end of the listener list.
|
||||||
|
*
|
||||||
|
* @param listener
|
||||||
|
*/
|
||||||
|
public void addListener(SyncTrainingListener listener) {
|
||||||
|
listeners.add(listener);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will train the model<p>
|
||||||
|
* The training stop when:<br>
|
||||||
|
* - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
|
||||||
|
* OR<br>
|
||||||
|
* - a listener explicitly stops it<br>
|
||||||
|
* <p>
|
||||||
|
* Listeners<br>
|
||||||
|
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
|
||||||
|
* returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
||||||
|
* Events:
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.</li>
|
||||||
|
* <li>{@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training</li>
|
||||||
|
* <li>{@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
public void train() {
|
public void train() {
|
||||||
|
|
||||||
try {
|
|
||||||
log.info("training starting.");
|
log.info("training starting.");
|
||||||
|
|
||||||
getDataManager().writeInfo(this);
|
boolean canContinue = notifyTrainingStarted();
|
||||||
|
if (canContinue) {
|
||||||
|
|
||||||
while (getStepCounter() < getConfiguration().getMaxStep()) {
|
while (getStepCounter() < getConfiguration().getMaxStep()) {
|
||||||
preEpoch();
|
preEpoch();
|
||||||
IDataManager.StatEntry statEntry = trainEpoch();
|
canContinue = notifyEpochStarted();
|
||||||
postEpoch();
|
if (!canContinue) {
|
||||||
|
break;
|
||||||
incrementEpoch();
|
|
||||||
|
|
||||||
if (getStepCounter() - lastSave >= Constants.MODEL_SAVE_FREQ) {
|
|
||||||
getDataManager().save(this);
|
|
||||||
lastSave = getStepCounter();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
getDataManager().appendStat(statEntry);
|
IDataManager.StatEntry statEntry = trainEpoch();
|
||||||
getDataManager().writeInfo(this);
|
|
||||||
|
postEpoch();
|
||||||
|
canContinue = notifyEpochFinished(statEntry);
|
||||||
|
if (!canContinue) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
||||||
|
|
||||||
|
incrementEpoch();
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Training failed.", e);
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
notifyTrainingFinished();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean notifyTrainingStarted() {
|
||||||
|
SyncTrainingEvent event = new SyncTrainingEvent(this);
|
||||||
|
for (SyncTrainingListener listener : listeners) {
|
||||||
|
if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void notifyTrainingFinished() {
|
||||||
|
for (SyncTrainingListener listener : listeners) {
|
||||||
|
listener.onTrainingEnd();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean notifyEpochStarted() {
|
||||||
|
SyncTrainingEvent event = new SyncTrainingEvent(this);
|
||||||
|
for (SyncTrainingListener listener : listeners) {
|
||||||
|
if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) {
|
||||||
|
SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry);
|
||||||
|
for (SyncTrainingListener listener : listeners) {
|
||||||
|
if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
protected abstract void preEpoch();
|
protected abstract void preEpoch();
|
||||||
|
|
||||||
protected abstract void postEpoch();
|
protected abstract void postEpoch();
|
||||||
|
|
||||||
protected abstract IDataManager.StatEntry trainEpoch();
|
protected abstract IDataManager.StatEntry trainEpoch(); // TODO: finish removal of IDataManager from Learning
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.listener;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A subclass of SyncTrainingEvent that is passed to SyncTrainingListener.onEpochEnd()
|
||||||
|
*/
|
||||||
|
public class SyncTrainingEpochEndEvent extends SyncTrainingEvent {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The stats of the epoch training
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
private final IDataManager.StatEntry statEntry;
|
||||||
|
|
||||||
|
public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) {
|
||||||
|
super(learning);
|
||||||
|
this.statEntry = statEntry;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.listener;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SyncTrainingEvent are passed as parameters to the events of SyncTrainingListener
|
||||||
|
*/
|
||||||
|
public class SyncTrainingEvent {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The source of the event
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
private final Learning learning;
|
||||||
|
|
||||||
|
public SyncTrainingEvent(Learning learning) {
|
||||||
|
this.learning = learning;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.listener;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A listener interface to use with a descendant of {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning}
|
||||||
|
*/
|
||||||
|
public interface SyncTrainingListener {
|
||||||
|
|
||||||
|
public enum ListenerResponse {
|
||||||
|
/**
|
||||||
|
* Tell SyncLearning to continue calling the listeners and the training.
|
||||||
|
*/
|
||||||
|
CONTINUE,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tell SyncLearning to stop calling the listeners and terminate the training.
|
||||||
|
*/
|
||||||
|
STOP,
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called once when the training starts.
|
||||||
|
* @param event
|
||||||
|
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
|
||||||
|
*/
|
||||||
|
ListenerResponse onTrainingStart(SyncTrainingEvent event);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called once when the training has finished. This method is called even when the training has been aborted.
|
||||||
|
*/
|
||||||
|
void onTrainingEnd();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called before the start of every epoch.
|
||||||
|
* @param event
|
||||||
|
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||||
|
*/
|
||||||
|
ListenerResponse onEpochStart(SyncTrainingEvent event);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called after the end of every epoch.
|
||||||
|
* @param event
|
||||||
|
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||||
|
*/
|
||||||
|
ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event);
|
||||||
|
}
|
|
@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
|
@ -29,12 +28,13 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -53,8 +53,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
@Getter
|
@Getter
|
||||||
final private QLConfiguration configuration;
|
final private QLConfiguration configuration;
|
||||||
@Getter
|
@Getter
|
||||||
final private IDataManager dataManager;
|
|
||||||
@Getter
|
|
||||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
final private MDP<O, Integer, DiscreteSpace> mdp;
|
||||||
@Getter
|
@Getter
|
||||||
final private IDQN currentDQN;
|
final private IDQN currentDQN;
|
||||||
|
@ -68,15 +66,23 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
private int lastAction;
|
private int lastAction;
|
||||||
private INDArray history[] = null;
|
private INDArray history[] = null;
|
||||||
private double accuReward = 0;
|
private double accuReward = 0;
|
||||||
private int lastMonitor = -Constants.MONITOR_FREQ;
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated
|
||||||
|
* Use QLearningDiscrete(MDP, IDQN, QLConfiguration, int) and add the required listeners with addListener() instead.
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||||
IDataManager dataManager, int epsilonNbStep) {
|
IDataManager dataManager, int epsilonNbStep) {
|
||||||
|
this(mdp, dqn, conf, epsilonNbStep);
|
||||||
|
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||||
|
int epsilonNbStep) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.dataManager = dataManager;
|
|
||||||
currentDQN = dqn;
|
currentDQN = dqn;
|
||||||
targetDQN = dqn.clone();
|
targetDQN = dqn.clone();
|
||||||
policy = new DQNPolicy(getCurrentDQN());
|
policy = new DQNPolicy(getCurrentDQN());
|
||||||
|
@ -85,7 +91,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public void postEpoch() {
|
public void postEpoch() {
|
||||||
|
|
||||||
if (getHistoryProcessor() != null)
|
if (getHistoryProcessor() != null)
|
||||||
|
@ -97,14 +102,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
history = null;
|
history = null;
|
||||||
lastAction = 0;
|
lastAction = 0;
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
|
|
||||||
if (getStepCounter() - lastMonitor >= Constants.MONITOR_FREQ && getHistoryProcessor() != null
|
|
||||||
&& getDataManager().isSaveData()) {
|
|
||||||
lastMonitor = getStepCounter();
|
|
||||||
int[] shape = getMdp().getObservationSpace().getShape();
|
|
||||||
getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + getEpochCounter() + "-"
|
|
||||||
+ getStepCounter() + ".mp4", shape);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -0,0 +1,126 @@
|
||||||
|
package org.deeplearning4j.rl4j.util;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
|
||||||
|
* training process can be fed to the DataManager
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class DataManagerSyncTrainingListener implements SyncTrainingListener {
|
||||||
|
private final IDataManager dataManager;
|
||||||
|
private final int saveFrequency;
|
||||||
|
private final int monitorFrequency;
|
||||||
|
|
||||||
|
private int lastSave;
|
||||||
|
private int lastMonitor;
|
||||||
|
|
||||||
|
private DataManagerSyncTrainingListener(Builder builder) {
|
||||||
|
this.dataManager = builder.dataManager;
|
||||||
|
|
||||||
|
this.saveFrequency = builder.saveFrequency;
|
||||||
|
this.lastSave = -builder.saveFrequency;
|
||||||
|
|
||||||
|
this.monitorFrequency = builder.monitorFrequency;
|
||||||
|
this.lastMonitor = -builder.monitorFrequency;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
|
||||||
|
try {
|
||||||
|
dataManager.writeInfo(event.getLearning());
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Training failed.", e);
|
||||||
|
return ListenerResponse.STOP;
|
||||||
|
}
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onTrainingEnd() {
|
||||||
|
// Do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
|
||||||
|
int stepCounter = event.getLearning().getStepCounter();
|
||||||
|
|
||||||
|
if (stepCounter - lastMonitor >= monitorFrequency
|
||||||
|
&& event.getLearning().getHistoryProcessor() != null
|
||||||
|
&& dataManager.isSaveData()) {
|
||||||
|
lastMonitor = stepCounter;
|
||||||
|
int[] shape = event.getLearning().getMdp().getObservationSpace().getShape();
|
||||||
|
event.getLearning().getHistoryProcessor().startMonitor(dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter() + "-"
|
||||||
|
+ stepCounter + ".mp4", shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
|
||||||
|
try {
|
||||||
|
int stepCounter = event.getLearning().getStepCounter();
|
||||||
|
if (stepCounter - lastSave >= saveFrequency) {
|
||||||
|
dataManager.save(event.getLearning());
|
||||||
|
lastSave = stepCounter;
|
||||||
|
}
|
||||||
|
|
||||||
|
dataManager.appendStat(event.getStatEntry());
|
||||||
|
dataManager.writeInfo(event.getLearning());
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Training failed.", e);
|
||||||
|
return ListenerResponse.STOP;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder(IDataManager dataManager) {
|
||||||
|
return new Builder(dataManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private final IDataManager dataManager;
|
||||||
|
private int saveFrequency = Constants.MODEL_SAVE_FREQ;
|
||||||
|
private int monitorFrequency = Constants.MONITOR_FREQ;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a Builder with the given DataManager
|
||||||
|
* @param dataManager
|
||||||
|
*/
|
||||||
|
public Builder(IDataManager dataManager) {
|
||||||
|
this.dataManager = dataManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A number that represent the number of steps since the last call to DataManager.save() before can it be called again.
|
||||||
|
* @param saveFrequency (Default: 100000)
|
||||||
|
*/
|
||||||
|
public Builder saveFrequency(int saveFrequency) {
|
||||||
|
this.saveFrequency = saveFrequency;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A number that represent the number of steps since the last call to HistoryProcessor.startMonitor() before can it be called again.
|
||||||
|
* @param monitorFrequency (Default: 10000)
|
||||||
|
*/
|
||||||
|
public Builder monitorFrequency(int monitorFrequency) {
|
||||||
|
this.monitorFrequency = monitorFrequency;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a DataManagerSyncTrainingListener with the configured parameters
|
||||||
|
* @return An instance of DataManagerSyncTrainingListener
|
||||||
|
*/
|
||||||
|
public DataManagerSyncTrainingListener build() {
|
||||||
|
return new DataManagerSyncTrainingListener(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,12 +1,13 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync;
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import lombok.Value;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
@ -15,61 +16,101 @@ import static org.junit.Assert.assertEquals;
|
||||||
public class SyncLearningTest {
|
public class SyncLearningTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void refac_checkDataManagerCallsRemainTheSame() {
|
public void when_training_expect_listenersToBeCalled() {
|
||||||
// Arrange
|
// Arrange
|
||||||
MockLConfiguration lconfig = new MockLConfiguration(10);
|
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig, dataManager, 2);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
|
sut.addListener(listener);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.train();
|
sut.train();
|
||||||
|
|
||||||
assertEquals(10, dataManager.statEntries.size());
|
assertEquals(1, listener.onTrainingStartCallCount);
|
||||||
for(int i = 0; i < 10; ++i) {
|
assertEquals(10, listener.onEpochStartCallCount);
|
||||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
assertEquals(10, listener.onEpochEndStartCallCount);
|
||||||
assertEquals(2, entry.getEpochCounter());
|
assertEquals(1, listener.onTrainingEndCallCount);
|
||||||
assertEquals(i+1, entry.getStepCounter());
|
|
||||||
assertEquals(1.0, entry.getReward(), 0.0);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
assertEquals(0, dataManager.isSaveDataCallCount);
|
|
||||||
assertEquals(0, dataManager.getVideoDirCallCount);
|
@Test
|
||||||
assertEquals(11, dataManager.writeInfoCallCount);
|
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
|
||||||
assertEquals(1, dataManager.saveCallCount);
|
// Arrange
|
||||||
|
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||||
|
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
||||||
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
|
sut.addListener(listener);
|
||||||
|
listener.trainingStartCanContinue = false;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.train();
|
||||||
|
|
||||||
|
assertEquals(1, listener.onTrainingStartCallCount);
|
||||||
|
assertEquals(0, listener.onEpochStartCallCount);
|
||||||
|
assertEquals(0, listener.onEpochEndStartCallCount);
|
||||||
|
assertEquals(1, listener.onTrainingEndCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_epochStartCanContinueFalse_expect_trainingStopped() {
|
||||||
|
// Arrange
|
||||||
|
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||||
|
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
||||||
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
|
sut.addListener(listener);
|
||||||
|
listener.nbStepsEpochStartCanContinue = 3;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.train();
|
||||||
|
|
||||||
|
assertEquals(1, listener.onTrainingStartCallCount);
|
||||||
|
assertEquals(3, listener.onEpochStartCallCount);
|
||||||
|
assertEquals(2, listener.onEpochEndStartCallCount);
|
||||||
|
assertEquals(1, listener.onTrainingEndCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_epochEndCanContinueFalse_expect_trainingStopped() {
|
||||||
|
// Arrange
|
||||||
|
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||||
|
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
||||||
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
|
sut.addListener(listener);
|
||||||
|
listener.nbStepsEpochEndCanContinue = 3;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.train();
|
||||||
|
|
||||||
|
assertEquals(1, listener.onTrainingStartCallCount);
|
||||||
|
assertEquals(3, listener.onEpochStartCallCount);
|
||||||
|
assertEquals(3, listener.onEpochEndStartCallCount);
|
||||||
|
assertEquals(1, listener.onTrainingEndCallCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class MockSyncLearning extends SyncLearning {
|
public static class MockSyncLearning extends SyncLearning {
|
||||||
|
|
||||||
private final IDataManager dataManager;
|
|
||||||
private LConfiguration conf;
|
private LConfiguration conf;
|
||||||
private final int epochSteps;
|
|
||||||
|
|
||||||
public MockSyncLearning(LConfiguration conf, IDataManager dataManager, int epochSteps) {
|
public MockSyncLearning(LConfiguration conf, IDataManager dataManager) {
|
||||||
|
super(conf);
|
||||||
|
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
|
||||||
|
this.conf = conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MockSyncLearning(LConfiguration conf) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.dataManager = dataManager;
|
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
this.epochSteps = epochSteps;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void preEpoch() {
|
protected void preEpoch() { }
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void postEpoch() {
|
protected void postEpoch() { }
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected IDataManager.StatEntry trainEpoch() {
|
protected IDataManager.StatEntry trainEpoch() {
|
||||||
setStepCounter(getStepCounter() + 1);
|
setStepCounter(getStepCounter() + 1);
|
||||||
return new MockStatEntry(epochSteps, getStepCounter(), 1.0);
|
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected IDataManager getDataManager() {
|
|
||||||
return dataManager;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -92,41 +133,4 @@ public class SyncLearningTest {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class MockLConfiguration implements ILearning.LConfiguration {
|
|
||||||
|
|
||||||
private final int maxStep;
|
|
||||||
|
|
||||||
public MockLConfiguration(int maxStep) {
|
|
||||||
this.maxStep = maxStep;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getSeed() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getMaxEpochStep() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getMaxStep() {
|
|
||||||
return maxStep;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getGamma() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Value
|
|
||||||
public static class MockStatEntry implements IDataManager.StatEntry {
|
|
||||||
int epochCounter;
|
|
||||||
int stepCounter;
|
|
||||||
double reward;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.qlearning;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.support.MockMDP;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
/**
 * Refactoring-guard test: verifies that moving the IDataManager calls out of
 * the training loop and into {@code DataManagerSyncTrainingListener} keeps
 * the exact same sequence/counts of data-manager interactions as before.
 */
public class QLearningDiscreteTest {
    @Test
    public void refac_checkDataManagerCallsRemainTheSame() {
        // Arrange: 10 training steps; trainEpoch() below advances exactly one
        // step per epoch, so the run performs 10 epochs.
        QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder()
                .maxStep(10)
                .expRepMaxSize(1)
                .build();
        MockDataManager dataManager = new MockDataManager(true);
        // saveFrequency = 2, monitorFrequency = 3 (in epochs)
        MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3);
        IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
                .build();
        sut.setHistoryProcessor(new MockHistoryProcessor(hpConfig));

        // Act
        sut.train();

        // Assert: one stat entry per epoch, with the counters the mock
        // trainEpoch() produced (epoch i ends at cumulative step i+1).
        assertEquals(10, dataManager.statEntries.size());
        for(int i = 0; i < 10; ++i) {
            IDataManager.StatEntry entry = dataManager.statEntries.get(i);
            assertEquals(i, entry.getEpochCounter());
            assertEquals(i+1, entry.getStepCounter());
            assertEquals(1.0, entry.getReward(), 0.0);

        }
        // NOTE(review): these counts pin the listener's pre-refactor call
        // pattern — presumably 4 monitor checks (every 3rd of 10 epochs),
        // 11 writeInfo (start + each epoch) and 5 saves (every 2nd epoch);
        // derived from DataManagerSyncTrainingListener internals — confirm there.
        assertEquals(4, dataManager.isSaveDataCallCount);
        assertEquals(4, dataManager.getVideoDirCallCount);
        assertEquals(11, dataManager.writeInfoCallCount);
        assertEquals(5, dataManager.saveCallCount);
    }

    /**
     * QLearningDiscrete with a trivial trainEpoch(): advances the step counter
     * by one and reports a constant reward, so data-manager call counts depend
     * only on the listener logic under test.
     */
    public static class MockQLearningDiscrete extends QLearningDiscrete {

        public MockQLearningDiscrete(int maxSteps, QLConfiguration conf,
                                     IDataManager dataManager, int saveFrequency, int monitorFrequency) {
            super(new MockMDP(maxSteps), new MockDQN(), conf, 2);
            // Route data-manager calls through the listener being tested.
            addListener(DataManagerSyncTrainingListener.builder(dataManager)
                    .saveFrequency(saveFrequency)
                    .monitorFrequency(monitorFrequency)
                    .build());
        }

        @Override
        protected IDataManager.StatEntry trainEpoch() {
            setStepCounter(getStepCounter() + 1);
            return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
        }
    }
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
|
||||||
|
/**
 * Do-nothing IDQN stub for sync-learning tests. Every operation is a no-op or
 * returns an empty/neutral value; it exists only so QLearningDiscrete can be
 * constructed without a real network.
 */
public class MockDQN implements IDQN {
    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        // No underlying networks to expose.
        return new NeuralNetwork[0];
    }

    @Override
    public boolean isRecurrent() {
        return false;
    }

    @Override
    public void reset() {
        // no-op
    }

    @Override
    public void fit(INDArray input, INDArray labels) {
        // no-op: training is not exercised by these tests
    }

    @Override
    public void fit(INDArray input, INDArray[] labels) {
        // no-op
    }

    @Override
    public INDArray output(INDArray batch) {
        // NOTE(review): returning null is safe only while tests never query
        // the network's output — confirm if this mock gains new users.
        return null;
    }

    @Override
    public INDArray[] outputAll(INDArray batch) {
        return new INDArray[0];
    }

    @Override
    public IDQN clone() {
        // Intentionally null: target-network duplication is unused here.
        return null;
    }

    @Override
    public void copy(NeuralNet from) {
        // no-op
    }

    @Override
    public void copy(IDQN from) {
        // no-op
    }

    @Override
    public Gradient[] gradient(INDArray input, INDArray label) {
        return new Gradient[0];
    }

    @Override
    public Gradient[] gradient(INDArray input, INDArray[] label) {
        return new Gradient[0];
    }

    @Override
    public void applyGradient(Gradient[] gradient, int batchSize) {
        // no-op
    }

    @Override
    public double getLatestScore() {
        return 0;
    }

    @Override
    public void save(OutputStream os) throws IOException {
        // no-op: nothing to persist
    }

    @Override
    public void save(String filename) throws IOException {
        // no-op
    }
}
|
|
@ -0,0 +1,79 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
public class MockMDP implements MDP<Object, Integer, DiscreteSpace> {
|
||||||
|
|
||||||
|
private final int maxSteps;
|
||||||
|
private final DiscreteSpace actionSpace = new DiscreteSpace(1);
|
||||||
|
private final MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
|
|
||||||
|
private int currentStep = 0;
|
||||||
|
|
||||||
|
public MockMDP(int maxSteps) {
|
||||||
|
|
||||||
|
this.maxSteps = maxSteps;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ObservationSpace<Object> getObservationSpace() {
|
||||||
|
return observationSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DiscreteSpace getActionSpace() {
|
||||||
|
return actionSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object reset() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public StepReply<Object> step(Integer integer) {
|
||||||
|
return new StepReply<Object>(null, 1.0, isDone(), null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isDone() {
|
||||||
|
return currentStep >= maxSteps;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MDP<Object, Integer, DiscreteSpace> newInstance() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class MockObservationSpace implements ObservationSpace {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int[] getShape() {
|
||||||
|
return new int[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getLow() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getHigh() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.support;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Value;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
/**
 * Immutable stand-in for {@link IDataManager.StatEntry} used by the sync
 * training tests. Lombok's {@code @Value} makes the fields private final and
 * generates getters; {@code @AllArgsConstructor} supplies the
 * (epochCounter, stepCounter, reward) constructor.
 */
@AllArgsConstructor
@Value
public class MockStatEntry implements IDataManager.StatEntry {
    int epochCounter;  // epoch index the entry was recorded for
    int stepCounter;   // cumulative step count at the end of that epoch
    double reward;     // reward reported for that epoch
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
||||||
|
|
||||||
|
public class MockSyncTrainingListener implements SyncTrainingListener {
|
||||||
|
|
||||||
|
public int onTrainingStartCallCount = 0;
|
||||||
|
public int onTrainingEndCallCount = 0;
|
||||||
|
public int onEpochStartCallCount = 0;
|
||||||
|
public int onEpochEndStartCallCount = 0;
|
||||||
|
|
||||||
|
public boolean trainingStartCanContinue = true;
|
||||||
|
public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE;
|
||||||
|
public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
|
||||||
|
++onTrainingStartCallCount;
|
||||||
|
return trainingStartCanContinue ? ListenerResponse.CONTINUE : ListenerResponse.STOP;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onTrainingEnd() {
|
||||||
|
++onTrainingEndCallCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
|
||||||
|
++onEpochStartCallCount;
|
||||||
|
if(onEpochStartCallCount >= nbStepsEpochStartCanContinue) {
|
||||||
|
return ListenerResponse.STOP;
|
||||||
|
}
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
|
||||||
|
++onEpochEndStartCallCount;
|
||||||
|
if(onEpochEndStartCallCount >= nbStepsEpochEndCanContinue) {
|
||||||
|
return ListenerResponse.STOP;
|
||||||
|
}
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue