RL4J - AsyncTrainingListener (#8072)
* Code clarity: Extracted parts of run() into private methods Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Added listener pattern to async learning Signed-off-by: unknown <aboulang2002@yahoo.com> * Merged all listeners logic Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Added interface and common data to training events Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Fixed missing info log file Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Fixed bad merge; removed useless TrainingEvent Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Removed param from training start/end event Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Removed 'event' classes from the training listener Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Reverted changes to QLearningDiscrete.setTarget()master
parent
d58a4b45b1
commit
59f1cbf0c6
|
@ -0,0 +1,31 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.learning;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The common API between Learning and AsyncThread.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public interface IEpochTrainer {
|
||||||
|
int getStepCounter();
|
||||||
|
int getEpochCounter();
|
||||||
|
IHistoryProcessor getHistoryProcessor();
|
||||||
|
MDP getMdp();
|
||||||
|
}
|
|
@ -17,7 +17,7 @@
|
||||||
package org.deeplearning4j.rl4j.learning;
|
package org.deeplearning4j.rl4j.learning;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
*/
|
*/
|
||||||
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> extends StepCountable {
|
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> extends StepCountable {
|
||||||
|
|
||||||
Policy<O, A> getPolicy();
|
IPolicy<O, A> getPolicy();
|
||||||
|
|
||||||
void train();
|
void train();
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> ex
|
||||||
|
|
||||||
MDP<O, A, AS> getMdp();
|
MDP<O, A, AS> getMdp();
|
||||||
|
|
||||||
|
IHistoryProcessor getHistoryProcessor();
|
||||||
|
|
||||||
interface LConfiguration {
|
interface LConfiguration {
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
@ -63,7 +62,6 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
||||||
@Getter
|
@Getter
|
||||||
private NN target;
|
private NN target;
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
|
||||||
private boolean running = true;
|
private boolean running = true;
|
||||||
|
|
||||||
public AsyncGlobal(NN initial, AsyncConfiguration a3cc) {
|
public AsyncGlobal(NN initial, AsyncConfiguration a3cc) {
|
||||||
|
@ -78,8 +76,10 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
||||||
}
|
}
|
||||||
|
|
||||||
public void enqueue(Gradient[] gradient, Integer nstep) {
|
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||||
|
if(running && !isTrainingComplete()) {
|
||||||
queue.add(new Pair<>(gradient, nstep));
|
queue.add(new Pair<>(gradient, nstep));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
@ -105,4 +105,12 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
|
||||||
|
*/
|
||||||
|
public void terminate() {
|
||||||
|
running = false;
|
||||||
|
queue.clear();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,33 +16,49 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
|
import lombok.AccessLevel;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread()
|
||||||
|
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
|
||||||
|
* (see setProgressEventInterval(int))
|
||||||
|
*
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
|
||||||
*
|
* @author Alexandre Boulanger
|
||||||
* Async learning always follow the same pattern in RL4J
|
|
||||||
* -launch the Global thread
|
|
||||||
* -launch the "save threads"
|
|
||||||
* -periodically evaluate the model of the global thread for monitoring purposes
|
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
extends Learning<O, A, AS, NN> {
|
extends Learning<O, A, AS, NN> {
|
||||||
|
|
||||||
protected abstract IDataManager getDataManager();
|
@Getter(AccessLevel.PROTECTED)
|
||||||
|
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||||
|
|
||||||
public AsyncLearning(AsyncConfiguration conf) {
|
public AsyncLearning(AsyncConfiguration conf) {
|
||||||
super(conf);
|
super(conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a {@link TrainingListener} listener at the end of the listener list.
|
||||||
|
*
|
||||||
|
* @param listener the listener to be added
|
||||||
|
*/
|
||||||
|
public void addListener(TrainingListener listener) {
|
||||||
|
listeners.add(listener);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the configuration
|
||||||
|
* @return the configuration (see {@link AsyncConfiguration})
|
||||||
|
*/
|
||||||
public abstract AsyncConfiguration getConfiguration();
|
public abstract AsyncConfiguration getConfiguration();
|
||||||
|
|
||||||
protected abstract AsyncThread newThread(int i, int deviceAffinity);
|
protected abstract AsyncThread newThread(int i, int deviceAffinity);
|
||||||
|
@ -57,41 +73,80 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
||||||
return getAsyncGlobal().isTrainingComplete();
|
return getAsyncGlobal().isTrainingComplete();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void launchThreads() {
|
private boolean canContinue = true;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of milliseconds between calls to onTrainingProgress
|
||||||
|
*/
|
||||||
|
@Getter @Setter
|
||||||
|
private int progressMonitorFrequency = 20000;
|
||||||
|
|
||||||
|
private void launchThreads() {
|
||||||
startGlobalThread();
|
startGlobalThread();
|
||||||
for (int i = 0; i < getConfiguration().getNumThread(); i++) {
|
for (int i = 0; i < getConfiguration().getNumThread(); i++) {
|
||||||
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
|
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
|
||||||
t.start();
|
t.start();
|
||||||
|
|
||||||
}
|
}
|
||||||
log.info("Threads launched.");
|
log.info("Threads launched.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The current step
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public int getStepCounter() {
|
public int getStepCounter() {
|
||||||
return getAsyncGlobal().getT().get();
|
return getAsyncGlobal().getT().get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will train the model<p>
|
||||||
|
* The training stop when:<br>
|
||||||
|
* - A worker thread terminate the AsyncGlobal thread (see {@link AsyncGlobal})<br>
|
||||||
|
* OR<br>
|
||||||
|
* - a listener explicitly stops it<br>
|
||||||
|
* <p>
|
||||||
|
* Listeners<br>
|
||||||
|
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
|
||||||
|
* returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
||||||
|
* Events:
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.</li>
|
||||||
|
* <li>{@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
public void train() {
|
public void train() {
|
||||||
|
|
||||||
try {
|
|
||||||
log.info("AsyncLearning training starting.");
|
log.info("AsyncLearning training starting.");
|
||||||
|
|
||||||
|
canContinue = listeners.notifyTrainingStarted();
|
||||||
|
if (canContinue) {
|
||||||
launchThreads();
|
launchThreads();
|
||||||
|
monitorTraining();
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupPostTraining();
|
||||||
|
listeners.notifyTrainingFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void monitorTraining() {
|
||||||
|
try {
|
||||||
|
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
||||||
|
canContinue = listeners.notifyTrainingProgress(this);
|
||||||
|
if(!canContinue) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//this is simply for stat purposes
|
|
||||||
getDataManager().writeInfo(this);
|
|
||||||
synchronized (this) {
|
synchronized (this) {
|
||||||
while (!isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
wait(progressMonitorFrequency);
|
||||||
getPolicy().play(getMdp(), getHistoryProcessor());
|
|
||||||
getDataManager().writeInfo(this);
|
|
||||||
wait(20000);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (InterruptedException e) {
|
||||||
log.error("Training failed.", e);
|
log.error("Training interrupted.", e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void cleanupPostTraining() {
|
||||||
|
// Worker threads stops automatically when the global thread stops
|
||||||
|
getAsyncGlobal().terminate();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,33 +21,31 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.*;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
|
||||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
|
||||||
*
|
|
||||||
* This represent a local thread that explore the environment
|
* This represent a local thread that explore the environment
|
||||||
* and calculate a gradient to enqueue to the global thread/model
|
* and calculate a gradient to enqueue to the global thread/model
|
||||||
*
|
*
|
||||||
* It has its own version of a model that it syncs at the start of every
|
* It has its own version of a model that it syncs at the start of every
|
||||||
* sub epoch
|
* sub epoch
|
||||||
*
|
*
|
||||||
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
extends Thread implements StepCountable {
|
extends Thread implements StepCountable, IEpochTrainer {
|
||||||
|
|
||||||
|
@Getter
|
||||||
private int threadNumber;
|
private int threadNumber;
|
||||||
@Getter
|
@Getter
|
||||||
protected final int deviceNum;
|
protected final int deviceNum;
|
||||||
|
@ -55,12 +53,16 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
||||||
private int stepCounter = 0;
|
private int stepCounter = 0;
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
private int epochCounter = 0;
|
private int epochCounter = 0;
|
||||||
|
@Getter
|
||||||
|
private MDP<O, A, AS> mdp;
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
private IHistoryProcessor historyProcessor;
|
private IHistoryProcessor historyProcessor;
|
||||||
@Getter
|
|
||||||
private int lastMonitor = -Constants.MONITOR_FREQ;
|
|
||||||
|
|
||||||
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
|
private final TrainingListenerList listeners;
|
||||||
|
|
||||||
|
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||||
|
this.mdp = mdp;
|
||||||
|
this.listeners = listeners;
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
this.deviceNum = deviceNum;
|
this.deviceNum = deviceNum;
|
||||||
}
|
}
|
||||||
|
@ -80,75 +82,106 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void preEpoch() {
|
protected void preEpoch() {
|
||||||
if (getStepCounter() - lastMonitor >= Constants.MONITOR_FREQ && getHistoryProcessor() != null
|
// Do nothing
|
||||||
&& getDataManager().isSaveData()) {
|
|
||||||
lastMonitor = getStepCounter();
|
|
||||||
int[] shape = getMdp().getObservationSpace().getShape();
|
|
||||||
getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + threadNumber + "-"
|
|
||||||
+ getEpochCounter() + "-" + getStepCounter() + ".mp4", shape);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will start the worker thread<p>
|
||||||
|
* The thread will stop when:<br>
|
||||||
|
* - The AsyncGlobal thread terminates or reports that the training is complete
|
||||||
|
* (see {@link AsyncGlobal#isTrainingComplete()}). In such case, the currently running epoch will still be handled normally and
|
||||||
|
* events will also be fired normally.<br>
|
||||||
|
* OR<br>
|
||||||
|
* - a listener explicitly stops it, in which case, the AsyncGlobal thread will be terminated along with
|
||||||
|
* all other worker threads <br>
|
||||||
|
* <p>
|
||||||
|
* Listeners<br>
|
||||||
|
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
|
||||||
|
* returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse
|
||||||
|
* TrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
||||||
|
* Events:
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} is called when a new epoch is started.</li>
|
||||||
|
* <li>{@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} is called at the end of every
|
||||||
|
* epoch. It will not be called if onNewEpoch() stops the training.</li>
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
RunContext<O> context = new RunContext<>();
|
||||||
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
|
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
|
||||||
|
|
||||||
|
|
||||||
try {
|
|
||||||
log.info("ThreadNum-" + threadNumber + " Started!");
|
log.info("ThreadNum-" + threadNumber + " Started!");
|
||||||
|
|
||||||
|
boolean canContinue = initWork(context);
|
||||||
|
if (canContinue) {
|
||||||
|
|
||||||
|
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
||||||
|
handleTraining(context);
|
||||||
|
if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||||
|
canContinue = finishEpoch(context) && startNewEpoch(context);
|
||||||
|
if (!canContinue) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
terminateWork();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initNewEpoch(RunContext context) {
|
||||||
getCurrent().reset();
|
getCurrent().reset();
|
||||||
Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor);
|
Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor);
|
||||||
O obs = initMdp.getLastObs();
|
|
||||||
double rewards = initMdp.getReward();
|
|
||||||
int length = initMdp.getSteps();
|
|
||||||
|
|
||||||
preEpoch();
|
context.obs = initMdp.getLastObs();
|
||||||
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
context.rewards = initMdp.getReward();
|
||||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - length);
|
context.length = initMdp.getSteps();
|
||||||
SubEpochReturn<O> subEpochReturn = trainSubEpoch(obs, maxSteps);
|
}
|
||||||
obs = subEpochReturn.getLastObs();
|
|
||||||
|
private void handleTraining(RunContext<O> context) {
|
||||||
|
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length);
|
||||||
|
SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);
|
||||||
|
|
||||||
|
context.obs = subEpochReturn.getLastObs();
|
||||||
stepCounter += subEpochReturn.getSteps();
|
stepCounter += subEpochReturn.getSteps();
|
||||||
length += subEpochReturn.getSteps();
|
context.length += subEpochReturn.getSteps();
|
||||||
rewards += subEpochReturn.getReward();
|
context.rewards += subEpochReturn.getReward();
|
||||||
double score = subEpochReturn.getScore();
|
context.score = subEpochReturn.getScore();
|
||||||
if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
}
|
||||||
postEpoch();
|
|
||||||
|
|
||||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
|
|
||||||
getDataManager().appendStat(statEntry);
|
|
||||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
|
||||||
|
|
||||||
getCurrent().reset();
|
|
||||||
initMdp = Learning.initMdp(getMdp(), historyProcessor);
|
|
||||||
obs = initMdp.getLastObs();
|
|
||||||
rewards = initMdp.getReward();
|
|
||||||
length = initMdp.getSteps();
|
|
||||||
epochCounter++;
|
|
||||||
|
|
||||||
|
private boolean initWork(RunContext context) {
|
||||||
|
initNewEpoch(context);
|
||||||
preEpoch();
|
preEpoch();
|
||||||
|
return listeners.notifyNewEpoch(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean startNewEpoch(RunContext context) {
|
||||||
|
initNewEpoch(context);
|
||||||
|
epochCounter++;
|
||||||
|
preEpoch();
|
||||||
|
return listeners.notifyNewEpoch(this);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Thread crashed: " + e.getCause());
|
private boolean finishEpoch(RunContext context) {
|
||||||
getAsyncGlobal().setRunning(false);
|
|
||||||
e.printStackTrace();
|
|
||||||
} finally {
|
|
||||||
postEpoch();
|
postEpoch();
|
||||||
|
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.length, context.score);
|
||||||
|
|
||||||
|
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
|
||||||
|
|
||||||
|
return listeners.notifyEpochTrainingResult(this, statEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void terminateWork() {
|
||||||
|
postEpoch();
|
||||||
|
getAsyncGlobal().terminate();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract NN getCurrent();
|
protected abstract NN getCurrent();
|
||||||
|
|
||||||
protected abstract int getThreadNumber();
|
|
||||||
|
|
||||||
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||||
|
|
||||||
protected abstract MDP<O, A, AS> getMdp();
|
|
||||||
|
|
||||||
protected abstract AsyncConfiguration getConf();
|
protected abstract AsyncConfiguration getConf();
|
||||||
|
|
||||||
protected abstract IDataManager getDataManager();
|
|
||||||
|
|
||||||
protected abstract Policy<O, A> getPolicy(NN net);
|
protected abstract Policy<O, A> getPolicy(NN net);
|
||||||
|
|
||||||
protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
|
protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
|
||||||
|
@ -172,4 +205,11 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
||||||
double score;
|
double score;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class RunContext<O extends Encodable> {
|
||||||
|
private O obs;
|
||||||
|
private double rewards;
|
||||||
|
private int length;
|
||||||
|
private double score;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,9 @@ import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
|
@ -44,8 +46,8 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
||||||
@Getter
|
@Getter
|
||||||
private NN current;
|
private NN current;
|
||||||
|
|
||||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
|
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, MDP<O, Integer, DiscreteSpace> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||||
super(asyncGlobal, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
synchronized (asyncGlobal) {
|
synchronized (asyncGlobal) {
|
||||||
current = (NN)asyncGlobal.getCurrent().clone();
|
current = (NN)asyncGlobal.getCurrent().clone();
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,9 +23,14 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
public interface IAsyncGlobal<NN extends NeuralNet> {
|
public interface IAsyncGlobal<NN extends NeuralNet> {
|
||||||
boolean isRunning();
|
boolean isRunning();
|
||||||
void setRunning(boolean value);
|
|
||||||
boolean isTrainingComplete();
|
boolean isTrainingComplete();
|
||||||
void start();
|
void start();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
|
||||||
|
*/
|
||||||
|
void terminate();
|
||||||
|
|
||||||
AtomicInteger getT();
|
AtomicInteger getT();
|
||||||
NN getCurrent();
|
NN getCurrent();
|
||||||
NN getTarget();
|
NN getTarget();
|
||||||
|
|
|
@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||||
|
@ -47,24 +46,19 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
final private AsyncGlobal asyncGlobal;
|
final private AsyncGlobal asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
final private ACPolicy<O> policy;
|
final private ACPolicy<O> policy;
|
||||||
@Getter
|
|
||||||
final private IDataManager dataManager;
|
|
||||||
|
|
||||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf,
|
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
|
||||||
IDataManager dataManager) {
|
|
||||||
super(conf);
|
super(conf);
|
||||||
this.iActorCritic = iActorCritic;
|
this.iActorCritic = iActorCritic;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.dataManager = dataManager;
|
|
||||||
policy = new ACPolicy<>(iActorCritic, getRandom());
|
policy = new ACPolicy<>(iActorCritic, getRandom());
|
||||||
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
|
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
protected AsyncThread newThread(int i, int deviceNum) {
|
protected AsyncThread newThread(int i, int deviceNum) {
|
||||||
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager, deviceNum);
|
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), deviceNum, getListeners(), i);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IActorCritic getNeuralNet() {
|
public IActorCritic getNeuralNet() {
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -43,24 +44,38 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
|
|
||||||
final private HistoryProcessor.Configuration hpconf;
|
final private HistoryProcessor.Configuration hpconf;
|
||||||
|
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
@Deprecated
|
||||||
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, IActorCritic, conf, dataManager);
|
this(mdp, actorCritic, hpconf, conf);
|
||||||
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
|
}
|
||||||
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||||
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
|
super(mdp, IActorCritic, conf);
|
||||||
this.hpconf = hpconf;
|
this.hpconf = hpconf;
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf,
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
dataManager);
|
}
|
||||||
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||||
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AsyncThread newThread(int i, int deviceNum) {
|
public AsyncThread newThread(int i, int deviceNum) {
|
||||||
|
|
|
@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.*;
|
import org.deeplearning4j.rl4j.network.ac.*;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -33,33 +34,58 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
*/
|
*/
|
||||||
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
super(mdp, IActorCritic, conf, dataManager);
|
this(mdp, IActorCritic, conf);
|
||||||
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
|
}
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
|
||||||
|
super(mdp, actorCritic, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
A3CConfiguration conf, IDataManager dataManager) {
|
A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
|
A3CConfiguration conf) {
|
||||||
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
|
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
|
||||||
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
A3CConfiguration conf, IDataManager dataManager) {
|
A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
|
A3CConfiguration conf) {
|
||||||
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
|
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
|
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
|
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) {
|
||||||
|
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,13 +22,13 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
@ -46,24 +46,19 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
@Getter
|
@Getter
|
||||||
final protected A3CDiscrete.A3CConfiguration conf;
|
final protected A3CDiscrete.A3CConfiguration conf;
|
||||||
@Getter
|
@Getter
|
||||||
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
|
||||||
@Getter
|
|
||||||
final protected AsyncGlobal<IActorCritic> asyncGlobal;
|
final protected AsyncGlobal<IActorCritic> asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
final protected int threadNumber;
|
final protected int threadNumber;
|
||||||
@Getter
|
|
||||||
final protected IDataManager dataManager;
|
|
||||||
|
|
||||||
final private Random random;
|
final private Random random;
|
||||||
|
|
||||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
||||||
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) {
|
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
||||||
super(asyncGlobal, threadNumber, deviceNum);
|
int threadNumber) {
|
||||||
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
this.conf = a3cc;
|
this.conf = a3cc;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
this.mdp = mdp;
|
|
||||||
this.dataManager = dataManager;
|
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
||||||
random = new Random(conf.getSeed() + threadNumber);
|
random = new Random(conf.getSeed() + threadNumber);
|
||||||
}
|
}
|
||||||
|
@ -85,15 +80,15 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
//if recurrent then train as a time serie with a batch size of 1
|
//if recurrent then train as a time serie with a batch size of 1
|
||||||
boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();
|
boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();
|
||||||
|
|
||||||
int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape()
|
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
||||||
: getHistoryProcessor().getConf().getShape();
|
: getHistoryProcessor().getConf().getShape();
|
||||||
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
|
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
|
||||||
: Learning.makeShape(size, shape);
|
: Learning.makeShape(size, shape);
|
||||||
|
|
||||||
INDArray input = Nd4j.create(nshape);
|
INDArray input = Nd4j.create(nshape);
|
||||||
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
|
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
|
||||||
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, mdp.getActionSpace().getSize(), size)
|
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size)
|
||||||
: Nd4j.zeros(size, mdp.getActionSpace().getSize());
|
: Nd4j.zeros(size, getMdp().getActionSpace().getSize());
|
||||||
|
|
||||||
double r = minTrans.getReward();
|
double r = minTrans.getReward();
|
||||||
for (int i = size - 1; i >= 0; i--) {
|
for (int i = size - 1; i >= 0; i--) {
|
||||||
|
|
|
@ -24,10 +24,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
|
@ -40,16 +39,12 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
@Getter
|
@Getter
|
||||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
final private MDP<O, Integer, DiscreteSpace> mdp;
|
||||||
@Getter
|
@Getter
|
||||||
final private IDataManager dataManager;
|
|
||||||
@Getter
|
|
||||||
final private AsyncGlobal<IDQN> asyncGlobal;
|
final private AsyncGlobal<IDQN> asyncGlobal;
|
||||||
|
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf,
|
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
|
||||||
IDataManager dataManager) {
|
|
||||||
super(conf);
|
super(conf);
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.dataManager = dataManager;
|
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||||
|
@ -57,14 +52,14 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AsyncThread newThread(int i, int deviceNum) {
|
public AsyncThread newThread(int i, int deviceNum) {
|
||||||
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum);
|
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, getListeners(), i, deviceNum);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IDQN getNeuralNet() {
|
public IDQN getNeuralNet() {
|
||||||
return asyncGlobal.getCurrent();
|
return asyncGlobal.getCurrent();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Policy<O, Integer> getPolicy() {
|
public IPolicy<O, Integer> getPolicy() {
|
||||||
return new DQNPolicy<O>(getNeuralNet());
|
return new DQNPolicy<O>(getNeuralNet());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -35,22 +36,38 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
|
||||||
|
|
||||||
final private HistoryProcessor.Configuration hpconf;
|
final private HistoryProcessor.Configuration hpconf;
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf, dataManager);
|
this(mdp, dqn, hpconf, conf);
|
||||||
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
|
}
|
||||||
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
||||||
|
super(mdp, dqn, conf);
|
||||||
this.hpconf = hpconf;
|
this.hpconf = hpconf;
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
||||||
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
||||||
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AsyncThread newThread(int i, int deviceNum) {
|
public AsyncThread newThread(int i, int deviceNum) {
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -29,19 +30,37 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
*/
|
*/
|
||||||
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
|
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf, dataManager);
|
super(mdp, dqn, conf);
|
||||||
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
|
AsyncNStepQLConfiguration conf) {
|
||||||
|
super(mdp, dqn, conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
|
AsyncNStepQLConfiguration conf) {
|
||||||
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
|
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
|
||||||
|
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,9 +19,10 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
|
@ -29,7 +30,6 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -44,31 +44,25 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
@Getter
|
@Getter
|
||||||
final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
|
final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
|
||||||
@Getter
|
@Getter
|
||||||
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
|
||||||
@Getter
|
|
||||||
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
final protected int threadNumber;
|
final protected int threadNumber;
|
||||||
@Getter
|
|
||||||
final protected IDataManager dataManager;
|
|
||||||
|
|
||||||
final private Random random;
|
final private Random random;
|
||||||
|
|
||||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
|
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
|
||||||
IDataManager dataManager, int deviceNum) {
|
TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||||
super(asyncGlobal, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
this.mdp = mdp;
|
|
||||||
this.dataManager = dataManager;
|
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
||||||
random = new Random(conf.getSeed() + threadNumber);
|
random = new Random(conf.getSeed() + threadNumber);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Policy<O, Integer> getPolicy(IDQN nn) {
|
public Policy<O, Integer> getPolicy(IDQN nn) {
|
||||||
return new EpsGreedy(new DQNPolicy(nn), mdp, conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
||||||
random, conf.getMinEpsilon(), this);
|
random, conf.getMinEpsilon(), this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,11 +75,11 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
|
|
||||||
int size = rewards.size();
|
int size = rewards.size();
|
||||||
|
|
||||||
int[] shape = getHistoryProcessor() == null ? mdp.getObservationSpace().getShape()
|
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
||||||
: getHistoryProcessor().getConf().getShape();
|
: getHistoryProcessor().getConf().getShape();
|
||||||
int[] nshape = Learning.makeShape(size, shape);
|
int[] nshape = Learning.makeShape(size, shape);
|
||||||
INDArray input = Nd4j.create(nshape);
|
INDArray input = Nd4j.create(nshape);
|
||||||
INDArray targets = Nd4j.create(size, mdp.getActionSpace().getSize());
|
INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize());
|
||||||
|
|
||||||
double r = minTrans.getReward();
|
double r = minTrans.getReward();
|
||||||
for (int i = size - 1; i >= 0; i--) {
|
for (int i = size - 1; i >= 0; i--) {
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.learning.listener;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The base definition of all training event listeners
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public interface TrainingListener {
|
||||||
|
enum ListenerResponse {
|
||||||
|
/**
|
||||||
|
* Tell the learning process to continue calling the listeners and the training.
|
||||||
|
*/
|
||||||
|
CONTINUE,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tell the learning process to stop calling the listeners and terminate the training.
|
||||||
|
*/
|
||||||
|
STOP,
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called once when the training starts.
|
||||||
|
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
|
||||||
|
*/
|
||||||
|
ListenerResponse onTrainingStart();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called once when the training has finished. This method is called even when the training has been aborted.
|
||||||
|
*/
|
||||||
|
void onTrainingEnd();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called before the start of every epoch.
|
||||||
|
* @param trainer A {@link IEpochTrainer}
|
||||||
|
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||||
|
*/
|
||||||
|
ListenerResponse onNewEpoch(IEpochTrainer trainer);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when an epoch has been completed
|
||||||
|
* @param trainer A {@link IEpochTrainer}
|
||||||
|
* @param statEntry A {@link org.deeplearning4j.rl4j.util.IDataManager.StatEntry}
|
||||||
|
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||||
|
*/
|
||||||
|
ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called regularly to monitor the training progress.
|
||||||
|
* @param learning A {@link ILearning}
|
||||||
|
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
||||||
|
*/
|
||||||
|
ListenerResponse onTrainingProgress(ILearning learning);
|
||||||
|
}
|
|
@ -0,0 +1,105 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.learning.listener;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The base logic to notify training listeners with the different training events.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class TrainingListenerList {
|
||||||
|
protected final List<TrainingListener> listeners = new ArrayList<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a listener at the end of the list
|
||||||
|
* @param listener The listener to be added
|
||||||
|
*/
|
||||||
|
public void add(TrainingListener listener) {
|
||||||
|
listeners.add(listener);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Notify the listeners that the training has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
|
||||||
|
* @return whether or not the source training should be stopped
|
||||||
|
*/
|
||||||
|
public boolean notifyTrainingStarted() {
|
||||||
|
for (TrainingListener listener : listeners) {
|
||||||
|
if (listener.onTrainingStart() == TrainingListener.ListenerResponse.STOP) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Notify the listeners that the training has finished.
|
||||||
|
*/
|
||||||
|
public void notifyTrainingFinished() {
|
||||||
|
for (TrainingListener listener : listeners) {
|
||||||
|
listener.onTrainingEnd();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Notify the listeners that a new epoch has started. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
|
||||||
|
* @return whether or not the source training should be stopped
|
||||||
|
*/
|
||||||
|
public boolean notifyNewEpoch(IEpochTrainer trainer) {
|
||||||
|
for (TrainingListener listener : listeners) {
|
||||||
|
if (listener.onNewEpoch(trainer) == TrainingListener.ListenerResponse.STOP) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Notify the listeners that an epoch has been completed and the training results are available. Will stop early if a listener returns {@link org.deeplearning4j.rl4j.learning.listener.TrainingListener.ListenerResponse#STOP}
|
||||||
|
* @return whether or not the source training should be stopped
|
||||||
|
*/
|
||||||
|
public boolean notifyEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
|
||||||
|
for (TrainingListener listener : listeners) {
|
||||||
|
if (listener.onEpochTrainingResult(trainer, statEntry) == TrainingListener.ListenerResponse.STOP) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Notify the listeners that they update the progress ot the trainning.
|
||||||
|
*/
|
||||||
|
public boolean notifyTrainingProgress(ILearning learning) {
|
||||||
|
for (TrainingListener listener : listeners) {
|
||||||
|
if (listener.onTrainingProgress(learning) == TrainingListener.ListenerResponse.STOP) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,19 +16,18 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.sync;
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mother class and useful factorisations for all training methods that
|
* Mother class and useful factorisations for all training methods that
|
||||||
* are not asynchronous.
|
* are not asynchronous.
|
||||||
|
@ -38,9 +37,9 @@ import java.util.List;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
extends Learning<O, A, AS, NN> {
|
extends Learning<O, A, AS, NN> implements IEpochTrainer {
|
||||||
|
|
||||||
private List<SyncTrainingListener> listeners = new ArrayList<>();
|
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||||
|
|
||||||
public SyncLearning(LConfiguration conf) {
|
public SyncLearning(LConfiguration conf) {
|
||||||
super(conf);
|
super(conf);
|
||||||
|
@ -49,12 +48,24 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
||||||
/**
|
/**
|
||||||
* Add a listener at the end of the listener list.
|
* Add a listener at the end of the listener list.
|
||||||
*
|
*
|
||||||
* @param listener
|
* @param listener The listener to add
|
||||||
*/
|
*/
|
||||||
public void addListener(SyncTrainingListener listener) {
|
public void addListener(TrainingListener listener) {
|
||||||
listeners.add(listener);
|
listeners.add(listener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of epochs between calls to onTrainingProgress. Default is 5
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
private int progressMonitorFrequency = 5;
|
||||||
|
|
||||||
|
public void setProgressMonitorFrequency(int value) {
|
||||||
|
if(value == 0) throw new IllegalArgumentException("The progressMonitorFrequency cannot be 0");
|
||||||
|
|
||||||
|
progressMonitorFrequency = value;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method will train the model<p>
|
* This method will train the model<p>
|
||||||
* The training stop when:<br>
|
* The training stop when:<br>
|
||||||
|
@ -64,81 +75,49 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
||||||
* <p>
|
* <p>
|
||||||
* Listeners<br>
|
* Listeners<br>
|
||||||
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
|
* For a given event, the listeners are called sequentially in same the order as they were added. If one listener
|
||||||
* returns {@link SyncTrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
* returns {@link TrainingListener.ListenerResponse SyncTrainingListener.ListenerResponse.STOP}, the remaining listeners in the list won't be called.<br>
|
||||||
* Events:
|
* Events:
|
||||||
* <ul>
|
* <ul>
|
||||||
* <li>{@link SyncTrainingListener#onTrainingStart(SyncTrainingEvent) onTrainingStart()} is called once when the training starts.</li>
|
* <li>{@link TrainingListener#onTrainingStart() onTrainingStart()} is called once when the training starts.</li>
|
||||||
* <li>{@link SyncTrainingListener#onEpochStart(SyncTrainingEvent) onEpochStart()} and {@link SyncTrainingListener#onEpochEnd(SyncTrainingEpochEndEvent) onEpochEnd()} are called for every epoch. onEpochEnd will not be called if onEpochStart stops the training</li>
|
* <li>{@link TrainingListener#onNewEpoch(IEpochTrainer) onNewEpoch()} and {@link TrainingListener#onEpochTrainingResult(IEpochTrainer, IDataManager.StatEntry) onEpochTrainingResult()} are called for every epoch. onEpochTrainingResult will not be called if onNewEpoch stops the training</li>
|
||||||
* <li>{@link SyncTrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
|
* <li>{@link TrainingListener#onTrainingProgress(ILearning) onTrainingProgress()} is called after onEpochTrainingResult()</li>
|
||||||
|
* <li>{@link TrainingListener#onTrainingEnd() onTrainingEnd()} is always called at the end of the training, even if the training was cancelled by a listener.</li>
|
||||||
* </ul>
|
* </ul>
|
||||||
*/
|
*/
|
||||||
public void train() {
|
public void train() {
|
||||||
|
|
||||||
log.info("training starting.");
|
log.info("training starting.");
|
||||||
|
|
||||||
boolean canContinue = notifyTrainingStarted();
|
boolean canContinue = listeners.notifyTrainingStarted();
|
||||||
if (canContinue) {
|
if (canContinue) {
|
||||||
while (getStepCounter() < getConfiguration().getMaxStep()) {
|
while (getStepCounter() < getConfiguration().getMaxStep()) {
|
||||||
preEpoch();
|
preEpoch();
|
||||||
canContinue = notifyEpochStarted();
|
canContinue = listeners.notifyNewEpoch(this);
|
||||||
if (!canContinue) {
|
if (!canContinue) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
IDataManager.StatEntry statEntry = trainEpoch();
|
IDataManager.StatEntry statEntry = trainEpoch();
|
||||||
|
canContinue = listeners.notifyEpochTrainingResult(this, statEntry);
|
||||||
postEpoch();
|
|
||||||
canContinue = notifyEpochFinished(statEntry);
|
|
||||||
if (!canContinue) {
|
if (!canContinue) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
postEpoch();
|
||||||
|
|
||||||
|
if(getEpochCounter() % progressMonitorFrequency == 0) {
|
||||||
|
canContinue = listeners.notifyTrainingProgress(this);
|
||||||
|
if (!canContinue) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
||||||
incrementEpoch();
|
incrementEpoch();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
notifyTrainingFinished();
|
listeners.notifyTrainingFinished();
|
||||||
}
|
|
||||||
|
|
||||||
private boolean notifyTrainingStarted() {
|
|
||||||
SyncTrainingEvent event = new SyncTrainingEvent(this);
|
|
||||||
for (SyncTrainingListener listener : listeners) {
|
|
||||||
if (listener.onTrainingStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void notifyTrainingFinished() {
|
|
||||||
for (SyncTrainingListener listener : listeners) {
|
|
||||||
listener.onTrainingEnd();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean notifyEpochStarted() {
|
|
||||||
SyncTrainingEvent event = new SyncTrainingEvent(this);
|
|
||||||
for (SyncTrainingListener listener : listeners) {
|
|
||||||
if (listener.onEpochStart(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean notifyEpochFinished(IDataManager.StatEntry statEntry) {
|
|
||||||
SyncTrainingEpochEndEvent event = new SyncTrainingEpochEndEvent(this, statEntry);
|
|
||||||
for (SyncTrainingListener listener : listeners) {
|
|
||||||
if (listener.onEpochEnd(event) == SyncTrainingListener.ListenerResponse.STOP) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract void preEpoch();
|
protected abstract void preEpoch();
|
||||||
|
|
|
@ -1,22 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.listener;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A subclass of SyncTrainingEvent that is passed to SyncTrainingListener.onEpochEnd()
|
|
||||||
*/
|
|
||||||
public class SyncTrainingEpochEndEvent extends SyncTrainingEvent {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The stats of the epoch training
|
|
||||||
*/
|
|
||||||
@Getter
|
|
||||||
private final IDataManager.StatEntry statEntry;
|
|
||||||
|
|
||||||
public SyncTrainingEpochEndEvent(Learning learning, IDataManager.StatEntry statEntry) {
|
|
||||||
super(learning);
|
|
||||||
this.statEntry = statEntry;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,21 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.listener;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* SyncTrainingEvent are passed as parameters to the events of SyncTrainingListener
|
|
||||||
*/
|
|
||||||
public class SyncTrainingEvent {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The source of the event
|
|
||||||
*/
|
|
||||||
@Getter
|
|
||||||
private final Learning learning;
|
|
||||||
|
|
||||||
public SyncTrainingEvent(Learning learning) {
|
|
||||||
this.learning = learning;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.listener;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A listener interface to use with a descendant of {@link org.deeplearning4j.rl4j.learning.sync.SyncLearning}
|
|
||||||
*/
|
|
||||||
public interface SyncTrainingListener {
|
|
||||||
|
|
||||||
public enum ListenerResponse {
|
|
||||||
/**
|
|
||||||
* Tell SyncLearning to continue calling the listeners and the training.
|
|
||||||
*/
|
|
||||||
CONTINUE,
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Tell SyncLearning to stop calling the listeners and terminate the training.
|
|
||||||
*/
|
|
||||||
STOP,
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called once when the training starts.
|
|
||||||
* @param event
|
|
||||||
* @return A ListenerResponse telling the source of the event if it should go on or cancel the training.
|
|
||||||
*/
|
|
||||||
ListenerResponse onTrainingStart(SyncTrainingEvent event);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called once when the training has finished. This method is called even when the training has been aborted.
|
|
||||||
*/
|
|
||||||
void onTrainingEnd();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called before the start of every epoch.
|
|
||||||
* @param event
|
|
||||||
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
|
||||||
*/
|
|
||||||
ListenerResponse onEpochStart(SyncTrainingEvent event);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called after the end of every epoch.
|
|
||||||
* @param event
|
|
||||||
* @return A ListenerResponse telling the source of the event if it should continue or stop the training.
|
|
||||||
*/
|
|
||||||
ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event);
|
|
||||||
}
|
|
|
@ -49,7 +49,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
// @Getter
|
// @Getter
|
||||||
// final private IExpReplay<A> expReplay;
|
// final private IExpReplay<A> expReplay;
|
||||||
@Getter
|
@Getter
|
||||||
@Setter(AccessLevel.PACKAGE)
|
@Setter(AccessLevel.PROTECTED)
|
||||||
protected IExpReplay<A> expReplay;
|
protected IExpReplay<A> expReplay;
|
||||||
|
|
||||||
public QLearning(QLConfiguration conf) {
|
public QLearning(QLConfiguration conf) {
|
||||||
|
|
|
@ -28,8 +28,6 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||||
|
@ -64,20 +62,9 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
@Setter
|
@Setter
|
||||||
private IDQN targetDQN;
|
private IDQN targetDQN;
|
||||||
private int lastAction;
|
private int lastAction;
|
||||||
private INDArray history[] = null;
|
private INDArray[] history = null;
|
||||||
private double accuReward = 0;
|
private double accuReward = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated
|
|
||||||
* Use QLearningDiscrete(MDP, IDQN, QLConfiguration, int) and add the required listeners with addListener() instead.
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
|
||||||
IDataManager dataManager, int epsilonNbStep) {
|
|
||||||
this(mdp, dqn, conf, epsilonNbStep);
|
|
||||||
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
|
|
||||||
}
|
|
||||||
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||||
int epsilonNbStep) {
|
int epsilonNbStep) {
|
||||||
super(conf);
|
super(conf);
|
||||||
|
@ -186,7 +173,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
|
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||||
if (transitions.size() == 0)
|
if (transitions.size() == 0)
|
||||||
throw new IllegalArgumentException("too few transitions");
|
throw new IllegalArgumentException("too few transitions");
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -33,20 +34,36 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> {
|
public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> {
|
||||||
|
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
QLConfiguration conf, IDataManager dataManager) {
|
QLConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
this(mdp, dqn, hpconf, conf);
|
||||||
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
|
}
|
||||||
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
|
QLConfiguration conf) {
|
||||||
|
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
||||||
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
||||||
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -31,21 +32,35 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> {
|
public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> {
|
||||||
|
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep());
|
this(mdp, dqn, conf);
|
||||||
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
|
}
|
||||||
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
|
||||||
|
super(mdp, dqn, conf, conf.getEpsilonNbStep());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
|
QLearning.QLConfiguration conf) {
|
||||||
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||||
|
QLearning.QLConfiguration conf) {
|
||||||
|
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
package org.deeplearning4j.rl4j.policy;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
|
||||||
|
public interface IPolicy<O extends Encodable, A> {
|
||||||
|
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
|
||||||
|
}
|
|
@ -35,7 +35,7 @@ import org.nd4j.linalg.util.ArrayUtil;
|
||||||
*
|
*
|
||||||
* A Policy responsability is to choose the next action given a state
|
* A Policy responsability is to choose the next action given a state
|
||||||
*/
|
*/
|
||||||
public abstract class Policy<O extends Encodable, A> {
|
public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
|
||||||
|
|
||||||
public abstract NeuralNet getNeuralNet();
|
public abstract NeuralNet getNeuralNet();
|
||||||
|
|
||||||
|
@ -49,6 +49,7 @@ public abstract class Policy<O extends Encodable, A> {
|
||||||
return play(mdp, new HistoryProcessor(conf));
|
return play(mdp, new HistoryProcessor(conf));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
||||||
getNeuralNet().reset();
|
getNeuralNet().reset();
|
||||||
Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
|
Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
|
||||||
|
|
|
@ -22,6 +22,7 @@ import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
@ -72,13 +73,13 @@ public class DataManager implements IDataManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void save(String path, Learning learning) throws IOException {
|
public static void save(String path, ILearning learning) throws IOException {
|
||||||
try (BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(path))) {
|
try (BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(path))) {
|
||||||
save(os, learning);
|
save(os, learning);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void save(OutputStream os, Learning learning) throws IOException {
|
public static void save(OutputStream os, ILearning learning) throws IOException {
|
||||||
|
|
||||||
try (ZipOutputStream zipfile = new ZipOutputStream(os)) {
|
try (ZipOutputStream zipfile = new ZipOutputStream(os)) {
|
||||||
|
|
||||||
|
@ -91,7 +92,9 @@ public class DataManager implements IDataManager {
|
||||||
zipfile.putNextEntry(dqn);
|
zipfile.putNextEntry(dqn);
|
||||||
|
|
||||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||||
learning.getNeuralNet().save(bos);
|
if(learning instanceof NeuralNetFetchable) {
|
||||||
|
((NeuralNetFetchable)learning).getNeuralNet().save(bos);
|
||||||
|
}
|
||||||
bos.flush();
|
bos.flush();
|
||||||
bos.close();
|
bos.close();
|
||||||
|
|
||||||
|
@ -104,7 +107,9 @@ public class DataManager implements IDataManager {
|
||||||
zipfile.putNextEntry(hpconf);
|
zipfile.putNextEntry(hpconf);
|
||||||
|
|
||||||
ByteArrayOutputStream bos2 = new ByteArrayOutputStream();
|
ByteArrayOutputStream bos2 = new ByteArrayOutputStream();
|
||||||
learning.getNeuralNet().save(bos2);
|
if(learning instanceof NeuralNetFetchable) {
|
||||||
|
((NeuralNetFetchable)learning).getNeuralNet().save(bos2);
|
||||||
|
}
|
||||||
bos2.flush();
|
bos2.flush();
|
||||||
bos2.close();
|
bos2.close();
|
||||||
|
|
||||||
|
@ -256,13 +261,15 @@ public class DataManager implements IDataManager {
|
||||||
return exists;
|
return exists;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void save(Learning learning) throws IOException {
|
public void save(ILearning learning) throws IOException {
|
||||||
|
|
||||||
if (!saveData)
|
if (!saveData)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
|
save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
|
||||||
learning.getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
|
if(learning instanceof NeuralNetFetchable) {
|
||||||
|
((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,126 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.util;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
|
|
||||||
* training process can be fed to the DataManager
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class DataManagerSyncTrainingListener implements SyncTrainingListener {
|
|
||||||
private final IDataManager dataManager;
|
|
||||||
private final int saveFrequency;
|
|
||||||
private final int monitorFrequency;
|
|
||||||
|
|
||||||
private int lastSave;
|
|
||||||
private int lastMonitor;
|
|
||||||
|
|
||||||
private DataManagerSyncTrainingListener(Builder builder) {
|
|
||||||
this.dataManager = builder.dataManager;
|
|
||||||
|
|
||||||
this.saveFrequency = builder.saveFrequency;
|
|
||||||
this.lastSave = -builder.saveFrequency;
|
|
||||||
|
|
||||||
this.monitorFrequency = builder.monitorFrequency;
|
|
||||||
this.lastMonitor = -builder.monitorFrequency;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
|
|
||||||
try {
|
|
||||||
dataManager.writeInfo(event.getLearning());
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Training failed.", e);
|
|
||||||
return ListenerResponse.STOP;
|
|
||||||
}
|
|
||||||
return ListenerResponse.CONTINUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onTrainingEnd() {
|
|
||||||
// Do nothing
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
|
|
||||||
int stepCounter = event.getLearning().getStepCounter();
|
|
||||||
|
|
||||||
if (stepCounter - lastMonitor >= monitorFrequency
|
|
||||||
&& event.getLearning().getHistoryProcessor() != null
|
|
||||||
&& dataManager.isSaveData()) {
|
|
||||||
lastMonitor = stepCounter;
|
|
||||||
int[] shape = event.getLearning().getMdp().getObservationSpace().getShape();
|
|
||||||
event.getLearning().getHistoryProcessor().startMonitor(dataManager.getVideoDir() + "/video-" + event.getLearning().getEpochCounter() + "-"
|
|
||||||
+ stepCounter + ".mp4", shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
return ListenerResponse.CONTINUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
|
|
||||||
try {
|
|
||||||
int stepCounter = event.getLearning().getStepCounter();
|
|
||||||
if (stepCounter - lastSave >= saveFrequency) {
|
|
||||||
dataManager.save(event.getLearning());
|
|
||||||
lastSave = stepCounter;
|
|
||||||
}
|
|
||||||
|
|
||||||
dataManager.appendStat(event.getStatEntry());
|
|
||||||
dataManager.writeInfo(event.getLearning());
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Training failed.", e);
|
|
||||||
return ListenerResponse.STOP;
|
|
||||||
}
|
|
||||||
|
|
||||||
return ListenerResponse.CONTINUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Builder builder(IDataManager dataManager) {
|
|
||||||
return new Builder(dataManager);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
private final IDataManager dataManager;
|
|
||||||
private int saveFrequency = Constants.MODEL_SAVE_FREQ;
|
|
||||||
private int monitorFrequency = Constants.MONITOR_FREQ;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a Builder with the given DataManager
|
|
||||||
* @param dataManager
|
|
||||||
*/
|
|
||||||
public Builder(IDataManager dataManager) {
|
|
||||||
this.dataManager = dataManager;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A number that represent the number of steps since the last call to DataManager.save() before can it be called again.
|
|
||||||
* @param saveFrequency (Default: 100000)
|
|
||||||
*/
|
|
||||||
public Builder saveFrequency(int saveFrequency) {
|
|
||||||
this.saveFrequency = saveFrequency;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A number that represent the number of steps since the last call to HistoryProcessor.startMonitor() before can it be called again.
|
|
||||||
* @param monitorFrequency (Default: 10000)
|
|
||||||
*/
|
|
||||||
public Builder monitorFrequency(int monitorFrequency) {
|
|
||||||
this.monitorFrequency = monitorFrequency;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a DataManagerSyncTrainingListener with the configured parameters
|
|
||||||
* @return An instance of DataManagerSyncTrainingListener
|
|
||||||
*/
|
|
||||||
public DataManagerSyncTrainingListener build() {
|
|
||||||
return new DataManagerSyncTrainingListener(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
package org.deeplearning4j.rl4j.util;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DataManagerSyncTrainingListener can be added to the listeners of SyncLearning so that the
|
||||||
|
* training process can be fed to the DataManager
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class DataManagerTrainingListener implements TrainingListener {
|
||||||
|
private final IDataManager dataManager;
|
||||||
|
|
||||||
|
private int lastSave = -Constants.MODEL_SAVE_FREQ;
|
||||||
|
|
||||||
|
public DataManagerTrainingListener(IDataManager dataManager) {
|
||||||
|
this.dataManager = dataManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onTrainingStart() {
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onTrainingEnd() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
|
||||||
|
IHistoryProcessor hp = trainer.getHistoryProcessor();
|
||||||
|
if(hp != null) {
|
||||||
|
int[] shape = trainer.getMdp().getObservationSpace().getShape();
|
||||||
|
String filename = dataManager.getVideoDir() + "/video-";
|
||||||
|
if (trainer instanceof AsyncThread) {
|
||||||
|
filename += ((AsyncThread) trainer).getThreadNumber() + "-";
|
||||||
|
}
|
||||||
|
filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4";
|
||||||
|
hp.startMonitor(filename, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
|
||||||
|
IHistoryProcessor hp = trainer.getHistoryProcessor();
|
||||||
|
if(hp != null) {
|
||||||
|
hp.stopMonitor();
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
dataManager.appendStat(statEntry);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Training failed.", e);
|
||||||
|
return ListenerResponse.STOP;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onTrainingProgress(ILearning learning) {
|
||||||
|
try {
|
||||||
|
int stepCounter = learning.getStepCounter();
|
||||||
|
if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) {
|
||||||
|
dataManager.save(learning);
|
||||||
|
lastSave = stepCounter;
|
||||||
|
}
|
||||||
|
|
||||||
|
dataManager.writeInfo(learning);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Training failed.", e);
|
||||||
|
return ListenerResponse.STOP;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
}
|
|
@ -27,7 +27,7 @@ public interface IDataManager {
|
||||||
String getVideoDir();
|
String getVideoDir();
|
||||||
void appendStat(StatEntry statEntry) throws IOException;
|
void appendStat(StatEntry statEntry) throws IOException;
|
||||||
void writeInfo(ILearning iLearning) throws IOException;
|
void writeInfo(ILearning iLearning) throws IOException;
|
||||||
void save(Learning learning) throws IOException;
|
void save(ILearning learning) throws IOException;
|
||||||
|
|
||||||
//In order for jackson to serialize StatEntry
|
//In order for jackson to serialize StatEntry
|
||||||
//please use Lombok @Value (see QLStatEntry)
|
//please use Lombok @Value (see QLStatEntry)
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
|
import org.deeplearning4j.rl4j.support.*;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class AsyncLearningTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_training_expect_AsyncGlobalStarted() {
|
||||||
|
// Arrange
|
||||||
|
TestContext context = new TestContext();
|
||||||
|
context.asyncGlobal.setMaxLoops(1);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
context.sut.train();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(context.asyncGlobal.hasBeenStarted);
|
||||||
|
assertTrue(context.asyncGlobal.hasBeenTerminated);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_trainStartReturnsStop_expect_noTraining() {
|
||||||
|
// Arrange
|
||||||
|
TestContext context = new TestContext();
|
||||||
|
context.listener.setRemainingTrainingStartCallCount(0);
|
||||||
|
// Act
|
||||||
|
context.sut.train();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, context.listener.onTrainingStartCallCount);
|
||||||
|
assertEquals(1, context.listener.onTrainingEndCallCount);
|
||||||
|
assertEquals(0, context.policy.playCallCount);
|
||||||
|
assertTrue(context.asyncGlobal.hasBeenTerminated);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_trainingIsComplete_expect_trainingStop() {
|
||||||
|
// Arrange
|
||||||
|
TestContext context = new TestContext();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
context.sut.train();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, context.listener.onTrainingStartCallCount);
|
||||||
|
assertEquals(1, context.listener.onTrainingEndCallCount);
|
||||||
|
assertTrue(context.asyncGlobal.hasBeenTerminated);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_training_expect_onTrainingProgressCalled() {
|
||||||
|
// Arrange
|
||||||
|
TestContext context = new TestContext();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
context.sut.train();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, context.listener.onTrainingProgressCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static class TestContext {
|
||||||
|
public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1);
|
||||||
|
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
|
||||||
|
public final MockPolicy policy = new MockPolicy();
|
||||||
|
public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy);
|
||||||
|
public final MockTrainingListener listener = new MockTrainingListener();
|
||||||
|
|
||||||
|
public TestContext() {
|
||||||
|
sut.addListener(listener);
|
||||||
|
asyncGlobal.setMaxLoops(1);
|
||||||
|
sut.setProgressMonitorFrequency(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
|
||||||
|
private final AsyncConfiguration conf;
|
||||||
|
private final IAsyncGlobal asyncGlobal;
|
||||||
|
private final IPolicy<MockEncodable, Integer> policy;
|
||||||
|
|
||||||
|
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
|
||||||
|
super(conf);
|
||||||
|
this.conf = conf;
|
||||||
|
this.asyncGlobal = asyncGlobal;
|
||||||
|
this.policy = policy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IPolicy getPolicy() {
|
||||||
|
return policy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AsyncConfiguration getConfiguration() {
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AsyncThread newThread(int i, int deviceAffinity) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MDP getMdp() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected IAsyncGlobal getAsyncGlobal() {
|
||||||
|
return asyncGlobal;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MockNeuralNet getNeuralNet() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,206 +1,135 @@
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
import org.deeplearning4j.rl4j.support.*;
|
||||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
|
||||||
import org.deeplearning4j.rl4j.support.MockMDP;
|
|
||||||
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.OutputStream;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
public class AsyncThreadTest {
|
public class AsyncThreadTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() {
|
public void when_newEpochStarted_expect_neuralNetworkReset() {
|
||||||
// Arrange
|
// Arrange
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
TestContext context = new TestContext();
|
||||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
context.listener.setRemainingOnNewEpochCallCount(5);
|
||||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
|
||||||
MockMDP mdp = new MockMDP(observationSpace);
|
|
||||||
MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2);
|
|
||||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager);
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.run();
|
context.sut.run();
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(4, dataManager.statEntries.size());
|
assertEquals(6, context.neuralNet.resetCallCount);
|
||||||
|
|
||||||
IDataManager.StatEntry entry = dataManager.statEntries.get(0);
|
|
||||||
assertEquals(2, entry.getStepCounter());
|
|
||||||
assertEquals(0, entry.getEpochCounter());
|
|
||||||
assertEquals(2.0, entry.getReward(), 0.0);
|
|
||||||
|
|
||||||
entry = dataManager.statEntries.get(1);
|
|
||||||
assertEquals(4, entry.getStepCounter());
|
|
||||||
assertEquals(1, entry.getEpochCounter());
|
|
||||||
assertEquals(2.0, entry.getReward(), 0.0);
|
|
||||||
|
|
||||||
entry = dataManager.statEntries.get(2);
|
|
||||||
assertEquals(6, entry.getStepCounter());
|
|
||||||
assertEquals(2, entry.getEpochCounter());
|
|
||||||
assertEquals(2.0, entry.getReward(), 0.0);
|
|
||||||
|
|
||||||
entry = dataManager.statEntries.get(3);
|
|
||||||
assertEquals(8, entry.getStepCounter());
|
|
||||||
assertEquals(3, entry.getEpochCounter());
|
|
||||||
assertEquals(2.0, entry.getReward(), 0.0);
|
|
||||||
|
|
||||||
assertEquals(0, dataManager.isSaveDataCallCount);
|
|
||||||
assertEquals(0, dataManager.getVideoDirCallCount);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() {
|
public void when_onNewEpochReturnsStop_expect_threadStopped() {
|
||||||
// Arrange
|
// Arrange
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
TestContext context = new TestContext();
|
||||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
context.listener.setRemainingOnNewEpochCallCount(1);
|
||||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
|
||||||
MockMDP mdp = new MockMDP(observationSpace);
|
|
||||||
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
|
|
||||||
|
|
||||||
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
|
||||||
.build();
|
|
||||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
|
|
||||||
|
|
||||||
|
|
||||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
|
|
||||||
sut.setHistoryProcessor(hp);
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.run();
|
context.sut.run();
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(9, dataManager.statEntries.size());
|
assertEquals(2, context.listener.onNewEpochCallCount);
|
||||||
|
assertEquals(1, context.listener.onEpochTrainingResultCallCount);
|
||||||
for(int i = 0; i < 9; ++i) {
|
|
||||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
|
||||||
assertEquals(i + 1, entry.getStepCounter());
|
|
||||||
assertEquals(i, entry.getEpochCounter());
|
|
||||||
assertEquals(79.0, entry.getReward(), 0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
assertEquals(10, dataManager.isSaveDataCallCount);
|
|
||||||
assertEquals(0, dataManager.getVideoDirCallCount);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() {
|
public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
|
||||||
// Arrange
|
// Arrange
|
||||||
MockDataManager dataManager = new MockDataManager(true);
|
TestContext context = new TestContext();
|
||||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
context.listener.setRemainingOnEpochTrainingResult(1);
|
||||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
|
||||||
MockMDP mdp = new MockMDP(observationSpace);
|
|
||||||
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
|
|
||||||
|
|
||||||
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
|
||||||
.build();
|
|
||||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
|
|
||||||
|
|
||||||
|
|
||||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
|
|
||||||
sut.setHistoryProcessor(hp);
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.run();
|
context.sut.run();
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(9, dataManager.statEntries.size());
|
assertEquals(2, context.listener.onNewEpochCallCount);
|
||||||
|
assertEquals(2, context.listener.onEpochTrainingResultCallCount);
|
||||||
for(int i = 0; i < 9; ++i) {
|
|
||||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
|
||||||
assertEquals(i + 1, entry.getStepCounter());
|
|
||||||
assertEquals(i, entry.getEpochCounter());
|
|
||||||
assertEquals(79.0, entry.getReward(), 0.0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(1, dataManager.isSaveDataCallCount);
|
@Test
|
||||||
assertEquals(1, dataManager.getVideoDirCallCount);
|
public void when_run_expect_preAndPostEpochCalled() {
|
||||||
|
// Arrange
|
||||||
|
TestContext context = new TestContext();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
context.sut.run();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(6, context.sut.preEpochCallCount);
|
||||||
|
assertEquals(6, context.sut.postEpochCallCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class MockAsyncGlobal implements IAsyncGlobal {
|
@Test
|
||||||
|
public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
|
||||||
|
// Arrange
|
||||||
|
TestContext context = new TestContext();
|
||||||
|
|
||||||
private final int maxLoops;
|
// Act
|
||||||
private int currentLoop = 0;
|
context.sut.run();
|
||||||
|
|
||||||
public MockAsyncGlobal(int maxLoops) {
|
// Assert
|
||||||
|
assertEquals(5, context.listener.statEntries.size());
|
||||||
this.maxLoops = maxLoops;
|
int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
|
||||||
|
for(int i = 0; i < 5; ++i) {
|
||||||
|
IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
|
||||||
|
assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
|
||||||
|
assertEquals(i, statEntry.getEpochCounter());
|
||||||
|
assertEquals(2.0, statEntry.getReward(), 0.0001);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private static class TestContext {
|
||||||
public boolean isRunning() {
|
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
|
||||||
return true;
|
public final MockNeuralNet neuralNet = new MockNeuralNet();
|
||||||
}
|
public final MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
|
public final MockMDP mdp = new MockMDP(observationSpace);
|
||||||
@Override
|
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 2);
|
||||||
public void setRunning(boolean value) {
|
public final TrainingListenerList listeners = new TrainingListenerList();
|
||||||
|
public final MockTrainingListener listener = new MockTrainingListener();
|
||||||
}
|
public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isTrainingComplete() {
|
|
||||||
return ++currentLoop >= maxLoops;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void start() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public AtomicInteger getT() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NeuralNet getCurrent() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NeuralNet getTarget() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void enqueue(Gradient[] gradient, Integer nstep) {
|
|
||||||
|
|
||||||
|
public TestContext() {
|
||||||
|
asyncGlobal.setMaxLoops(10);
|
||||||
|
listeners.add(listener);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class MockAsyncThread extends AsyncThread {
|
public static class MockAsyncThread extends AsyncThread {
|
||||||
|
|
||||||
IAsyncGlobal asyncGlobal;
|
public int preEpochCallCount = 0;
|
||||||
private final MockNeuralNet neuralNet;
|
public int postEpochCallCount = 0;
|
||||||
private final MDP mdp;
|
|
||||||
private final AsyncConfiguration conf;
|
|
||||||
private final IDataManager dataManager;
|
|
||||||
|
|
||||||
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
|
|
||||||
super(asyncGlobal, threadNumber, 0);
|
private final IAsyncGlobal asyncGlobal;
|
||||||
|
private final MockNeuralNet neuralNet;
|
||||||
|
private final AsyncConfiguration conf;
|
||||||
|
|
||||||
|
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
|
||||||
|
super(asyncGlobal, mdp, listeners, threadNumber, 0);
|
||||||
|
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.neuralNet = neuralNet;
|
this.neuralNet = neuralNet;
|
||||||
this.mdp = mdp;
|
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
this.dataManager = dataManager;
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void preEpoch() {
|
||||||
|
++preEpochCallCount;
|
||||||
|
super.preEpoch();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void postEpoch() {
|
||||||
|
++postEpochCallCount;
|
||||||
|
super.postEpoch();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -208,31 +137,16 @@ public class AsyncThreadTest {
|
||||||
return neuralNet;
|
return neuralNet;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int getThreadNumber() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected IAsyncGlobal getAsyncGlobal() {
|
protected IAsyncGlobal getAsyncGlobal() {
|
||||||
return asyncGlobal;
|
return asyncGlobal;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
protected MDP getMdp() {
|
|
||||||
return mdp;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected AsyncConfiguration getConf() {
|
protected AsyncConfiguration getConf() {
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
protected IDataManager getDataManager() {
|
|
||||||
return dataManager;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Policy getPolicy(NeuralNet net) {
|
protected Policy getPolicy(NeuralNet net) {
|
||||||
return null;
|
return null;
|
||||||
|
@ -244,129 +158,6 @@ public class AsyncThreadTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class MockNeuralNet implements NeuralNet {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NeuralNetwork[] getNeuralNetworks() {
|
|
||||||
return new NeuralNetwork[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isRecurrent() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reset() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray[] outputAll(INDArray batch) {
|
|
||||||
return new INDArray[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NeuralNet clone() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(NeuralNet from) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
|
||||||
return new Gradient[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void fit(INDArray input, INDArray[] labels) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void applyGradient(Gradient[] gradients, int batchSize) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getLatestScore() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void save(OutputStream os) throws IOException {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void save(String filename) throws IOException {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class MockAsyncConfiguration implements AsyncConfiguration {
|
|
||||||
|
|
||||||
private final int nStep;
|
|
||||||
private final int maxEpochStep;
|
|
||||||
|
|
||||||
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
|
|
||||||
this.nStep = nStep;
|
|
||||||
|
|
||||||
this.maxEpochStep = maxEpochStep;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getSeed() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getMaxEpochStep() {
|
|
||||||
return maxEpochStep;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getMaxStep() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getNumThread() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getNstep() {
|
|
||||||
return nStep;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getTargetDqnUpdateFreq() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getUpdateStart() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getRewardFactor() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getGamma() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getErrorClamp() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.async.listener;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class AsyncTrainingListenerListTest {
|
||||||
|
@Test
|
||||||
|
public void when_listIsEmpty_expect_notifyTrainingStartedReturnTrue() {
|
||||||
|
// Arrange
|
||||||
|
TrainingListenerList sut = new TrainingListenerList();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean resultTrainingStarted = sut.notifyTrainingStarted();
|
||||||
|
boolean resultNewEpoch = sut.notifyNewEpoch(null);
|
||||||
|
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(resultTrainingStarted);
|
||||||
|
assertTrue(resultNewEpoch);
|
||||||
|
assertTrue(resultEpochTrainingResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_firstListerStops_expect_othersListnersNotCalled() {
|
||||||
|
// Arrange
|
||||||
|
MockTrainingListener listener1 = new MockTrainingListener();
|
||||||
|
listener1.onTrainingResultResponse = TrainingListener.ListenerResponse.STOP;
|
||||||
|
MockTrainingListener listener2 = new MockTrainingListener();
|
||||||
|
TrainingListenerList sut = new TrainingListenerList();
|
||||||
|
sut.add(listener1);
|
||||||
|
sut.add(listener2);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.notifyEpochTrainingResult(null, null);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, listener1.onEpochTrainingResultCallCount);
|
||||||
|
assertEquals(0, listener2.onEpochTrainingResultCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_allListenersContinue_expect_listReturnsTrue() {
|
||||||
|
// Arrange
|
||||||
|
MockTrainingListener listener1 = new MockTrainingListener();
|
||||||
|
MockTrainingListener listener2 = new MockTrainingListener();
|
||||||
|
TrainingListenerList sut = new TrainingListenerList();
|
||||||
|
sut.add(listener1);
|
||||||
|
sut.add(listener2);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean resultTrainingProgress = sut.notifyEpochTrainingResult(null, null);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(resultTrainingProgress);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class MockTrainingListener implements TrainingListener {
|
||||||
|
|
||||||
|
public int onEpochTrainingResultCallCount = 0;
|
||||||
|
public ListenerResponse onTrainingResultResponse = ListenerResponse.CONTINUE;
|
||||||
|
public int onTrainingProgressCallCount = 0;
|
||||||
|
public ListenerResponse onTrainingProgressResponse = ListenerResponse.CONTINUE;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onTrainingStart() {
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onTrainingEnd() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
|
||||||
|
++onEpochTrainingResultCallCount;
|
||||||
|
return onTrainingResultResponse;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onTrainingProgress(ILearning learning) {
|
||||||
|
++onTrainingProgressCallCount;
|
||||||
|
return onTrainingProgressResponse;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,83 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.listener;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.support.MockTrainingListener;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
public class TrainingListenerListTest {
|
||||||
|
@Test
|
||||||
|
public void when_listIsEmpty_expect_notifyReturnTrue() {
|
||||||
|
// Arrange
|
||||||
|
TrainingListenerList sut = new TrainingListenerList();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean resultTrainingStarted = sut.notifyTrainingStarted();
|
||||||
|
boolean resultNewEpoch = sut.notifyNewEpoch(null);
|
||||||
|
boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(resultTrainingStarted);
|
||||||
|
assertTrue(resultNewEpoch);
|
||||||
|
assertTrue(resultEpochFinished);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_firstListerStops_expect_othersListnersNotCalled() {
|
||||||
|
// Arrange
|
||||||
|
MockTrainingListener listener1 = new MockTrainingListener();
|
||||||
|
listener1.setRemainingTrainingStartCallCount(0);
|
||||||
|
listener1.setRemainingOnNewEpochCallCount(0);
|
||||||
|
listener1.setRemainingonTrainingProgressCallCount(0);
|
||||||
|
listener1.setRemainingOnEpochTrainingResult(0);
|
||||||
|
MockTrainingListener listener2 = new MockTrainingListener();
|
||||||
|
TrainingListenerList sut = new TrainingListenerList();
|
||||||
|
sut.add(listener1);
|
||||||
|
sut.add(listener2);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.notifyTrainingStarted();
|
||||||
|
sut.notifyNewEpoch(null);
|
||||||
|
sut.notifyEpochTrainingResult(null, null);
|
||||||
|
sut.notifyTrainingProgress(null);
|
||||||
|
sut.notifyTrainingFinished();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, listener1.onTrainingStartCallCount);
|
||||||
|
assertEquals(0, listener2.onTrainingStartCallCount);
|
||||||
|
|
||||||
|
assertEquals(1, listener1.onNewEpochCallCount);
|
||||||
|
assertEquals(0, listener2.onNewEpochCallCount);
|
||||||
|
|
||||||
|
assertEquals(1, listener1.onEpochTrainingResultCallCount);
|
||||||
|
assertEquals(0, listener2.onEpochTrainingResultCallCount);
|
||||||
|
|
||||||
|
assertEquals(1, listener1.onTrainingProgressCallCount);
|
||||||
|
assertEquals(0, listener2.onTrainingProgressCallCount);
|
||||||
|
|
||||||
|
assertEquals(1, listener1.onTrainingEndCallCount);
|
||||||
|
assertEquals(1, listener2.onTrainingEndCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_allListenersContinue_expect_listReturnsTrue() {
|
||||||
|
// Arrange
|
||||||
|
MockTrainingListener listener1 = new MockTrainingListener();
|
||||||
|
MockTrainingListener listener2 = new MockTrainingListener();
|
||||||
|
TrainingListenerList sut = new TrainingListenerList();
|
||||||
|
sut.add(listener1);
|
||||||
|
sut.add(listener2);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean resultTrainingStarted = sut.notifyTrainingStarted();
|
||||||
|
boolean resultNewEpoch = sut.notifyNewEpoch(null);
|
||||||
|
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
|
||||||
|
boolean resultProgress = sut.notifyTrainingProgress(null);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(resultTrainingStarted);
|
||||||
|
assertTrue(resultNewEpoch);
|
||||||
|
assertTrue(resultEpochTrainingResult);
|
||||||
|
assertTrue(resultProgress);
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,12 +2,10 @@ package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockSyncTrainingListener;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
import org.deeplearning4j.rl4j.support.MockTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
@ -19,7 +17,7 @@ public class SyncLearningTest {
|
||||||
public void when_training_expect_listenersToBeCalled() {
|
public void when_training_expect_listenersToBeCalled() {
|
||||||
// Arrange
|
// Arrange
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||||
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
MockTrainingListener listener = new MockTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
sut.addListener(listener);
|
sut.addListener(listener);
|
||||||
|
|
||||||
|
@ -27,8 +25,8 @@ public class SyncLearningTest {
|
||||||
sut.train();
|
sut.train();
|
||||||
|
|
||||||
assertEquals(1, listener.onTrainingStartCallCount);
|
assertEquals(1, listener.onTrainingStartCallCount);
|
||||||
assertEquals(10, listener.onEpochStartCallCount);
|
assertEquals(10, listener.onNewEpochCallCount);
|
||||||
assertEquals(10, listener.onEpochEndStartCallCount);
|
assertEquals(10, listener.onEpochTrainingResultCallCount);
|
||||||
assertEquals(1, listener.onTrainingEndCallCount);
|
assertEquals(1, listener.onTrainingEndCallCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,65 +34,59 @@ public class SyncLearningTest {
|
||||||
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
|
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
|
||||||
// Arrange
|
// Arrange
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||||
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
MockTrainingListener listener = new MockTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
sut.addListener(listener);
|
sut.addListener(listener);
|
||||||
listener.trainingStartCanContinue = false;
|
listener.setRemainingTrainingStartCallCount(0);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.train();
|
sut.train();
|
||||||
|
|
||||||
assertEquals(1, listener.onTrainingStartCallCount);
|
assertEquals(1, listener.onTrainingStartCallCount);
|
||||||
assertEquals(0, listener.onEpochStartCallCount);
|
assertEquals(0, listener.onNewEpochCallCount);
|
||||||
assertEquals(0, listener.onEpochEndStartCallCount);
|
assertEquals(0, listener.onEpochTrainingResultCallCount);
|
||||||
assertEquals(1, listener.onTrainingEndCallCount);
|
assertEquals(1, listener.onTrainingEndCallCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_epochStartCanContinueFalse_expect_trainingStopped() {
|
public void when_newEpochCanContinueFalse_expect_trainingStopped() {
|
||||||
// Arrange
|
// Arrange
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||||
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
MockTrainingListener listener = new MockTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
sut.addListener(listener);
|
sut.addListener(listener);
|
||||||
listener.nbStepsEpochStartCanContinue = 3;
|
listener.setRemainingOnNewEpochCallCount(2);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.train();
|
sut.train();
|
||||||
|
|
||||||
assertEquals(1, listener.onTrainingStartCallCount);
|
assertEquals(1, listener.onTrainingStartCallCount);
|
||||||
assertEquals(3, listener.onEpochStartCallCount);
|
assertEquals(3, listener.onNewEpochCallCount);
|
||||||
assertEquals(2, listener.onEpochEndStartCallCount);
|
assertEquals(2, listener.onEpochTrainingResultCallCount);
|
||||||
assertEquals(1, listener.onTrainingEndCallCount);
|
assertEquals(1, listener.onTrainingEndCallCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_epochEndCanContinueFalse_expect_trainingStopped() {
|
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
|
||||||
// Arrange
|
// Arrange
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
||||||
MockSyncTrainingListener listener = new MockSyncTrainingListener();
|
MockTrainingListener listener = new MockTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
sut.addListener(listener);
|
sut.addListener(listener);
|
||||||
listener.nbStepsEpochEndCanContinue = 3;
|
listener.setRemainingOnEpochTrainingResult(2);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.train();
|
sut.train();
|
||||||
|
|
||||||
assertEquals(1, listener.onTrainingStartCallCount);
|
assertEquals(1, listener.onTrainingStartCallCount);
|
||||||
assertEquals(3, listener.onEpochStartCallCount);
|
assertEquals(3, listener.onNewEpochCallCount);
|
||||||
assertEquals(3, listener.onEpochEndStartCallCount);
|
assertEquals(3, listener.onEpochTrainingResultCallCount);
|
||||||
assertEquals(1, listener.onTrainingEndCallCount);
|
assertEquals(1, listener.onTrainingEndCallCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class MockSyncLearning extends SyncLearning {
|
public static class MockSyncLearning extends SyncLearning {
|
||||||
|
|
||||||
private LConfiguration conf;
|
private final LConfiguration conf;
|
||||||
|
|
||||||
public MockSyncLearning(LConfiguration conf, IDataManager dataManager) {
|
|
||||||
super(conf);
|
|
||||||
addListener(DataManagerSyncTrainingListener.builder(dataManager).build());
|
|
||||||
this.conf = conf;
|
|
||||||
}
|
|
||||||
|
|
||||||
public MockSyncLearning(LConfiguration conf) {
|
public MockSyncLearning(LConfiguration conf) {
|
||||||
super(conf);
|
super(conf);
|
||||||
|
@ -119,7 +111,7 @@ public class SyncLearningTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Policy getPolicy() {
|
public IPolicy getPolicy() {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning;
|
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscrete;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockMDP;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
|
||||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
|
||||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerSyncTrainingListener;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
public class QLearningDiscreteTest {
|
|
||||||
@Test
|
|
||||||
public void refac_checkDataManagerCallsRemainTheSame() {
|
|
||||||
// Arrange
|
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder()
|
|
||||||
.maxStep(10)
|
|
||||||
.expRepMaxSize(1)
|
|
||||||
.build();
|
|
||||||
MockDataManager dataManager = new MockDataManager(true);
|
|
||||||
MockQLearningDiscrete sut = new MockQLearningDiscrete(10, lconfig, dataManager, 2, 3);
|
|
||||||
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
|
||||||
.build();
|
|
||||||
sut.setHistoryProcessor(new MockHistoryProcessor(hpConfig));
|
|
||||||
|
|
||||||
// Act
|
|
||||||
sut.train();
|
|
||||||
|
|
||||||
assertEquals(10, dataManager.statEntries.size());
|
|
||||||
for(int i = 0; i < 10; ++i) {
|
|
||||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
|
||||||
assertEquals(i, entry.getEpochCounter());
|
|
||||||
assertEquals(i+1, entry.getStepCounter());
|
|
||||||
assertEquals(1.0, entry.getReward(), 0.0);
|
|
||||||
|
|
||||||
}
|
|
||||||
assertEquals(4, dataManager.isSaveDataCallCount);
|
|
||||||
assertEquals(4, dataManager.getVideoDirCallCount);
|
|
||||||
assertEquals(11, dataManager.writeInfoCallCount);
|
|
||||||
assertEquals(5, dataManager.saveCallCount);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class MockQLearningDiscrete extends QLearningDiscrete {
|
|
||||||
|
|
||||||
public MockQLearningDiscrete(int maxSteps, QLConfiguration conf,
|
|
||||||
IDataManager dataManager, int saveFrequency, int monitorFrequency) {
|
|
||||||
super(new MockMDP(maxSteps), new MockDQN(), conf, 2);
|
|
||||||
addListener(DataManagerSyncTrainingListener.builder(dataManager)
|
|
||||||
.saveFrequency(saveFrequency)
|
|
||||||
.monitorFrequency(monitorFrequency)
|
|
||||||
.build());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected IDataManager.StatEntry trainEpoch() {
|
|
||||||
setStepCounter(getStepCounter() + 1);
|
|
||||||
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -8,6 +8,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
import org.deeplearning4j.rl4j.support.*;
|
||||||
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -29,12 +30,11 @@ public class QLearningDiscreteTest {
|
||||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||||
0, 1.0, 0, 0, 0, 0, true);
|
0, 1.0, 0, 0, 0, 0, true);
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10);
|
MockExpReplay expReplay = new MockExpReplay();
|
||||||
|
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10);
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||||
sut.setHistoryProcessor(hp);
|
sut.setHistoryProcessor(hp);
|
||||||
MockExpReplay expReplay = new MockExpReplay();
|
|
||||||
sut.setExpReplay(expReplay);
|
|
||||||
MockEncodable obs = new MockEncodable(1);
|
MockEncodable obs = new MockEncodable(1);
|
||||||
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -131,8 +131,11 @@ public class QLearningDiscreteTest {
|
||||||
|
|
||||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
|
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
|
||||||
QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) {
|
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
||||||
super(mdp, dqn, conf, dataManager, epsilonNbStep);
|
int epsilonNbStep) {
|
||||||
|
super(mdp, dqn, conf, epsilonNbStep);
|
||||||
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
|
setExpReplay(expReplay);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1,46 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.support;
|
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEpochEndEvent;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingEvent;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.listener.SyncTrainingListener;
|
|
||||||
|
|
||||||
public class MockSyncTrainingListener implements SyncTrainingListener {
|
|
||||||
|
|
||||||
public int onTrainingStartCallCount = 0;
|
|
||||||
public int onTrainingEndCallCount = 0;
|
|
||||||
public int onEpochStartCallCount = 0;
|
|
||||||
public int onEpochEndStartCallCount = 0;
|
|
||||||
|
|
||||||
public boolean trainingStartCanContinue = true;
|
|
||||||
public int nbStepsEpochStartCanContinue = Integer.MAX_VALUE;
|
|
||||||
public int nbStepsEpochEndCanContinue = Integer.MAX_VALUE;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ListenerResponse onTrainingStart(SyncTrainingEvent event) {
|
|
||||||
++onTrainingStartCallCount;
|
|
||||||
return trainingStartCanContinue ? ListenerResponse.CONTINUE : ListenerResponse.STOP;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onTrainingEnd() {
|
|
||||||
++onTrainingEndCallCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ListenerResponse onEpochStart(SyncTrainingEvent event) {
|
|
||||||
++onEpochStartCallCount;
|
|
||||||
if(onEpochStartCallCount >= nbStepsEpochStartCanContinue) {
|
|
||||||
return ListenerResponse.STOP;
|
|
||||||
}
|
|
||||||
return ListenerResponse.CONTINUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ListenerResponse onEpochEnd(SyncTrainingEpochEndEvent event) {
|
|
||||||
++onEpochEndStartCallCount;
|
|
||||||
if(onEpochEndStartCallCount >= nbStepsEpochEndCanContinue) {
|
|
||||||
return ListenerResponse.STOP;
|
|
||||||
}
|
|
||||||
return ListenerResponse.CONTINUE;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
|
||||||
|
|
||||||
|
public class MockAsyncConfiguration implements AsyncConfiguration {
|
||||||
|
|
||||||
|
private final int nStep;
|
||||||
|
private final int maxEpochStep;
|
||||||
|
|
||||||
|
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
|
||||||
|
this.nStep = nStep;
|
||||||
|
|
||||||
|
this.maxEpochStep = maxEpochStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getSeed() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxEpochStep() {
|
||||||
|
return maxEpochStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxStep() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumThread() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNstep() {
|
||||||
|
return nStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getTargetDqnUpdateFreq() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getUpdateStart() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getRewardFactor() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getGamma() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getErrorClamp() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import lombok.Setter;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
public class MockAsyncGlobal implements IAsyncGlobal {
|
||||||
|
|
||||||
|
public boolean hasBeenStarted = false;
|
||||||
|
public boolean hasBeenTerminated = false;
|
||||||
|
|
||||||
|
@Setter
|
||||||
|
private int maxLoops;
|
||||||
|
@Setter
|
||||||
|
private int numLoopsStopRunning;
|
||||||
|
private int currentLoop = 0;
|
||||||
|
|
||||||
|
public MockAsyncGlobal() {
|
||||||
|
maxLoops = Integer.MAX_VALUE;
|
||||||
|
numLoopsStopRunning = Integer.MAX_VALUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isRunning() {
|
||||||
|
return currentLoop < numLoopsStopRunning;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void terminate() {
|
||||||
|
hasBeenTerminated = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isTrainingComplete() {
|
||||||
|
return ++currentLoop > maxLoops;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void start() {
|
||||||
|
hasBeenStarted = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AtomicInteger getT() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNet getCurrent() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNet getTarget() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -44,7 +44,7 @@ public class MockDataManager implements IDataManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void save(Learning learning) throws IOException {
|
public void save(ILearning learning) throws IOException {
|
||||||
++saveCallCount;
|
++saveCallCount;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
|
||||||
|
public class MockNeuralNet implements NeuralNet {
|
||||||
|
|
||||||
|
public int resetCallCount = 0;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNetwork[] getNeuralNetworks() {
|
||||||
|
return new NeuralNetwork[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isRecurrent() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
++resetCallCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray[] outputAll(INDArray batch) {
|
||||||
|
return new INDArray[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNet clone() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void copy(NeuralNet from) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||||
|
return new Gradient[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(INDArray input, INDArray[] labels) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void applyGradient(Gradient[] gradients, int batchSize) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getLatestScore() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(OutputStream os) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(String filename) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
|
|
||||||
|
public class MockPolicy implements IPolicy<MockEncodable, Integer> {
|
||||||
|
|
||||||
|
public int playCallCount = 0;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
|
||||||
|
++playCallCount;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import lombok.Setter;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MockTrainingListener implements TrainingListener {
|
||||||
|
|
||||||
|
public int onTrainingStartCallCount = 0;
|
||||||
|
public int onTrainingEndCallCount = 0;
|
||||||
|
public int onNewEpochCallCount = 0;
|
||||||
|
public int onEpochTrainingResultCallCount = 0;
|
||||||
|
public int onTrainingProgressCallCount = 0;
|
||||||
|
|
||||||
|
@Setter
|
||||||
|
private int remainingTrainingStartCallCount = Integer.MAX_VALUE;
|
||||||
|
@Setter
|
||||||
|
private int remainingOnNewEpochCallCount = Integer.MAX_VALUE;
|
||||||
|
@Setter
|
||||||
|
private int remainingOnEpochTrainingResult = Integer.MAX_VALUE;
|
||||||
|
@Setter
|
||||||
|
private int remainingonTrainingProgressCallCount = Integer.MAX_VALUE;
|
||||||
|
|
||||||
|
public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onTrainingStart() {
|
||||||
|
++onTrainingStartCallCount;
|
||||||
|
--remainingTrainingStartCallCount;
|
||||||
|
return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
|
||||||
|
++onNewEpochCallCount;
|
||||||
|
--remainingOnNewEpochCallCount;
|
||||||
|
return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
|
||||||
|
++onEpochTrainingResultCallCount;
|
||||||
|
--remainingOnEpochTrainingResult;
|
||||||
|
statEntries.add(statEntry);
|
||||||
|
return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse onTrainingProgress(ILearning learning) {
|
||||||
|
++onTrainingProgressCallCount;
|
||||||
|
--remainingonTrainingProgressCallCount;
|
||||||
|
return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onTrainingEnd() {
|
||||||
|
++onTrainingEndCallCount;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,169 @@
|
||||||
|
package org.deeplearning4j.rl4j.util;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertSame;
|
||||||
|
|
||||||
|
public class DataManagerTrainingListenerTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingOnNewEpochWithoutHistoryProcessor_expect_noException() {
|
||||||
|
// Arrange
|
||||||
|
TestTrainer trainer = new TestTrainer();
|
||||||
|
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingOnNewEpochWithHistoryProcessor_expect_startMonitorNotCalled() {
|
||||||
|
// Arrange
|
||||||
|
TestTrainer trainer = new TestTrainer();
|
||||||
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||||
|
trainer.setHistoryProcessor(hp);
|
||||||
|
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
TrainingListener.ListenerResponse response = sut.onNewEpoch(trainer);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, hp.startMonitorCallCount);
|
||||||
|
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingOnEpochTrainingResultWithoutHistoryProcessor_expect_noException() {
|
||||||
|
// Arrange
|
||||||
|
TestTrainer trainer = new TestTrainer();
|
||||||
|
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingOnNewEpochWithHistoryProcessor_expect_stopMonitorNotCalled() {
|
||||||
|
// Arrange
|
||||||
|
TestTrainer trainer = new TestTrainer();
|
||||||
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||||
|
trainer.setHistoryProcessor(hp);
|
||||||
|
DataManagerTrainingListener sut = new DataManagerTrainingListener(new MockDataManager(false));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, null);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, hp.stopMonitorCallCount);
|
||||||
|
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingOnEpochTrainingResult_expect_callToDataManagerAppendStat() {
|
||||||
|
// Arrange
|
||||||
|
TestTrainer trainer = new TestTrainer();
|
||||||
|
MockDataManager dm = new MockDataManager(false);
|
||||||
|
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
|
||||||
|
MockStatEntry statEntry = new MockStatEntry(0, 0, 0.0);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
TrainingListener.ListenerResponse response = sut.onEpochTrainingResult(trainer, statEntry);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||||
|
assertEquals(1, dm.statEntries.size());
|
||||||
|
assertSame(statEntry, dm.statEntries.get(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingOnTrainingProgress_expect_callToDataManagerSaveAndWriteInfo() {
|
||||||
|
// Arrange
|
||||||
|
TestTrainer learning = new TestTrainer();
|
||||||
|
MockDataManager dm = new MockDataManager(false);
|
||||||
|
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||||
|
assertEquals(1, dm.writeInfoCallCount);
|
||||||
|
assertEquals(1, dm.saveCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_stepCounterCloseToLastSave_expect_dataManagerSaveNotCalled() {
|
||||||
|
// Arrange
|
||||||
|
TestTrainer learning = new TestTrainer();
|
||||||
|
MockDataManager dm = new MockDataManager(false);
|
||||||
|
DataManagerTrainingListener sut = new DataManagerTrainingListener(dm);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
TrainingListener.ListenerResponse response = sut.onTrainingProgress(learning);
|
||||||
|
TrainingListener.ListenerResponse response2 = sut.onTrainingProgress(learning);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response);
|
||||||
|
assertEquals(TrainingListener.ListenerResponse.CONTINUE, response2);
|
||||||
|
assertEquals(1, dm.saveCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class TestTrainer implements IEpochTrainer, ILearning
|
||||||
|
{
|
||||||
|
@Override
|
||||||
|
public int getStepCounter() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getEpochCounter() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
private IHistoryProcessor historyProcessor;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IPolicy getPolicy() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void train() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LConfiguration getConfiguration() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MDP getMdp() {
|
||||||
|
return new MockMDP(new MockObservationSpace());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue