RL4J - AsyncTrainingListener (#8072)

* Code clarity: Extracted parts of run() into private methods

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Added listener pattern to async learning

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Merged all listeners logic

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Added interface and common data to training events

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Fixed missing info log file

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Fixed bad merge; removed useless TrainingEvent

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Removed param from training start/end event

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Removed 'event' classes from the training listener

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Reverted changes to QLearningDiscrete.setTarget()
master
Alexandre Boulanger 2019-09-18 21:28:13 -04:00 committed by Alex Black
parent d58a4b45b1
commit 59f1cbf0c6
47 changed files with 1576 additions and 881 deletions

View File

@ -0,0 +1,31 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning;
import org.deeplearning4j.rl4j.mdp.MDP;
/**
* The common API between Learning and AsyncThread.
*
* @author Alexandre Boulanger
*/
public interface IEpochTrainer {
int getStepCounter();
int getEpochCounter();
IHistoryProcessor getHistoryProcessor();
MDP getMdp();
}

View File

@ -17,7 +17,7 @@
package org.deeplearning4j.rl4j.learning; package org.deeplearning4j.rl4j.learning;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
*/ */
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> extends StepCountable { public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> extends StepCountable {
Policy<O, A> getPolicy(); IPolicy<O, A> getPolicy();
void train(); void train();
@ -38,6 +38,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> ex
MDP<O, A, AS> getMdp(); MDP<O, A, AS> getMdp();
IHistoryProcessor getHistoryProcessor();
interface LConfiguration { interface LConfiguration {

View File

@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.async;
import lombok.Getter; import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
@ -63,7 +62,6 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
@Getter @Getter
private NN target; private NN target;
@Getter @Getter
@Setter
private boolean running = true; private boolean running = true;
public AsyncGlobal(NN initial, AsyncConfiguration a3cc) { public AsyncGlobal(NN initial, AsyncConfiguration a3cc) {
@ -78,8 +76,10 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
} }
public void enqueue(Gradient[] gradient, Integer nstep) { public void enqueue(Gradient[] gradient, Integer nstep) {
if(running && !isTrainingComplete()) {
queue.add(new Pair<>(gradient, nstep)); queue.add(new Pair<>(gradient, nstep));
} }
}
@Override @Override
public void run() { public void run() {
@ -105,4 +105,12 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
} }
/**
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
*/
public void terminate() {
running = false;
queue.clear();
}
} }

View File

@ -16,33 +16,49 @@
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.async;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
* The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread()
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
* (see setProgressEventInterval(int))
*
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
* * @author Alexandre Boulanger
* Async learning always follow the same pattern in RL4J
* -launch the Global thread
* -launch the "save threads"
* -periodically evaluate the model of the global thread for monitoring purposes
*
*/ */
@Slf4j @Slf4j
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> { extends Learning<O, A, AS, NN> {
protected abstract IDataManager getDataManager(); @Getter(AccessLevel.PROTECTED)
private final TrainingListenerList listeners = new TrainingListenerList();
public AsyncLearning(AsyncConfiguration conf) { public AsyncLearning(AsyncConfiguration conf) {
super(conf); super(conf);
} }
/**
* Add a {@link TrainingListener} listener at the end of the listener list.
*
* @param listener the listener to be added
*/
public void addListener(TrainingListener listener) {
listeners.add(listener);
}
/**
* Returns the configuration
* @return the configuration (see {@link AsyncConfiguration})
*/
public abstract AsyncConfiguration getConfiguration(); public abstract AsyncConfiguration getConfiguration();
protected abstract AsyncThread newThread(int i, int deviceAffinity); protected abstract AsyncThread newThread(int i, int deviceAffinity);
@ -57,41 +73,80 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
return getAsyncGlobal().isTrainingComplete(); return getAsyncGlobal().isTrainingComplete();
} }
public void launchThreads() { private boolean canContinue = true;
/**
* Number of milliseconds between calls to onTrainingProgress
*/
@Getter @Setter
private int progressMonitorFrequency = 20000;
private void launchThreads() {
startGlobalThread(); startGlobalThread();
for (int i = 0; i < getConfiguration().getNumThread(); i++) { for (int i = 0; i < getConfiguration().getNumThread(); i++) {
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices()); Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
t.start(); t.start();
} }
log.info("Threads launched."); log.info("Threads launched.");
} }
/**
* @return The current step
*/
@Override @Override
public int getStepCounter() { public int getStepCounter() {
return getAsyncGlobal().getT().get(); return getAsyncGlobal().getT().get();
} }
/**
* This method will train the model<p>
* The training stop when:<br>
* - A worker thread terminate the AsyncGlobal thread (see {@link AsyncGlobal})<br>
* OR<br>
* - a listener explicitly stops it<br>
* <p>
* Listeners<br>
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
* returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* Events:
* <ul>
* <li>{@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.</li>
* <li>{@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
* </ul>
*/
public void train() { public void train() {
try {
log.info("AsyncLearning training starting."); log.info("AsyncLearning training starting.");
canContinue = listeners.notifyTrainingStarted();
if (canContinue) {
launchThreads(); launchThreads();
monitorTraining();
}
cleanupPostTraining();
listeners.notifyTrainingFinished();
}
protected void monitorTraining() {
try {
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
canContinue = listeners.notifyTrainingProgress(this);
if(!canContinue) {
return;
}
//this is simply for stat purposes
getDataManager().writeInfo(this);
synchronized (this) { synchronized (this) {
while (!isTrainingComplete() && getAsyncGlobal().isRunning()) { wait(progressMonitorFrequency);
getPolicy().play(getMdp(), getHistoryProcessor());
getDataManager().writeInfo(this);
wait(20000);
} }
} }
} catch (Exception e) { } catch (InterruptedException e) {
log.error("Training failed.", e); log.error("Training interrupted.", e);
e.printStackTrace();
} }
} }
protected void cleanupPostTraining() {
// Worker threads stops automatically when the global thread stops
getAsyncGlobal().terminate();
}
} }

View File

