RL4J - AsyncTrainingListener (#8072)
* Code clarity: Extracted parts of run() into private methods Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Added listener pattern to async learning Signed-off-by: unknown <aboulang2002@yahoo.com> * Merged all listeners logic Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Added interface and common data to training events Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Fixed missing info log file Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Fixed bad merge; removed useless TrainingEvent Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Removed param from training start/end event Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Removed 'event' classes from the training listener Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Reverted changes to QLearningDiscrete.setTarget()master
parent
d58a4b45b1
commit
59f1cbf0c6
|
@ -0,0 +1,31 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.learning;
|
||||
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
|
||||
/**
|
||||
* The common API between Learning and AsyncThread.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface IEpochTrainer {
|
||||
int getStepCounter();
|
||||
int getEpochCounter();
|
||||
IHistoryProcessor getHistoryProcessor();
|
||||
MDP getMdp();
|
||||
}
|
|
@ -17,7 +17,7 @@
|
|||
package org.deeplearning4j.rl4j.learning;
|
||||
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
|
||||
|
@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
|||
*/
|
||||
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> extends StepCountable {
|
||||
|
||||
Policy<O, A> getPolicy();
|
||||
IPolicy<O, A> getPolicy();
|
||||
|
||||
void train();
|
||||
|
||||
|
@ -38,6 +38,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> ex
|
|||
|
||||
MDP<O, A, AS> getMdp();
|
||||
|
||||
IHistoryProcessor getHistoryProcessor();
|
||||
|
||||
interface LConfiguration {
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
|||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
|
@ -63,7 +62,6 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
|||
@Getter
|
||||
private NN target;
|
||||
@Getter
|
||||
@Setter
|
||||
private boolean running = true;
|
||||
|
||||
public AsyncGlobal(NN initial, AsyncConfiguration a3cc) {
|
||||
|
@ -78,8 +76,10 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
|||
}
|
||||
|
||||
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||
if(running && !isTrainingComplete()) {
|
||||
queue.add(new Pair<>(gradient, nstep));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
|
@ -105,4 +105,12 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
|
||||
*/
|
||||
public void terminate() {
|
||||
running = false;
|
||||
queue.clear();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -16,33 +16,49 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread()
|
||||
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
|
||||
* (see setProgressEventInterval(int))
|
||||
*
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
|
||||
*
|
||||
* Async learning always follow the same pattern in RL4J
|
||||
* -launch the Global thread
|
||||
* -launch the "save threads"
|
||||
* -periodically evaluate the model of the global thread for monitoring purposes
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||
extends Learning<O, A, AS, NN> {
|
||||
|
||||
protected abstract IDataManager getDataManager();
|
||||
@Getter(AccessLevel.PROTECTED)
|
||||
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||
|
||||
public AsyncLearning(AsyncConfiguration conf) {
|
||||
super(conf);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a {@link TrainingListener} listener at the end of the listener list.
|
||||
*
|
||||
* @param listener the listener to be added
|
||||
*/
|
||||
public void addListener(TrainingListener listener) {
|
||||
listeners.add(listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the configuration
|
||||
* @return the configuration (see {@link AsyncConfiguration})
|
||||
*/
|
||||
public abstract AsyncConfiguration getConfiguration();
|
||||
|
||||
protected abstract AsyncThread newThread(int i, int deviceAffinity);
|
||||
|
@ -57,41 +73,80 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
|||
return getAsyncGlobal().isTrainingComplete();
|
||||
}
|
||||
|
||||
public void launchThreads() {
|
||||
private boolean canContinue = true;
|
||||
|
||||
/**
|
||||
* Number of milliseconds between calls to onTrainingProgress
|
||||
*/
|
||||
@Getter @Setter
|
||||
private int progressMonitorFrequency = 20000;
|
||||
|
||||
private void launchThreads() {
|
||||
startGlobalThread();
|
||||
for (int i = 0; i < getConfiguration().getNumThread(); i++) {
|
||||
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
|
||||
t.start();
|
||||
|
||||
}
|
||||
log.info("Threads launched.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The current step
|
||||
*/
|
||||
@Override
|
||||
public int getStepCounter() {
|
||||
return getAsyncGlobal().getT().get();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method will train the model<p>
|
||||
* The training stop when:<br>
|
||||
* - A worker thread terminate the AsyncGlobal thread (see {@link AsyncGlobal})<br>
|
||||
* OR<br>
|
||||
* - a listener explicitly stops it<br>
|
||||
* <p>
|
||||
* Listeners<br>
|
||||
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
|
||||
* returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
||||
* Events:
|
||||
* <ul>
|
||||
* <li>{@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.</li>
|
||||
* <li>{@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
|
||||
* </ul>
|
||||
*/
|
||||
public void train() {
|
||||
|
||||
try {
|
||||
log.info("AsyncLearning training starting.");
|
||||
|
||||
canContinue = listeners.notifyTrainingStarted();
|
||||
if (canContinue) {
|
||||
launchThreads();
|
||||
monitorTraining();
|
||||
}
|
||||
|
||||
cleanupPostTraining();
|
||||
listeners.notifyTrainingFinished();
|
||||
}
|
||||
|
||||
protected void monitorTraining() {
|
||||
try {
|
||||
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
||||
canContinue = listeners.notifyTrainingProgress(this);
|
||||
if(!canContinue) {
|
||||
return;
|
||||
}
|
||||
|
||||
//this is simply for stat purposes
|
||||
getDataManager().writeInfo(this);
|
||||
synchronized (this) {
|
||||
while (!isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
||||
getPolicy().play(getMdp(), getHistoryProcessor());
|
||||
getDataManager().writeInfo(this);
|
||||
wait(20000);
|
||||
wait(progressMonitorFrequency);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("Training failed.", e);
|
||||
e.printStackTrace();
|
||||
} catch (InterruptedException e) {
|
||||
log.error("Training interrupted.", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
protected void cleanupPostTraining() {
|
||||
// Worker threads stops automatically when the global thread stops
|
||||
getAsyncGlobal().terminate();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,33 +21,31 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import lombok.Value;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
||||
import org.deeplearning4j.rl4j.learning.*;
|
||||
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.Constants;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
*
|
||||
* This represent a local thread that explore the environment
|
||||
* and calculate a gradient to enqueue to the global thread/model
|
||||
*
|
||||
* It has its own version of a model that it syncs at the start of every
|
||||
* sub epoch
|
||||
*
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||
extends Thread implements StepCountable {
|
||||
extends Thread implements StepCountable, IEpochTrainer {
|
||||
|
||||
@Getter
|
||||
private int threadNumber;
|
||||
@Getter
|
||||
protected final int deviceNum;
|
||||
|
@ -55,12 +53,16 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
private int stepCounter = 0;
|
||||
@Getter @Setter
|
||||
private int epochCounter = 0;
|
||||
@Getter
|
||||
private MDP<O, A, AS> mdp;
|
||||
@Getter @Setter
|
||||
private IHistoryProcessor historyProcessor;
|
||||
@Getter
|
||||
private int lastMonitor = -Constants.MONITOR_FREQ;
|
||||
|
||||
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
|
||||
private final TrainingListenerList listeners;
|
||||
|
||||
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||
this.mdp = mdp;
|
||||
this.listeners = listeners;
|
||||
this.threadNumber = threadNumber;
|
||||
this.deviceNum = deviceNum;
|
||||
}
|
||||
|
@ -80,75 +82,106 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
}
|
||||
|
||||
protected void preEpoch() {
|
||||
if (getStepCounter() - lastMonitor >= Constants.MONITOR_FREQ && getHistoryProcessor() != null
|
||||
&& getDataManager().isSaveData()) {
|
||||
lastMonitor = getStepCounter();
|
||||
int[] shape = getMdp().getObservationSpace().getShape();
|
||||
getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + threadNumber + "-"
|
||||
+ getEpochCounter() + "-" + getStepCounter() + ".mp4", shape);
|
||||
}
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
/**
|
||||
* This method will start the worker thread<p>
|
||||
* The thread will stop when:<br>
|
||||
* - The AsyncGlobal thread terminates or reports that the training is complete
|
||||
* (see {@link AsyncGlobal#isTrainingComplete()}). In such case, the currently running epoch will still be handled normally and
|
||||
* events will also be fired normally.<br>
|
||||
* OR<br>
|
||||
* - a listener explicitly stops it, in which case, the AsyncGlobal thread will be terminated along with
|
||||
* all other worker threads <br>
|
||||
* <p>
|
||||
* Listeners<br>
|
||||
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
|
||||
* returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse
|
||||
* TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
||||
* Events:
|
||||
* <ul>
|
||||
* <li>{@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} is called when a new epoch is started.</li>
|
||||
* <li>{@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} is called at the end of every
|
||||
* epoch. It will not be called if onNewEpoch() stops the training.</li>
|
||||
* </ul>
|
||||
*/
|
||||
@Override
|
||||
public void run() {
|
||||
RunContext<O> context = new RunContext<>();
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
|
||||
|
||||
|
||||
try {
|
||||
log.info("ThreadNum-" + threadNumber + " Started!");
|
||||
|
||||
boolean canContinue = initWork(context);
|
||||
if (canContinue) {
|
||||
|
||||
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
||||
handleTraining(context);
|
||||
if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||
canContinue = finishEpoch(context) && startNewEpoch(context);
|
||||
if (!canContinue) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
terminateWork();
|
||||
}
|
||||
|
||||
private void initNewEpoch(RunContext context) {
|
||||
getCurrent().reset();
|
||||
Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor);
|
||||
O obs = initMdp.getLastObs();
|
||||
double rewards = initMdp.getReward();
|
||||
int length = initMdp.getSteps();
|
||||
|
||||
preEpoch();
|
||||
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - length);
|
||||
SubEpochReturn<O> subEpochReturn = trainSubEpoch(obs, maxSteps);
|
||||
obs = subEpochReturn.getLastObs();
|
||||
context.obs = initMdp.getLastObs();
|
||||
context.rewards = initMdp.getReward();
|
||||
context.length = initMdp.getSteps();
|
||||
}
|
||||
|
||||
private void handleTraining(RunContext<O> context) {
|
||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length);
|
||||
SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);
|
||||
|
||||
context.obs = subEpochReturn.getLastObs();
|
||||
stepCounter += subEpochReturn.getSteps();
|
||||
length += subEpochReturn.getSteps();
|
||||
rewards += subEpochReturn.getReward();
|
||||
double score = subEpochReturn.getScore();
|
||||
if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||
postEpoch();
|
||||
|
||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
|
||||
getDataManager().appendStat(statEntry);
|
||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
||||
|
||||
getCurrent().reset();
|
||||
initMdp = Learning.initMdp(getMdp(), historyProcessor);
|
||||
obs = initMdp.getLastObs();
|
||||
rewards = initMdp.getReward();
|
||||
length = initMdp.getSteps();
|
||||
epochCounter++;
|
||||
context.length += subEpochReturn.getSteps();
|
||||
context.rewards += subEpochReturn.getReward();
|
||||
context.score = subEpochReturn.getScore();
|
||||
}
|
||||
|
||||
private boolean initWork(RunContext context) {
|
||||
initNewEpoch(context);
|
||||
preEpoch();
|
||||
return listeners.notifyNewEpoch(this);
|
||||
}
|
||||
|
||||
private boolean startNewEpoch(RunContext context) {
|
||||
initNewEpoch(context);
|
||||
epochCounter++;
|
||||
preEpoch();
|
||||
return listeners.notifyNewEpoch(this);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("Thread crashed: " + e.getCause());
|
||||
getAsyncGlobal().setRunning(false);
|
||||
e.printStackTrace();
|
||||
} finally {
|
||||
|
||||
private boolean finishEpoch(RunContext context) {
|
||||
postEpoch();
|
||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.length, context.score);
|
||||
|
||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
|
||||
|
||||
return listeners.notifyEpochTrainingResult(this, statEntry);
|
||||
}
|
||||
|
||||
private void terminateWork() {
|
||||
postEpoch();
|
||||
getAsyncGlobal().terminate();
|
||||
}
|
||||
|
||||
protected abstract NN getCurrent();
|
||||
|
||||
protected abstract int getThreadNumber();
|
||||
|
||||
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||
|
||||
protected abstract MDP<O, A, AS> getMdp();
|
||||
|
||||
protected abstract AsyncConfiguration getConf();
|
||||
|
||||
protected abstract IDataManager getDataManager();
|
||||
|
||||
protected abstract Policy<O, A> getPolicy(NN net);
|
||||
|
||||
protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
|
||||
|
@ -172,4 +205,11 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
double score;
|
||||
}
|
||||
|
||||
private static class RunContext<O extends Encodable> {
|
||||
private O obs;
|
||||
private double rewards;
|
||||
private int length;
|
||||
private double score;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -21,7 +21,9 @@ import org.deeplearning4j.gym.StepReply;
|
|||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
|
@ -44,8 +46,8 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
|||
@Getter
|
||||
private NN current;
|
||||
|
||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
|
||||
super(asyncGlobal, threadNumber, deviceNum);
|
||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, MDP<O, Integer, DiscreteSpace> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||
synchronized (asyncGlobal) {
|
||||
current = (NN)asyncGlobal.getCurrent().clone();
|
||||
}
|
||||
|
|
|
@ -23,9 +23,14 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
|
||||
public interface IAsyncGlobal<NN extends NeuralNet> {
|
||||
boolean isRunning();
|
||||
void setRunning(boolean value);
|
||||
boolean isTrainingComplete();
|
||||
void start();
|
||||
|
||||
/**
|
||||
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
|
||||
*/
|
||||
void terminate();
|
||||
|
||||
AtomicInteger getT();
|
||||
NN getCurrent();
|
||||
NN getTarget();
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
|||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||
|
@ -47,24 +46,19 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
|||
final private AsyncGlobal asyncGlobal;
|
||||
@Getter
|
||||
final private ACPolicy<O> policy;
|
||||
@Getter
|
||||
final private IDataManager dataManager;
|
||||
|
||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf,
|
||||
IDataManager dataManager) {
|
||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
|
||||
super(conf);
|
||||
this.iActorCritic = iActorCritic;
|
||||
this.mdp = mdp;
|
||||
this.configuration = conf;
|
||||
this.dataManager = dataManager;
|
||||
policy = new ACPolicy<>(iActorCritic, getRandom());
|
||||
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
|
||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AsyncThread newThread(int i, int deviceNum) {
|
||||
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager, deviceNum);
|
||||
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), deviceNum, getListeners(), i);
|
||||
}
|
||||
|
||||
public IActorCritic getNeuralNet() {
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
|
|||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
|
@ -43,24 +44,38 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
|||
|
||||
final private HistoryProcessor.Configuration hpconf;
|
||||
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||
@Deprecated
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||
super(mdp, IActorCritic, conf, dataManager);
|
||||
this(mdp, actorCritic, hpconf, conf);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
}
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||
super(mdp, IActorCritic, conf);
|
||||
this.hpconf = hpconf;
|
||||
setHistoryProcessor(hpconf);
|
||||
}
|
||||
|
||||
|
||||
@Deprecated
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf,
|
||||
dataManager);
|
||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||
}
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
|
||||
}
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsyncThread newThread(int i, int deviceNum) {
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
|||
import org.deeplearning4j.rl4j.network.ac.*;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
|
@ -33,33 +34,58 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
|||
*/
|
||||
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
||||
|
||||
@Deprecated
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
||||
IDataManager dataManager) {
|
||||
super(mdp, IActorCritic, conf, dataManager);
|
||||
this(mdp, IActorCritic, conf);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
}
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
|
||||
super(mdp, actorCritic, conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||
A3CConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||
dataManager);
|
||||
}
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||
A3CConfiguration conf) {
|
||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
||||
IDataManager dataManager) {
|
||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
|
||||
}
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
|
||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||
A3CConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||
dataManager);
|
||||
}
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||
A3CConfiguration conf) {
|
||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
|
||||
IDataManager dataManager) {
|
||||
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
|
||||
}
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) {
|
||||
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -22,13 +22,13 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
|||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
@ -46,24 +46,19 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
@Getter
|
||||
final protected A3CDiscrete.A3CConfiguration conf;
|
||||
@Getter
|
||||
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
||||
@Getter
|
||||
final protected AsyncGlobal<IActorCritic> asyncGlobal;
|
||||
@Getter
|
||||
final protected int threadNumber;
|
||||
@Getter
|
||||
final protected IDataManager dataManager;
|
||||
|
||||
final private Random random;
|
||||
|
||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
||||
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) {
|
||||
super(asyncGlobal, threadNumber, deviceNum);
|
||||
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
||||
int threadNumber) {
|
||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||
this.conf = a3cc;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.threadNumber = threadNumber;
|
||||
this.mdp = mdp;
|
||||
this.dataManager = dataManager;
|
||||
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
||||
random = new Random(conf.getSeed() + threadNumber);
|
||||
}
|
||||
|
@ -85,15 +80,15 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
//if recurrent then train as a time serie with a batch size of 1
|
||||
boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();
|
||||
|
||||
int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape()
|
||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
||||
: getHistoryProcessor().getConf().getShape();
|
||||
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
|
||||
: Learning.makeShape(size, shape);
|
||||
|
||||
INDArray input = Nd4j.create(nshape);
|
||||
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
|
||||
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, mdp.getActionSpace().getSize(), size)
|
||||
: Nd4j.zeros(size, mdp.getActionSpace().getSize());
|
||||
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size)
|
||||
: Nd4j.zeros(size, getMdp().getActionSpace().getSize());
|
||||
|
||||
double r = minTrans.getReward();
|
||||
for (int i = size - 1; i >= 0; i--) {
|
||||
|
|
|
@ -24,10 +24,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
|||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
|
@ -40,16 +39,12 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
|||
@Getter
|
||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
||||
@Getter
|
||||
final private IDataManager dataManager;
|
||||
@Getter
|
||||
final private AsyncGlobal<IDQN> asyncGlobal;
|
||||
|
||||
|
||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf,
|
||||
IDataManager dataManager) {
|
||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
|
||||
super(conf);
|
||||
this.mdp = mdp;
|
||||
this.dataManager = dataManager;
|
||||
this.configuration = conf;
|
||||
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||
|
@ -57,14 +52,14 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
|||
|
||||
@Override
|
||||
public AsyncThread newThread(int i, int deviceNum) {
|
||||
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum);
|
||||
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, getListeners(), i, deviceNum);
|
||||
}
|
||||
|
||||
public IDQN getNeuralNet() {
|
||||
return asyncGlobal.getCurrent();
|
||||
}
|
||||
|
||||
public Policy<O, Integer> getPolicy() {
|
||||
public IPolicy<O, Integer> getPolicy() {
|
||||
return new DQNPolicy<O>(getNeuralNet());
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
|
@ -35,22 +36,38 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
|
|||
|
||||
final private HistoryProcessor.Configuration hpconf;
|
||||
|
||||
@Deprecated
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
super(mdp, dqn, conf, dataManager);
|
||||
this(mdp, dqn, hpconf, conf);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
}
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
||||
super(mdp, dqn, conf);
|
||||
this.hpconf = hpconf;
|
||||
setHistoryProcessor(hpconf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||
}
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||
}
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsyncThread newThread(int i, int deviceNum) {
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
|
@ -29,19 +30,37 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
|||
*/
|
||||
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
|
||||
|
||||
@Deprecated
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
super(mdp, dqn, conf, dataManager);
|
||||
super(mdp, dqn, conf);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
}
|
||||
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||
AsyncNStepQLConfiguration conf) {
|
||||
super(mdp, dqn, conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||
dataManager);
|
||||
}
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
AsyncNStepQLConfiguration conf) {
|
||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
||||
}
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
|
||||
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -19,9 +19,10 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
|||
import lombok.Getter;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||
|
@ -29,7 +30,6 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
|||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -44,31 +44,25 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
@Getter
|
||||
final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
|
||||
@Getter
|
||||
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
||||
@Getter
|
||||
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
||||
@Getter
|
||||
final protected int threadNumber;
|
||||
@Getter
|
||||
final protected IDataManager dataManager;
|
||||
|
||||
final private Random random;
|
||||
|
||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
|
||||
IDataManager dataManager, int deviceNum) {
|
||||
super(asyncGlobal, threadNumber, deviceNum);
|
||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
|
||||
TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||
this.conf = conf;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.threadNumber = threadNumber;
|
||||
this.mdp = mdp;
|
||||
this.dataManager = dataManager;
|
||||
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
||||
random = new Random(conf.getSeed() + threadNumber);
|
||||
}
|
||||
|
||||
public Policy<O, Integer> getPolicy(IDQN nn) {
|
||||
return new EpsGreedy(new DQNPolicy(nn), mdp, conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
||||
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
||||
random, conf.getMinEpsilon(), this);
|
||||
}
|
||||
|
||||
|
@ -81,11 +75,11 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
|
||||
int size = rewards.size();
|
||||
|
||||
int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape()
|
||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
||||
: getHistoryProcessor().getConf().getShape();
|
||||
int[] nshape = Learning.makeShape(size, shape);
|
||||
INDArray input = Nd4j.create(nshape);
|
||||
INDArray targets = Nd4j.create(size, mdp.getActionSpace().getSize());
|
||||
INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize());
|
||||
|
||||
double r = minTrans.getReward();
|
||||
for (int i = size - 1; i >= 0; i--) {
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.learning.listener;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* The base definition of all training event listeners
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface TrainingListener {
|
||||
enum ListenerResponse {
|
||||
/**
|
||||
* Tell the learning process to continue calling the listeners and the training.
|
||||
*/
|
||||
CONTINUE,
|
||||
|
||||
/**
|
||||
* Tell the learning process to stop calling the listeners and terminate the training.
|
||||
*/
|
||||
STOP,
|
||||
}
|
||||
|
||||
/**
|
||||
* Called once when the training starts.
|
||||
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
|
||||
*/
|
||||
ListenerResponse onTrainingStart();
|
||||
|
||||
/**
|
||||
* Called once when the training has finished. This method is called even when the training has been aborted.
|
||||
*/
|
||||
void onTrainingEnd();
|
||||
|
||||
/**
|
||||
* Called before the start of every epoch.
|
||||
* @param trainer A {@link IEpochTrainer}
|
||||
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||
*/
|
||||
ListenerResponse onNewEpoch(IEpochTrainer trainer);
|
||||
|
||||
/**
|
||||
* Called when an epoch has been completed
|
||||
* @param trainer A {@link IEpochTrainer}
|
||||
* @param statEntry A {@link org.deeplearning4j.rl4j.util.IDataManager.StatEntry}
|
||||
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||
*/
|
||||
ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry);
|
||||
|
||||
/**
|
||||
* Called regularly to monitor the training progress.
|
||||
* @param learning A {@link ILearning}
|
||||
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||
*/
|
||||
ListenerResponse onTrainingProgress(ILearning learning);
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.learning.listener;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* The base logic to notify training listeners with the different training events.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class TrainingListenerList {
|
||||
protected final List<TrainingListener> listeners = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* Add a listener at the end of the list
|
||||
* @param listener The listener to be added
|
||||
*/
|
||||
public void add(TrainingListener listener) {
|
||||
listeners.add(listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* Notify the listeners that the training has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
|
||||
* @return whether or not the source training should be stopped
|
||||
*/
|
||||
public boolean notifyTrainingStarted() {
|
||||
for (TrainingListener listener : listeners) {
|
||||
if (listener.onTrainingStart() == TrainingListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Notify the listeners that the training has finished.
|
||||
*/
|
||||
public void notifyTrainingFinished() {
|
||||
for (TrainingListener listener : listeners) {
|
||||
listener.onTrainingEnd();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Notify the listeners that a new epoch has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
|
||||
* @return whether or not the source training should be stopped
|
||||
*/
|
||||
public boolean notifyNewEpoch(IEpochTrainer trainer) {
|
||||
for (TrainingListener listener : listeners) {
|
||||
if (listener.onNewEpoch(trainer) == TrainingListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Notify the listeners that an epoch has been completed and the training results are available. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
|
||||
* @return whether or not the source training should be stopped
|
||||
*/
|
||||
public boolean notifyEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
|
||||
for (TrainingListener listener : listeners) {
|
||||
if (listener.onEpochTrainingResult(trainer, statEntry) == TrainingListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Notify the listeners that they update the progress ot the trainning.
|
||||
*/
|
||||
public boolean notifyTrainingProgress(ILearning learning) {
|
||||
for (TrainingListener listener : listeners) {
|
||||
if (listener.onTrainingProgress(learning) == TrainingListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -16,19 +16,18 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.learning.sync;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
||||
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Mother class and useful factorisations for all training methods that
|
||||
* are not asynchronous.
|
||||
|
@ -38,9 +37,9 @@ import java.util.List;
|
|||
*/
|
||||
@Slf4j
|
||||
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||
extends Learning<O, A, AS, NN> {
|
||||
extends Learning<O, A, AS, NN> implements IEpochTrainer {
|
||||
|
||||
private List<SyncTrainingListener> listeners = new ArrayList<>();
|
||||
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||
|
||||
public SyncLearning(LConfiguration conf) {
|
||||
super(conf);
|
||||
|
@ -49,12 +48,24 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
|||
/**
|
||||
* Add a listener at the end of the listener list.
|
||||
*
|
||||
* @param listener
|
||||
* @param listener The listener to add
|
||||
*/
|
||||
public void addListener(SyncTrainingListener listener) {
|
||||
public void addListener(TrainingListener listener) {
|
||||
listeners.add(listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of epochs between calls to onTrainingProgress. Default is 5
|
||||
*/
|
||||
@Getter
|
||||
private int progressMonitorFrequency = 5;
|
||||
|
||||
public void setProgressMonitorFrequency(int value) {
|
||||
if(value == 0) throw new IllegalArgumentException("The progressMonitorFrequency cannot be 0");
|
||||
|
||||
progressMonitorFrequency = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method will train the model<p>
|
||||
* The training stop when:<br>
|
||||
|
@ -64,81 +75,49 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
|||
* <p>
|
||||
* Listeners<br>
|
||||
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
|
||||
* returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
||||
* returns {@link TrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
||||
* Events:
|
||||
* <ul>
|
||||
* <li>{@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.</li>
|
||||
* <li>{@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training</li>
|
||||
* <li>{@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
|
||||
* <li>{@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.</li>
|
||||
* <li>{@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} and {@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} are called for every epoch. onEpochTrainingResult will not be called if onNewEpoch stops the training</li>
|
||||
* <li>{@link TrainingListener#onTrainingProgress(ILearning) onTrainingProgress()} is called after onEpochTrainingResult()</li>
|
||||
* <li>{@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
|
||||
* </ul>
|
||||
*/
|
||||
public void train() {
|
||||
|
||||
log.info("training starting.");
|
||||
|
||||
boolean canContinue = notifyTrainingStarted();
|
||||
boolean canContinue = listeners.notifyTrainingStarted();
|
||||
if (canContinue) {
|
||||
while (getStepCounter() < getConfiguration().getMaxStep()) {
|
||||
preEpoch();
|
||||
canContinue = notifyEpochStarted();
|
||||
canContinue = listeners.notifyNewEpoch(this);
|
||||
if (!canContinue) {
|
||||
break;
|
||||
}
|
||||
|
||||
IDataManager.StatEntry statEntry = trainEpoch();
|
||||
|
||||
postEpoch();
|
||||
canContinue = notifyEpochFinished(statEntry);
|
||||
canContinue = listeners.notifyEpochTrainingResult(this, statEntry);
|
||||
if (!canContinue) {
|
||||
break;
|
||||
}
|
||||
|
||||
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
||||
postEpoch();
|
||||
|
||||
if(getEpochCounter() % progressMonitorFrequency == 0) {
|
||||
canContinue = listeners.notifyTrainingProgress(this);
|
||||
if (!canContinue) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
||||
incrementEpoch();
|
||||
}
|
||||
}
|
||||
|
||||
notifyTrainingFinished();
|
||||
}
|
||||
|
||||
private boolean notifyTrainingStarted() {
|
||||
SyncTrainingEvent event = new SyncTrainingEvent(this);
|
||||
for (SyncTrainingListener listener : listeners) {
|
||||
if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private void notifyTrainingFinished() {
|
||||
for (SyncTrainingListener listener : listeners) {
|
||||
listener.onTrainingEnd();
|
||||
}
|
||||
}
|
||||
|
||||
private boolean notifyEpochStarted() {
|
||||
SyncTrainingEvent event = new SyncTrainingEvent(this);
|
||||
for (SyncTrainingListener listener : listeners) {
|
||||
if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) {
|
||||
SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry);
|
||||
for (SyncTrainingListener listener : listeners) {
|
||||
if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
listeners.notifyTrainingFinished();
|
||||
}
|
||||
|
||||
protected abstract void preEpoch();
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync.listener;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* A subclass of SyncTrainingEvent that is passed to SyncTrainingListener.onEpochEnd()
|
||||
*/
|
||||
public class SyncTrainingEpochEndEvent extends SyncTrainingEvent {
|
||||
|
||||
/**
|
||||
* The stats of the epoch training
|
||||
*/
|
||||
@Getter
|
||||
private final IDataManager.StatEntry statEntry;
|
||||
|
||||
public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) {
|
||||
super(learning);
|
||||
this.statEntry = statEntry;
|
||||
}
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync.listener;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
|
||||
/**
|
||||
* SyncTrainingEvent are passed as parameters to the events of SyncTrainingListener
|
||||
*/
|
||||
public class SyncTrainingEvent {
|
||||
|
||||
/**
|
||||
* The source of the event
|
||||
*/
|
||||
@Getter
|
||||
private final Learning learning;
|
||||
|
||||
public SyncTrainingEvent(Learning learning) {
|
||||
this.learning = learning;
|
||||
}
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync.listener;
|
||||
|
||||
/**
|
||||
* A listener interface to use with a descendant of {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning}
|
||||
*/
|
||||
public interface SyncTrainingListener {
|
||||
|
||||
public enum ListenerResponse {
|
||||
/**
|
||||
* Tell SyncLearning to continue calling the listeners and the training.
|
||||
*/
|
||||
CONTINUE,
|
||||
|
||||
/**
|
||||
* Tell SyncLearning to stop calling the listeners and terminate the training.
|
||||
*/
|
||||
STOP,
|
||||
}
|
||||
|
||||
/**
|
||||
* Called once when the training starts.
|
||||
* @param event
|
||||
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
|
||||
*/
|
||||
ListenerResponse onTrainingStart(SyncTrainingEvent event);
|
||||
|
||||
/**
|
||||
* Called once when the training has finished. This method is called even when the training has been aborted.
|
||||
*/
|
||||
void onTrainingEnd();
|
||||
|
||||
/**
|
||||
* Called before the start of every epoch.
|
||||
* @param event
|
||||
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||
*/
|
||||
ListenerResponse onEpochStart(SyncTrainingEvent event);
|
||||
|
||||
/**
|
||||
* Called after the end of every epoch.
|
||||
* @param event
|
||||
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||
*/
|
||||
ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event);
|
||||
}
|
|
@ -49,7 +49,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
// @Getter
|
||||
// final private IExpReplay<A> expReplay;
|
||||
@Getter
|
||||
@Setter(AccessLevel.PACKAGE)
|
||||
@Setter(AccessLevel.PROTECTED)
|
||||
protected IExpReplay<A> expReplay;
|
||||
|
||||
public QLearning(QLConfiguration conf) {
|
||||
|
|
|
@ -28,8 +28,6 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
|||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
|
@ -64,20 +62,9 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
@Setter
|
||||
private IDQN targetDQN;
|
||||
private int lastAction;
|
||||
private INDArray history[] = null;
|
||||
private INDArray[] history = null;
|
||||
private double accuReward = 0;
|
||||
|
||||
/**
|
||||
* @deprecated
|
||||
* Use QLearningDiscrete(MDP, IDQN, QLConfiguration, int) and add the required listeners with addListener() instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||
IDataManager dataManager, int epsilonNbStep) {
|
||||
this(mdp, dqn, conf, epsilonNbStep);
|
||||
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
|
||||
}
|
||||
|
||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||
int epsilonNbStep) {
|
||||
super(conf);
|
||||
|
@ -186,7 +173,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
|
||||
}
|
||||
|
||||
|
||||
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||
if (transitions.size() == 0)
|
||||
throw new IllegalArgumentException("too few transitions");
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
|
@ -33,20 +34,36 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
|||
public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> {
|
||||
|
||||
|
||||
|
||||
@Deprecated
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||
QLConfiguration conf, IDataManager dataManager) {
|
||||
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||
this(mdp, dqn, hpconf, conf);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
}
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||
QLConfiguration conf) {
|
||||
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||
setHistoryProcessor(hpconf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||
}
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||
}
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
|
@ -31,21 +32,35 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
|||
public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> {
|
||||
|
||||
|
||||
|
||||
@Deprecated
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
|
||||
IDataManager dataManager) {
|
||||
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep());
|
||||
this(mdp, dqn, conf);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
}
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
|
||||
super(mdp, dqn, conf, conf.getEpsilonNbStep());
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||
dataManager);
|
||||
}
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
QLearning.QLConfiguration conf) {
|
||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
||||
}
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||
QLearning.QLConfiguration conf) {
|
||||
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
package org.deeplearning4j.rl4j.policy;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
|
||||
public interface IPolicy<O extends Encodable, A> {
|
||||
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
|
||||
}
|
|
@ -35,7 +35,7 @@ import org.nd4j.linalg.util.ArrayUtil;
|
|||
*
|
||||
* A Policy responsability is to choose the next action given a state
|
||||
*/
|
||||
public abstract class Policy<O extends Encodable, A> {
|
||||
public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
|
||||
|
||||
public abstract NeuralNet getNeuralNet();
|
||||
|
||||
|
@ -49,6 +49,7 @@ public abstract class Policy<O extends Encodable, A> {
|
|||
return play(mdp, new HistoryProcessor(conf));
|
||||
}
|
||||
|
||||
@Override
|
||||
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
||||
getNeuralNet().reset();
|
||||
Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
|
||||
|
|
|
@ -22,6 +22,7 @@ import lombok.Builder;
|
|||
import lombok.Getter;
|
||||
import lombok.Value;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
|
@ -72,13 +73,13 @@ public class DataManager implements IDataManager {
|
|||
}
|
||||
}
|
||||
|
||||
public static void save(String path, Learning learning) throws IOException {
|
||||
public static void save(String path, ILearning learning) throws IOException {
|
||||
try (BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(path))) {
|
||||
save(os, learning);
|
||||
}
|
||||
}
|
||||
|
||||
public static void save(OutputStream os, Learning learning) throws IOException {
|
||||
public static void save(OutputStream os, ILearning learning) throws IOException {
|
||||
|
||||
try (ZipOutputStream zipfile = new ZipOutputStream(os)) {
|
||||
|
||||
|
@ -91,7 +92,9 @@ public class DataManager implements IDataManager {
|
|||
zipfile.putNextEntry(dqn);
|
||||
|
||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
learning.getNeuralNet().save(bos);
|
||||
if(learning instanceof NeuralNetFetchable) {
|
||||
((NeuralNetFetchable)learning).getNeuralNet().save(bos);
|
||||
}
|
||||
bos.flush();
|
||||
bos.close();
|
||||
|
||||
|
@ -104,7 +107,9 @@ public class DataManager implements IDataManager {
|
|||
zipfile.putNextEntry(hpconf);
|
||||
|
||||
ByteArrayOutputStream bos2 = new ByteArrayOutputStream();
|
||||
learning.getNeuralNet().save(bos2);
|
||||
if(learning instanceof NeuralNetFetchable) {
|
||||
((NeuralNetFetchable)learning).getNeuralNet().save(bos2);
|
||||
}
|
||||
bos2.flush();
|
||||
bos2.close();
|
||||
|
||||
|
@ -256,13 +261,15 @@ public class DataManager implements IDataManager {
|
|||
return exists;
|
||||
}
|
||||
|
||||
public void save(Learning learning) throws IOException {
|
||||
public void save(ILearning learning) throws IOException {
|
||||
|
||||
if (!saveData)
|
||||
return;
|
||||
|
||||
save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
|
||||
learning.getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
|
||||
if(learning instanceof NeuralNetFetchable) {
|
||||
((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -1,126 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.util;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
||||
|
||||
/**
|
||||
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
|
||||
* training process can be fed to the DataManager
|
||||
*/
|
||||
@Slf4j
|
||||
public class DataManagerSyncTrainingListener implements SyncTrainingListener {
|
||||
private final IDataManager dataManager;
|
||||
private final int saveFrequency;
|
||||
private final int monitorFrequency;
|
||||
|
||||
private int lastSave;
|
||||
private int lastMonitor;
|
||||
|
||||
private DataManagerSyncTrainingListener(Builder builder) {
|
||||
this.dataManager = builder.dataManager;
|
||||
|
||||
this.saveFrequency = builder.saveFrequency;
|
||||
this.lastSave = -builder.saveFrequency;
|
||||
|
||||
this.monitorFrequency = builder.monitorFrequency;
|
||||
this.lastMonitor = -builder.monitorFrequency;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
|
||||
try {
|
||||
dataManager.writeInfo(event.getLearning());
|
||||
} catch (Exception e) {
|
||||
log.error("Training failed.", e);
|
||||
return ListenerResponse.STOP;
|
||||
}
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTrainingEnd() {
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
|
||||
int stepCounter = event.getLearning().getStepCounter();
|
||||
|
||||
if (stepCounter - lastMonitor >= monitorFrequency
|
||||
&& event.getLearning().getHistoryProcessor() != null
|
||||
&& dataManager.isSaveData()) {
|
||||
lastMonitor = stepCounter;
|
||||
int[] shape = event.getLearning().getMdp().getObservationSpace().getShape();
|
||||
event.getLearning().getHistoryProcessor().startMonitor(dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter() + "-"
|
||||
+ stepCounter + ".mp4", shape);
|
||||
}
|
||||
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
|
||||
try {
|
||||
int stepCounter = event.getLearning().getStepCounter();
|
||||
if (stepCounter - lastSave >= saveFrequency) {
|
||||
dataManager.save(event.getLearning());
|
||||
lastSave = stepCounter;
|
||||
}
|
||||
|
||||
dataManager.appendStat(event.getStatEntry());
|
||||
dataManager.writeInfo(event.getLearning());
|
||||
} catch (Exception e) {
|
||||
log.error("Training failed.", e);
|
||||
return ListenerResponse.STOP;
|
||||
}
|
||||
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
public static Builder builder(IDataManager dataManager) {
|
||||
return new Builder(dataManager);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private final IDataManager dataManager;
|
||||
private int saveFrequency = Constants.MODEL_SAVE_FREQ;
|
||||
private int monitorFrequency = Constants.MONITOR_FREQ;
|
||||
|
||||
/**
|
||||
* Create a Builder with the given DataManager
|
||||
* @param dataManager
|
||||
*/
|
||||
public Builder(IDataManager dataManager) {
|
||||
this.dataManager = dataManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* A number that represent the number of steps since the last call to DataManager.save() before can it be called again.
|
||||
* @param saveFrequency (Default: 100000)
|
||||
*/
|
||||
public Builder saveFrequency(int saveFrequency) {
|
||||
this.saveFrequency = saveFrequency;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* A number that represent the number of steps since the last call to HistoryProcessor.startMonitor() before can it be called again.
|
||||
* @param monitorFrequency (Default: 10000)
|
||||
*/
|
||||
public Builder monitorFrequency(int monitorFrequency) {
|
||||
this.monitorFrequency = monitorFrequency;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a DataManagerSyncTrainingListener with the configured parameters
|
||||
* @return An instance of DataManagerSyncTrainingListener
|
||||
*/
|
||||
public DataManagerSyncTrainingListener build() {
|
||||
return new DataManagerSyncTrainingListener(this);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
package org.deeplearning4j.rl4j.util;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||
|
||||
/**
|
||||
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
|
||||
* training process can be fed to the DataManager
|
||||
*/
|
||||
@Slf4j
|
||||
public class DataManagerTrainingListener implements TrainingListener {
|
||||
private final IDataManager dataManager;
|
||||
|
||||
private int lastSave = -Constants.MODEL_SAVE_FREQ;
|
||||
|
||||
public DataManagerTrainingListener(IDataManager dataManager) {
|
||||
this.dataManager = dataManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onTrainingStart() {
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTrainingEnd() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
|
||||
IHistoryProcessor hp = trainer.getHistoryProcessor();
|
||||
if(hp != null) {
|
||||
int[] shape = trainer.getMdp().getObservationSpace().getShape();
|
||||
String filename = dataManager.getVideoDir() + "/video-";
|
||||
if (trainer instanceof AsyncThread) {
|
||||
filename += ((AsyncThread) trainer).getThreadNumber() + "-";
|
||||
}
|
||||
filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4";
|
||||
hp.startMonitor(filename, shape);
|
||||
}
|
||||
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
|
||||
IHistoryProcessor hp = trainer.getHistoryProcessor();
|
||||
if(hp != null) {
|
||||
hp.stopMonitor();
|
||||
}
|
||||
try {
|
||||
dataManager.appendStat(statEntry);
|
||||
} catch (Exception e) {
|
||||
log.error("Training failed.", e);
|
||||
return ListenerResponse.STOP;
|
||||
}
|
||||
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onTrainingProgress(ILearning learning) {
|
||||
try {
|
||||
int stepCounter = learning.getStepCounter();
|
||||
if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) {
|
||||
dataManager.save(learning);
|
||||
lastSave = stepCounter;
|
||||
}
|
||||
|
||||
dataManager.writeInfo(learning);
|
||||
} catch (Exception e) {
|
||||
log.error("Training failed.", e);
|
||||
return ListenerResponse.STOP;
|
||||
}
|
||||
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
}
|
|
@ -27,7 +27,7 @@ public interface IDataManager {
|
|||
String getVideoDir();
|
||||
void appendStat(StatEntry statEntry) throws IOException;
|
||||
void writeInfo(ILearning iLearning) throws IOException;
|
||||
void save(Learning learning) throws IOException;
|
||||
void save(ILearning learning) throws IOException;
|
||||
|
||||
//In order for jackson to serialize StatEntry
|
||||
//please use Lombok @Value (see QLStatEntry)
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class AsyncLearningTest {
|
||||
|
||||
@Test
|
||||
public void when_training_expect_AsyncGlobalStarted() {
|
||||
// Arrange
|
||||
TestContext context = new TestContext();
|
||||
context.asyncGlobal.setMaxLoops(1);
|
||||
|
||||
// Act
|
||||
context.sut.train();
|
||||
|
||||
// Assert
|
||||
assertTrue(context.asyncGlobal.hasBeenStarted);
|
||||
assertTrue(context.asyncGlobal.hasBeenTerminated);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_trainStartReturnsStop_expect_noTraining() {
|
||||
// Arrange
|
||||
TestContext context = new TestContext();
|
||||
context.listener.setRemainingTrainingStartCallCount(0);
|
||||
// Act
|
||||
context.sut.train();
|
||||
|
||||
// Assert
|
||||
assertEquals(1, context.listener.onTrainingStartCallCount);
|
||||
assertEquals(1, context.listener.onTrainingEndCallCount);
|
||||
assertEquals(0, context.policy.playCallCount);
|
||||
assertTrue(context.asyncGlobal.hasBeenTerminated);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_trainingIsComplete_expect_trainingStop() {
|
||||
// Arrange
|
||||
TestContext context = new TestContext();
|
||||
|
||||
// Act
|
||||
context.sut.train();
|
||||
|
||||
// Assert
|
||||
assertEquals(1, context.listener.onTrainingStartCallCount);
|
||||
assertEquals(1, context.listener.onTrainingEndCallCount);
|
||||
assertTrue(context.asyncGlobal.hasBeenTerminated);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_training_expect_onTrainingProgressCalled() {
|
||||
// Arrange
|
||||
TestContext context = new TestContext();
|
||||
|
||||
// Act
|
||||
context.sut.train();
|
||||
|
||||
// Assert
|
||||
assertEquals(1, context.listener.onTrainingProgressCallCount);
|
||||
}
|
||||
|
||||
|
||||
public static class TestContext {
|
||||
public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1);
|
||||
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
|
||||
public final MockPolicy policy = new MockPolicy();
|
||||
public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy);
|
||||
public final MockTrainingListener listener = new MockTrainingListener();
|
||||
|
||||
public TestContext() {
|
||||
sut.addListener(listener);
|
||||
asyncGlobal.setMaxLoops(1);
|
||||
sut.setProgressMonitorFrequency(1);
|
||||
}
|
||||
}
|
||||
|
||||
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
|
||||
private final AsyncConfiguration conf;
|
||||
private final IAsyncGlobal asyncGlobal;
|
||||
private final IPolicy<MockEncodable, Integer> policy;
|
||||
|
||||
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
|
||||
super(conf);
|
||||
this.conf = conf;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.policy = policy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public IPolicy getPolicy() {
|
||||
return policy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsyncConfiguration getConfiguration() {
|
||||
return conf;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AsyncThread newThread(int i, int deviceAffinity) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MDP getMdp() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IAsyncGlobal getAsyncGlobal() {
|
||||
return asyncGlobal;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockNeuralNet getNeuralNet() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,206 +1,135 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class AsyncThreadTest {
|
||||
|
||||
@Test
|
||||
public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() {
|
||||
public void when_newEpochStarted_expect_neuralNetworkReset() {
|
||||
// Arrange
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2);
|
||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager);
|
||||
TestContext context = new TestContext();
|
||||
context.listener.setRemainingOnNewEpochCallCount(5);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
context.sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(4, dataManager.statEntries.size());
|
||||
|
||||
IDataManager.StatEntry entry = dataManager.statEntries.get(0);
|
||||
assertEquals(2, entry.getStepCounter());
|
||||
assertEquals(0, entry.getEpochCounter());
|
||||
assertEquals(2.0, entry.getReward(), 0.0);
|
||||
|
||||
entry = dataManager.statEntries.get(1);
|
||||
assertEquals(4, entry.getStepCounter());
|
||||
assertEquals(1, entry.getEpochCounter());
|
||||
assertEquals(2.0, entry.getReward(), 0.0);
|
||||
|
||||
entry = dataManager.statEntries.get(2);
|
||||
assertEquals(6, entry.getStepCounter());
|
||||
assertEquals(2, entry.getEpochCounter());
|
||||
assertEquals(2.0, entry.getReward(), 0.0);
|
||||
|
||||
entry = dataManager.statEntries.get(3);
|
||||
assertEquals(8, entry.getStepCounter());
|
||||
assertEquals(3, entry.getEpochCounter());
|
||||
assertEquals(2.0, entry.getReward(), 0.0);
|
||||
|
||||
assertEquals(0, dataManager.isSaveDataCallCount);
|
||||
assertEquals(0, dataManager.getVideoDirCallCount);
|
||||
assertEquals(6, context.neuralNet.resetCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() {
|
||||
public void when_onNewEpochReturnsStop_expect_threadStopped() {
|
||||
// Arrange
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
|
||||
|
||||
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
||||
.build();
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
|
||||
|
||||
|
||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
|
||||
sut.setHistoryProcessor(hp);
|
||||
TestContext context = new TestContext();
|
||||
context.listener.setRemainingOnNewEpochCallCount(1);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
context.sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(9, dataManager.statEntries.size());
|
||||
|
||||
for(int i = 0; i < 9; ++i) {
|
||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||
assertEquals(i + 1, entry.getStepCounter());
|
||||
assertEquals(i, entry.getEpochCounter());
|
||||
assertEquals(79.0, entry.getReward(), 0.0);
|
||||
}
|
||||
|
||||
assertEquals(10, dataManager.isSaveDataCallCount);
|
||||
assertEquals(0, dataManager.getVideoDirCallCount);
|
||||
assertEquals(2, context.listener.onNewEpochCallCount);
|
||||
assertEquals(1, context.listener.onEpochTrainingResultCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() {
|
||||
public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
|
||||
// Arrange
|
||||
MockDataManager dataManager = new MockDataManager(true);
|
||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
|
||||
|
||||
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
||||
.build();
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
|
||||
|
||||
|
||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
|
||||
sut.setHistoryProcessor(hp);
|
||||
TestContext context = new TestContext();
|
||||
context.listener.setRemainingOnEpochTrainingResult(1);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
context.sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(9, dataManager.statEntries.size());
|
||||
|
||||
for(int i = 0; i < 9; ++i) {
|
||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||
assertEquals(i + 1, entry.getStepCounter());
|
||||
assertEquals(i, entry.getEpochCounter());
|
||||
assertEquals(79.0, entry.getReward(), 0.0);
|
||||
assertEquals(2, context.listener.onNewEpochCallCount);
|
||||
assertEquals(2, context.listener.onEpochTrainingResultCallCount);
|
||||
}
|
||||
|
||||
assertEquals(1, dataManager.isSaveDataCallCount);
|
||||
assertEquals(1, dataManager.getVideoDirCallCount);
|
||||
@Test
|
||||
public void when_run_expect_preAndPostEpochCalled() {
|
||||
// Arrange
|
||||
TestContext context = new TestContext();
|
||||
|
||||
// Act
|
||||
context.sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(6, context.sut.preEpochCallCount);
|
||||
assertEquals(6, context.sut.postEpochCallCount);
|
||||
}
|
||||
|
||||
public static class MockAsyncGlobal implements IAsyncGlobal {
|
||||
@Test
|
||||
public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
|
||||
// Arrange
|
||||
TestContext context = new TestContext();
|
||||
|
||||
private final int maxLoops;
|
||||
private int currentLoop = 0;
|
||||
// Act
|
||||
context.sut.run();
|
||||
|
||||
public MockAsyncGlobal(int maxLoops) {
|
||||
|
||||
this.maxLoops = maxLoops;
|
||||
// Assert
|
||||
assertEquals(5, context.listener.statEntries.size());
|
||||
int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
|
||||
for(int i = 0; i < 5; ++i) {
|
||||
IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
|
||||
assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
|
||||
assertEquals(i, statEntry.getEpochCounter());
|
||||
assertEquals(2.0, statEntry.getReward(), 0.0001);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRunning() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setRunning(boolean value) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTrainingComplete() {
|
||||
return ++currentLoop >= maxLoops;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public AtomicInteger getT() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getCurrent() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getTarget() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||
private static class TestContext {
|
||||
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
|
||||
public final MockNeuralNet neuralNet = new MockNeuralNet();
|
||||
public final MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
public final MockMDP mdp = new MockMDP(observationSpace);
|
||||
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 2);
|
||||
public final TrainingListenerList listeners = new TrainingListenerList();
|
||||
public final MockTrainingListener listener = new MockTrainingListener();
|
||||
public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
|
||||
|
||||
public TestContext() {
|
||||
asyncGlobal.setMaxLoops(10);
|
||||
listeners.add(listener);
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockAsyncThread extends AsyncThread {
|
||||
|
||||
IAsyncGlobal asyncGlobal;
|
||||
private final MockNeuralNet neuralNet;
|
||||
private final MDP mdp;
|
||||
private final AsyncConfiguration conf;
|
||||
private final IDataManager dataManager;
|
||||
public int preEpochCallCount = 0;
|
||||
public int postEpochCallCount = 0;
|
||||
|
||||
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
|
||||
super(asyncGlobal, threadNumber, 0);
|
||||
|
||||
private final IAsyncGlobal asyncGlobal;
|
||||
private final MockNeuralNet neuralNet;
|
||||
private final AsyncConfiguration conf;
|
||||
|
||||
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
|
||||
super(asyncGlobal, mdp, listeners, threadNumber, 0);
|
||||
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.neuralNet = neuralNet;
|
||||
this.mdp = mdp;
|
||||
this.conf = conf;
|
||||
this.dataManager = dataManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void preEpoch() {
|
||||
++preEpochCallCount;
|
||||
super.preEpoch();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void postEpoch() {
|
||||
++postEpochCallCount;
|
||||
super.postEpoch();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -208,31 +137,16 @@ public class AsyncThreadTest {
|
|||
return neuralNet;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int getThreadNumber() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IAsyncGlobal getAsyncGlobal() {
|
||||
return asyncGlobal;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MDP getMdp() {
|
||||
return mdp;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AsyncConfiguration getConf() {
|
||||
return conf;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IDataManager getDataManager() {
|
||||
return dataManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Policy getPolicy(NeuralNet net) {
|
||||
return null;
|
||||
|
@ -244,129 +158,6 @@ public class AsyncThreadTest {
|
|||
}
|
||||
}
|
||||
|
||||
public static class MockNeuralNet implements NeuralNet {
|
||||
|
||||
@Override
|
||||
public NeuralNetwork[] getNeuralNetworks() {
|
||||
return new NeuralNetwork[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRecurrent() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] outputAll(INDArray batch) {
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet clone() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||
return new Gradient[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(INDArray input, INDArray[] labels) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyGradient(Gradient[] gradients, int batchSize) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getLatestScore() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(OutputStream os) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(String filename) throws IOException {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockAsyncConfiguration implements AsyncConfiguration {
|
||||
|
||||
private final int nStep;
|
||||
private final int maxEpochStep;
|
||||
|
||||
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
|
||||
this.nStep = nStep;
|
||||
|
||||
this.maxEpochStep = maxEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSeed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxEpochStep() {
|
||||
return maxEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxStep() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumThread() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNstep() {
|
||||
return nStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTargetDqnUpdateFreq() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getUpdateStart() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getRewardFactor() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getGamma() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getErrorClamp() {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
package org.deeplearning4j.rl4j.learning.async.listener;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class AsyncTrainingListenerListTest {
|
||||
@Test
|
||||
public void when_listIsEmpty_expect_notifyTrainingStartedReturnTrue() {
|
||||
// Arrange
|
||||
TrainingListenerList sut = new TrainingListenerList();
|
||||
|
||||
// Act
|
||||
boolean resultTrainingStarted = sut.notifyTrainingStarted();
|
||||
boolean resultNewEpoch = sut.notifyNewEpoch(null);
|
||||
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
|
||||
|
||||
// Assert
|
||||
assertTrue(resultTrainingStarted);
|
||||
assertTrue(resultNewEpoch);
|
||||
assertTrue(resultEpochTrainingResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_firstListerStops_expect_othersListnersNotCalled() {
|
||||
// Arrange
|
||||
MockTrainingListener listener1 = new MockTrainingListener();
|
||||
listener1.onTrainingResultResponse = TrainingListener.ListenerResponse.STOP;
|
||||
MockTrainingListener listener2 = new MockTrainingListener();
|
||||
TrainingListenerList sut = new TrainingListenerList();
|
||||
sut.add(listener1);
|
||||
sut.add(listener2);
|
||||
|
||||
// Act
|
||||
sut.notifyEpochTrainingResult(null, null);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, listener1.onEpochTrainingResultCallCount);
|
||||
assertEquals(0, listener2.onEpochTrainingResultCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_allListenersContinue_expect_listReturnsTrue() {
|
||||
// Arrange
|
||||
MockTrainingListener listener1 = new MockTrainingListener();
|
||||
MockTrainingListener listener2 = new MockTrainingListener();
|
||||
TrainingListenerList sut = new TrainingListenerList();
|
||||
sut.add(listener1);
|
||||
sut.add(listener2);
|
||||
|
||||
// Act
|
||||
boolean resultTrainingProgress = sut.notifyEpochTrainingResult(null, null);
|
||||
|
||||
// Assert
|
||||
assertTrue(resultTrainingProgress);
|
||||
}
|
||||
|
||||
private static class MockTrainingListener implements TrainingListener {
|
||||
|
||||
public int onEpochTrainingResultCallCount = 0;
|
||||
public ListenerResponse onTrainingResultResponse = ListenerResponse.CONTINUE;
|
||||
public int onTrainingProgressCallCount = 0;
|
||||
public ListenerResponse onTrainingProgressResponse = ListenerResponse.CONTINUE;
|
||||
|
||||
@Override
|
||||
public ListenerResponse onTrainingStart() {
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTrainingEnd() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
|
||||
++onEpochTrainingResultCallCount;
|
||||
return onTrainingResultResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onTrainingProgress(ILearning learning) {
|
||||
++onTrainingProgressCallCount;
|
||||
return onTrainingProgressResponse;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
package org.deeplearning4j.rl4j.learning.listener;
|
||||
|
||||
import org.deeplearning4j.rl4j.support.MockTrainingListener;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TrainingListenerListTest {
|
||||
@Test
|
||||
public void when_listIsEmpty_expect_notifyReturnTrue() {
|
||||
// Arrange
|
||||
TrainingListenerList sut = new TrainingListenerList();
|
||||
|
||||
// Act
|
||||
boolean resultTrainingStarted = sut.notifyTrainingStarted();
|
||||
boolean resultNewEpoch = sut.notifyNewEpoch(null);
|
||||
boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null);
|
||||
|
||||
// Assert
|
||||
assertTrue(resultTrainingStarted);
|
||||
assertTrue(resultNewEpoch);
|
||||
assertTrue(resultEpochFinished);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_firstListerStops_expect_othersListnersNotCalled() {
|
||||
// Arrange
|
||||
MockTrainingListener listener1 = new MockTrainingListener();
|
||||
listener1.setRemainingTrainingStartCallCount(0);
|
||||
listener1.setRemainingOnNewEpochCallCount(0);
|
||||
listener1.setRemainingonTrainingProgressCallCount(0);
|
||||
listener1.setRemainingOnEpochTrainingResult(0);
|
||||
MockTrainingListener listener2 = new MockTrainingListener();
|
||||
TrainingListenerList sut = new TrainingListenerList();
|
||||
sut.add(listener1);
|
||||
sut.add(listener2);
|
||||
|
||||
// Act
|
||||
sut.notifyTrainingStarted();
|
||||
sut.notifyNewEpoch(null);
|
||||
sut.notifyEpochTrainingResult(null, null);
|
||||
sut.notifyTrainingProgress(null);
|
||||
sut.notifyTrainingFinished();
|
||||
|
||||
// Assert
|
||||
assertEquals(1, listener1.onTrainingStartCallCount);
|
||||
assertEquals(0, listener2.onTrainingStartCallCount);
|
||||
|
||||
assertEquals(1, listener1.onNewEpochCallCount);
|
||||
assertEquals(0, listener2.onNewEpochCallCount);
|
||||
|
||||
assertEquals(1, listener1.onEpochTrainingResultCallCount);
|
||||
assertEquals(0, listener2.onEpochTrainingResultCallCount);
|
||||
|
||||
assertEquals(1, listener1.onTrainingProgressCallCount);
|
||||
assertEquals(0, listener2.onTrainingProgressCallCount);
|
||||
|
||||
assertEquals(1, listener1.onTrainingEndCallCount);
|
||||
assertEquals(1, listener2.onTrainingEndCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_allListenersContinue_expect_listReturnsTrue() {
|
||||
// Arrange
|
||||
MockTrainingListener listener1 = new MockTrainingListener();
|
||||
MockTrainingListener listener2 = new MockTrainingListener();
|
||||
TrainingListenerList sut = new TrainingListenerList();
|
||||
sut.add(listener1);
|
||||
sut.add(listener2);
|
||||
|
||||
// Act
|
||||
boolean resultTrainingStarted = sut.notifyTrainingStarted();
|
||||
boolean resultNewEpoch = sut.notifyNewEpoch(null);
|
||||
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
|
||||
boolean resultProgress = sut.notifyTrainingProgress(null);
|
||||
|
||||
// Assert
|
||||
assertTrue(resultTrainingStarted);
|
||||
assertTrue(resultNewEpoch);
|
||||
assertTrue(resultEpochTrainingResult);
|
||||
assertTrue(resultProgress);
|
||||
}
|
||||
}
|
|
@ -2,12 +2,10 @@ package org.deeplearning4j.rl4j.learning.sync;
|
|||
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||
import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.support.MockTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -19,7 +17,7 @@ public class SyncLearningTest {
|
|||
public void when_training_expect_listenersToBeCalled() {
|
||||
// Arrange
|
||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
||||
MockTrainingListener listener = new MockTrainingListener();
|
||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||
sut.addListener(listener);
|
||||
|
||||
|
@ -27,8 +25,8 @@ public class SyncLearningTest {
|
|||
sut.train();
|
||||
|
||||
assertEquals(1, listener.onTrainingStartCallCount);
|
||||
assertEquals(10, listener.onEpochStartCallCount);
|
||||
assertEquals(10, listener.onEpochEndStartCallCount);
|
||||
assertEquals(10, listener.onNewEpochCallCount);
|
||||
assertEquals(10, listener.onEpochTrainingResultCallCount);
|
||||
assertEquals(1, listener.onTrainingEndCallCount);
|
||||
}
|
||||
|
||||
|
@ -36,65 +34,59 @@ public class SyncLearningTest {
|
|||
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
|
||||
// Arrange
|
||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
||||
MockTrainingListener listener = new MockTrainingListener();
|
||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||
sut.addListener(listener);
|
||||
listener.trainingStartCanContinue = false;
|
||||
listener.setRemainingTrainingStartCallCount(0);
|
||||
|
||||
// Act
|
||||
sut.train();
|
||||
|
||||
assertEquals(1, listener.onTrainingStartCallCount);
|
||||
assertEquals(0, listener.onEpochStartCallCount);
|
||||
assertEquals(0, listener.onEpochEndStartCallCount);
|
||||
assertEquals(0, listener.onNewEpochCallCount);
|
||||
assertEquals(0, listener.onEpochTrainingResultCallCount);
|
||||
assertEquals(1, listener.onTrainingEndCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_epochStartCanContinueFalse_expect_trainingStopped() {
|
||||
public void when_newEpochCanContinueFalse_expect_trainingStopped() {
|
||||
// Arrange
|
||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
||||
MockTrainingListener listener = new MockTrainingListener();
|
||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||
sut.addListener(listener);
|
||||
listener.nbStepsEpochStartCanContinue = 3;
|
||||
listener.setRemainingOnNewEpochCallCount(2);
|
||||
|
||||
// Act
|
||||
sut.train();
|
||||
|
||||
assertEquals(1, listener.onTrainingStartCallCount);
|
||||
assertEquals(3, listener.onEpochStartCallCount);
|
||||
assertEquals(2, listener.onEpochEndStartCallCount);
|
||||
assertEquals(3, listener.onNewEpochCallCount);
|
||||
assertEquals(2, listener.onEpochTrainingResultCallCount);
|
||||
assertEquals(1, listener.onTrainingEndCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_epochEndCanContinueFalse_expect_trainingStopped() {
|
||||
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
|
||||
// Arrange
|
||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
||||
MockTrainingListener listener = new MockTrainingListener();
|
||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||
sut.addListener(listener);
|
||||
listener.nbStepsEpochEndCanContinue = 3;
|
||||
listener.setRemainingOnEpochTrainingResult(2);
|
||||
|
||||
// Act
|
||||
sut.train();
|
||||
|
||||
assertEquals(1, listener.onTrainingStartCallCount);
|
||||
assertEquals(3, listener.onEpochStartCallCount);
|
||||
assertEquals(3, listener.onEpochEndStartCallCount);
|
||||
assertEquals(3, listener.onNewEpochCallCount);
|
||||
assertEquals(3, listener.onEpochTrainingResultCallCount);
|
||||
assertEquals(1, listener.onTrainingEndCallCount);
|
||||
}
|
||||
|
||||
public static class MockSyncLearning extends SyncLearning {
|
||||
|
||||
private LConfiguration conf;
|
||||
|
||||
public MockSyncLearning(LConfiguration conf, IDataManager dataManager) {
|
||||
super(conf);
|
||||
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
|
||||
this.conf = conf;
|
||||
}
|
||||
private final LConfiguration conf;
|
||||
|
||||
public MockSyncLearning(LConfiguration conf) {
|
||||
super(conf);
|
||||
|
@ -119,7 +111,7 @@ public class SyncLearningTest {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Policy getPolicy() {
|
||||
public IPolicy getPolicy() {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync.qlearning;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete;
|
||||
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
||||
import org.deeplearning4j.rl4j.learning.sync.support.MockMDP;
|
||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class QLearningDiscreteTest {
|
||||
@Test
|
||||
public void refac_checkDataManagerCallsRemainTheSame() {
|
||||
// Arrange
|
||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder()
|
||||
.maxStep(10)
|
||||
.expRepMaxSize(1)
|
||||
.build();
|
||||
MockDataManager dataManager = new MockDataManager(true);
|
||||
MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3);
|
||||
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
||||
.build();
|
||||
sut.setHistoryProcessor(new MockHistoryProcessor(hpConfig));
|
||||
|
||||
// Act
|
||||
sut.train();
|
||||
|
||||
assertEquals(10, dataManager.statEntries.size());
|
||||
for(int i = 0; i < 10; ++i) {
|
||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||
assertEquals(i, entry.getEpochCounter());
|
||||
assertEquals(i+1, entry.getStepCounter());
|
||||
assertEquals(1.0, entry.getReward(), 0.0);
|
||||
|
||||
}
|
||||
assertEquals(4, dataManager.isSaveDataCallCount);
|
||||
assertEquals(4, dataManager.getVideoDirCallCount);
|
||||
assertEquals(11, dataManager.writeInfoCallCount);
|
||||
assertEquals(5, dataManager.saveCallCount);
|
||||
}
|
||||
|
||||
public static class MockQLearningDiscrete extends QLearningDiscrete {
|
||||
|
||||
public MockQLearningDiscrete(int maxSteps, QLConfiguration conf,
|
||||
IDataManager dataManager, int saveFrequency, int monitorFrequency) {
|
||||
super(new MockMDP(maxSteps), new MockDQN(), conf, 2);
|
||||
addListener(DataManagerSyncTrainingListener.builder(dataManager)
|
||||
.saveFrequency(saveFrequency)
|
||||
.monitorFrequency(monitorFrequency)
|
||||
.build());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IDataManager.StatEntry trainEpoch() {
|
||||
setStepCounter(getStepCounter() + 1);
|
||||
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,6 +8,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -29,12 +30,11 @@ public class QLearningDiscreteTest {
|
|||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||
0, 1.0, 0, 0, 0, 0, true);
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10);
|
||||
MockExpReplay expReplay = new MockExpReplay();
|
||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10);
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
sut.setHistoryProcessor(hp);
|
||||
MockExpReplay expReplay = new MockExpReplay();
|
||||
sut.setExpReplay(expReplay);
|
||||
MockEncodable obs = new MockEncodable(1);
|
||||
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
||||
|
||||
|
@ -131,8 +131,11 @@ public class QLearningDiscreteTest {
|
|||
|
||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
|
||||
QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) {
|
||||
super(mdp, dqn, conf, dataManager, epsilonNbStep);
|
||||
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
||||
int epsilonNbStep) {
|
||||
super(mdp, dqn, conf, epsilonNbStep);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
setExpReplay(expReplay);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1,46 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
||||
|
||||
public class MockSyncTrainingListener implements SyncTrainingListener {
|
||||
|
||||
public int onTrainingStartCallCount = 0;
|
||||
public int onTrainingEndCallCount = 0;
|
||||
public int onEpochStartCallCount = 0;
|
||||
public int onEpochEndStartCallCount = 0;
|
||||
|
||||
public boolean trainingStartCanContinue = true;
|
||||
public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE;
|
||||
public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE;
|
||||
|
||||
@Override
|
||||
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
|
||||
++onTrainingStartCallCount;
|
||||
return trainingStartCanContinue ? ListenerResponse.CONTINUE : ListenerResponse.STOP;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTrainingEnd() {
|
||||
++onTrainingEndCallCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
|
||||
++onEpochStartCallCount;
|
||||
if(onEpochStartCallCount >= nbStepsEpochStartCanContinue) {
|
||||
return ListenerResponse.STOP;
|
||||
}
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
|
||||
++onEpochEndStartCallCount;
|
||||
if(onEpochEndStartCallCount >= nbStepsEpochEndCanContinue) {
|
||||
return ListenerResponse.STOP;
|
||||
}
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
|
||||
|
||||
public class MockAsyncConfiguration implements AsyncConfiguration {
|
||||
|
||||
private final int nStep;
|
||||
private final int maxEpochStep;
|
||||
|
||||
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
|
||||
this.nStep = nStep;
|
||||
|
||||
this.maxEpochStep = maxEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSeed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxEpochStep() {
|
||||
return maxEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxStep() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumThread() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNstep() {
|
||||
return nStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTargetDqnUpdateFreq() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getUpdateStart() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getRewardFactor() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getGamma() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getErrorClamp() {
|
||||
return 0;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class MockAsyncGlobal implements IAsyncGlobal {
|
||||
|
||||
public boolean hasBeenStarted = false;
|
||||
public boolean hasBeenTerminated = false;
|
||||
|
||||
@Setter
|
||||
private int maxLoops;
|
||||
@Setter
|
||||
private int numLoopsStopRunning;
|
||||
private int currentLoop = 0;
|
||||
|
||||
public MockAsyncGlobal() {
|
||||
maxLoops = Integer.MAX_VALUE;
|
||||
numLoopsStopRunning = Integer.MAX_VALUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRunning() {
|
||||
return currentLoop < numLoopsStopRunning;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void terminate() {
|
||||
hasBeenTerminated = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTrainingComplete() {
|
||||
return ++currentLoop > maxLoops;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
hasBeenStarted = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AtomicInteger getT() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getCurrent() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getTarget() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||
|
||||
}
|
||||
}
|
|
@ -44,7 +44,7 @@ public class MockDataManager implements IDataManager {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void save(Learning learning) throws IOException {
|
||||
public void save(ILearning learning) throws IOException {
|
||||
++saveCallCount;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
|
||||
public class MockNeuralNet implements NeuralNet {
|
||||
|
||||
public int resetCallCount = 0;
|
||||
|
||||
@Override
|
||||
public NeuralNetwork[] getNeuralNetworks() {
|
||||
return new NeuralNetwork[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRecurrent() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
++resetCallCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] outputAll(INDArray batch) {
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet clone() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||
return new Gradient[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(INDArray input, INDArray[] labels) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyGradient(Gradient[] gradients, int batchSize) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getLatestScore() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(OutputStream os) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(String filename) throws IOException {
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
|
||||
public class MockPolicy implements IPolicy<MockEncodable, Integer> {
|
||||
|
||||
public int playCallCount = 0;
|
||||
|
||||
@Override
|
||||
public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
|
||||
++playCallCount;
|
||||
return 0;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockTrainingListener implements TrainingListener {
|
||||
|
||||
public int onTrainingStartCallCount = 0;
|
||||
public int onTrainingEndCallCount = 0;
|
||||
public int onNewEpochCallCount = 0;
|
||||
public int onEpochTrainingResultCallCount = 0;
|
||||
public int onTrainingProgressCallCount = 0;
|
||||
|
||||
@Setter
|
||||
private int remainingTrainingStartCallCount = Integer.MAX_VALUE;
|
||||
@Setter
|
||||
private int remainingOnNewEpochCallCount = Integer.MAX_VALUE;
|
||||
@Setter
|
||||
private int remainingOnEpochTrainingResult = Integer.MAX_VALUE;
|
||||
@Setter
|
||||
private int remainingonTrainingProgressCallCount = Integer.MAX_VALUE;
|
||||
|
||||
public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
|
||||
|
||||
|
||||
@Override
|
||||
public ListenerResponse onTrainingStart() {
|
||||
++onTrainingStartCallCount;
|
||||
--remainingTrainingStartCallCount;
|
||||
return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
|
||||
++onNewEpochCallCount;
|
||||
--remainingOnNewEpochCallCount;
|
||||
return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
|
||||
++onEpochTrainingResultCallCount;
|
||||
--remainingOnEpochTrainingResult;
|
||||
statEntries.add(statEntry);
|
||||
return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse onTrainingProgress(ILearning learning) {
|
||||
++onTrainingProgressCallCount;
|
||||
--remainingonTrainingProgressCallCount;
|
||||
return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTrainingEnd() {
|
||||
++onTrainingEndCallCount;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
package org.deeplearning4j.rl4j.util;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertSame;
|
||||
|
||||
public class DataManagerTrainingListenerTest {
|
||||
|
||||
@Test
|
||||
public void when_callingOnNewEpochWithoutHistoryProcessor_expect_noException() {
|
||||
// Arrange
|
||||
TestTrainer trainer = new TestTrainer();
|
||||
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
|
||||
|
||||
// Act
|
||||
TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer);
|
||||
|
||||
// Assert
|
||||
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingOnNewEpochWithHistoryProcessor_expect_startMonitorNotCalled() {
|
||||
// Arrange
|
||||
TestTrainer trainer = new TestTrainer();
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
trainer.setHistoryProcessor(hp);
|
||||
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
|
||||
|
||||
// Act
|
||||
TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, hp.startMonitorCallCount);
|
||||
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingOnEpochTrainingResultWithoutHistoryProcessor_expect_noException() {
|
||||
// Arrange
|
||||
TestTrainer trainer = new TestTrainer();
|
||||
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
|
||||
|
||||
// Act
|
||||
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null);
|
||||
|
||||
// Assert
|
||||
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingOnNewEpochWithHistoryProcessor_expect_stopMonitorNotCalled() {
|
||||
// Arrange
|
||||
TestTrainer trainer = new TestTrainer();
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
trainer.setHistoryProcessor(hp);
|
||||
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
|
||||
|
||||
// Act
|
||||
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, hp.stopMonitorCallCount);
|
||||
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingOnEpochTrainingResult_expect_callToDataManagerAppendStat() {
|
||||
// Arrange
|
||||
TestTrainer trainer = new TestTrainer();
|
||||
MockDataManager dm = new MockDataManager(false);
|
||||
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
|
||||
MockStatEntry statEntry = new MockStatEntry(0, 0, 0.0);
|
||||
|
||||
// Act
|
||||
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, statEntry);
|
||||
|
||||
// Assert
|
||||
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||
assertEquals(1, dm.statEntries.size());
|
||||
assertSame(statEntry, dm.statEntries.get(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingOnTrainingProgress_expect_callToDataManagerSaveAndWriteInfo() {
|
||||
// Arrange
|
||||
TestTrainer learning = new TestTrainer();
|
||||
MockDataManager dm = new MockDataManager(false);
|
||||
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
|
||||
|
||||
// Act
|
||||
TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning);
|
||||
|
||||
// Assert
|
||||
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||
assertEquals(1, dm.writeInfoCallCount);
|
||||
assertEquals(1, dm.saveCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_stepCounterCloseToLastSave_expect_dataManagerSaveNotCalled() {
|
||||
// Arrange
|
||||
TestTrainer learning = new TestTrainer();
|
||||
MockDataManager dm = new MockDataManager(false);
|
||||
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
|
||||
|
||||
// Act
|
||||
TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning);
|
||||
TrainingListener.ListenerResponse response2 = sut.onTrainingProgress(learning);
|
||||
|
||||
// Assert
|
||||
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response2);
|
||||
assertEquals(1, dm.saveCallCount);
|
||||
}
|
||||
|
||||
private static class TestTrainer implements IEpochTrainer, ILearning
|
||||
{
|
||||
@Override
|
||||
public int getStepCounter() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getEpochCounter() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
private IHistoryProcessor historyProcessor;
|
||||
|
||||
@Override
|
||||
public IPolicy getPolicy() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public LConfiguration getConfiguration() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MDP getMdp() {
|
||||
return new MockMDP(new MockObservationSpace());
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue