RL4J - AsyncTrainingListener (#8072)

* Code clarity: Extracted parts of run() into private methods

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Added listener pattern to async learning

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Merged all listeners logic

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Added interface and common data to training events

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Fixed missing info log file

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Fixed bad merge; removed useless TrainingEvent

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Removed param from training start/end event

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Removed 'event' classes from the training listener

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Reverted changes to QLearningDiscrete.setTarget()
master
Alexandre Boulanger 2019-09-18 21:28:13 -04:00 committed by Alex Black
parent d58a4b45b1
commit 59f1cbf0c6
47 changed files with 1576 additions and 881 deletions

View File

@ -0,0 +1,31 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning;
import org.deeplearning4j.rl4j.mdp.MDP;
/**
* The common API between Learning and AsyncThread.
*
* @author Alexandre Boulanger
*/
public interface IEpochTrainer {
int getStepCounter();
int getEpochCounter();
IHistoryProcessor getHistoryProcessor();
MDP getMdp();
}

View File

@ -17,7 +17,7 @@
package org.deeplearning4j.rl4j.learning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
*/
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> extends StepCountable {
Policy<O, A> getPolicy();
IPolicy<O, A> getPolicy();
void train();
@ -38,6 +38,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> ex
MDP<O, A, AS> getMdp();
IHistoryProcessor getHistoryProcessor();
interface LConfiguration {

View File

@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.rl4j.learning.async;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
@ -63,7 +62,6 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
@Getter
private NN target;
@Getter
@Setter
private boolean running = true;
public AsyncGlobal(NN initial, AsyncConfiguration a3cc) {
@ -78,7 +76,9 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
}
public void enqueue(Gradient[] gradient, Integer nstep) {
queue.add(new Pair<>(gradient, nstep));
if(running && !isTrainingComplete()) {
queue.add(new Pair<>(gradient, nstep));
}
}
@Override
@ -105,4 +105,12 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
}
/**
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
*/
public void terminate() {
running = false;
queue.clear();
}
}

View File

@ -16,33 +16,49 @@
package org.deeplearning4j.rl4j.learning.async;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.factory.Nd4j;
/**
* The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread()
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
* (see setProgressEventInterval(int))
*
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
*
* Async learning always follow the same pattern in RL4J
* -launch the Global thread
* -launch the "save threads"
* -periodically evaluate the model of the global thread for monitoring purposes
*
* @author Alexandre Boulanger
*/
@Slf4j
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> {
protected abstract IDataManager getDataManager();
@Getter(AccessLevel.PROTECTED)
private final TrainingListenerList listeners = new TrainingListenerList();
public AsyncLearning(AsyncConfiguration conf) {
super(conf);
}
/**
* Add a {@link TrainingListener} listener at the end of the listener list.
*
* @param listener the listener to be added
*/
public void addListener(TrainingListener listener) {
listeners.add(listener);
}
/**
* Returns the configuration
* @return the configuration (see {@link AsyncConfiguration})
*/
public abstract AsyncConfiguration getConfiguration();
protected abstract AsyncThread newThread(int i, int deviceAffinity);
@ -57,41 +73,80 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
return getAsyncGlobal().isTrainingComplete();
}
public void launchThreads() {
private boolean canContinue = true;
/**
* Number of milliseconds between calls to onTrainingProgress
*/
@Getter @Setter
private int progressMonitorFrequency = 20000;
private void launchThreads() {
startGlobalThread();
for (int i = 0; i < getConfiguration().getNumThread(); i++) {
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
t.start();
}
log.info("Threads launched.");
}
/**
* @return The current step
*/
@Override
public int getStepCounter() {
return getAsyncGlobal().getT().get();
}
/**
* This method will train the model<p>
* The training stop when:<br>
* - A worker thread terminate the AsyncGlobal thread (see {@link AsyncGlobal})<br>
* OR<br>
* - a listener explicitly stops it<br>
* <p>
* Listeners<br>
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
* returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* Events:
* <ul>
* <li>{@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.</li>
* <li>{@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
* </ul>
*/
public void train() {
try {
log.info("AsyncLearning training starting.");
launchThreads();
log.info("AsyncLearning training starting.");
//this is simply for stat purposes
getDataManager().writeInfo(this);
synchronized (this) {
while (!isTrainingComplete() && getAsyncGlobal().isRunning()) {
getPolicy().play(getMdp(), getHistoryProcessor());
getDataManager().writeInfo(this);
wait(20000);
canContinue = listeners.notifyTrainingStarted();
if (canContinue) {
launchThreads();
monitorTraining();
}
cleanupPostTraining();
listeners.notifyTrainingFinished();
}
protected void monitorTraining() {
try {
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
canContinue = listeners.notifyTrainingProgress(this);
if(!canContinue) {
return;
}
synchronized (this) {
wait(progressMonitorFrequency);
}
}
} catch (Exception e) {
log.error("Training failed.", e);
e.printStackTrace();
} catch (InterruptedException e) {
log.error("Training interrupted.", e);
}
}
protected void cleanupPostTraining() {
// Worker threads stops automatically when the global thread stops
getAsyncGlobal().terminate();
}
}

View File

@ -21,33 +21,31 @@ import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.learning.*;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*
* This represent a local thread that explore the environment
* and calculate a gradient to enqueue to the global thread/model
*
* It has its own version of a model that it syncs at the start of every
* sub epoch
*
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
* @author Alexandre Boulanger
*/
@Slf4j
public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Thread implements StepCountable {
extends Thread implements StepCountable, IEpochTrainer {
@Getter
private int threadNumber;
@Getter
protected final int deviceNum;
@ -55,12 +53,16 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
private int stepCounter = 0;
@Getter @Setter
private int epochCounter = 0;
@Getter
private MDP<O, A, AS> mdp;
@Getter @Setter
private IHistoryProcessor historyProcessor;
@Getter
private int lastMonitor = -Constants.MONITOR_FREQ;
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
private final TrainingListenerList listeners;
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
this.mdp = mdp;
this.listeners = listeners;
this.threadNumber = threadNumber;
this.deviceNum = deviceNum;
}
@ -80,75 +82,106 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
}
protected void preEpoch() {
if (getStepCounter() - lastMonitor >= Constants.MONITOR_FREQ && getHistoryProcessor() != null
&& getDataManager().isSaveData()) {
lastMonitor = getStepCounter();
int[] shape = getMdp().getObservationSpace().getShape();
getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + threadNumber + "-"
+ getEpochCounter() + "-" + getStepCounter() + ".mp4", shape);
}
// Do nothing
}
/**
* This method will start the worker thread<p>
* The thread will stop when:<br>
* - The AsyncGlobal thread terminates or reports that the training is complete
* (see {@link AsyncGlobal#isTrainingComplete()}). In such case, the currently running epoch will still be handled normally and
* events will also be fired normally.<br>
* OR<br>
* - a listener explicitly stops it, in which case, the AsyncGlobal thread will be terminated along with
* all other worker threads <br>
* <p>
* Listeners<br>
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
* returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse
* TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* Events:
* <ul>
* <li>{@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} is called when a new epoch is started.</li>
* <li>{@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} is called at the end of every
* epoch. It will not be called if onNewEpoch() stops the training.</li>
* </ul>
*/
@Override
public void run() {
RunContext<O> context = new RunContext<>();
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
log.info("ThreadNum-" + threadNumber + " Started!");
try {
log.info("ThreadNum-" + threadNumber + " Started!");
getCurrent().reset();
Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor);
O obs = initMdp.getLastObs();
double rewards = initMdp.getReward();
int length = initMdp.getSteps();
boolean canContinue = initWork(context);
if (canContinue) {
preEpoch();
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - length);
SubEpochReturn<O> subEpochReturn = trainSubEpoch(obs, maxSteps);
obs = subEpochReturn.getLastObs();
stepCounter += subEpochReturn.getSteps();
length += subEpochReturn.getSteps();
rewards += subEpochReturn.getReward();
double score = subEpochReturn.getScore();
if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
postEpoch();
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
getDataManager().appendStat(statEntry);
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
getCurrent().reset();
initMdp = Learning.initMdp(getMdp(), historyProcessor);
obs = initMdp.getLastObs();
rewards = initMdp.getReward();
length = initMdp.getSteps();
epochCounter++;
preEpoch();
handleTraining(context);
if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
canContinue = finishEpoch(context) && startNewEpoch(context);
if (!canContinue) {
break;
}
}
}
} catch (Exception e) {
log.error("Thread crashed: " + e.getCause());
getAsyncGlobal().setRunning(false);
e.printStackTrace();
} finally {
postEpoch();
}
terminateWork();
}
private void initNewEpoch(RunContext context) {
getCurrent().reset();
Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor);
context.obs = initMdp.getLastObs();
context.rewards = initMdp.getReward();
context.length = initMdp.getSteps();
}
private void handleTraining(RunContext<O> context) {
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length);
SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs();
stepCounter += subEpochReturn.getSteps();
context.length += subEpochReturn.getSteps();
context.rewards += subEpochReturn.getReward();
context.score = subEpochReturn.getScore();
}
private boolean initWork(RunContext context) {
initNewEpoch(context);
preEpoch();
return listeners.notifyNewEpoch(this);
}
private boolean startNewEpoch(RunContext context) {
initNewEpoch(context);
epochCounter++;
preEpoch();
return listeners.notifyNewEpoch(this);
}
private boolean finishEpoch(RunContext context) {
postEpoch();
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.length, context.score);
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
return listeners.notifyEpochTrainingResult(this, statEntry);
}
private void terminateWork() {
postEpoch();
getAsyncGlobal().terminate();
}
protected abstract NN getCurrent();
protected abstract int getThreadNumber();
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
protected abstract MDP<O, A, AS> getMdp();
protected abstract AsyncConfiguration getConf();
protected abstract IDataManager getDataManager();
protected abstract Policy<O, A> getPolicy(NN net);
protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
@ -172,4 +205,11 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
double score;
}
private static class RunContext<O extends Encodable> {
private O obs;
private double rewards;
private int length;
private double score;
}
}