@ -21,33 +21,31 @@ import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.Value; import lombok.Value;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.*;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*
* This represent a local thread that explore the environment * This represent a local thread that explore the environment
* and calculate a gradient to enqueue to the global thread/model * and calculate a gradient to enqueue to the global thread/model
* *
* It has its own version of a model that it syncs at the start of every * It has its own version of a model that it syncs at the start of every
* sub epoch * sub epoch
* *
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
* @author Alexandre Boulanger
*/ */
@Slf4j @Slf4j
public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Thread implements StepCountable { extends Thread implements StepCountable, IEpochTrainer {
@Getter
private int threadNumber; private int threadNumber;
@Getter @Getter
protected final int deviceNum; protected final int deviceNum;
@ -55,12 +53,16 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
private int stepCounter = 0; private int stepCounter = 0;
@Getter @Setter @Getter @Setter
private int epochCounter = 0; private int epochCounter = 0;
@Getter
private MDP<O, A, AS> mdp;
@Getter @Setter @Getter @Setter
private IHistoryProcessor historyProcessor; private IHistoryProcessor historyProcessor;
@Getter
private int lastMonitor = -Constants.MONITOR_FREQ;
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) { private final TrainingListenerList listeners;
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
this.mdp = mdp;
this.listeners = listeners;
this.threadNumber = threadNumber; this.threadNumber = threadNumber;
this.deviceNum = deviceNum; this.deviceNum = deviceNum;
} }
@ -80,75 +82,106 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
} }
protected void preEpoch() { protected void preEpoch() {
if (getStepCounter() - lastMonitor >= Constants.MONITOR_FREQ && getHistoryProcessor() != null // Do nothing
&& getDataManager().isSaveData()) {
lastMonitor = getStepCounter();
int[] shape = getMdp().getObservationSpace().getShape();
getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + threadNumber + "-"
+ getEpochCounter() + "-" + getStepCounter() + ".mp4", shape);
}
} }
/**
* This method will start the worker thread<p>
* The thread will stop when:<br>
* - The AsyncGlobal thread terminates or reports that the training is complete
* (see {@link AsyncGlobal#isTrainingComplete()}). In such case, the currently running epoch will still be handled normally and
* events will also be fired normally.<br>
* OR<br>
* - a listener explicitly stops it, in which case, the AsyncGlobal thread will be terminated along with
* all other worker threads <br>
* <p>
* Listeners<br>
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
* returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse
* TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* Events:
* <ul>
* <li>{@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} is called when a new epoch is started.</li>
* <li>{@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} is called at the end of every
* epoch. It will not be called if onNewEpoch() stops the training.</li>
* </ul>
*/
@Override @Override
public void run() { public void run() {
RunContext<O> context = new RunContext<>();
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum); Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
try {
log.info("ThreadNum-" + threadNumber + " Started!"); log.info("ThreadNum-" + threadNumber + " Started!");
boolean canContinue = initWork(context);
if (canContinue) {
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
handleTraining(context);
if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
canContinue = finishEpoch(context) && startNewEpoch(context);
if (!canContinue) {
break;
}
}
}
}
terminateWork();
}
private void initNewEpoch(RunContext context) {
getCurrent().reset(); getCurrent().reset();
Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor); Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor);
O obs = initMdp.getLastObs();
double rewards = initMdp.getReward();
int length = initMdp.getSteps();
preEpoch(); context.obs = initMdp.getLastObs();
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) { context.rewards = initMdp.getReward();
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - length); context.length = initMdp.getSteps();
SubEpochReturn<O> subEpochReturn = trainSubEpoch(obs, maxSteps); }
obs = subEpochReturn.getLastObs();
private void handleTraining(RunContext<O> context) {
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length);
SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs();
stepCounter += subEpochReturn.getSteps(); stepCounter += subEpochReturn.getSteps();
length += subEpochReturn.getSteps(); context.length += subEpochReturn.getSteps();
rewards += subEpochReturn.getReward(); context.rewards += subEpochReturn.getReward();
double score = subEpochReturn.getScore(); context.score = subEpochReturn.getScore();
if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) { }
postEpoch();
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
getDataManager().appendStat(statEntry);
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
getCurrent().reset();
initMdp = Learning.initMdp(getMdp(), historyProcessor);
obs = initMdp.getLastObs();
rewards = initMdp.getReward();
length = initMdp.getSteps();
epochCounter++;
private boolean initWork(RunContext context) {
initNewEpoch(context);
preEpoch(); preEpoch();
return listeners.notifyNewEpoch(this);
} }
private boolean startNewEpoch(RunContext context) {
initNewEpoch(context);
epochCounter++;
preEpoch();
return listeners.notifyNewEpoch(this);
} }
} catch (Exception e) {
log.error("Thread crashed: " + e.getCause()); private boolean finishEpoch(RunContext context) {
getAsyncGlobal().setRunning(false);
e.printStackTrace();
} finally {
postEpoch(); postEpoch();
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.length, context.score);
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
return listeners.notifyEpochTrainingResult(this, statEntry);
} }
private void terminateWork() {
postEpoch();
getAsyncGlobal().terminate();
} }
protected abstract NN getCurrent(); protected abstract NN getCurrent();
protected abstract int getThreadNumber();
protected abstract IAsyncGlobal<NN> getAsyncGlobal(); protected abstract IAsyncGlobal<NN> getAsyncGlobal();
protected abstract MDP<O, A, AS> getMdp();
protected abstract AsyncConfiguration getConf(); protected abstract AsyncConfiguration getConf();
protected abstract IDataManager getDataManager();
protected abstract Policy<O, A> getPolicy(NN net); protected abstract Policy<O, A> getPolicy(NN net);
protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep); protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
@ -172,4 +205,11 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
double score; double score;
} }
private static class RunContext<O extends Encodable> {
private O obs;
private double rewards;
private int length;
private double score;
}
} }

View File

@ -21,7 +21,9 @@ import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
@ -44,8 +46,8 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
@Getter @Getter
private NN current; private NN current;
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) { public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, MDP<O, Integer, DiscreteSpace> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
super(asyncGlobal, threadNumber, deviceNum); super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
synchronized (asyncGlobal) { synchronized (asyncGlobal) {
current = (NN)asyncGlobal.getCurrent().clone(); current = (NN)asyncGlobal.getCurrent().clone();
} }

View File

@ -23,9 +23,14 @@ import java.util.concurrent.atomic.AtomicInteger;
public interface IAsyncGlobal<NN extends NeuralNet> { public interface IAsyncGlobal<NN extends NeuralNet> {
boolean isRunning(); boolean isRunning();
void setRunning(boolean value);
boolean isTrainingComplete(); boolean isTrainingComplete();
void start(); void start();
/**
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
*/
void terminate();
AtomicInteger getT(); AtomicInteger getT();
NN getCurrent(); NN getCurrent();
NN getTarget(); NN getTarget();

View File

@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
@ -47,24 +46,19 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
final private AsyncGlobal asyncGlobal; final private AsyncGlobal asyncGlobal;
@Getter @Getter
final private ACPolicy<O> policy; final private ACPolicy<O> policy;
@Getter
final private IDataManager dataManager;
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf, public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
IDataManager dataManager) {
super(conf); super(conf);
this.iActorCritic = iActorCritic; this.iActorCritic = iActorCritic;
this.mdp = mdp; this.mdp = mdp;
this.configuration = conf; this.configuration = conf;
this.dataManager = dataManager;
policy = new ACPolicy<>(iActorCritic, getRandom()); policy = new ACPolicy<>(iActorCritic, getRandom());
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf); asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
mdp.getActionSpace().setSeed(conf.getSeed()); mdp.getActionSpace().setSeed(conf.getSeed());
} }
@Override
protected AsyncThread newThread(int i, int deviceNum) { protected AsyncThread newThread(int i, int deviceNum) {
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager, deviceNum); return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), deviceNum, getListeners(), i);
} }
public IActorCritic getNeuralNet() { public IActorCritic getNeuralNet() {

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
@ -43,24 +44,38 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
final private HistoryProcessor.Configuration hpconf; final private HistoryProcessor.Configuration hpconf;
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, @Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
super(mdp, IActorCritic, conf, dataManager); this(mdp, actorCritic, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
super(mdp, IActorCritic, conf);
this.hpconf = hpconf; this.hpconf = hpconf;
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
} }
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
dataManager); }
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
} }
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager); this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
} }
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
}
@Override @Override
public AsyncThread newThread(int i, int deviceNum) { public AsyncThread newThread(int i, int deviceNum) {

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.*; import org.deeplearning4j.rl4j.network.ac.*;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
@ -33,33 +34,58 @@ import org.deeplearning4j.rl4j.util.IDataManager;
*/ */
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> { public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
IDataManager dataManager) { IDataManager dataManager) {
super(mdp, IActorCritic, conf, dataManager); this(mdp, IActorCritic, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
super(mdp, actorCritic, conf);
} }
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf, IDataManager dataManager) { A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) { IDataManager dataManager) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager); this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf, IDataManager dataManager) { A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf, ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) { IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager); this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf);
}
} }

View File

@ -22,13 +22,13 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -46,24 +46,19 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
@Getter @Getter
final protected A3CDiscrete.A3CConfiguration conf; final protected A3CDiscrete.A3CConfiguration conf;
@Getter @Getter
final protected MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final protected AsyncGlobal<IActorCritic> asyncGlobal; final protected AsyncGlobal<IActorCritic> asyncGlobal;
@Getter @Getter
final protected int threadNumber; final protected int threadNumber;
@Getter
final protected IDataManager dataManager;
final private Random random; final private Random random;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal, public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) { A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
super(asyncGlobal, threadNumber, deviceNum); int threadNumber) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = a3cc; this.conf = a3cc;
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber; this.threadNumber = threadNumber;
this.mdp = mdp;
this.dataManager = dataManager;
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber); mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
random = new Random(conf.getSeed() + threadNumber); random = new Random(conf.getSeed() + threadNumber);
} }
@ -85,15 +80,15 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
//if recurrent then train as a time serie with a batch size of 1 //if recurrent then train as a time serie with a batch size of 1
boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent(); boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();
int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape() int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape(); : getHistoryProcessor().getConf().getShape();
int[] nshape = recurrent ? Learning.makeShape(1, shape, size) int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
: Learning.makeShape(size, shape); : Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape); INDArray input = Nd4j.create(nshape);
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1); INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, mdp.getActionSpace().getSize(), size) INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size)
: Nd4j.zeros(size, mdp.getActionSpace().getSize()); : Nd4j.zeros(size, getMdp().getActionSpace().getSize());
double r = minTrans.getReward(); double r = minTrans.getReward();
for (int i = size - 1; i >= 0; i--) { for (int i = size - 1; i >= 0; i--) {

View File

@ -24,10 +24,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -40,16 +39,12 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
@Getter @Getter
final private MDP<O, Integer, DiscreteSpace> mdp; final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter @Getter
final private IDataManager dataManager;
@Getter
final private AsyncGlobal<IDQN> asyncGlobal; final private AsyncGlobal<IDQN> asyncGlobal;
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf, public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
IDataManager dataManager) {
super(conf); super(conf);
this.mdp = mdp; this.mdp = mdp;
this.dataManager = dataManager;
this.configuration = conf; this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf); this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
mdp.getActionSpace().setSeed(conf.getSeed()); mdp.getActionSpace().setSeed(conf.getSeed());
@ -57,14 +52,14 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
@Override @Override
public AsyncThread newThread(int i, int deviceNum) { public AsyncThread newThread(int i, int deviceNum) {
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum); return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, getListeners(), i, deviceNum);
} }
public IDQN getNeuralNet() { public IDQN getNeuralNet() {
return asyncGlobal.getCurrent(); return asyncGlobal.getCurrent();
} }
public Policy<O, Integer> getPolicy() { public IPolicy<O, Integer> getPolicy() {
return new DQNPolicy<O>(getNeuralNet()); return new DQNPolicy<O>(getNeuralNet());
} }

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
@ -35,22 +36,38 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
final private HistoryProcessor.Configuration hpconf; final private HistoryProcessor.Configuration hpconf;
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager); this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
super(mdp, dqn, conf);
this.hpconf = hpconf; this.hpconf = hpconf;
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
} }
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
} }
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
} }
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}
@Override @Override
public AsyncThread newThread(int i, int deviceNum) { public AsyncThread newThread(int i, int deviceNum) {

View File

@ -22,6 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
@ -29,19 +30,37 @@ import org.deeplearning4j.rl4j.util.IDataManager;
*/ */
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> { public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf, IDataManager dataManager) { AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager); super(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager));
} }
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf) {
super(mdp, dqn, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf, IDataManager dataManager) { AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
} }
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}
} }

View File