View File

@ -21,7 +21,9 @@ import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
@ -44,8 +46,8 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
@Getter
private NN current;
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
super(asyncGlobal, threadNumber, deviceNum);
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, MDP<O, Integer, DiscreteSpace> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
synchronized (asyncGlobal) {
current = (NN)asyncGlobal.getCurrent().clone();
}

View File

@ -23,9 +23,14 @@ import java.util.concurrent.atomic.AtomicInteger;
public interface IAsyncGlobal<NN extends NeuralNet> {
boolean isRunning();
void setRunning(boolean value);
boolean isTrainingComplete();
void start();
/**
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
*/
void terminate();
AtomicInteger getT();
NN getCurrent();
NN getTarget();

View File

@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
@ -47,24 +46,19 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
final private AsyncGlobal asyncGlobal;
@Getter
final private ACPolicy<O> policy;
@Getter
final private IDataManager dataManager;
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf,
IDataManager dataManager) {
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
super(conf);
this.iActorCritic = iActorCritic;
this.mdp = mdp;
this.configuration = conf;
this.dataManager = dataManager;
policy = new ACPolicy<>(iActorCritic, getRandom());
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
mdp.getActionSpace().setSeed(conf.getSeed());
}
@Override
protected AsyncThread newThread(int i, int deviceNum) {
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager, deviceNum);
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), deviceNum, getListeners(), i);
}
public IActorCritic getNeuralNet() {

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
@ -43,24 +44,38 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
final private HistoryProcessor.Configuration hpconf;
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
super(mdp, IActorCritic, conf, dataManager);
this(mdp, actorCritic, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
super(mdp, IActorCritic, conf);
this.hpconf = hpconf;
setHistoryProcessor(hpconf);
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf,
dataManager);
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
}
@Override
public AsyncThread newThread(int i, int deviceNum) {

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.*;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
@ -33,33 +34,58 @@ import org.deeplearning4j.rl4j.util.IDataManager;
*/
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
IDataManager dataManager) {
super(mdp, IActorCritic, conf, dataManager);
this(mdp, IActorCritic, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
super(mdp, actorCritic, conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf);
}
}

View File

@ -22,13 +22,13 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -46,24 +46,19 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
@Getter
final protected A3CDiscrete.A3CConfiguration conf;
@Getter
final protected MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final protected AsyncGlobal<IActorCritic> asyncGlobal;
@Getter
final protected int threadNumber;
@Getter
final protected IDataManager dataManager;
final private Random random;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) {
super(asyncGlobal, threadNumber, deviceNum);
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
int threadNumber) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = a3cc;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;
this.mdp = mdp;
this.dataManager = dataManager;
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
random = new Random(conf.getSeed() + threadNumber);
}
@ -85,15 +80,15 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
//if recurrent then train as a time serie with a batch size of 1
boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();
int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape()
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape();
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
: Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape);
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, mdp.getActionSpace().getSize(), size)
: Nd4j.zeros(size, mdp.getActionSpace().getSize());
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size)
: Nd4j.zeros(size, getMdp().getActionSpace().getSize());
double r = minTrans.getReward();
for (int i = size - 1; i >= 0; i--) {

View File

@ -24,10 +24,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -40,16 +39,12 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
@Getter
final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final private IDataManager dataManager;
@Getter
final private AsyncGlobal<IDQN> asyncGlobal;
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf,
IDataManager dataManager) {
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
super(conf);
this.mdp = mdp;
this.dataManager = dataManager;
this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
mdp.getActionSpace().setSeed(conf.getSeed());
@ -57,14 +52,14 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
@Override
public AsyncThread newThread(int i, int deviceNum) {
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum);
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, getListeners(), i, deviceNum);
}
public IDQN getNeuralNet() {
return asyncGlobal.getCurrent();
}
public Policy<O, Integer> getPolicy() {
public IPolicy<O, Integer> getPolicy() {
return new DQNPolicy<O>(getNeuralNet());
}

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
@ -35,22 +36,38 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
final private HistoryProcessor.Configuration hpconf;
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager);
this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
super(mdp, dqn, conf);
this.hpconf = hpconf;
setHistoryProcessor(hpconf);
}
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}
@Override
public AsyncThread newThread(int i, int deviceNum) {

View File

@ -22,6 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
@ -29,19 +30,37 @@ import org.deeplearning4j.rl4j.util.IDataManager;
*/
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager);
super(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf) {
super(mdp, dqn, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}
}