@ -19,9 +19,10 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.DQNPolicy;
@ -29,7 +30,6 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -44,31 +44,25 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
@Getter @Getter
final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf; final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
@Getter @Getter
final protected MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final protected IAsyncGlobal<IDQN> asyncGlobal; final protected IAsyncGlobal<IDQN> asyncGlobal;
@Getter @Getter
final protected int threadNumber; final protected int threadNumber;
@Getter
final protected IDataManager dataManager;
final private Random random; final private Random random;
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal, public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
IDataManager dataManager, int deviceNum) { TrainingListenerList listeners, int threadNumber, int deviceNum) {
super(asyncGlobal, threadNumber, deviceNum); super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = conf; this.conf = conf;
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber; this.threadNumber = threadNumber;
this.mdp = mdp;
this.dataManager = dataManager;
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber); mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
random = new Random(conf.getSeed() + threadNumber); random = new Random(conf.getSeed() + threadNumber);
} }
public Policy<O, Integer> getPolicy(IDQN nn) { public Policy<O, Integer> getPolicy(IDQN nn) {
return new EpsGreedy(new DQNPolicy(nn), mdp, conf.getUpdateStart(), conf.getEpsilonNbStep(), return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
random, conf.getMinEpsilon(), this); random, conf.getMinEpsilon(), this);
} }
@ -81,11 +75,11 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
int size = rewards.size(); int size = rewards.size();
int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape() int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape(); : getHistoryProcessor().getConf().getShape();
int[] nshape = Learning.makeShape(size, shape); int[] nshape = Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape); INDArray input = Nd4j.create(nshape);
INDArray targets = Nd4j.create(size, mdp.getActionSpace().getSize()); INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize());
double r = minTrans.getReward(); double r = minTrans.getReward();
for (int i = size - 1; i >= 0; i--) { for (int i = size - 1; i >= 0; i--) {

View File

@ -0,0 +1,72 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.listener;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
* The base definition of all training event listeners
*
* @author Alexandre Boulanger
*/
public interface TrainingListener {
enum ListenerResponse {
/**
* Tell the learning process to continue calling the listeners and the training.
*/
CONTINUE,
/**
* Tell the learning process to stop calling the listeners and terminate the training.
*/
STOP,
}
/**
* Called once when the training starts.
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
*/
ListenerResponse onTrainingStart();
/**
* Called once when the training has finished. This method is called even when the training has been aborted.
*/
void onTrainingEnd();
/**
* Called before the start of every epoch.
* @param trainer A {@link IEpochTrainer}
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onNewEpoch(IEpochTrainer trainer);
/**
* Called when an epoch has been completed
* @param trainer A {@link IEpochTrainer}
* @param statEntry A {@link org.deeplearning4j.rl4j.util.IDataManager.StatEntry}
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry);
/**
* Called regularly to monitor the training progress.
* @param learning A {@link ILearning}
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onTrainingProgress(ILearning learning);
}

View File

@ -0,0 +1,105 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.listener;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
/**
* The base logic to notify training listeners with the different training events.
*
* @author Alexandre Boulanger
*/
public class TrainingListenerList {
protected final List<TrainingListener> listeners = new ArrayList<>();
/**
* Add a listener at the end of the list
* @param listener The listener to be added
*/
public void add(TrainingListener listener) {
listeners.add(listener);
}
/**
* Notify the listeners that the training has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
* @return whether or not the source training should be stopped
*/
public boolean notifyTrainingStarted() {
for (TrainingListener listener : listeners) {
if (listener.onTrainingStart() == TrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
/**
* Notify the listeners that the training has finished.
*/
public void notifyTrainingFinished() {
for (TrainingListener listener : listeners) {
listener.onTrainingEnd();
}
}
/**
* Notify the listeners that a new epoch has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
* @return whether or not the source training should be stopped
*/
public boolean notifyNewEpoch(IEpochTrainer trainer) {
for (TrainingListener listener : listeners) {
if (listener.onNewEpoch(trainer) == TrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
/**
* Notify the listeners that an epoch has been completed and the training results are available. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
* @return whether or not the source training should be stopped
*/
public boolean notifyEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
for (TrainingListener listener : listeners) {
if (listener.onEpochTrainingResult(trainer, statEntry) == TrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
/**
* Notify the listeners that they update the progress ot the trainning.
*/
public boolean notifyTrainingProgress(ILearning learning) {
for (TrainingListener listener : listeners) {
if (listener.onTrainingProgress(learning) == TrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
}

View File

@ -16,19 +16,18 @@
package org.deeplearning4j.rl4j.learning.sync; package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent; import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
/** /**
* Mother class and useful factorisations for all training methods that * Mother class and useful factorisations for all training methods that
* are not asynchronous. * are not asynchronous.
@ -38,9 +37,9 @@ import java.util.List;
*/ */
@Slf4j @Slf4j
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> { extends Learning<O, A, AS, NN> implements IEpochTrainer {
private List<SyncTrainingListener> listeners = new ArrayList<>(); private final TrainingListenerList listeners = new TrainingListenerList();
public SyncLearning(LConfiguration conf) { public SyncLearning(LConfiguration conf) {
super(conf); super(conf);
@ -49,12 +48,24 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
/** /**
* Add a listener at the end of the listener list. * Add a listener at the end of the listener list.
* *
* @param listener * @param listener The listener to add
*/ */
public void addListener(SyncTrainingListener listener) { public void addListener(TrainingListener listener) {
listeners.add(listener); listeners.add(listener);
} }
/**
* Number of epochs between calls to onTrainingProgress. Default is 5
*/
@Getter
private int progressMonitorFrequency = 5;
public void setProgressMonitorFrequency(int value) {
if(value == 0) throw new IllegalArgumentException("The progressMonitorFrequency cannot be 0");
progressMonitorFrequency = value;
}
/** /**
* This method will train the model<p> * This method will train the model<p>
* The training stop when:<br> * The training stop when:<br>
@ -64,81 +75,49 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
* <p> * <p>
* Listeners<br> * Listeners<br>
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener * For a given event, the listeners are called sequentially in same the order as they were added. If one listener
* returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br> * returns {@link TrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
* Events: * Events:
* <ul> * <ul>
* <li>{@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.</li> * <li>{@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.</li>
* <li>{@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training</li> * <li>{@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} and {@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} are called for every epoch. onEpochTrainingResult will not be called if onNewEpoch stops the training</li>
* <li>{@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li> * <li>{@link TrainingListener#onTrainingProgress(ILearning) onTrainingProgress()} is called after onEpochTrainingResult()</li>
* <li>{@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
* </ul> * </ul>
*/ */
public void train() { public void train() {
log.info("training starting."); log.info("training starting.");
boolean canContinue = notifyTrainingStarted(); boolean canContinue = listeners.notifyTrainingStarted();
if (canContinue) { if (canContinue) {
while (getStepCounter() < getConfiguration().getMaxStep()) { while (getStepCounter() < getConfiguration().getMaxStep()) {
preEpoch(); preEpoch();
canContinue = notifyEpochStarted(); canContinue = listeners.notifyNewEpoch(this);
if (!canContinue) { if (!canContinue) {
break; break;
} }
IDataManager.StatEntry statEntry = trainEpoch(); IDataManager.StatEntry statEntry = trainEpoch();
canContinue = listeners.notifyEpochTrainingResult(this, statEntry);
postEpoch();
canContinue = notifyEpochFinished(statEntry);
if (!canContinue) { if (!canContinue) {
break; break;
} }
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); postEpoch();
if(getEpochCounter() % progressMonitorFrequency == 0) {
canContinue = listeners.notifyTrainingProgress(this);
if (!canContinue) {
break;
}
}
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
incrementEpoch(); incrementEpoch();
} }
} }
notifyTrainingFinished(); listeners.notifyTrainingFinished();
}
private boolean notifyTrainingStarted() {
SyncTrainingEvent event = new SyncTrainingEvent(this);
for (SyncTrainingListener listener : listeners) {
if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
private void notifyTrainingFinished() {
for (SyncTrainingListener listener : listeners) {
listener.onTrainingEnd();
}
}
private boolean notifyEpochStarted() {
SyncTrainingEvent event = new SyncTrainingEvent(this);
for (SyncTrainingListener listener : listeners) {
if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) {
SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry);
for (SyncTrainingListener listener : listeners) {
if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) {
return false;
}
}
return true;
} }
protected abstract void preEpoch(); protected abstract void preEpoch();

View File

@ -1,22 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
* A subclass of SyncTrainingEvent that is passed to SyncTrainingListener.onEpochEnd()
*/
public class SyncTrainingEpochEndEvent extends SyncTrainingEvent {
/**
* The stats of the epoch training
*/
@Getter
private final IDataManager.StatEntry statEntry;
public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) {
super(learning);
this.statEntry = statEntry;
}
}

View File

@ -1,21 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.Learning;
/**
* SyncTrainingEvent are passed as parameters to the events of SyncTrainingListener
*/
public class SyncTrainingEvent {
/**
* The source of the event
*/
@Getter
private final Learning learning;
public SyncTrainingEvent(Learning learning) {
this.learning = learning;
}
}

View File

@ -1,45 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.listener;
/**
* A listener interface to use with a descendant of {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning}
*/
public interface SyncTrainingListener {
public enum ListenerResponse {
/**
* Tell SyncLearning to continue calling the listeners and the training.
*/
CONTINUE,
/**
* Tell SyncLearning to stop calling the listeners and terminate the training.
*/
STOP,
}
/**
* Called once when the training starts.
* @param event
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
*/
ListenerResponse onTrainingStart(SyncTrainingEvent event);
/**
* Called once when the training has finished. This method is called even when the training has been aborted.
*/
void onTrainingEnd();
/**
* Called before the start of every epoch.
* @param event
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onEpochStart(SyncTrainingEvent event);
/**
* Called after the end of every epoch.
* @param event
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
*/
ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event);
}

View File

@ -49,7 +49,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
// @Getter // @Getter
// final private IExpReplay<A> expReplay; // final private IExpReplay<A> expReplay;
@Getter @Getter
@Setter(AccessLevel.PACKAGE) @Setter(AccessLevel.PROTECTED)
protected IExpReplay<A> expReplay; protected IExpReplay<A> expReplay;
public QLearning(QLConfiguration conf) { public QLearning(QLConfiguration conf) {

View File

@ -28,8 +28,6 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
@ -64,20 +62,9 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Setter @Setter
private IDQN targetDQN; private IDQN targetDQN;
private int lastAction; private int lastAction;
private INDArray history[] = null; private INDArray[] history = null;
private double accuReward = 0; private double accuReward = 0;
/**
* @deprecated
* Use QLearningDiscrete(MDP, IDQN, QLConfiguration, int) and add the required listeners with addListener() instead.
*/
@Deprecated
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
IDataManager dataManager, int epsilonNbStep) {
this(mdp, dqn, conf, epsilonNbStep);
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf, public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep) { int epsilonNbStep) {
super(conf); super(conf);
@ -186,7 +173,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
} }
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) { protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
if (transitions.size() == 0) if (transitions.size() == 0)
throw new IllegalArgumentException("too few transitions"); throw new IllegalArgumentException("too few transitions");

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
@ -33,20 +34,36 @@ import org.deeplearning4j.rl4j.util.IDataManager;
public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> { public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> {
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf, IDataManager dataManager) { QLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
} }
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
} }
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
} }
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}
} }

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
@ -31,21 +32,35 @@ import org.deeplearning4j.rl4j.util.IDataManager;
public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> { public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> {
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
IDataManager dataManager) { IDataManager dataManager) {
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep()); this(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep());
} }
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf, IDataManager dataManager) { QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf, IDataManager dataManager) { QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
} }
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}
} }

View File

@ -0,0 +1,10 @@
package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
public interface IPolicy<O extends Encodable, A> {
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
}

View File

@ -35,7 +35,7 @@ import org.nd4j.linalg.util.ArrayUtil;
* *
* A Policy responsability is to choose the next action given a state * A Policy responsability is to choose the next action given a state
*/ */
public abstract class Policy<O extends Encodable, A> { public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
public abstract NeuralNet getNeuralNet(); public abstract NeuralNet getNeuralNet();
@ -49,6 +49,7 @@ public abstract class Policy<O extends Encodable, A> {
return play(mdp, new HistoryProcessor(conf)); return play(mdp, new HistoryProcessor(conf));
} }
@Override
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) { public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
getNeuralNet().reset(); getNeuralNet().reset();
Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp); Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);

View File

@ -22,6 +22,7 @@ import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.Value; import lombok.Value;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
@ -72,13 +73,13 @@ public class DataManager implements IDataManager {
} }
} }
public static void save(String path, Learning learning) throws IOException { public static void save(String path, ILearning learning) throws IOException {
try (BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(path))) { try (BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(path))) {
save(os, learning); save(os, learning);
} }
} }
public static void save(OutputStream os, Learning learning) throws IOException { public static void save(OutputStream os, ILearning learning) throws IOException {
try (ZipOutputStream zipfile = new ZipOutputStream(os)) { try (ZipOutputStream zipfile = new ZipOutputStream(os)) {
@ -91,7 +92,9 @@ public class DataManager implements IDataManager {
zipfile.putNextEntry(dqn); zipfile.putNextEntry(dqn);
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
learning.getNeuralNet().save(bos); if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(bos);
}
bos.flush(); bos.flush();
bos.close(); bos.close();
@ -104,7 +107,9 @@ public class DataManager implements IDataManager {
zipfile.putNextEntry(hpconf); zipfile.putNextEntry(hpconf);
ByteArrayOutputStream bos2 = new ByteArrayOutputStream(); ByteArrayOutputStream bos2 = new ByteArrayOutputStream();
learning.getNeuralNet().save(bos2); if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(bos2);
}
bos2.flush(); bos2.flush();
bos2.close(); bos2.close();
@ -256,13 +261,15 @@ public class DataManager implements IDataManager {
return exists; return exists;
} }
public void save(Learning learning) throws IOException { public void save(ILearning learning) throws IOException {
if (!saveData) if (!saveData)
return; return;
save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning); save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
learning.getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model"); if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
}
} }

View File

@ -1,126 +0,0 @@
package org.deeplearning4j.rl4j.util;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
/**
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
* training process can be fed to the DataManager
*/
@Slf4j
public class DataManagerSyncTrainingListener implements SyncTrainingListener {
private final IDataManager dataManager;
private final int saveFrequency;
private final int monitorFrequency;
private int lastSave;
private int lastMonitor;
private DataManagerSyncTrainingListener(Builder builder) {
this.dataManager = builder.dataManager;
this.saveFrequency = builder.saveFrequency;
this.lastSave = -builder.saveFrequency;
this.monitorFrequency = builder.monitorFrequency;
this.lastMonitor = -builder.monitorFrequency;
}
@Override
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
try {
dataManager.writeInfo(event.getLearning());
} catch (Exception e) {
log.error("Training failed.", e);
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
// Do nothing
}
@Override
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
int stepCounter = event.getLearning().getStepCounter();
if (stepCounter - lastMonitor >= monitorFrequency
&& event.getLearning().getHistoryProcessor() != null
&& dataManager.isSaveData()) {
lastMonitor = stepCounter;
int[] shape = event.getLearning().getMdp().getObservationSpace().getShape();
event.getLearning().getHistoryProcessor().startMonitor(dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter() + "-"
+ stepCounter + ".mp4", shape);
}
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
try {
int stepCounter = event.getLearning().getStepCounter();
if (stepCounter - lastSave >= saveFrequency) {
dataManager.save(event.getLearning());
lastSave = stepCounter;
}
dataManager.appendStat(event.getStatEntry());
dataManager.writeInfo(event.getLearning());
} catch (Exception e) {
log.error("Training failed.", e);
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
public static Builder builder(IDataManager dataManager) {
return new Builder(dataManager);
}
public static class Builder {
private final IDataManager dataManager;
private int saveFrequency = Constants.MODEL_SAVE_FREQ;
private int monitorFrequency = Constants.MONITOR_FREQ;
/**
* Create a Builder with the given DataManager
* @param dataManager
*/
public Builder(IDataManager dataManager) {
this.dataManager = dataManager;
}
/**
* A number that represent the number of steps since the last call to DataManager.save() before can it be called again.
* @param saveFrequency (Default: 100000)
*/
public Builder saveFrequency(int saveFrequency) {
this.saveFrequency = saveFrequency;
return this;
}
/**
* A number that represent the number of steps since the last call to HistoryProcessor.startMonitor() before can it be called again.
* @param monitorFrequency (Default: 10000)
*/
public Builder monitorFrequency(int monitorFrequency) {
this.monitorFrequency = monitorFrequency;
return this;
}
/**
* Creates a DataManagerSyncTrainingListener with the configured parameters
* @return An instance of DataManagerSyncTrainingListener
*/
public DataManagerSyncTrainingListener build() {
return new DataManagerSyncTrainingListener(this);
}
}
}

View File

@ -0,0 +1,83 @@
package org.deeplearning4j.rl4j.util;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
/**
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
* training process can be fed to the DataManager
*/
@Slf4j
public class DataManagerTrainingListener implements TrainingListener {
private final IDataManager dataManager;
private int lastSave = -Constants.MODEL_SAVE_FREQ;
public DataManagerTrainingListener(IDataManager dataManager) {
this.dataManager = dataManager;
}
@Override
public ListenerResponse onTrainingStart() {
return ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
}
@Override
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
IHistoryProcessor hp = trainer.getHistoryProcessor();
if(hp != null) {
int[] shape = trainer.getMdp().getObservationSpace().getShape();
String filename = dataManager.getVideoDir() + "/video-";
if (trainer instanceof AsyncThread) {
filename += ((AsyncThread) trainer).getThreadNumber() + "-";
}
filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4";
hp.startMonitor(filename, shape);
}
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
IHistoryProcessor hp = trainer.getHistoryProcessor();
if(hp != null) {
hp.stopMonitor();
}
try {
dataManager.appendStat(statEntry);
} catch (Exception e) {
log.error("Training failed.", e);
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
try {
int stepCounter = learning.getStepCounter();
if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) {
dataManager.save(learning);
lastSave = stepCounter;
}
dataManager.writeInfo(learning);
} catch (Exception e) {
log.error("Training failed.", e);
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
}

View File

@ -27,7 +27,7 @@ public interface IDataManager {
String getVideoDir(); String getVideoDir();
void appendStat(StatEntry statEntry) throws IOException; void appendStat(StatEntry statEntry) throws IOException;
void writeInfo(ILearning iLearning) throws IOException; void writeInfo(ILearning iLearning) throws IOException;
void save(Learning learning) throws IOException; void save(ILearning learning) throws IOException;
//In order for jackson to serialize StatEntry //In order for jackson to serialize StatEntry
//please use Lombok @Value (see QLStatEntry) //please use Lombok @Value (see QLStatEntry)

View File

@ -0,0 +1,127 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class AsyncLearningTest {
@Test
public void when_training_expect_AsyncGlobalStarted() {
// Arrange
TestContext context = new TestContext();
context.asyncGlobal.setMaxLoops(1);
// Act
context.sut.train();
// Assert
assertTrue(context.asyncGlobal.hasBeenStarted);
assertTrue(context.asyncGlobal.hasBeenTerminated);
}
@Test
public void when_trainStartReturnsStop_expect_noTraining() {
// Arrange
TestContext context = new TestContext();
context.listener.setRemainingTrainingStartCallCount(0);
// Act
context.sut.train();
// Assert
assertEquals(1, context.listener.onTrainingStartCallCount);
assertEquals(1, context.listener.onTrainingEndCallCount);
assertEquals(0, context.policy.playCallCount);
assertTrue(context.asyncGlobal.hasBeenTerminated);
}
@Test
public void when_trainingIsComplete_expect_trainingStop() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.train();
// Assert
assertEquals(1, context.listener.onTrainingStartCallCount);
assertEquals(1, context.listener.onTrainingEndCallCount);
assertTrue(context.asyncGlobal.hasBeenTerminated);
}
@Test
public void when_training_expect_onTrainingProgressCalled() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.train();
// Assert
assertEquals(1, context.listener.onTrainingProgressCallCount);
}
public static class TestContext {
public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1);
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockPolicy policy = new MockPolicy();
public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy);
public final MockTrainingListener listener = new MockTrainingListener();
public TestContext() {
sut.addListener(listener);
asyncGlobal.setMaxLoops(1);
sut.setProgressMonitorFrequency(1);
}
}
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
private final AsyncConfiguration conf;
private final IAsyncGlobal asyncGlobal;
private final IPolicy<MockEncodable, Integer> policy;
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
super(conf);
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.policy = policy;
}
@Override
public IPolicy getPolicy() {
return policy;
}
@Override
public AsyncConfiguration getConfiguration() {
return conf;
}
@Override
protected AsyncThread newThread(int i, int deviceAffinity) {
return null;
}
@Override
public MDP getMdp() {
return null;
}
@Override
protected IAsyncGlobal getAsyncGlobal() {
return asyncGlobal;
}
@Override
public MockNeuralNet getNeuralNet() {
return null;
}
}
}

View File

@ -1,206 +1,135 @@
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.MockDataManager; import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class AsyncThreadTest { public class AsyncThreadTest {
@Test @Test
public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() { public void when_newEpochStarted_expect_neuralNetworkReset() {
// Arrange // Arrange
MockDataManager dataManager = new MockDataManager(false); TestContext context = new TestContext();
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); context.listener.setRemainingOnNewEpochCallCount(5);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager);
// Act // Act
sut.run(); context.sut.run();
// Assert // Assert
assertEquals(4, dataManager.statEntries.size()); assertEquals(6, context.neuralNet.resetCallCount);
IDataManager.StatEntry entry = dataManager.statEntries.get(0);
assertEquals(2, entry.getStepCounter());
assertEquals(0, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(1);
assertEquals(4, entry.getStepCounter());
assertEquals(1, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(2);
assertEquals(6, entry.getStepCounter());
assertEquals(2, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(3);
assertEquals(8, entry.getStepCounter());
assertEquals(3, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
assertEquals(0, dataManager.isSaveDataCallCount);
assertEquals(0, dataManager.getVideoDirCallCount);
} }
@Test @Test
public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() { public void when_onNewEpochReturnsStop_expect_threadStopped() {
// Arrange // Arrange
MockDataManager dataManager = new MockDataManager(false); TestContext context = new TestContext();
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); context.listener.setRemainingOnNewEpochCallCount(1);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
.build();
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
sut.setHistoryProcessor(hp);
// Act // Act
sut.run(); context.sut.run();
// Assert // Assert
assertEquals(9, dataManager.statEntries.size()); assertEquals(2, context.listener.onNewEpochCallCount);
assertEquals(1, context.listener.onEpochTrainingResultCallCount);
for(int i = 0; i < 9; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i + 1, entry.getStepCounter());
assertEquals(i, entry.getEpochCounter());
assertEquals(79.0, entry.getReward(), 0.0);
}
assertEquals(10, dataManager.isSaveDataCallCount);
assertEquals(0, dataManager.getVideoDirCallCount);
} }
@Test @Test
public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() { public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
// Arrange // Arrange
MockDataManager dataManager = new MockDataManager(true); TestContext context = new TestContext();
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); context.listener.setRemainingOnEpochTrainingResult(1);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
.build();
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
sut.setHistoryProcessor(hp);
// Act // Act
sut.run(); context.sut.run();
// Assert // Assert
assertEquals(9, dataManager.statEntries.size()); assertEquals(2, context.listener.onNewEpochCallCount);
assertEquals(2, context.listener.onEpochTrainingResultCallCount);
for(int i = 0; i < 9; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i + 1, entry.getStepCounter());
assertEquals(i, entry.getEpochCounter());
assertEquals(79.0, entry.getReward(), 0.0);
} }
assertEquals(1, dataManager.isSaveDataCallCount); @Test
assertEquals(1, dataManager.getVideoDirCallCount); public void when_run_expect_preAndPostEpochCalled() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.run();
// Assert
assertEquals(6, context.sut.preEpochCallCount);
assertEquals(6, context.sut.postEpochCallCount);
} }
public static class MockAsyncGlobal implements IAsyncGlobal { @Test
public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
// Arrange
TestContext context = new TestContext();
private final int maxLoops; // Act
private int currentLoop = 0; context.sut.run();
public MockAsyncGlobal(int maxLoops) { // Assert
assertEquals(5, context.listener.statEntries.size());
this.maxLoops = maxLoops; int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
for(int i = 0; i < 5; ++i) {
IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
assertEquals(i, statEntry.getEpochCounter());
assertEquals(2.0, statEntry.getReward(), 0.0001);
}
} }
@Override private static class TestContext {
public boolean isRunning() { public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
return true; public final MockNeuralNet neuralNet = new MockNeuralNet();
} public final MockObservationSpace observationSpace = new MockObservationSpace();
public final MockMDP mdp = new MockMDP(observationSpace);
@Override public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 2);
public void setRunning(boolean value) { public final TrainingListenerList listeners = new TrainingListenerList();
public final MockTrainingListener listener = new MockTrainingListener();
} public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
@Override
public boolean isTrainingComplete() {
return ++currentLoop >= maxLoops;
}
@Override
public void start() {
}
@Override
public AtomicInteger getT() {
return null;
}
@Override
public NeuralNet getCurrent() {
return null;
}
@Override
public NeuralNet getTarget() {
return null;
}
@Override
public void enqueue(Gradient[] gradient, Integer nstep) {
public TestContext() {
asyncGlobal.setMaxLoops(10);
listeners.add(listener);
} }
} }
public static class MockAsyncThread extends AsyncThread { public static class MockAsyncThread extends AsyncThread {
IAsyncGlobal asyncGlobal; public int preEpochCallCount = 0;
private final MockNeuralNet neuralNet; public int postEpochCallCount = 0;
private final MDP mdp;
private final AsyncConfiguration conf;
private final IDataManager dataManager;
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
super(asyncGlobal, threadNumber, 0); private final IAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet;
private final AsyncConfiguration conf;
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
super(asyncGlobal, mdp, listeners, threadNumber, 0);
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
this.neuralNet = neuralNet; this.neuralNet = neuralNet;
this.mdp = mdp;
this.conf = conf; this.conf = conf;
this.dataManager = dataManager; }
@Override
protected void preEpoch() {
++preEpochCallCount;
super.preEpoch();
}
@Override
protected void postEpoch() {
++postEpochCallCount;
super.postEpoch();
} }
@Override @Override
@ -208,31 +137,16 @@ public class AsyncThreadTest {
return neuralNet; return neuralNet;
} }
@Override
protected int getThreadNumber() {
return 0;
}
@Override @Override
protected IAsyncGlobal getAsyncGlobal() { protected IAsyncGlobal getAsyncGlobal() {
return asyncGlobal; return asyncGlobal;
} }
@Override
protected MDP getMdp() {
return mdp;
}
@Override @Override
protected AsyncConfiguration getConf() { protected AsyncConfiguration getConf() {
return conf; return conf;
} }
@Override
protected IDataManager getDataManager() {
return dataManager;
}
@Override @Override
protected Policy getPolicy(NeuralNet net) { protected Policy getPolicy(NeuralNet net) {
return null; return null;
@ -244,129 +158,6 @@ public class AsyncThreadTest {
} }
} }
public static class MockNeuralNet implements NeuralNet {
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];
}
@Override
public NeuralNet clone() {
return null;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
return new Gradient[0];
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public void applyGradient(Gradient[] gradients, int batchSize) {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}
public static class MockAsyncConfiguration implements AsyncConfiguration {
private final int nStep;
private final int maxEpochStep;
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
this.nStep = nStep;
this.maxEpochStep = maxEpochStep;
}
@Override
public int getSeed() {
return 0;
}
@Override
public int getMaxEpochStep() {
return maxEpochStep;
}
@Override
public int getMaxStep() {
return 0;
}
@Override
public int getNumThread() {
return 0;
}
@Override
public int getNstep() {
return nStep;
}
@Override
public int getTargetDqnUpdateFreq() {
return 0;
}
@Override
public int getUpdateStart() {
return 0;
}
@Override
public double getRewardFactor() {
return 0;
}
@Override
public double getGamma() {
return 0;
}
@Override
public double getErrorClamp() {
return 0;
}
}
} }