View File

@ -19,9 +19,10 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import lombok.Getter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
@ -29,7 +30,6 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -44,31 +44,25 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
@Getter
final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
@Getter
final protected MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final protected IAsyncGlobal<IDQN> asyncGlobal;
@Getter
final protected int threadNumber;
@Getter
final protected IDataManager dataManager;
final private Random random;
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
IDataManager dataManager, int deviceNum) {
super(asyncGlobal, threadNumber, deviceNum);
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
TrainingListenerList listeners, int threadNumber, int deviceNum) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;
this.mdp = mdp;
this.dataManager = dataManager;
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
random = new Random(conf.getSeed() + threadNumber);
}
public Policy<O, Integer> getPolicy(IDQN nn) {
return new EpsGreedy(new DQNPolicy(nn), mdp, conf.getUpdateStart(), conf.getEpsilonNbStep(),
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
random, conf.getMinEpsilon(), this);
}
@ -81,11 +75,11 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
int size = rewards.size();
int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape()
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape();
int[] nshape = Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape);
INDArray targets = Nd4j.create(size, mdp.getActionSpace().getSize());
INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize());
double r = minTrans.getReward();
for (int i = size - 1; i >= 0; i--) {

View File

@ -0,0 +1,72 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.listener;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
* The base definition of all training event listeners
*
* @author Alexandre Boulanger
*/
public interface TrainingListener {
enum ListenerResponse {
/**
* Tell the learning process to continue calling the listeners and the training.
*/
CONTINUE,
/**
* Tell the learning process to stop calling the listeners and terminate the training.
*/
STOP,
}
/**
* Called once when the training starts.
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
*/
ListenerResponse onTrainingStart();
/**
* Called once when the training has finished. This method is called even when the training has been aborted.
*/
void onTrainingEnd();
/**
* Called before the start of every epoch.
* @param trainer A {@link IEpochTrainer}
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onNewEpoch(IEpochTrainer trainer);
/**
* Called when an epoch has been completed
* @param trainer A {@link IEpochTrainer}
* @param statEntry A {@link org.deeplearning4j.rl4j.util.IDataManager.StatEntry}
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry);
/**
* Called regularly to monitor the training progress.
* @param learning A {@link ILearning}
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onTrainingProgress(ILearning learning);
}

View File

@ -0,0 +1,105 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.listener;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
/**
* The base logic to notify training listeners with the different training events.
*
* @author Alexandre Boulanger
*/
public class TrainingListenerList {
protected final List<TrainingListener> listeners = new ArrayList<>();
/**
* Add a listener at the end of the list
* @param listener The listener to be added
*/
public void add(TrainingListener listener) {
listeners.add(listener);
}
/**
* Notify the listeners that the training has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
* @return whether or not the source training should be stopped
*/
public boolean notifyTrainingStarted() {
for (TrainingListener listener : listeners) {
if (listener.onTrainingStart() == TrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
/**
* Notify the listeners that the training has finished.
*/
public void notifyTrainingFinished() {
for (TrainingListener listener : listeners) {
listener.onTrainingEnd();
}
}
/**
* Notify the listeners that a new epoch has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
* @return whether or not the source training should be stopped
*/
public boolean notifyNewEpoch(IEpochTrainer trainer) {
for (TrainingListener listener : listeners) {
if (listener.onNewEpoch(trainer) == TrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
/**
* Notify the listeners that an epoch has been completed and the training results are available. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
* @return whether or not the source training should be stopped
*/
public boolean notifyEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
for (TrainingListener listener : listeners) {
if (listener.onEpochTrainingResult(trainer, statEntry) == TrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
/**
* Notify the listeners that they update the progress ot the trainning.
*/
public boolean notifyTrainingProgress(ILearning learning) {
for (TrainingListener listener : listeners) {
if (listener.onTrainingProgress(learning) == TrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
}

View File

@ -16,19 +16,18 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
/**
* Mother class and useful factorisations for all training methods that
* are not asynchronous.
@ -38,9 +37,9 @@ import java.util.List;
*/
@Slf4j
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> {
extends Learning<O, A, AS, NN> implements IEpochTrainer {
private List<SyncTrainingListener> listeners = new ArrayList<>();
private final TrainingListenerList listeners = new TrainingListenerList();
public SyncLearning(LConfiguration conf) {
super(conf);
@ -49,12 +48,24 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
/**
* Add a listener at the end of the listener list.
*
* @param listener
* @param listener The listener to add
*/
public void addListener(SyncTrainingListener listener) {
public void addListener(TrainingListener listener) {
listeners.add(listener);
}
/**
* Number of epochs between calls to onTrainingProgress. Default is 5
*/
@Getter
private int progressMonitorFrequency = 5;
public void setProgressMonitorFrequency(int value) {
if(value == 0) throw new IllegalArgumentException("The progressMonitorFrequency cannot be 0");
progressMonitorFrequency = value;
}
/**
* This method will train the model<p>
* The training stop when:<br>
@ -64,81 +75,49 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
* <p>
* Listeners<br>
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
* returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* returns {@link TrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* Events:
* <ul>
* <li>{@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.</li>
* <li>{@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training</li>
* <li>{@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
* <li>{@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.</li>
* <li>{@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} and {@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} are called for every epoch. onEpochTrainingResult will not be called if onNewEpoch stops the training</li>
* <li>{@link TrainingListener#onTrainingProgress(ILearning) onTrainingProgress()} is called after onEpochTrainingResult()</li>
* <li>{@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
* </ul>
*/
public void train() {
log.info("training starting.");
boolean canContinue = notifyTrainingStarted();
boolean canContinue = listeners.notifyTrainingStarted();
if (canContinue) {
while (getStepCounter() < getConfiguration().getMaxStep()) {
preEpoch();
canContinue = notifyEpochStarted();
canContinue = listeners.notifyNewEpoch(this);
if (!canContinue) {
break;
}
IDataManager.StatEntry statEntry = trainEpoch();
postEpoch();
canContinue = notifyEpochFinished(statEntry);
canContinue = listeners.notifyEpochTrainingResult(this, statEntry);
if (!canContinue) {
break;
}
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
postEpoch();
if(getEpochCounter() % progressMonitorFrequency == 0) {
canContinue = listeners.notifyTrainingProgress(this);
if (!canContinue) {
break;
}
}
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
incrementEpoch();
}
}
notifyTrainingFinished();
}
private boolean notifyTrainingStarted() {
SyncTrainingEvent event = new SyncTrainingEvent(this);
for (SyncTrainingListener listener : listeners) {
if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
private void notifyTrainingFinished() {
for (SyncTrainingListener listener : listeners) {
listener.onTrainingEnd();
}
}
private boolean notifyEpochStarted() {
SyncTrainingEvent event = new SyncTrainingEvent(this);
for (SyncTrainingListener listener : listeners) {
if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) {
SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry);
for (SyncTrainingListener listener : listeners) {
if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
listeners.notifyTrainingFinished();
}
protected abstract void preEpoch();

View File

@ -1,22 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
* A subclass of SyncTrainingEvent that is passed to SyncTrainingListener.onEpochEnd()
*/
public class SyncTrainingEpochEndEvent extends SyncTrainingEvent {
/**
* The stats of the epoch training
*/
@Getter
private final IDataManager.StatEntry statEntry;
public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) {
super(learning);
this.statEntry = statEntry;
}
}

View File

@ -1,21 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.Learning;
/**
* SyncTrainingEvent are passed as parameters to the events of SyncTrainingListener
*/
public class SyncTrainingEvent {
/**
* The source of the event
*/
@Getter
private final Learning learning;
public SyncTrainingEvent(Learning learning) {
this.learning = learning;
}
}

View File

@ -1,45 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
/**
* A listener interface to use with a descendant of {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning}
*/
public interface SyncTrainingListener {
public enum ListenerResponse {
/**
* Tell SyncLearning to continue calling the listeners and the training.
*/
CONTINUE,
/**
* Tell SyncLearning to stop calling the listeners and terminate the training.
*/
STOP,
}
/**
* Called once when the training starts.
* @param event
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
*/
ListenerResponse onTrainingStart(SyncTrainingEvent event);
/**
* Called once when the training has finished. This method is called even when the training has been aborted.
*/
void onTrainingEnd();
/**
* Called before the start of every epoch.
* @param event
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onEpochStart(SyncTrainingEvent event);
/**
* Called after the end of every epoch.
* @param event
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event);
}

View File

@ -49,7 +49,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
// @Getter
// final private IExpReplay<A> expReplay;
@Getter
@Setter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PROTECTED)
protected IExpReplay<A> expReplay;
public QLearning(QLConfiguration conf) {

View File

@ -28,8 +28,6 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
@ -64,20 +62,9 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Setter
private IDQN targetDQN;
private int lastAction;
private INDArray history[] = null;
private INDArray[] history = null;
private double accuReward = 0;
/**
* @deprecated
* Use QLearningDiscrete(MDP, IDQN, QLConfiguration, int) and add the required listeners with addListener() instead.
*/
@Deprecated
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
IDataManager dataManager, int epsilonNbStep) {
this(mdp, dqn, conf, epsilonNbStep);
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep) {
super(conf);
@ -186,7 +173,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
}
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
if (transitions.size() == 0)
throw new IllegalArgumentException("too few transitions");
@ -194,7 +180,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
int size = transitions.size();
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape();
: getHistoryProcessor().getConf().getShape();
int[] nshape = makeShape(size, shape);
INDArray obs = Nd4j.create(nshape);
INDArray nextObs = Nd4j.create(nshape);

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
@ -33,20 +34,36 @@ import org.deeplearning4j.rl4j.util.IDataManager;
public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> {
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}
}

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
@ -31,21 +32,35 @@ import org.deeplearning4j.rl4j.util.IDataManager;
public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> {
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
IDataManager dataManager) {
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep());
this(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep());
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}
}

View File

@ -0,0 +1,10 @@
package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
public interface IPolicy<O extends Encodable, A> {
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
}

View File

@ -35,7 +35,7 @@ import org.nd4j.linalg.util.ArrayUtil;
*
* A Policy responsability is to choose the next action given a state
*/
public abstract class Policy<O extends Encodable, A> {
public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
public abstract NeuralNet getNeuralNet();
@ -49,6 +49,7 @@ public abstract class Policy<O extends Encodable, A> {
return play(mdp, new HistoryProcessor(conf));
}
@Override
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
getNeuralNet().reset();
Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);

View File

@ -22,6 +22,7 @@ import lombok.Builder;
import lombok.Getter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
@ -72,13 +73,13 @@ public class DataManager implements IDataManager {
}
}
public static void save(String path, Learning learning) throws IOException {
public static void save(String path, ILearning learning) throws IOException {
try (BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(path))) {
save(os, learning);
}
}
public static void save(OutputStream os, Learning learning) throws IOException {
public static void save(OutputStream os, ILearning learning) throws IOException {
try (ZipOutputStream zipfile = new ZipOutputStream(os)) {
@ -91,7 +92,9 @@ public class DataManager implements IDataManager {
zipfile.putNextEntry(dqn);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
learning.getNeuralNet().save(bos);
if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(bos);
}
bos.flush();
bos.close();
@ -104,7 +107,9 @@ public class DataManager implements IDataManager {
zipfile.putNextEntry(hpconf);
ByteArrayOutputStream bos2 = new ByteArrayOutputStream();
learning.getNeuralNet().save(bos2);
if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(bos2);
}
bos2.flush();
bos2.close();
@ -256,13 +261,15 @@ public class DataManager implements IDataManager {
return exists;
}
public void save(Learning learning) throws IOException {
public void save(ILearning learning) throws IOException {
if (!saveData)
return;
save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
learning.getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
}
}

View File

@ -1,126 +0,0 @@
package org.deeplearning4j.rl4j.util;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
/**
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
* training process can be fed to the DataManager
*/
@Slf4j
public class DataManagerSyncTrainingListener implements SyncTrainingListener {
private final IDataManager dataManager;
private final int saveFrequency;
private final int monitorFrequency;
private int lastSave;
private int lastMonitor;
private DataManagerSyncTrainingListener(Builder builder) {
this.dataManager = builder.dataManager;
this.saveFrequency = builder.saveFrequency;
this.lastSave = -builder.saveFrequency;
this.monitorFrequency = builder.monitorFrequency;
this.lastMonitor = -builder.monitorFrequency;
}
@Override
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
try {
dataManager.writeInfo(event.getLearning());
} catch (Exception e) {
log.error("Training failed.", e);
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
// Do nothing
}
@Override
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
int stepCounter = event.getLearning().getStepCounter();
if (stepCounter - lastMonitor >= monitorFrequency
&& event.getLearning().getHistoryProcessor() != null
&& dataManager.isSaveData()) {
lastMonitor = stepCounter;
int[] shape = event.getLearning().getMdp().getObservationSpace().getShape();
event.getLearning().getHistoryProcessor().startMonitor(dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter() + "-"
+ stepCounter + ".mp4", shape);
}
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
try {
int stepCounter = event.getLearning().getStepCounter();
if (stepCounter - lastSave >= saveFrequency) {
dataManager.save(event.getLearning());
lastSave = stepCounter;
}
dataManager.appendStat(event.getStatEntry());
dataManager.writeInfo(event.getLearning());
} catch (Exception e) {
log.error("Training failed.", e);
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
public static Builder builder(IDataManager dataManager) {
return new Builder(dataManager);
}
public static class Builder {
private final IDataManager dataManager;
private int saveFrequency = Constants.MODEL_SAVE_FREQ;
private int monitorFrequency = Constants.MONITOR_FREQ;
/**
* Create a Builder with the given DataManager
* @param dataManager
*/
public Builder(IDataManager dataManager) {
this.dataManager = dataManager;
}
/**
* A number that represent the number of steps since the last call to DataManager.save() before can it be called again.
* @param saveFrequency (Default: 100000)
*/
public Builder saveFrequency(int saveFrequency) {
this.saveFrequency = saveFrequency;
return this;
}
/**
* A number that represent the number of steps since the last call to HistoryProcessor.startMonitor() before can it be called again.
* @param monitorFrequency (Default: 10000)
*/
public Builder monitorFrequency(int monitorFrequency) {
this.monitorFrequency = monitorFrequency;
return this;
}
/**
* Creates a DataManagerSyncTrainingListener with the configured parameters
* @return An instance of DataManagerSyncTrainingListener
*/
public DataManagerSyncTrainingListener build() {
return new DataManagerSyncTrainingListener(this);
}
}
}

View File

@ -0,0 +1,83 @@
package org.deeplearning4j.rl4j.util;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
/**
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
* training process can be fed to the DataManager
*/
@Slf4j
public class DataManagerTrainingListener implements TrainingListener {
private final IDataManager dataManager;
private int lastSave = -Constants.MODEL_SAVE_FREQ;
public DataManagerTrainingListener(IDataManager dataManager) {
this.dataManager = dataManager;
}
@Override
public ListenerResponse onTrainingStart() {
return ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
}
@Override
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
IHistoryProcessor hp = trainer.getHistoryProcessor();
if(hp != null) {
int[] shape = trainer.getMdp().getObservationSpace().getShape();
String filename = dataManager.getVideoDir() + "/video-";
if (trainer instanceof AsyncThread) {
filename += ((AsyncThread) trainer).getThreadNumber() + "-";
}
filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4";
hp.startMonitor(filename, shape);
}
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
IHistoryProcessor hp = trainer.getHistoryProcessor();
if(hp != null) {
hp.stopMonitor();
}
try {
dataManager.appendStat(statEntry);
} catch (Exception e) {
log.error("Training failed.", e);
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
try {
int stepCounter = learning.getStepCounter();
if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) {
dataManager.save(learning);
lastSave = stepCounter;
}
dataManager.writeInfo(learning);
} catch (Exception e) {
log.error("Training failed.", e);
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
}

View File

@ -27,7 +27,7 @@ public interface IDataManager {
String getVideoDir();
void appendStat(StatEntry statEntry) throws IOException;
void writeInfo(ILearning iLearning) throws IOException;
void save(Learning learning) throws IOException;
void save(ILearning learning) throws IOException;
//In order for jackson to serialize StatEntry
//please use Lombok @Value (see QLStatEntry)

View File

@ -0,0 +1,127 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class AsyncLearningTest {
@Test
public void when_training_expect_AsyncGlobalStarted() {
// Arrange
TestContext context = new TestContext();
context.asyncGlobal.setMaxLoops(1);
// Act
context.sut.train();
// Assert
assertTrue(context.asyncGlobal.hasBeenStarted);
assertTrue(context.asyncGlobal.hasBeenTerminated);
}
@Test
public void when_trainStartReturnsStop_expect_noTraining() {
// Arrange
TestContext context = new TestContext();
context.listener.setRemainingTrainingStartCallCount(0);
// Act
context.sut.train();
// Assert
assertEquals(1, context.listener.onTrainingStartCallCount);
assertEquals(1, context.listener.onTrainingEndCallCount);
assertEquals(0, context.policy.playCallCount);
assertTrue(context.asyncGlobal.hasBeenTerminated);
}
@Test
public void when_trainingIsComplete_expect_trainingStop() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.train();
// Assert
assertEquals(1, context.listener.onTrainingStartCallCount);
assertEquals(1, context.listener.onTrainingEndCallCount);
assertTrue(context.asyncGlobal.hasBeenTerminated);
}
@Test
public void when_training_expect_onTrainingProgressCalled() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.train();
// Assert
assertEquals(1, context.listener.onTrainingProgressCallCount);
}
public static class TestContext {
public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1);
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockPolicy policy = new MockPolicy();
public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy);
public final MockTrainingListener listener = new MockTrainingListener();
public TestContext() {
sut.addListener(listener);
asyncGlobal.setMaxLoops(1);
sut.setProgressMonitorFrequency(1);
}
}
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
private final AsyncConfiguration conf;
private final IAsyncGlobal asyncGlobal;
private final IPolicy<MockEncodable, Integer> policy;
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
super(conf);
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.policy = policy;
}
@Override
public IPolicy getPolicy() {
return policy;
}
@Override
public AsyncConfiguration getConfiguration() {
return conf;
}
@Override
protected AsyncThread newThread(int i, int deviceAffinity) {
return null;
}
@Override
public MDP getMdp() {
return null;
}
@Override
protected IAsyncGlobal getAsyncGlobal() {
return asyncGlobal;
}
@Override
public MockNeuralNet getNeuralNet() {
return null;
}
}
}

View File

@ -1,206 +1,135 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
public class AsyncThreadTest {
@Test
public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() {
public void when_newEpochStarted_expect_neuralNetworkReset() {
// Arrange
MockDataManager dataManager = new MockDataManager(false);
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager);
TestContext context = new TestContext();
context.listener.setRemainingOnNewEpochCallCount(5);
// Act
sut.run();
context.sut.run();
// Assert
assertEquals(4, dataManager.statEntries.size());
IDataManager.StatEntry entry = dataManager.statEntries.get(0);
assertEquals(2, entry.getStepCounter());
assertEquals(0, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(1);
assertEquals(4, entry.getStepCounter());
assertEquals(1, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(2);
assertEquals(6, entry.getStepCounter());
assertEquals(2, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(3);
assertEquals(8, entry.getStepCounter());
assertEquals(3, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
assertEquals(0, dataManager.isSaveDataCallCount);
assertEquals(0, dataManager.getVideoDirCallCount);
assertEquals(6, context.neuralNet.resetCallCount);
}
@Test
public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() {
public void when_onNewEpochReturnsStop_expect_threadStopped() {
// Arrange
MockDataManager dataManager = new MockDataManager(false);
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
.build();
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
sut.setHistoryProcessor(hp);
TestContext context = new TestContext();
context.listener.setRemainingOnNewEpochCallCount(1);
// Act
sut.run();
context.sut.run();
// Assert
assertEquals(9, dataManager.statEntries.size());
for(int i = 0; i < 9; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i + 1, entry.getStepCounter());
assertEquals(i, entry.getEpochCounter());
assertEquals(79.0, entry.getReward(), 0.0);
}
assertEquals(10, dataManager.isSaveDataCallCount);
assertEquals(0, dataManager.getVideoDirCallCount);
assertEquals(2, context.listener.onNewEpochCallCount);
assertEquals(1, context.listener.onEpochTrainingResultCallCount);
}
@Test
public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() {
public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
// Arrange
MockDataManager dataManager = new MockDataManager(true);
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
.build();
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
sut.setHistoryProcessor(hp);
TestContext context = new TestContext();
context.listener.setRemainingOnEpochTrainingResult(1);
// Act
sut.run();
context.sut.run();
// Assert
assertEquals(9, dataManager.statEntries.size());
for(int i = 0; i < 9; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i + 1, entry.getStepCounter());
assertEquals(i, entry.getEpochCounter());
assertEquals(79.0, entry.getReward(), 0.0);
}
assertEquals(1, dataManager.isSaveDataCallCount);
assertEquals(1, dataManager.getVideoDirCallCount);
assertEquals(2, context.listener.onNewEpochCallCount);
assertEquals(2, context.listener.onEpochTrainingResultCallCount);
}
public static class MockAsyncGlobal implements IAsyncGlobal {
@Test
public void when_run_expect_preAndPostEpochCalled() {
// Arrange
TestContext context = new TestContext();
private final int maxLoops;
private int currentLoop = 0;
// Act
context.sut.run();
public MockAsyncGlobal(int maxLoops) {
// Assert
assertEquals(6, context.sut.preEpochCallCount);
assertEquals(6, context.sut.postEpochCallCount);
}
this.maxLoops = maxLoops;
@Test
public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.run();
// Assert
assertEquals(5, context.listener.statEntries.size());
int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
for(int i = 0; i < 5; ++i) {
IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
assertEquals(i, statEntry.getEpochCounter());
assertEquals(2.0, statEntry.getReward(), 0.0001);
}
}
@Override
public boolean isRunning() {
return true;
}
@Override
public void setRunning(boolean value) {
}
@Override
public boolean isTrainingComplete() {
return ++currentLoop >= maxLoops;
}
@Override
public void start() {
}
@Override
public AtomicInteger getT() {
return null;
}
@Override
public NeuralNet getCurrent() {
return null;
}
@Override
public NeuralNet getTarget() {
return null;
}
@Override
public void enqueue(Gradient[] gradient, Integer nstep) {
private static class TestContext {
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockNeuralNet neuralNet = new MockNeuralNet();
public final MockObservationSpace observationSpace = new MockObservationSpace();
public final MockMDP mdp = new MockMDP(observationSpace);
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 2);
public final TrainingListenerList listeners = new TrainingListenerList();
public final MockTrainingListener listener = new MockTrainingListener();
public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
public TestContext() {
asyncGlobal.setMaxLoops(10);
listeners.add(listener);
}
}
public static class MockAsyncThread extends AsyncThread {
IAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet;
private final MDP mdp;
private final AsyncConfiguration conf;
private final IDataManager dataManager;
public int preEpochCallCount = 0;
public int postEpochCallCount = 0;
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
super(asyncGlobal, threadNumber, 0);
private final IAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet;
private final AsyncConfiguration conf;
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
super(asyncGlobal, mdp, listeners, threadNumber, 0);
this.asyncGlobal = asyncGlobal;
this.neuralNet = neuralNet;
this.mdp = mdp;
this.conf = conf;
this.dataManager = dataManager;
}
@Override
protected void preEpoch() {
++preEpochCallCount;
super.preEpoch();
}
@Override
protected void postEpoch() {
++postEpochCallCount;
super.postEpoch();
}
@Override
@ -208,31 +137,16 @@ public class AsyncThreadTest {
return neuralNet;
}
@Override
protected int getThreadNumber() {
return 0;
}
@Override
protected IAsyncGlobal getAsyncGlobal() {
return asyncGlobal;
}
@Override
protected MDP getMdp() {
return mdp;
}
@Override
protected AsyncConfiguration getConf() {
return conf;
}
@Override
protected IDataManager getDataManager() {
return dataManager;
}
@Override
protected Policy getPolicy(NeuralNet net) {
return null;
@ -244,129 +158,6 @@ public class AsyncThreadTest {
}
}
public static class MockNeuralNet implements NeuralNet {
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];
}
@Override
public NeuralNet clone() {
return null;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
return new Gradient[0];
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public void applyGradient(Gradient[] gradients, int batchSize) {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}
public static class MockAsyncConfiguration implements AsyncConfiguration {
private final int nStep;
private final int maxEpochStep;
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
this.nStep = nStep;
this.maxEpochStep = maxEpochStep;
}
@Override
public int getSeed() {
return 0;
}
@Override
public int getMaxEpochStep() {
return maxEpochStep;
}
@Override
public int getMaxStep() {
return 0;
}
@Override
public int getNumThread() {
return 0;
}
@Override
public int getNstep() {
return nStep;
}
@Override
public int getTargetDqnUpdateFreq() {
return 0;
}
@Override
public int getUpdateStart() {
return 0;
}
@Override
public double getRewardFactor() {
return 0;
}
@Override
public double getGamma() {
return 0;
}
@Override
public double getErrorClamp() {
return 0;
}
}
}

View File

@ -0,0 +1,98 @@
package org.deeplearning4j.rl4j.learning.async.listener;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class AsyncTrainingListenerListTest {
@Test
public void when_listIsEmpty_expect_notifyTrainingStartedReturnTrue() {
// Arrange
TrainingListenerList sut = new TrainingListenerList();
// Act
boolean resultTrainingStarted = sut.notifyTrainingStarted();
boolean resultNewEpoch = sut.notifyNewEpoch(null);
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
// Assert
assertTrue(resultTrainingStarted);
assertTrue(resultNewEpoch);
assertTrue(resultEpochTrainingResult);
}
@Test
public void when_firstListerStops_expect_othersListnersNotCalled() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
listener1.onTrainingResultResponse = TrainingListener.ListenerResponse.STOP;
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
// Act
sut.notifyEpochTrainingResult(null, null);
// Assert
assertEquals(1, listener1.onEpochTrainingResultCallCount);
assertEquals(0, listener2.onEpochTrainingResultCallCount);
}
@Test
public void when_allListenersContinue_expect_listReturnsTrue() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
// Act
boolean resultTrainingProgress = sut.notifyEpochTrainingResult(null, null);
// Assert
assertTrue(resultTrainingProgress);
}
private static class MockTrainingListener implements TrainingListener {
public int onEpochTrainingResultCallCount = 0;
public ListenerResponse onTrainingResultResponse = ListenerResponse.CONTINUE;
public int onTrainingProgressCallCount = 0;
public ListenerResponse onTrainingProgressResponse = ListenerResponse.CONTINUE;
@Override
public ListenerResponse onTrainingStart() {
return ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
}
@Override
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
++onEpochTrainingResultCallCount;
return onTrainingResultResponse;
}
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
++onTrainingProgressCallCount;
return onTrainingProgressResponse;
}
}
}

View File

@ -0,0 +1,83 @@
package org.deeplearning4j.rl4j.learning.listener;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.junit.Test;
import static org.junit.Assert.*;
public class TrainingListenerListTest {
@Test
public void when_listIsEmpty_expect_notifyReturnTrue() {
// Arrange
TrainingListenerList sut = new TrainingListenerList();
// Act
boolean resultTrainingStarted = sut.notifyTrainingStarted();
boolean resultNewEpoch = sut.notifyNewEpoch(null);
boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null);
// Assert
assertTrue(resultTrainingStarted);
assertTrue(resultNewEpoch);
assertTrue(resultEpochFinished);
}
@Test
public void when_firstListerStops_expect_othersListnersNotCalled() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
listener1.setRemainingTrainingStartCallCount(0);
listener1.setRemainingOnNewEpochCallCount(0);
listener1.setRemainingonTrainingProgressCallCount(0);
listener1.setRemainingOnEpochTrainingResult(0);
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
// Act
sut.notifyTrainingStarted();
sut.notifyNewEpoch(null);
sut.notifyEpochTrainingResult(null, null);
sut.notifyTrainingProgress(null);
sut.notifyTrainingFinished();
// Assert
assertEquals(1, listener1.onTrainingStartCallCount);
assertEquals(0, listener2.onTrainingStartCallCount);
assertEquals(1, listener1.onNewEpochCallCount);
assertEquals(0, listener2.onNewEpochCallCount);
assertEquals(1, listener1.onEpochTrainingResultCallCount);
assertEquals(0, listener2.onEpochTrainingResultCallCount);
assertEquals(1, listener1.onTrainingProgressCallCount);
assertEquals(0, listener2.onTrainingProgressCallCount);
assertEquals(1, listener1.onTrainingEndCallCount);
assertEquals(1, listener2.onTrainingEndCallCount);
}
@Test
public void when_allListenersContinue_expect_listReturnsTrue() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
// Act
boolean resultTrainingStarted = sut.notifyTrainingStarted();
boolean resultNewEpoch = sut.notifyNewEpoch(null);
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
boolean resultProgress = sut.notifyTrainingProgress(null);
// Assert
assertTrue(resultTrainingStarted);
assertTrue(resultNewEpoch);
assertTrue(resultEpochTrainingResult);
assertTrue(resultProgress);
}
}

View File

@ -2,12 +2,10 @@ package org.deeplearning4j.rl4j.learning.sync;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
@ -19,7 +17,7 @@ public class SyncLearningTest {
public void when_training_expect_listenersToBeCalled() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@ -27,8 +25,8 @@ public class SyncLearningTest {
sut.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(10, listener.onEpochStartCallCount);
assertEquals(10, listener.onEpochEndStartCallCount);
assertEquals(10, listener.onNewEpochCallCount);
assertEquals(10, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
@ -36,65 +34,59 @@ public class SyncLearningTest {
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.trainingStartCanContinue = false;
listener.setRemainingTrainingStartCallCount(0);
// Act
sut.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(0, listener.onEpochStartCallCount);
assertEquals(0, listener.onEpochEndStartCallCount);
assertEquals(0, listener.onNewEpochCallCount);
assertEquals(0, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
@Test
public void when_epochStartCanContinueFalse_expect_trainingStopped() {
public void when_newEpochCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.nbStepsEpochStartCanContinue = 3;
listener.setRemainingOnNewEpochCallCount(2);
// Act
sut.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(3, listener.onEpochStartCallCount);
assertEquals(2, listener.onEpochEndStartCallCount);
assertEquals(3, listener.onNewEpochCallCount);
assertEquals(2, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
@Test
public void when_epochEndCanContinueFalse_expect_trainingStopped() {
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.nbStepsEpochEndCanContinue = 3;
listener.setRemainingOnEpochTrainingResult(2);
// Act
sut.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(3, listener.onEpochStartCallCount);
assertEquals(3, listener.onEpochEndStartCallCount);
assertEquals(3, listener.onNewEpochCallCount);
assertEquals(3, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
public static class MockSyncLearning extends SyncLearning {
private LConfiguration conf;
public MockSyncLearning(LConfiguration conf, IDataManager dataManager) {
super(conf);
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
this.conf = conf;
}
private final LConfiguration conf;
public MockSyncLearning(LConfiguration conf) {
super(conf);
@ -119,7 +111,7 @@ public class SyncLearningTest {
}
@Override
public Policy getPolicy() {
public IPolicy getPolicy() {
return null;
}

View File

@ -1,65 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockMDP;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class QLearningDiscreteTest {
@Test
public void refac_checkDataManagerCallsRemainTheSame() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder()
.maxStep(10)
.expRepMaxSize(1)
.build();
MockDataManager dataManager = new MockDataManager(true);
MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3);
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
.build();
sut.setHistoryProcessor(new MockHistoryProcessor(hpConfig));
// Act
sut.train();
assertEquals(10, dataManager.statEntries.size());
for(int i = 0; i < 10; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i, entry.getEpochCounter());
assertEquals(i+1, entry.getStepCounter());
assertEquals(1.0, entry.getReward(), 0.0);
}
assertEquals(4, dataManager.isSaveDataCallCount);
assertEquals(4, dataManager.getVideoDirCallCount);
assertEquals(11, dataManager.writeInfoCallCount);
assertEquals(5, dataManager.saveCallCount);
}
public static class MockQLearningDiscrete extends QLearningDiscrete {
public MockQLearningDiscrete(int maxSteps, QLConfiguration conf,
IDataManager dataManager, int saveFrequency, int monitorFrequency) {
super(new MockMDP(maxSteps), new MockDQN(), conf, 2);
addListener(DataManagerSyncTrainingListener.builder(dataManager)
.saveFrequency(saveFrequency)
.monitorFrequency(monitorFrequency)
.build());
}
@Override
protected IDataManager.StatEntry trainEpoch() {
setStepCounter(getStepCounter() + 1);
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
}
}
}

View File

@ -8,6 +8,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -29,12 +30,11 @@ public class QLearningDiscreteTest {
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
0, 1.0, 0, 0, 0, 0, true);
MockDataManager dataManager = new MockDataManager(false);
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10);
MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hp);
MockExpReplay expReplay = new MockExpReplay();
sut.setExpReplay(expReplay);
MockEncodable obs = new MockEncodable(1);
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
@ -131,8 +131,11 @@ public class QLearningDiscreteTest {
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) {
super(mdp, dqn, conf, dataManager, epsilonNbStep);
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
int epsilonNbStep) {
super(mdp, dqn, conf, epsilonNbStep);
addListener(new DataManagerTrainingListener(dataManager));
setExpReplay(expReplay);
}
@Override

View File

@ -1,46 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
public class MockSyncTrainingListener implements SyncTrainingListener {
public int onTrainingStartCallCount = 0;
public int onTrainingEndCallCount = 0;
public int onEpochStartCallCount = 0;
public int onEpochEndStartCallCount = 0;
public boolean trainingStartCanContinue = true;
public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE;
public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE;
@Override
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
++onTrainingStartCallCount;
return trainingStartCanContinue ? ListenerResponse.CONTINUE : ListenerResponse.STOP;
}
@Override
public void onTrainingEnd() {
++onTrainingEndCallCount;
}
@Override
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
++onEpochStartCallCount;
if(onEpochStartCallCount >= nbStepsEpochStartCanContinue) {
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
++onEpochEndStartCallCount;
if(onEpochEndStartCallCount >= nbStepsEpochEndCanContinue) {
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
}

View File

@ -0,0 +1,65 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
public class MockAsyncConfiguration implements AsyncConfiguration {
private final int nStep;
private final int maxEpochStep;
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
this.nStep = nStep;
this.maxEpochStep = maxEpochStep;
}
@Override
public int getSeed() {
return 0;
}
@Override
public int getMaxEpochStep() {
return maxEpochStep;
}
@Override
public int getMaxStep() {
return 0;
}
@Override
public int getNumThread() {
return 0;
}
@Override
public int getNstep() {
return nStep;
}
@Override
public int getTargetDqnUpdateFreq() {
return 0;
}
@Override
public int getUpdateStart() {
return 0;
}
@Override
public double getRewardFactor() {
return 0;
}
@Override
public double getGamma() {
return 0;
}
@Override
public double getErrorClamp() {
return 0;
}
}

View File

@ -0,0 +1,65 @@
package org.deeplearning4j.rl4j.support;
import lombok.Setter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.network.NeuralNet;
import java.util.concurrent.atomic.AtomicInteger;
public class MockAsyncGlobal implements IAsyncGlobal {
public boolean hasBeenStarted = false;
public boolean hasBeenTerminated = false;
@Setter
private int maxLoops;
@Setter
private int numLoopsStopRunning;
private int currentLoop = 0;
public MockAsyncGlobal() {
maxLoops = Integer.MAX_VALUE;
numLoopsStopRunning = Integer.MAX_VALUE;
}
@Override
public boolean isRunning() {
return currentLoop < numLoopsStopRunning;
}
@Override
public void terminate() {
hasBeenTerminated = true;
}
@Override
public boolean isTrainingComplete() {
return ++currentLoop > maxLoops;
}
@Override
public void start() {
hasBeenStarted = true;
}
@Override
public AtomicInteger getT() {
return null;
}
@Override
public NeuralNet getCurrent() {
return null;
}
@Override
public NeuralNet getTarget() {
return null;
}
@Override
public void enqueue(Gradient[] gradient, Integer nstep) {
}
}

View File

@ -44,7 +44,7 @@ public class MockDataManager implements IDataManager {
}
@Override
public void save(Learning learning) throws IOException {
public void save(ILearning learning) throws IOException {
++saveCallCount;
}
}

View File

@ -0,0 +1,74 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
import java.io.OutputStream;
public class MockNeuralNet implements NeuralNet {
public int resetCallCount = 0;
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
++resetCallCount;
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];
}
@Override
public NeuralNet clone() {
return null;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
return new Gradient[0];
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public void applyGradient(Gradient[] gradients, int batchSize) {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}

View File

@ -0,0 +1,17 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
public class MockPolicy implements IPolicy<MockEncodable, Integer> {
public int playCallCount = 0;
@Override
public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
++playCallCount;
return 0;
}
}

View File

@ -0,0 +1,65 @@
package org.deeplearning4j.rl4j.support;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
public class MockTrainingListener implements TrainingListener {
public int onTrainingStartCallCount = 0;
public int onTrainingEndCallCount = 0;
public int onNewEpochCallCount = 0;
public int onEpochTrainingResultCallCount = 0;
public int onTrainingProgressCallCount = 0;
@Setter
private int remainingTrainingStartCallCount = Integer.MAX_VALUE;
@Setter
private int remainingOnNewEpochCallCount = Integer.MAX_VALUE;
@Setter
private int remainingOnEpochTrainingResult = Integer.MAX_VALUE;
@Setter
private int remainingonTrainingProgressCallCount = Integer.MAX_VALUE;
public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
@Override
public ListenerResponse onTrainingStart() {
++onTrainingStartCallCount;
--remainingTrainingStartCallCount;
return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
++onNewEpochCallCount;
--remainingOnNewEpochCallCount;
return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
++onEpochTrainingResultCallCount;
--remainingOnEpochTrainingResult;
statEntries.add(statEntry);
return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
++onTrainingProgressCallCount;
--remainingonTrainingProgressCallCount;
return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
++onTrainingEndCallCount;
}
}

View File

@ -0,0 +1,169 @@
package org.deeplearning4j.rl4j.util;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
public class DataManagerTrainingListenerTest {
@Test
public void when_callingOnNewEpochWithoutHistoryProcessor_expect_noException() {
// Arrange
TestTrainer trainer = new TestTrainer();
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
// Act
TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
}
@Test
public void when_callingOnNewEpochWithHistoryProcessor_expect_startMonitorNotCalled() {
// Arrange
TestTrainer trainer = new TestTrainer();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
trainer.setHistoryProcessor(hp);
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
// Act
TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer);
// Assert
assertEquals(1, hp.startMonitorCallCount);
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
}
@Test
public void when_callingOnEpochTrainingResultWithoutHistoryProcessor_expect_noException() {
// Arrange
TestTrainer trainer = new TestTrainer();
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
// Act
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
}
@Test
public void when_callingOnNewEpochWithHistoryProcessor_expect_stopMonitorNotCalled() {
// Arrange
TestTrainer trainer = new TestTrainer();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
trainer.setHistoryProcessor(hp);
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
// Act
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null);
// Assert
assertEquals(1, hp.stopMonitorCallCount);
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
}
@Test
public void when_callingOnEpochTrainingResult_expect_callToDataManagerAppendStat() {
// Arrange
TestTrainer trainer = new TestTrainer();
MockDataManager dm = new MockDataManager(false);
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
MockStatEntry statEntry = new MockStatEntry(0, 0, 0.0);
// Act
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, statEntry);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
assertEquals(1, dm.statEntries.size());
assertSame(statEntry, dm.statEntries.get(0));
}
@Test
public void when_callingOnTrainingProgress_expect_callToDataManagerSaveAndWriteInfo() {
// Arrange
TestTrainer learning = new TestTrainer();
MockDataManager dm = new MockDataManager(false);
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
// Act
TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
assertEquals(1, dm.writeInfoCallCount);
assertEquals(1, dm.saveCallCount);
}
@Test
public void when_stepCounterCloseToLastSave_expect_dataManagerSaveNotCalled() {
// Arrange
TestTrainer learning = new TestTrainer();
MockDataManager dm = new MockDataManager(false);
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
// Act
TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning);
TrainingListener.ListenerResponse response2 = sut.onTrainingProgress(learning);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response2);
assertEquals(1, dm.saveCallCount);
}
private static class TestTrainer implements IEpochTrainer, ILearning
{
@Override
public int getStepCounter() {
return 0;
}
@Override
public int getEpochCounter() {
return 0;
}
@Getter
@Setter
private IHistoryProcessor historyProcessor;
@Override
public IPolicy getPolicy() {
return null;
}
@Override
public void train() {
}
@Override
public LConfiguration getConfiguration() {
return null;
}
@Override
public MDP getMdp() {
return new MockMDP(new MockObservationSpace());
}
}
}