View File

@ -0,0 +1,98 @@
package org.deeplearning4j.rl4j.learning.async.listener;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class AsyncTrainingListenerListTest {
@Test
public void when_listIsEmpty_expect_notifyTrainingStartedReturnTrue() {
// Arrange
TrainingListenerList sut = new TrainingListenerList();
// Act
boolean resultTrainingStarted = sut.notifyTrainingStarted();
boolean resultNewEpoch = sut.notifyNewEpoch(null);
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
// Assert
assertTrue(resultTrainingStarted);
assertTrue(resultNewEpoch);
assertTrue(resultEpochTrainingResult);
}
@Test
public void when_firstListerStops_expect_othersListnersNotCalled() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
listener1.onTrainingResultResponse = TrainingListener.ListenerResponse.STOP;
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
// Act
sut.notifyEpochTrainingResult(null, null);
// Assert
assertEquals(1, listener1.onEpochTrainingResultCallCount);
assertEquals(0, listener2.onEpochTrainingResultCallCount);
}
@Test
public void when_allListenersContinue_expect_listReturnsTrue() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
// Act
boolean resultTrainingProgress = sut.notifyEpochTrainingResult(null, null);
// Assert
assertTrue(resultTrainingProgress);
}
private static class MockTrainingListener implements TrainingListener {
public int onEpochTrainingResultCallCount = 0;
public ListenerResponse onTrainingResultResponse = ListenerResponse.CONTINUE;
public int onTrainingProgressCallCount = 0;
public ListenerResponse onTrainingProgressResponse = ListenerResponse.CONTINUE;
@Override
public ListenerResponse onTrainingStart() {
return ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
}
@Override
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
++onEpochTrainingResultCallCount;
return onTrainingResultResponse;
}
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
++onTrainingProgressCallCount;
return onTrainingProgressResponse;
}
}
}

View File

@ -0,0 +1,83 @@
package org.deeplearning4j.rl4j.learning.listener;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.junit.Test;
import static org.junit.Assert.*;
public class TrainingListenerListTest {
@Test
public void when_listIsEmpty_expect_notifyReturnTrue() {
// Arrange
TrainingListenerList sut = new TrainingListenerList();
// Act
boolean resultTrainingStarted = sut.notifyTrainingStarted();
boolean resultNewEpoch = sut.notifyNewEpoch(null);
boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null);
// Assert
assertTrue(resultTrainingStarted);
assertTrue(resultNewEpoch);
assertTrue(resultEpochFinished);
}
@Test
public void when_firstListerStops_expect_othersListnersNotCalled() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
listener1.setRemainingTrainingStartCallCount(0);
listener1.setRemainingOnNewEpochCallCount(0);
listener1.setRemainingonTrainingProgressCallCount(0);
listener1.setRemainingOnEpochTrainingResult(0);
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
// Act
sut.notifyTrainingStarted();
sut.notifyNewEpoch(null);
sut.notifyEpochTrainingResult(null, null);
sut.notifyTrainingProgress(null);
sut.notifyTrainingFinished();
// Assert
assertEquals(1, listener1.onTrainingStartCallCount);
assertEquals(0, listener2.onTrainingStartCallCount);
assertEquals(1, listener1.onNewEpochCallCount);
assertEquals(0, listener2.onNewEpochCallCount);
assertEquals(1, listener1.onEpochTrainingResultCallCount);
assertEquals(0, listener2.onEpochTrainingResultCallCount);
assertEquals(1, listener1.onTrainingProgressCallCount);
assertEquals(0, listener2.onTrainingProgressCallCount);
assertEquals(1, listener1.onTrainingEndCallCount);
assertEquals(1, listener2.onTrainingEndCallCount);
}
@Test
public void when_allListenersContinue_expect_listReturnsTrue() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
// Act
boolean resultTrainingStarted = sut.notifyTrainingStarted();
boolean resultNewEpoch = sut.notifyNewEpoch(null);
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
boolean resultProgress = sut.notifyTrainingProgress(null);
// Assert
assertTrue(resultTrainingStarted);
assertTrue(resultNewEpoch);
assertTrue(resultEpochTrainingResult);
assertTrue(resultProgress);
}
}

View File

@ -2,12 +2,10 @@ package org.deeplearning4j.rl4j.learning.sync;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.support.MockDataManager; import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test; import org.junit.Test;
@ -19,7 +17,7 @@ public class SyncLearningTest {
public void when_training_expect_listenersToBeCalled() { public void when_training_expect_listenersToBeCalled() {
// Arrange // Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener(); MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig); MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener); sut.addListener(listener);
@ -27,8 +25,8 @@ public class SyncLearningTest {
sut.train(); sut.train();
assertEquals(1, listener.onTrainingStartCallCount); assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(10, listener.onEpochStartCallCount); assertEquals(10, listener.onNewEpochCallCount);
assertEquals(10, listener.onEpochEndStartCallCount); assertEquals(10, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount); assertEquals(1, listener.onTrainingEndCallCount);
} }
@ -36,65 +34,59 @@ public class SyncLearningTest {
public void when_trainingStartCanContinueFalse_expect_trainingStopped() { public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
// Arrange // Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener(); MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig); MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener); sut.addListener(listener);
listener.trainingStartCanContinue = false; listener.setRemainingTrainingStartCallCount(0);
// Act // Act
sut.train(); sut.train();
assertEquals(1, listener.onTrainingStartCallCount); assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(0, listener.onEpochStartCallCount); assertEquals(0, listener.onNewEpochCallCount);
assertEquals(0, listener.onEpochEndStartCallCount); assertEquals(0, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount); assertEquals(1, listener.onTrainingEndCallCount);
} }
@Test @Test
public void when_epochStartCanContinueFalse_expect_trainingStopped() { public void when_newEpochCanContinueFalse_expect_trainingStopped() {
// Arrange // Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener(); MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig); MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener); sut.addListener(listener);
listener.nbStepsEpochStartCanContinue = 3; listener.setRemainingOnNewEpochCallCount(2);
// Act // Act
sut.train(); sut.train();
assertEquals(1, listener.onTrainingStartCallCount); assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(3, listener.onEpochStartCallCount); assertEquals(3, listener.onNewEpochCallCount);
assertEquals(2, listener.onEpochEndStartCallCount); assertEquals(2, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount); assertEquals(1, listener.onTrainingEndCallCount);
} }
@Test @Test
public void when_epochEndCanContinueFalse_expect_trainingStopped() { public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
// Arrange // Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
MockSyncTrainingListener listener = new MockSyncTrainingListener(); MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig); MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener); sut.addListener(listener);
listener.nbStepsEpochEndCanContinue = 3; listener.setRemainingOnEpochTrainingResult(2);
// Act // Act
sut.train(); sut.train();
assertEquals(1, listener.onTrainingStartCallCount); assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(3, listener.onEpochStartCallCount); assertEquals(3, listener.onNewEpochCallCount);
assertEquals(3, listener.onEpochEndStartCallCount); assertEquals(3, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount); assertEquals(1, listener.onTrainingEndCallCount);
} }
public static class MockSyncLearning extends SyncLearning { public static class MockSyncLearning extends SyncLearning {
private LConfiguration conf; private final LConfiguration conf;
public MockSyncLearning(LConfiguration conf, IDataManager dataManager) {
super(conf);
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
this.conf = conf;
}
public MockSyncLearning(LConfiguration conf) { public MockSyncLearning(LConfiguration conf) {
super(conf); super(conf);
@ -119,7 +111,7 @@ public class SyncLearningTest {
} }
@Override @Override
public Policy getPolicy() { public IPolicy getPolicy() {
return null; return null;
} }

View File

@ -1,65 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockMDP;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class QLearningDiscreteTest {
@Test
public void refac_checkDataManagerCallsRemainTheSame() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder()
.maxStep(10)
.expRepMaxSize(1)
.build();
MockDataManager dataManager = new MockDataManager(true);
MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3);
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
.build();
sut.setHistoryProcessor(new MockHistoryProcessor(hpConfig));
// Act
sut.train();
assertEquals(10, dataManager.statEntries.size());
for(int i = 0; i < 10; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i, entry.getEpochCounter());
assertEquals(i+1, entry.getStepCounter());
assertEquals(1.0, entry.getReward(), 0.0);
}
assertEquals(4, dataManager.isSaveDataCallCount);
assertEquals(4, dataManager.getVideoDirCallCount);
assertEquals(11, dataManager.writeInfoCallCount);
assertEquals(5, dataManager.saveCallCount);
}
public static class MockQLearningDiscrete extends QLearningDiscrete {
public MockQLearningDiscrete(int maxSteps, QLConfiguration conf,
IDataManager dataManager, int saveFrequency, int monitorFrequency) {
super(new MockMDP(maxSteps), new MockDQN(), conf, 2);
addListener(DataManagerSyncTrainingListener.builder(dataManager)
.saveFrequency(saveFrequency)
.monitorFrequency(monitorFrequency)
.build());
}
@Override
protected IDataManager.StatEntry trainEpoch() {
setStepCounter(getStepCounter() + 1);
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
}
}
}

View File

@ -8,6 +8,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -29,12 +30,11 @@ public class QLearningDiscreteTest {
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
0, 1.0, 0, 0, 0, 0, true); 0, 1.0, 0, 0, 0, 0, true);
MockDataManager dataManager = new MockDataManager(false); MockDataManager dataManager = new MockDataManager(false);
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10); MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hp); sut.setHistoryProcessor(hp);
MockExpReplay expReplay = new MockExpReplay();
sut.setExpReplay(expReplay);
MockEncodable obs = new MockEncodable(1); MockEncodable obs = new MockEncodable(1);
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>(); List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
@ -131,8 +131,11 @@ public class QLearningDiscreteTest {
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> { public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn, public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) { QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
super(mdp, dqn, conf, dataManager, epsilonNbStep); int epsilonNbStep) {
super(mdp, dqn, conf, epsilonNbStep);
addListener(new DataManagerTrainingListener(dataManager));
setExpReplay(expReplay);
} }
@Override @Override

View File

@ -1,46 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
public class MockSyncTrainingListener implements SyncTrainingListener {
public int onTrainingStartCallCount = 0;
public int onTrainingEndCallCount = 0;
public int onEpochStartCallCount = 0;
public int onEpochEndStartCallCount = 0;
public boolean trainingStartCanContinue = true;
public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE;
public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE;
@Override
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
++onTrainingStartCallCount;
return trainingStartCanContinue ? ListenerResponse.CONTINUE : ListenerResponse.STOP;
}
@Override
public void onTrainingEnd() {
++onTrainingEndCallCount;
}
@Override
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
++onEpochStartCallCount;
if(onEpochStartCallCount >= nbStepsEpochStartCanContinue) {
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
++onEpochEndStartCallCount;
if(onEpochEndStartCallCount >= nbStepsEpochEndCanContinue) {
return ListenerResponse.STOP;
}
return ListenerResponse.CONTINUE;
}
}

View File

@ -0,0 +1,65 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
public class MockAsyncConfiguration implements AsyncConfiguration {
private final int nStep;
private final int maxEpochStep;
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
this.nStep = nStep;
this.maxEpochStep = maxEpochStep;
}
@Override
public int getSeed() {
return 0;
}
@Override
public int getMaxEpochStep() {
return maxEpochStep;
}
@Override
public int getMaxStep() {
return 0;
}
@Override
public int getNumThread() {
return 0;
}
@Override
public int getNstep() {
return nStep;
}
@Override
public int getTargetDqnUpdateFreq() {
return 0;
}
@Override
public int getUpdateStart() {
return 0;
}
@Override
public double getRewardFactor() {
return 0;
}
@Override
public double getGamma() {
return 0;
}
@Override
public double getErrorClamp() {
return 0;
}
}

View File

@ -0,0 +1,65 @@
package org.deeplearning4j.rl4j.support;
import lombok.Setter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.network.NeuralNet;
import java.util.concurrent.atomic.AtomicInteger;
public class MockAsyncGlobal implements IAsyncGlobal {
public boolean hasBeenStarted = false;
public boolean hasBeenTerminated = false;
@Setter
private int maxLoops;
@Setter
private int numLoopsStopRunning;
private int currentLoop = 0;
public MockAsyncGlobal() {
maxLoops = Integer.MAX_VALUE;
numLoopsStopRunning = Integer.MAX_VALUE;
}
@Override
public boolean isRunning() {
return currentLoop < numLoopsStopRunning;
}
@Override
public void terminate() {
hasBeenTerminated = true;
}
@Override
public boolean isTrainingComplete() {
return ++currentLoop > maxLoops;
}
@Override
public void start() {
hasBeenStarted = true;
}
@Override
public AtomicInteger getT() {
return null;
}
@Override
public NeuralNet getCurrent() {
return null;
}
@Override
public NeuralNet getTarget() {
return null;
}
@Override
public void enqueue(Gradient[] gradient, Integer nstep) {
}
}

View File

@ -44,7 +44,7 @@ public class MockDataManager implements IDataManager {
} }
@Override @Override
public void save(Learning learning) throws IOException { public void save(ILearning learning) throws IOException {
++saveCallCount; ++saveCallCount;
} }
} }

View File

@ -0,0 +1,74 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
import java.io.OutputStream;
public class MockNeuralNet implements NeuralNet {
public int resetCallCount = 0;
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
++resetCallCount;
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];
}
@Override
public NeuralNet clone() {
return null;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
return new Gradient[0];
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public void applyGradient(Gradient[] gradients, int batchSize) {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}

View File

@ -0,0 +1,17 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
public class MockPolicy implements IPolicy<MockEncodable, Integer> {
public int playCallCount = 0;
@Override
public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
++playCallCount;
return 0;
}
}

View File

@ -0,0 +1,65 @@
package org.deeplearning4j.rl4j.support;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
public class MockTrainingListener implements TrainingListener {
public int onTrainingStartCallCount = 0;
public int onTrainingEndCallCount = 0;
public int onNewEpochCallCount = 0;
public int onEpochTrainingResultCallCount = 0;
public int onTrainingProgressCallCount = 0;
@Setter
private int remainingTrainingStartCallCount = Integer.MAX_VALUE;
@Setter
private int remainingOnNewEpochCallCount = Integer.MAX_VALUE;
@Setter
private int remainingOnEpochTrainingResult = Integer.MAX_VALUE;
@Setter
private int remainingonTrainingProgressCallCount = Integer.MAX_VALUE;
public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
@Override
public ListenerResponse onTrainingStart() {
++onTrainingStartCallCount;
--remainingTrainingStartCallCount;
return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
++onNewEpochCallCount;
--remainingOnNewEpochCallCount;
return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
++onEpochTrainingResultCallCount;
--remainingOnEpochTrainingResult;
statEntries.add(statEntry);
return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
++onTrainingProgressCallCount;
--remainingonTrainingProgressCallCount;
return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
++onTrainingEndCallCount;
}
}

View File

@ -0,0 +1,169 @@
package org.deeplearning4j.rl4j.util;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
public class DataManagerTrainingListenerTest {
@Test
public void when_callingOnNewEpochWithoutHistoryProcessor_expect_noException() {
// Arrange
TestTrainer trainer = new TestTrainer();
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
// Act
TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
}
@Test
public void when_callingOnNewEpochWithHistoryProcessor_expect_startMonitorNotCalled() {
// Arrange
TestTrainer trainer = new TestTrainer();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
trainer.setHistoryProcessor(hp);
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
// Act
TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer);
// Assert
assertEquals(1, hp.startMonitorCallCount);
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
}
@Test
public void when_callingOnEpochTrainingResultWithoutHistoryProcessor_expect_noException() {
// Arrange
TestTrainer trainer = new TestTrainer();
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
// Act
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
}
@Test
public void when_callingOnNewEpochWithHistoryProcessor_expect_stopMonitorNotCalled() {
// Arrange
TestTrainer trainer = new TestTrainer();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
trainer.setHistoryProcessor(hp);
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
// Act
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null);
// Assert
assertEquals(1, hp.stopMonitorCallCount);
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
}
@Test
public void when_callingOnEpochTrainingResult_expect_callToDataManagerAppendStat() {
// Arrange
TestTrainer trainer = new TestTrainer();
MockDataManager dm = new MockDataManager(false);
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
MockStatEntry statEntry = new MockStatEntry(0, 0, 0.0);
// Act
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, statEntry);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
assertEquals(1, dm.statEntries.size());
assertSame(statEntry, dm.statEntries.get(0));
}
@Test
public void when_callingOnTrainingProgress_expect_callToDataManagerSaveAndWriteInfo() {
// Arrange
TestTrainer learning = new TestTrainer();
MockDataManager dm = new MockDataManager(false);
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
// Act
TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
assertEquals(1, dm.writeInfoCallCount);
assertEquals(1, dm.saveCallCount);
}
@Test
public void when_stepCounterCloseToLastSave_expect_dataManagerSaveNotCalled() {
// Arrange
TestTrainer learning = new TestTrainer();
MockDataManager dm = new MockDataManager(false);
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
// Act
TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning);
TrainingListener.ListenerResponse response2 = sut.onTrainingProgress(learning);
// Assert
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response2);
assertEquals(1, dm.saveCallCount);
}
private static class TestTrainer implements IEpochTrainer, ILearning
{
@Override
public int getStepCounter() {
return 0;
}
@Override
public int getEpochCounter() {
return 0;
}
@Getter
@Setter
private IHistoryProcessor historyProcessor;
@Override
public IPolicy getPolicy() {
return null;
}
@Override
public void train() {
}
@Override
public LConfiguration getConfiguration() {
return null;
}
@Override
public MDP getMdp() {
return new MockMDP(new MockObservationSpace());
}
}
}