RL4J: Sanitize async learner (#327)

* Refactor the global async learner to use a much simpler update procedure with a single global lock

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Simplify the async learning algorithms, improve stability, and use better hyperparameters

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Start experimenting with Mockito for tests

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Refactor tests for the async classes and simplify the async code

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* More work on Mockito tests; make some tests less complex and more explicit about what they verify

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* some fixes from merging

* Do not allow copying of the current network to worker threads; fix a debug line

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Add more tests in response to PR review

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Adding more tests after review comments

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* A few more tests and fixes from PR review

* Revert the rename of maxEpochStep to maxStepsPerEpisode, as agreed to review it in a separate PR

* Use 2019 instead of 2018 in copyright headers

* Add Konduit copyright headers to files

* Add some more copyright headers

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

Co-authored-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Chris Bamford 2020-04-20 03:21:01 +01:00 committed by GitHub
parent 455a5d112d
commit 74420bca31
54 changed files with 1423 additions and 1712 deletions

.gitignore
View File

@ -70,3 +70,6 @@ venv2/
# Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
# Ignore meld temp files
*.orig

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.rl4j.mdp;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
/**
@ -31,20 +32,20 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
* in a "functionnal manner" if step return a mdp
*
*/
public interface MDP<O, A, AS extends ActionSpace<A>> {
public interface MDP<OBSERVATION, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {
ObservationSpace<O> getObservationSpace();
ObservationSpace<OBSERVATION> getObservationSpace();
AS getActionSpace();
ACTION_SPACE getActionSpace();
O reset();
OBSERVATION reset();
void close();
StepReply<O> step(A action);
StepReply<OBSERVATION> step(ACTION action);
boolean isDone();
MDP<O, A, AS> newInstance();
MDP<OBSERVATION, ACTION, ACTION_SPACE> newInstance();
}
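
As an illustration of the renamed generic parameters at an implementation site, a minimal sketch of a toy environment (hypothetical class, not part of this commit; it assumes DiscreteSpace exposes an int-size constructor and that Encodable only requires toArray(), as used elsewhere in this diff):

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;

public class CoinFlipMDP implements MDP<CoinFlipMDP.State, Integer, DiscreteSpace> {
    // OBSERVATION = State, ACTION = Integer, ACTION_SPACE = DiscreteSpace

    public static class State implements Encodable {
        private final double value;
        public State(double value) { this.value = value; }
        @Override public double[] toArray() { return new double[] { value }; }
    }

    private final DiscreteSpace actionSpace = new DiscreteSpace(2);  // assumed DiscreteSpace(int) constructor
    private final ObservationSpace<State> observationSpace;          // supplied by the caller
    private boolean done = false;

    public CoinFlipMDP(ObservationSpace<State> observationSpace) {
        this.observationSpace = observationSpace;
    }

    @Override public ObservationSpace<State> getObservationSpace() { return observationSpace; }
    @Override public DiscreteSpace getActionSpace() { return actionSpace; }
    @Override public State reset() { done = false; return new State(0.0); }
    @Override public void close() { }
    @Override public StepReply<State> step(Integer action) {
        done = true;                                                 // single-step episode
        double reward = action == 1 ? 1.0 : 0.0;
        return new StepReply<>(new State(1.0), reward, done, null);
    }
    @Override public boolean isDone() { return done; }
    @Override public MDP<State, Integer, DiscreteSpace> newInstance() { return new CoinFlipMDP(observationSpace); }
}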

View File

@ -17,24 +17,24 @@
package org.deeplearning4j.rl4j.space;
/**
* @param <A> the type of Action
* @param <ACTION> the type of Action
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
* <p>
* Should contain contextual information about the Action space, which is the space of all the actions that could be available.
* Also must know how to return a randomly uniformly sampled action.
*/
public interface ActionSpace<A> {
public interface ActionSpace<ACTION> {
/**
* @return A random action,
*/
A randomAction();
ACTION randomAction();
Object encode(A action);
Object encode(ACTION action);
int getSize();
A noOp();
ACTION noOp();
}

View File

@ -121,6 +121,13 @@
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.3.3</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -30,7 +30,7 @@ import java.util.List;
*/
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
private List<StateActionPair<A>> stateActionPairs;
private List<StateActionPair<A>> stateActionPairs = new ArrayList<>();
public void setFinalObservation(Observation observation) {
// Do nothing

View File

@ -1,21 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning;
public interface EpochStepCounter {
int getCurrentEpochStep();
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -28,9 +29,11 @@ import org.deeplearning4j.rl4j.mdp.MDP;
* @author Alexandre Boulanger
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*/
public interface IEpochTrainer extends EpochStepCounter {
int getStepCounter();
int getEpochCounter();
public interface IEpochTrainer {
int getStepCount();
int getEpochCount();
int getEpisodeCount();
int getCurrentEpisodeStepCount();
IHistoryProcessor getHistoryProcessor();
MDP getMdp();
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.learning;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Value;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,7 +52,7 @@ public interface IHistoryProcessor {
@AllArgsConstructor
@Builder
@Value
@Data
public static class Configuration {
@Builder.Default int historyLength = 4;
@Builder.Default int rescaledWidth = 84;

View File

@ -21,19 +21,20 @@ import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16.
*
* A common interface that any training method should implement
*/
public interface ILearning<O, A, AS extends ActionSpace<A>> {
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
IPolicy<O, A> getPolicy();
void train();
int getStepCounter();
int getStepCount();
ILearningConfiguration getConfiguration();

View File

@ -38,13 +38,13 @@ import org.nd4j.linalg.factory.Nd4j;
*
*/
@Slf4j
public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
@Getter @Setter
private int stepCounter = 0;
protected int stepCount = 0;
@Getter @Setter
private int epochCounter = 0;
private int epochCount = 0;
@Getter @Setter
private IHistoryProcessor historyProcessor = null;
@ -73,11 +73,11 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
public abstract NN getNeuralNet();
public void incrementStep() {
stepCounter++;
stepCount++;
}
public void incrementEpoch() {
epochCounter++;
epochCount++;
}
public void setHistoryProcessor(HistoryProcessor.Configuration conf) {

View File

@ -20,13 +20,11 @@ package org.deeplearning4j.rl4j.learning.async;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.primitives.Pair;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -52,69 +50,75 @@ import java.util.concurrent.atomic.AtomicInteger;
* structure
*/
@Slf4j
public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
public class AsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {
@Getter
final private NN current;
final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
final private IAsyncLearningConfiguration configuration;
private final IAsyncLearning learning;
@Getter
private AtomicInteger T = new AtomicInteger(0);
@Getter
private NN target;
@Getter
private boolean running = true;
public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) {
private NN target;
final private IAsyncLearningConfiguration configuration;
@Getter
private final Lock updateLock;
/**
* The number of times the gradient has been updated by worker threads
*/
@Getter
private int workerUpdateCount;
@Getter
private int stepCount;
public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration) {
this.current = initial;
target = (NN) initial.clone();
this.configuration = configuration;
this.learning = learning;
queue = new ConcurrentLinkedQueue<>();
// This is used to sync between
updateLock = new ReentrantLock();
}
public boolean isTrainingComplete() {
return T.get() >= configuration.getMaxStep();
return stepCount >= configuration.getMaxStep();
}
public void enqueue(Gradient[] gradient, Integer nstep) {
if (running && !isTrainingComplete()) {
queue.add(new Pair<>(gradient, nstep));
public void applyGradient(Gradient[] gradient, int nstep) {
if (isTrainingComplete()) {
return;
}
try {
updateLock.lock();
current.applyGradient(gradient, nstep);
stepCount += nstep;
workerUpdateCount++;
int targetUpdateFrequency = configuration.getLearnerUpdateFrequency();
// If we have a target update frequency, this means we only want to update the workers after a certain number of async updates
// This can lead to more stable training
if (targetUpdateFrequency != -1 && workerUpdateCount % targetUpdateFrequency == 0) {
log.info("Updating target network at updates={} steps={}", workerUpdateCount, stepCount);
} else {
target.copy(current);
}
} finally {
updateLock.unlock();
}
}
@Override
public void run() {
while (!isTrainingComplete() && running) {
if (!queue.isEmpty()) {
Pair<Gradient[], Integer> pair = queue.poll();
T.addAndGet(pair.getSecond());
Gradient[] gradient = pair.getFirst();
synchronized (this) {
current.applyGradient(gradient, pair.getSecond());
}
if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond())
/ configuration.getLearnerUpdateFrequency()) {
log.info("TARGET UPDATE at T = " + T.get());
synchronized (this) {
target.copy(current);
}
}
}
}
}
/**
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
*/
public void terminate() {
if (running) {
running = false;
queue.clear();
learning.terminate();
public NN getTarget() {
try {
updateLock.lock();
return target;
} finally {
updateLock.unlock();
}
}

View File

@ -40,8 +40,8 @@ import org.nd4j.linalg.factory.Nd4j;
* @author Alexandre Boulanger
*/
@Slf4j
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN>
public abstract class AsyncLearning<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
extends Learning<OBSERVATION, ACTION, ACTION_SPACE, NN>
implements IAsyncLearning {
private Thread monitorThread = null;
@ -69,10 +69,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
protected void startGlobalThread() {
getAsyncGlobal().start();
}
protected boolean isTrainingComplete() {
return getAsyncGlobal().isTrainingComplete();
}
@ -87,7 +83,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
private int progressMonitorFrequency = 20000;
private void launchThreads() {
startGlobalThread();
for (int i = 0; i < getConfiguration().getNumThreads(); i++) {
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
t.start();
@ -99,8 +94,8 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
* @return The current step
*/
@Override
public int getStepCounter() {
return getAsyncGlobal().getT().get();
public int getStepCount() {
return getAsyncGlobal().getStepCount();
}
/**
@ -129,14 +124,13 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
monitorTraining();
}
cleanupPostTraining();
listeners.notifyTrainingFinished();
}
protected void monitorTraining() {
try {
monitorThread = Thread.currentThread();
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
while (canContinue && !isTrainingComplete()) {
canContinue = listeners.notifyTrainingProgress(this);
if (!canContinue) {
return;
@ -152,11 +146,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
monitorThread = null;
}
protected void cleanupPostTraining() {
// Worker threads stops automatically when the global thread stops
getAsyncGlobal().terminate();
}
/**
* Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
*/

View File

@ -32,6 +32,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.factory.Nd4j;
@ -47,39 +48,63 @@ import org.nd4j.linalg.factory.Nd4j;
* @author Alexandre Boulanger
*/
@Slf4j
public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
extends Thread implements IEpochTrainer {
@Getter
private int threadNumber;
@Getter
protected final int deviceNum;
/**
* The number of steps that this async thread has produced
*/
@Getter @Setter
private int stepCounter = 0;
protected int stepCount = 0;
/**
* The number of epochs (updates) that this thread has sent to the global learner
*/
@Getter @Setter
private int epochCounter = 0;
protected int epochCount = 0;
/**
* The number of environment episodes that have been played out
*/
@Getter @Setter
protected int episodeCount = 0;
/**
* The number of steps in the current episode
*/
@Getter
protected int currentEpisodeStepCount = 0;
/**
* If the current episode needs to be reset
*/
boolean episodeComplete = true;
@Getter @Setter
private IHistoryProcessor historyProcessor;
@Getter
private int currentEpochStep = 0;
private boolean isEpochStarted = false;
private final LegacyMDPWrapper<O, A, AS> mdp;
private boolean isEpisodeStarted = false;
private final LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> mdp;
private final TrainingListenerList listeners;
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
this.mdp = new LegacyMDPWrapper<O, A, AS>(mdp, null, this);
public AsyncThread(MDP<OBSERVATION, ACTION, ACTION_SPACE> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
this.mdp = new LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE>(mdp, null);
this.listeners = listeners;
this.threadNumber = threadNumber;
this.deviceNum = deviceNum;
}
public MDP<O, A, AS> getMdp() {
public MDP<OBSERVATION, ACTION, ACTION_SPACE> getMdp() {
return mdp.getWrappedMDP();
}
protected LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper() {
protected LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> getLegacyMDPWrapper() {
return mdp;
}
@ -92,13 +117,13 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
mdp.setHistoryProcessor(historyProcessor);
}
protected void postEpoch() {
protected void postEpisode() {
if (getHistoryProcessor() != null)
getHistoryProcessor().stopMonitor();
}
protected void preEpoch() {
protected void preEpisode() {
// Do nothing
}
@ -125,74 +150,69 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
*/
@Override
public void run() {
try {
RunContext context = new RunContext();
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
log.info("ThreadNum-" + threadNumber + " Started!");
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
if (!isEpochStarted) {
boolean canContinue = startNewEpoch(context);
if (!canContinue) {
break;
}
while (!getAsyncGlobal().isTrainingComplete()) {
if (episodeComplete) {
startEpisode(context);
}
handleTraining(context);
if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) {
boolean canContinue = finishEpoch(context);
if (!canContinue) {
if(!startEpoch(context)) {
break;
}
++epochCounter;
episodeComplete = handleTraining(context);
if(!finishEpoch(context)) {
break;
}
if(episodeComplete) {
finishEpisode(context);
}
}
finally {
terminateWork();
}
}
private void handleTraining(RunContext context) {
int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
private boolean finishEpoch(RunContext context) {
epochCount++;
IDataManager.StatEntry statEntry = new AsyncStatEntry(stepCount, epochCount, context.rewards, currentEpisodeStepCount, context.score);
return listeners.notifyEpochTrainingResult(this, statEntry);
}
private boolean startEpoch(RunContext context) {
return listeners.notifyNewEpoch(this);
}
private boolean handleTraining(RunContext context) {
int maxTrainSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpisodeStepCount);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps);
context.obs = subEpochReturn.getLastObs();
context.rewards += subEpochReturn.getReward();
context.score = subEpochReturn.getScore();
return subEpochReturn.isEpisodeComplete();
}
private boolean startNewEpoch(RunContext context) {
private void startEpisode(RunContext context) {
getCurrent().reset();
Learning.InitMdp<Observation> initMdp = refacInitMdp();
context.obs = initMdp.getLastObs();
context.rewards = initMdp.getReward();
isEpochStarted = true;
preEpoch();
return listeners.notifyNewEpoch(this);
preEpisode();
episodeCount++;
}
private boolean finishEpoch(RunContext context) {
isEpochStarted = false;
postEpoch();
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score);
private void finishEpisode(RunContext context) {
postEpisode();
log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards);
return listeners.notifyEpochTrainingResult(this, statEntry);
}
private void terminateWork() {
getAsyncGlobal().terminate();
if(isEpochStarted) {
postEpoch();
}
log.info("ThreadNum-{} Episode step: {}, Episode: {}, Epoch: {}, reward: {}", threadNumber, currentEpisodeStepCount, episodeCount, epochCount, context.rewards);
}
protected abstract NN getCurrent();
@ -201,35 +221,35 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
protected abstract IAsyncLearningConfiguration getConf();
protected abstract IPolicy<O, A> getPolicy(NN net);
protected abstract IPolicy<OBSERVATION, ACTION> getPolicy(NN net);
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
private Learning.InitMdp<Observation> refacInitMdp() {
currentEpochStep = 0;
currentEpisodeStepCount = 0;
double reward = 0;
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> mdp = getLegacyMDPWrapper();
Observation observation = mdp.reset();
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
ACTION action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
while (observation.isSkipped() && !mdp.isDone()) {
StepReply<Observation> stepReply = mdp.step(action);
reward += stepReply.getReward();
observation = stepReply.getObservation();
incrementStep();
incrementSteps();
}
return new Learning.InitMdp(0, observation, reward);
}
public void incrementStep() {
++stepCounter;
++currentEpochStep;
public void incrementSteps() {
stepCount++;
currentEpisodeStepCount++;
}
@AllArgsConstructor
@ -239,6 +259,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
Observation lastObs;
double reward;
double score;
boolean episodeComplete;
}
@AllArgsConstructor
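
In summary, the reworked worker loop above separates an epoch (one training batch of at most nStep transitions, ending in a single gradient push to the global network) from an episode (one full playthrough of the environment): the environment is reset only when the previous episode has completed, listeners are notified at epoch boundaries via startEpoch/finishEpoch, and postEpisode() runs when the MDP reports done or maxEpochStep is reached.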

View File

@ -24,6 +24,9 @@ import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -31,14 +34,18 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Stack;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*
* <p>
* Async Learning specialized for the Discrete Domain
*
*/
public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
extends AsyncThread<O, Integer, DiscreteSpace, NN> {
@Getter
@ -48,7 +55,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
private UpdateAlgorithm<NN> updateAlgorithm;
// TODO: Make it configurable with a builder
@Setter(AccessLevel.PROTECTED)
@Setter(AccessLevel.PROTECTED) @Getter
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
@ -56,9 +63,9 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
TrainingListenerList listeners,
int threadNumber,
int deviceNum) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
super(mdp, listeners, threadNumber, deviceNum);
synchronized (asyncGlobal) {
current = (NN)asyncGlobal.getCurrent().clone();
current = (NN) asyncGlobal.getTarget().clone();
}
}
@ -72,7 +79,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
}
@Override
protected void preEpoch() {
protected void preEpisode() {
experienceHandler.reset();
}
@ -82,27 +89,22 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
* that stack rewards with t_max MiniTrans
*
* @param sObs the obs to start from
* @param nstep the number of max nstep (step until t_max or state is terminal)
* @param trainingSteps the number of training steps
* @return subepoch training informations
*/
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
public SubEpochReturn trainSubEpoch(Observation sObs, int trainingSteps) {
synchronized (getAsyncGlobal()) {
current.copy(getAsyncGlobal().getCurrent());
}
current.copy(getAsyncGlobal().getTarget());
Observation obs = sObs;
IPolicy<O, Integer> policy = getPolicy(current);
Integer action = getMdp().getActionSpace().noOp();
IHistoryProcessor hp = getHistoryProcessor();
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
double reward = 0;
double accuReward = 0;
int stepAtStart = getCurrentEpochStep();
int lastStep = nstep * skipFrame + stepAtStart;
while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
while (!getMdp().isDone() && experienceHandler.getTrainingBatchSize() != trainingSteps) {
//if step of training, just repeat lastAction
if (!obs.isSkipped()) {
@ -115,20 +117,26 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
if (!obs.isSkipped()) {
experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
accuReward = 0;
incrementSteps();
}
obs = stepReply.getObservation();
reward += stepReply.getReward();
incrementStep();
}
if (getMdp().isDone() && getCurrentEpochStep() < lastStep) {
boolean episodeComplete = getMdp().isDone() || getConf().getMaxEpochStep() == currentEpisodeStepCount;
if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) {
experienceHandler.setFinalObservation(obs);
}
getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep());
int experienceSize = experienceHandler.getTrainingBatchSize();
return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
getAsyncGlobal().applyGradient(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), experienceSize);
return new SubEpochReturn(experienceSize, obs, reward, current.getLatestScore(), episodeComplete);
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -22,17 +23,29 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import java.util.concurrent.atomic.AtomicInteger;
public interface IAsyncGlobal<NN extends NeuralNet> {
boolean isRunning();
boolean isTrainingComplete();
void start();
/**
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
* The number of updates that have been applied by worker threads.
*/
void terminate();
int getWorkerUpdateCount();
AtomicInteger getT();
NN getCurrent();
/**
* The total number of environment steps that have been processed.
*/
int getStepCount();
/**
* A copy of the global network that is updated after a certain number of worker episodes.
*/
NN getTarget();
void enqueue(Gradient[] gradient, Integer nstep);
/**
* Apply gradients to the global network
* @param gradient
* @param batchSize
*/
void applyGradient(Gradient[] gradient, int batchSize);
}
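
A hypothetical helper (not part of this commit) showing the worker-side contract this interface now expresses: read the latest target copy, do local work, then push the resulting gradient together with the number of steps it covers.

import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;

public class WorkerSketch {
    public static <NN extends NeuralNet> void pushUpdate(IAsyncGlobal<NN> global, Gradient[] gradients, int nSteps) {
        if (global.isTrainingComplete()) {
            return;                               // getStepCount() has reached the configured max step
        }
        NN latestTarget = global.getTarget();     // workers clone this to refresh their local network
        // ... in AsyncThreadDiscrete, 'gradients' are computed from nSteps of experience using that copy ...
        global.applyGradient(gradients, nSteps);  // applied under the global lock; advances getStepCount()
    }
}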

View File

@ -57,7 +57,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
this.iActorCritic = iActorCritic;
this.mdp = mdp;
this.configuration = conf;
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
Long seed = conf.getSeed();
Random rnd = Nd4j.getRandom();

View File

@ -27,7 +27,6 @@ import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
@ -73,6 +72,6 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
@Override
protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma());
return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma());
}
}

View File

@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -27,28 +26,25 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.List;
public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
/**
* The Advantage Actor-Critic update algorithm can be used by A2C and A3C algorithms alike
*/
public class AdvantageActorCriticUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
private final IAsyncGlobal asyncGlobal;
private final int[] shape;
private final int actionSpaceSize;
private final int targetDqnUpdateFreq;
private final double gamma;
private final boolean recurrent;
public A3CUpdateAlgorithm(IAsyncGlobal asyncGlobal,
public AdvantageActorCriticUpdateAlgorithm(boolean recurrent,
int[] shape,
int actionSpaceSize,
int targetDqnUpdateFreq,
double gamma) {
this.asyncGlobal = asyncGlobal;
//if recurrent then train as a time serie with a batch size of 1
recurrent = asyncGlobal.getCurrent().isRecurrent();
this.recurrent = recurrent;
this.shape = shape;
this.actionSpaceSize = actionSpaceSize;
this.targetDqnUpdateFreq = targetDqnUpdateFreq;
this.gamma = gamma;
}
@ -65,18 +61,12 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
: Nd4j.zeros(size, actionSpaceSize);
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
double r;
double value;
if (stateActionPair.isTerminal()) {
r = 0;
}
else {
INDArray[] output = null;
if (targetDqnUpdateFreq == -1)
output = current.outputAll(stateActionPair.getObservation().getData());
else synchronized (asyncGlobal) {
output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
}
r = output[0].getDouble(0);
value = 0;
} else {
INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
value = output[0].getDouble(0);
}
for (int i = size - 1; i >= 0; --i) {
@ -86,7 +76,7 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
INDArray[] output = current.outputAll(observationData);
r = stateActionPair.getReward() + gamma * r;
value = stateActionPair.getReward() + gamma * value;
if (recurrent) {
input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData);
} else {
@ -94,11 +84,11 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
}
//the critic
targets.putScalar(i, r);
targets.putScalar(i, value);
//the actor
double expectedV = output[0].getDouble(0);
double advantage = r - expectedV;
double advantage = value - expectedV;
if (recurrent) {
logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage);
} else {
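
To make the backward pass above concrete, a self-contained sketch (plain arrays, invented numbers, no RL4J types) of how the critic targets and actor advantages are accumulated:

public class AdvantageExample {
    public static void main(String[] args) {
        double gamma = 0.99;
        double[] rewards = {0.0, 0.0, 1.0};             // hypothetical 3-step training batch
        double[] criticValues = {0.40, 0.55, 0.70};     // V(s_i) as estimated by the current network
        boolean lastStateTerminal = true;               // if false, bootstrap from V of the last state instead of 0

        double value = lastStateTerminal ? 0.0 : criticValues[criticValues.length - 1];
        double[] targets = new double[rewards.length];      // critic regression targets
        double[] advantages = new double[rewards.length];   // actor advantages fed into the log-softmax labels
        for (int i = rewards.length - 1; i >= 0; --i) {
            value = rewards[i] + gamma * value;             // discounted n-step return
            targets[i] = value;
            advantages[i] = value - criticValues[i];
        }
        // targets ≈ [0.98, 0.99, 1.00], advantages ≈ [0.58, 0.44, 0.30]
    }
}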

View File

@ -50,7 +50,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
this.mdp = mdp;
this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
}
@Override
@ -59,7 +59,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
}
public IDQN getNeuralNet() {
return asyncGlobal.getCurrent();
return asyncGlobal.getTarget();
}
public IPolicy<O, Integer> getPolicy() {

View File

@ -30,8 +30,8 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -72,6 +72,6 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
@Override
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma());
return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma());
}
}

View File

@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -28,22 +27,16 @@ import java.util.List;
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
private final IAsyncGlobal asyncGlobal;
private final int[] shape;
private final int actionSpaceSize;
private final int targetDqnUpdateFreq;
private final double gamma;
public QLearningUpdateAlgorithm(IAsyncGlobal asyncGlobal,
int[] shape,
public QLearningUpdateAlgorithm(int[] shape,
int actionSpaceSize,
int targetDqnUpdateFreq,
double gamma) {
this.asyncGlobal = asyncGlobal;
this.shape = shape;
this.actionSpaceSize = actionSpaceSize;
this.targetDqnUpdateFreq = targetDqnUpdateFreq;
this.gamma = gamma;
}
@ -60,14 +53,9 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
double r;
if (stateActionPair.isTerminal()) {
r = 0;
}
else {
} else {
INDArray[] output = null;
if (targetDqnUpdateFreq == -1)
output = current.outputAll(stateActionPair.getObservation().getData());
else synchronized (asyncGlobal) {
output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
}
r = Nd4j.max(output[0]).getDouble(0);
}
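
The key difference from the actor-critic update above is the bootstrap value: here it is the maximum Q-value of the last state under the current network (the synchronized target-network branch is gone, since target handling now lives in AsyncGlobal). A short sketch with invented numbers; the remainder of the method, truncated in this hunk, uses this value to bootstrap the n-step targets:

double[] lastStateQ = {0.20, 0.70, 0.10};   // hypothetical Q(s_last, a) from the current network
boolean terminal = false;
double r = terminal ? 0.0 : java.util.Arrays.stream(lastStateQ).max().getAsDouble();  // mirrors Nd4j.max(output[0]); r == 0.70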

View File

@ -20,8 +20,14 @@ public interface IAsyncLearningConfiguration extends ILearningConfiguration {
int getNumThreads();
/**
* The number of steps to collect for each worker thread between each global update
*/
int getNStep();
/**
* The frequency of worker thread gradient updates to perform a copy of the current working network to the target network
*/
int getLearnerUpdateFrequency();
int getMaxStep();
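
As a hypothetical illustration of these two settings (numbers invented here): with getNStep() returning 5 and getLearnerUpdateFrequency() returning 10, each worker would accumulate 5 environment steps between gradient pushes to the global network, and the target network would be refreshed after every 10 such worker updates, i.e. roughly every 50 global environment steps.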

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
@ -35,8 +36,8 @@ import org.deeplearning4j.rl4j.util.IDataManager;
* @author Alexandre Boulanger
*/
@Slf4j
public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> implements IEpochTrainer {
public abstract class SyncLearning<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
extends Learning<OBSERVATION, ACTION, ACTION_SPACE, NN> implements IEpochTrainer {
private final TrainingListenerList listeners = new TrainingListenerList();
@ -85,7 +86,7 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
boolean canContinue = listeners.notifyTrainingStarted();
if (canContinue) {
while (getStepCounter() < getConfiguration().getMaxStep()) {
while (this.getStepCount() < getConfiguration().getMaxStep()) {
preEpoch();
canContinue = listeners.notifyNewEpoch(this);
if (!canContinue) {
@ -100,14 +101,14 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
postEpoch();
if(getEpochCounter() % progressMonitorFrequency == 0) {
if(getEpochCount() % progressMonitorFrequency == 0) {
canContinue = listeners.notifyTrainingProgress(this);
if (!canContinue) {
break;
}
}
log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
log.info("Epoch: " + getEpochCount() + ", reward: " + statEntry.getReward());
incrementEpoch();
}
}

View File

@ -19,21 +19,16 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -43,8 +38,6 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
@ -58,7 +51,7 @@ import java.util.List;
@Slf4j
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
extends SyncLearning<O, A, AS, IDQN>
implements TargetQNetworkSource, EpochStepCounter {
implements TargetQNetworkSource, IEpochTrainer {
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
@ -90,7 +83,10 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
protected abstract QLStepReturn<Observation> trainStep(Observation obs);
@Getter
private int currentEpochStep = 0;
private int episodeCount;
@Getter
private int currentEpisodeStepCount = 0;
protected StatEntry trainEpoch() {
resetNetworks();
@ -104,9 +100,9 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
double meanQ = 0;
int numQ = 0;
List<Double> scores = new ArrayList<>();
while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
updateTargetNetwork();
}
@ -132,20 +128,20 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
meanQ /= (numQ + 0.001); //avoid div zero
StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, currentEpochStep, scores,
StatEntry statEntry = new QLStatEntry(this.getStepCount(), getEpochCount(), reward, currentEpisodeStepCount, scores,
getEgPolicy().getEpsilon(), startQ, meanQ);
return statEntry;
}
protected void finishEpoch(Observation observation) {
// Do Nothing
episodeCount++;
}
@Override
public void incrementStep() {
super.incrementStep();
++currentEpochStep;
++currentEpisodeStepCount;
}
protected void resetNetworks() {
@ -154,7 +150,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
}
private InitMdp<Observation> refacInitMdp() {
currentEpochStep = 0;
currentEpisodeStepCount = 0;
double reward = 0;

View File

@ -47,6 +47,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
* <p>
@ -90,7 +91,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
int epsilonNbStep, Random random) {
this.configuration = conf;
this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
this.mdp = new LegacyMDPWrapper<>(mdp, null);
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
@ -164,13 +165,13 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
// Update NN
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
if (getStepCounter() > updateStart) {
if (this.getStepCount() > updateStart) {
DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}
}
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
}
protected DataSet setTarget(List<Transition<Integer>> transitions) {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.rl4j.observation;
import lombok.Getter;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
/**

View File

@ -40,7 +40,7 @@ public class EncodableToImageWritableTransform implements Operation<Encodable, I
@Override
public ImageWritable transform(Encodable encodable) {
INDArray indArray = Nd4j.create((encodable).toArray()).reshape(height, width, colorChannels);
INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels);
Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
return new ImageWritable(converter.convert(mat));
}

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
@ -40,7 +41,7 @@ import org.nd4j.linalg.api.rng.Random;
*/
@AllArgsConstructor
@Slf4j
public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {
final private Policy<O, A> policy;
final private MDP<O, A, AS> mdp;
@ -57,8 +58,8 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
public A nextAction(INDArray input) {
double ep = getEpsilon();
if (learning.getStepCounter() % 500 == 1)
log.info("EP: " + ep + " " + learning.getStepCounter());
if (learning.getStepCount() % 500 == 1)
log.info("EP: " + ep + " " + learning.getStepCount());
if (rnd.nextDouble() > ep)
return policy.nextAction(input);
else
@ -70,6 +71,6 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
}
public double getEpsilon() {
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep));
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep));
}
}
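
A worked example of the annealing schedule implemented by getEpsilon() (hypothetical constructor values, not from this commit):

public class EpsilonScheduleExample {
    public static void main(String[] args) {
        double minEpsilon = 0.1;
        int updateStart = 0;
        int epsilonNbStep = 10_000;
        for (int step : new int[] {0, 2_500, 5_000, 9_000, 50_000}) {
            double eps = Math.min(1.0, Math.max(minEpsilon, 1.0 - (step - updateStart) * 1.0 / epsilonNbStep));
            System.out.println(step + " -> " + eps);  // 1.0, 0.75, 0.5, 0.1, 0.1
        }
    }
}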

View File

@ -7,7 +7,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
public interface IPolicy<O, A> {
public interface IPolicy<O extends Encodable, A> {
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
A nextAction(INDArray input);
A nextAction(Observation observation);

View File

@ -16,10 +16,7 @@
package org.deeplearning4j.rl4j.policy;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
@ -27,6 +24,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
/**
@ -36,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
*
* A Policy responsability is to choose the next action given a state
*/
public abstract class Policy<O, A> implements IPolicy<O, A> {
public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
public abstract NeuralNet getNeuralNet();
@ -54,10 +52,9 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
resetNetworks();
RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter();
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, epochStepCounter);
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp);
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter);
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp);
Observation obs = initMdp.getLastObs();
double reward = initMdp.getReward();
@ -79,7 +76,6 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
reward += stepReply.getReward();
obs = stepReply.getObservation();
epochStepCounter.incrementEpochStep();
}
return reward;
@ -89,8 +85,7 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
getNeuralNet().reset();
}
protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
epochStepCounter.setCurrentEpochStep(0);
protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
double reward = 0;
@ -104,21 +99,9 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
reward += stepReply.getReward();
observation = stepReply.getObservation();
epochStepCounter.incrementEpochStep();
}
return new Learning.InitMdp(0, observation, reward);
}
public class RefacEpochStepCounter implements EpochStepCounter {
@Getter
@Setter
private int currentEpochStep = 0;
public void incrementEpochStep() {
++currentEpochStep;
}
}
}

View File

@ -278,7 +278,7 @@ public class DataManager implements IDataManager {
Path infoPath = Paths.get(getInfo());
Info info = new Info(iLearning.getClass().getSimpleName(), iLearning.getMdp().getClass().getSimpleName(),
iLearning.getConfiguration(), iLearning.getStepCounter(), System.currentTimeMillis());
iLearning.getConfiguration(), iLearning.getStepCount(), System.currentTimeMillis());
String toWrite = toJson(info);
Files.write(infoPath, toWrite.getBytes(), StandardOpenOption.TRUNCATE_EXISTING);
@ -300,12 +300,12 @@ public class DataManager implements IDataManager {
if (!saveData)
return;
save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
save(getModelDir() + "/" + learning.getStepCount() + ".training", learning);
if(learning instanceof NeuralNetFetchable) {
try {
((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCount() + ".model");
} catch (UnsupportedOperationException e) {
String path = getModelDir() + "/" + learning.getStepCounter();
String path = getModelDir() + "/" + learning.getStepCount();
((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(path + "_value.model", path + "_policy.model");
}
}

View File

@ -40,7 +40,7 @@ public class DataManagerTrainingListener implements TrainingListener {
if (trainer instanceof AsyncThread) {
filename += ((AsyncThread) trainer).getThreadNumber() + "-";
}
filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4";
filename += trainer.getEpochCount() + "-" + trainer.getStepCount() + ".mp4";
hp.startMonitor(filename, shape);
}
@ -66,7 +66,7 @@ public class DataManagerTrainingListener implements TrainingListener {
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
try {
int stepCounter = learning.getStepCounter();
int stepCounter = learning.getStepCount();
if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) {
dataManager.save(learning);
lastSave = stepCounter;

View File

@ -8,7 +8,6 @@ import org.datavec.image.transform.CropImageTransform;
import org.datavec.image.transform.MultiImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
@ -30,10 +29,10 @@ import java.util.Map;
import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY;
public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
@Getter
private final MDP<O, A, AS> wrappedMDP;
private final MDP<OBSERVATION, A, AS> wrappedMDP;
@Getter
private final WrapperObservationSpace observationSpace;
private final int[] shape;
@ -44,16 +43,14 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
@Getter(AccessLevel.PRIVATE)
private IHistoryProcessor historyProcessor;
private final EpochStepCounter epochStepCounter;
private int skipFrame = 1;
private int steps = 0;
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) {
this.wrappedMDP = wrappedMDP;
this.shape = wrappedMDP.getObservationSpace().getShape();
this.observationSpace = new WrapperObservationSpace(shape);
this.historyProcessor = historyProcessor;
this.epochStepCounter = epochStepCounter;
setHistoryProcessor(historyProcessor);
}
@ -63,6 +60,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
createTransformProcess();
}
//TODO: this transform process should be decoupled from history processor and configured seperately by the end-user
private void createTransformProcess() {
IHistoryProcessor historyProcessor = getHistoryProcessor();
@ -103,7 +101,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
public Observation reset() {
transformProcess.reset();
O rawResetResponse = wrappedMDP.reset();
OBSERVATION rawResetResponse = wrappedMDP.reset();
record(rawResetResponse);
if(historyProcessor != null) {
@ -118,21 +116,21 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
public StepReply<Observation> step(A a) {
IHistoryProcessor historyProcessor = getHistoryProcessor();
StepReply<O> rawStepReply = wrappedMDP.step(a);
StepReply<OBSERVATION> rawStepReply = wrappedMDP.step(a);
INDArray rawObservation = getInput(rawStepReply.getObservation());
if(historyProcessor != null) {
historyProcessor.record(rawObservation);
}
int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
int stepOfObservation = steps++;
Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
}
private void record(O obs) {
private void record(OBSERVATION obs) {
INDArray rawObservation = getInput(obs);
IHistoryProcessor historyProcessor = getHistoryProcessor();
@ -141,7 +139,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
}
}
private Map<String, Object> buildChannelsData(final O obs) {
private Map<String, Object> buildChannelsData(final OBSERVATION obs) {
return new HashMap<String, Object>() {{
put("data", obs);
}};
@ -159,11 +157,11 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
@Override
public MDP<Observation, A, AS> newInstance() {
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), historyProcessor, epochStepCounter);
return new LegacyMDPWrapper<>(wrappedMDP.newInstance(), historyProcessor);
}
private INDArray getInput(O obs) {
INDArray arr = Nd4j.create(((Encodable)obs).toArray());
private INDArray getInput(OBSERVATION obs) {
INDArray arr = Nd4j.create(obs.toArray());
int[] shape = observationSpace.getShape();
if (shape.length == 1)
return arr.reshape(new long[] {1, arr.length()});

View File

@ -18,132 +18,93 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockPolicy;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class AsyncLearningTest {
@Test
public void when_training_expect_AsyncGlobalStarted() {
// Arrange
TestContext context = new TestContext();
context.asyncGlobal.setMaxLoops(1);
AsyncLearning<Box, INDArray, ActionSpace<INDArray>, NeuralNet> asyncLearning;
// Act
context.sut.train();
@Mock
TrainingListener mockTrainingListener;
// Assert
assertTrue(context.asyncGlobal.hasBeenStarted);
assertTrue(context.asyncGlobal.hasBeenTerminated);
@Mock
AsyncGlobal<NeuralNet> mockAsyncGlobal;
@Mock
IAsyncLearningConfiguration mockConfiguration;
@Before
public void setup() {
asyncLearning = mock(AsyncLearning.class, Mockito.withSettings()
.useConstructor()
.defaultAnswer(Mockito.CALLS_REAL_METHODS));
asyncLearning.addListener(mockTrainingListener);
when(asyncLearning.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
when(asyncLearning.getConfiguration()).thenReturn(mockConfiguration);
// Don't actually start any threads in any of these tests
when(mockConfiguration.getNumThreads()).thenReturn(0);
}
@Test
public void when_trainStartReturnsStop_expect_noTraining() {
// Arrange
TestContext context = new TestContext();
context.listener.setRemainingTrainingStartCallCount(0);
when(mockTrainingListener.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);
// Act
context.sut.train();
asyncLearning.train();
// Assert
assertEquals(1, context.listener.onTrainingStartCallCount);
assertEquals(1, context.listener.onTrainingEndCallCount);
assertEquals(0, context.policy.playCallCount);
assertTrue(context.asyncGlobal.hasBeenTerminated);
verify(mockTrainingListener, times(1)).onTrainingStart();
verify(mockTrainingListener, times(1)).onTrainingEnd();
}
@Test
public void when_trainingIsComplete_expect_trainingStop() {
// Arrange
TestContext context = new TestContext();
when(mockAsyncGlobal.isTrainingComplete()).thenReturn(true);
// Act
context.sut.train();
asyncLearning.train();
// Assert
assertEquals(1, context.listener.onTrainingStartCallCount);
assertEquals(1, context.listener.onTrainingEndCallCount);
assertTrue(context.asyncGlobal.hasBeenTerminated);
verify(mockTrainingListener, times(1)).onTrainingStart();
verify(mockTrainingListener, times(1)).onTrainingEnd();
}
@Test
public void when_training_expect_onTrainingProgressCalled() {
// Arrange
TestContext context = new TestContext();
asyncLearning.setProgressMonitorFrequency(100);
when(mockTrainingListener.onTrainingProgress(eq(asyncLearning))).thenReturn(TrainingListener.ListenerResponse.STOP);
// Act
context.sut.train();
asyncLearning.train();
// Assert
assertEquals(1, context.listener.onTrainingProgressCallCount);
}
public static class TestContext {
MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0, 0, 0, 0, 0);
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockPolicy policy = new MockPolicy();
public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
public final MockTrainingListener listener = new MockTrainingListener(asyncGlobal);
public TestContext() {
sut.addListener(listener);
asyncGlobal.setMaxLoops(1);
sut.setProgressMonitorFrequency(1);
verify(mockTrainingListener, times(1)).onTrainingStart();
verify(mockTrainingListener, times(1)).onTrainingEnd();
verify(mockTrainingListener, times(1)).onTrainingProgress(eq(asyncLearning));
}
}
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
private final IAsyncLearningConfiguration conf;
private final IAsyncGlobal asyncGlobal;
private final IPolicy<MockEncodable, Integer> policy;
public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.policy = policy;
}
@Override
public IPolicy getPolicy() {
return policy;
}
@Override
public IAsyncLearningConfiguration getConfiguration() {
return conf;
}
@Override
protected AsyncThread newThread(int i, int deviceAffinity) {
return null;
}
@Override
public MDP getMdp() {
return null;
}
@Override
protected IAsyncGlobal getAsyncGlobal() {
return asyncGlobal;
}
@Override
public MockNeuralNet getNeuralNet() {
return null;
}
}
}
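As an aside, the tests above rely on Mockito's partial-mock support: the class under test is mocked with its real constructor and CALLS_REAL_METHODS so real logic runs while selected calls are stubbed. A minimal, self-contained sketch of that pattern follows; the Worker class is hypothetical and exists only for illustration, it is not part of RL4J:

import org.mockito.Mockito;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class Worker {
    int stepsPerEpoch() { return 10; }
    int totalSteps(int epochs) { return epochs * stepsPerEpoch(); }
}

class PartialMockSketch {
    public static void main(String[] args) {
        // Real methods run by default; the real constructor is invoked as well.
        Worker worker = mock(Worker.class, Mockito.withSettings()
                .useConstructor()
                .defaultAnswer(Mockito.CALLS_REAL_METHODS));

        // Stub a single collaborator call while totalSteps() keeps its real implementation.
        when(worker.stepsPerEpoch()).thenReturn(5);

        System.out.println(worker.totalSteps(3)); // prints 15
    }
}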

View File

@@ -1,5 +1,4 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
@@ -17,161 +16,230 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class AsyncThreadDiscreteTest {
AsyncThreadDiscrete<Encodable, NeuralNet> asyncThreadDiscrete;
@Mock
IAsyncLearningConfiguration mockAsyncConfiguration;
@Mock
UpdateAlgorithm<NeuralNet> mockUpdateAlgorithm;
@Mock
IAsyncGlobal<NeuralNet> mockAsyncGlobal;
@Mock
Policy<Encodable, Integer> mockGlobalCurrentPolicy;
@Mock
NeuralNet mockGlobalTargetNetwork;
@Mock
MDP<Encodable, Integer, DiscreteSpace> mockMDP;
@Mock
LegacyMDPWrapper<Encodable, Integer, DiscreteSpace> mockLegacyMDPWrapper;
@Mock
DiscreteSpace mockActionSpace;
@Mock
ObservationSpace<Encodable> mockObservationSpace;
@Mock
TrainingListenerList mockTrainingListenerList;
@Mock
Observation mockObservation;
int[] observationShape = new int[]{3, 10, 10};
int actionSize = 4;
private void setupMDPMocks() {
when(mockActionSpace.noOp()).thenReturn(0);
when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
when(mockObservationSpace.getShape()).thenReturn(observationShape);
when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
}
private void setupNNMocks() {
when(mockAsyncGlobal.getTarget()).thenReturn(mockGlobalTargetNetwork);
when(mockGlobalTargetNetwork.clone()).thenReturn(mockGlobalTargetNetwork);
}
@Before
public void setup() {
setupMDPMocks();
setupNNMocks();
asyncThreadDiscrete = mock(AsyncThreadDiscrete.class, Mockito.withSettings()
.useConstructor(mockAsyncGlobal, mockMDP, mockTrainingListenerList, 0, 0)
.defaultAnswer(Mockito.CALLS_REAL_METHODS));
asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration);
when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);
when(mockGlobalCurrentPolicy.nextAction(any(Observation.class))).thenReturn(0);
when(asyncThreadDiscrete.getLegacyMDPWrapper()).thenReturn(mockLegacyMDPWrapper);
}
@Test
public void refac_AsyncThreadDiscrete_trainSubEpoch() {
public void when_episodeCompletes_expect_stepsToBeInLineWithEpisodeLength() {
// Arrange
int numEpochs = 1;
MockNeuralNet nnMock = new MockNeuralNet();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
asyncGlobalMock.setMaxLoops(hpConf.getSkipFrame() * numEpochs);
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy();
MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0, 0, 0, 0, 0, 0, 2, 5);
MockExperienceHandler experienceHandlerMock = new MockExperienceHandler();
MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm();
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
int episodeRemaining = 5;
int remainingTrainingSteps = 10;
// return done after 5 steps (the episode ends before the 10-step nstep budget is used up)
when(mockMDP.isDone()).thenAnswer(invocation ->
asyncThreadDiscrete.getStepCount() == episodeRemaining
);
when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));
// Act
sut.run();
AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
// Assert
assertEquals(2, sut.trainSubEpochResults.size());
double[][] expectedLastObservations = new double[][] {
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
};
double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
for(int i = 0; i < 2; ++i) {
AsyncThread.SubEpochReturn result = sut.trainSubEpochResults.get(i);
assertEquals(4, result.getSteps());
assertEquals(expectedSubEpochReturnRewards[i], result.getReward(), 0.00001);
assertEquals(0.0, result.getScore(), 0.00001);
double[] expectedLastObservation = expectedLastObservations[i];
assertEquals(expectedLastObservation.length, result.getLastObs().getData().shape()[1]);
for(int j = 0; j < expectedLastObservation.length; ++j) {
assertEquals(expectedLastObservation[j], 255.0 * result.getLastObs().getData().getDouble(j), 0.00001);
}
}
assertEquals(2, asyncGlobalMock.enqueueCallCount);
// HistoryProcessor
double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, };
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
for(int i = 0; i < expectedRecordValues.length; ++i) {
assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);
assertTrue(subEpochReturn.isEpisodeComplete());
assertEquals(5, subEpochReturn.getSteps());
}
// Policy
double[][] expectedPolicyInputs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
};
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
double[] expectedRow = expectedPolicyInputs[i];
INDArray input = policyMock.actionInputs.get(i);
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
}
@Test
public void when_episodeCompletesDueToMaxStepsReached_expect_isEpisodeComplete() {
// Arrange
int remainingTrainingSteps = 50;
// Episode does not complete due to MDP
when(mockMDP.isDone()).thenReturn(false);
when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));
when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(50);
// Act
AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
// Assert
assertTrue(subEpochReturn.isEpisodeComplete());
assertEquals(50, subEpochReturn.getSteps());
}
// ExperienceHandler
double[][] expectedExperienceHandlerInputs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
};
assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size());
for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) {
double[] expectedRow = expectedExperienceHandlerInputs[i];
INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData();
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
}
}
@Test
public void when_episodeLongerThanNsteps_expect_returnNStepLength() {
// Arrange
int episodeRemaining = 5;
int remainingTrainingSteps = 4;
// the episode would end after 5 steps, but the nstep budget of 4 is exhausted first
when(mockMDP.isDone()).thenAnswer(invocation ->
asyncThreadDiscrete.getStepCount() == episodeRemaining
);
when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));
// Act
AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
// Assert
assertFalse(subEpochReturn.isEpisodeComplete());
assertEquals(remainingTrainingSteps, subEpochReturn.getSteps());
}
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
@Test
public void when_framesAreSkipped_expect_proportionateStepCounterUpdates() {
int skipFrames = 2;
int remainingTrainingSteps = 10;
private final MockAsyncGlobal asyncGlobal;
private final MockPolicy policy;
private final MockAsyncConfiguration config;
// Episode does not complete due to MDP
when(mockMDP.isDone()).thenReturn(false);
public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
AtomicInteger stepCount = new AtomicInteger();
public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
MockAsyncConfiguration config, IHistoryProcessor hp,
ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
UpdateAlgorithm<MockNeuralNet> updateAlgorithm) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.asyncGlobal = asyncGlobal;
this.policy = policy;
this.config = config;
setHistoryProcessor(hp);
setExperienceHandler(experienceHandler);
setUpdateAlgorithm(updateAlgorithm);
// Use skipFrames to decide whether each returned observation is a skipped frame or a real one
when(mockLegacyMDPWrapper.step(anyInt())).thenAnswer(invocationOnMock -> {
boolean isSkipped = stepCount.incrementAndGet() % skipFrames != 0;
Observation mockObs = new Observation(isSkipped ? null : Nd4j.create(observationShape));
return new StepReply<>(mockObs, 0.0, false, null);
});
// Act
AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
// Assert
assertFalse(subEpochReturn.isEpisodeComplete());
assertEquals(remainingTrainingSteps, subEpochReturn.getSteps());
assertEquals((remainingTrainingSteps - 1) * skipFrames + 1, stepCount.get());
}
@Override
protected IAsyncGlobal<MockNeuralNet> getAsyncGlobal() {
return asyncGlobal;
@Test
public void when_preEpisodeCalled_expect_experienceHandlerReset() {
// Arrange
int trainingSteps = 100;
for (int i = 0; i < trainingSteps; i++) {
asyncThreadDiscrete.getExperienceHandler().addExperience(mockObservation, 0, 0.0, false);
}
@Override
protected IAsyncLearningConfiguration getConf() {
return config;
int experienceHandlerSizeBeforeReset = asyncThreadDiscrete.getExperienceHandler().getTrainingBatchSize();
// Act
asyncThreadDiscrete.preEpisode();
// Assert
assertEquals(100, experienceHandlerSizeBeforeReset);
assertEquals(0, asyncThreadDiscrete.getExperienceHandler().getTrainingBatchSize());
}
@Override
protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
return policy;
}
@Override
protected UpdateAlgorithm<MockNeuralNet> buildUpdateAlgorithm() {
return null;
}
@Override
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
asyncGlobal.increaseCurrentLoop();
SubEpochReturn result = super.trainSubEpoch(sObs, nstep);
trainSubEpochResults.add(result);
return result;
}
}
}
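For reference, a self-contained sketch of the counter-driven thenAnswer stubbing used above to end an episode after a fixed number of steps. The Episode interface below is hypothetical and stands in for the mocked MDP:

import java.util.concurrent.atomic.AtomicInteger;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

interface Episode {
    boolean isDone();
}

class EpisodeTerminationSketch {
    public static void main(String[] args) {
        AtomicInteger polls = new AtomicInteger();
        Episode episode = mock(Episode.class);

        // The mock reports "done" on the 5th poll, like the episodeRemaining stubs above.
        when(episode.isDone()).thenAnswer(invocation -> polls.incrementAndGet() >= 5);

        int steps = 0;
        while (!episode.isDone()) {
            steps++;
        }
        System.out.println(steps); // prints 4: four polls return false, the fifth returns true
    }
}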

View File

@@ -1,220 +1,277 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Before;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.base.Preconditions;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class AsyncThreadTest {
@Test
public void when_newEpochStarted_expect_neuralNetworkReset() {
// Arrange
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
@Mock
ActionSpace<INDArray> mockActionSpace;
// Act
context.sut.run();
@Mock
ObservationSpace<Box> mockObservationSpace;
// Assert
assertEquals(numberOfEpochs, context.neuralNet.resetCallCount);
@Mock
IAsyncLearningConfiguration mockAsyncConfiguration;
@Mock
NeuralNet mockNeuralNet;
@Mock
IAsyncGlobal<NeuralNet> mockAsyncGlobal;
@Mock
MDP<Box, INDArray, ActionSpace<INDArray>> mockMDP;
@Mock
TrainingListenerList mockTrainingListeners;
int[] observationShape = new int[]{3, 10, 10};
int actionSize = 4;
AsyncThread<Box, INDArray, ActionSpace<INDArray>, NeuralNet> thread;
@Before
public void setup() {
setupMDPMocks();
setupThreadMocks();
}
private void setupThreadMocks() {
thread = mock(AsyncThread.class, Mockito.withSettings()
.useConstructor(mockMDP, mockTrainingListeners, 0, 0)
.defaultAnswer(Mockito.CALLS_REAL_METHODS));
when(thread.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
when(thread.getCurrent()).thenReturn(mockNeuralNet);
}
private void setupMDPMocks() {
when(mockObservationSpace.getShape()).thenReturn(observationShape);
when(mockActionSpace.noOp()).thenReturn(Nd4j.zeros(actionSize));
when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
int dataLength = 1;
for (int d : observationShape) {
dataLength *= d;
}
when(mockMDP.reset()).thenReturn(new Box(new double[dataLength]));
}
private void mockTrainingListeners() {
mockTrainingListeners(false, false);
}
private void mockTrainingListeners(boolean stopOnNotifyNewEpoch, boolean stopOnNotifyEpochTrainingResult) {
when(mockTrainingListeners.notifyNewEpoch(eq(thread))).thenReturn(!stopOnNotifyNewEpoch);
when(mockTrainingListeners.notifyEpochTrainingResult(eq(thread), any(IDataManager.StatEntry.class))).thenReturn(!stopOnNotifyEpochTrainingResult);
}
private void mockTrainingContext() {
mockTrainingContext(1000, 100, 10);
}
private void mockTrainingContext(int maxSteps, int maxStepsPerEpisode, int nstep) {
// Some conditions of this test harness
Preconditions.checkArgument(maxStepsPerEpisode >= nstep, "maxStepsPerEpisode must be greater than or equal to nstep");
Preconditions.checkArgument(maxStepsPerEpisode % nstep == 0, "maxStepsPerEpisode must be a multiple of nstep");
Observation mockObs = new Observation(Nd4j.zeros(observationShape));
when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
when(thread.getConf()).thenReturn(mockAsyncConfiguration);
// training is considered complete once the thread reaches the max step count
when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);
when(thread.trainSubEpoch(any(Observation.class), anyInt())).thenAnswer(invocationOnMock -> {
int steps = invocationOnMock.getArgument(1);
thread.stepCount += steps;
thread.currentEpisodeStepCount += steps;
boolean isEpisodeComplete = thread.getCurrentEpisodeStepCount() % maxStepsPerEpisode == 0;
return new AsyncThread.SubEpochReturn(steps, mockObs, 0.0, 0.0, isEpisodeComplete);
});
}
@Test
public void when_onNewEpochReturnsStop_expect_threadStopped() {
public void when_episodeComplete_expect_neuralNetworkReset() {
// Arrange
int stopAfterNumCalls = 1;
TestContext context = new TestContext(100000);
context.listener.setRemainingOnNewEpochCallCount(stopAfterNumCalls);
mockTrainingContext(100, 10, 10);
mockTrainingListeners();
// Act
context.sut.run();
thread.run();
// Assert
assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: The call that returns stop is counted
assertEquals(stopAfterNumCalls, context.listener.onEpochTrainingResultCallCount);
verify(mockNeuralNet, times(10)).reset(); // there are 10 episodes, so the network should be reset once per episode
assertEquals(10, thread.getEpochCount()); // We are performing a training iteration every 10 steps, so there should be 10 epochs
assertEquals(10, thread.getEpisodeCount()); // There should be 10 completed episodes
assertEquals(100, thread.getStepCount()); // 100 steps overall
}
@Test
public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
public void when_notifyNewEpochReturnsStop_expect_threadStopped() {
// Arrange
int stopAfterNumCalls = 1;
TestContext context = new TestContext(100000);
context.listener.setRemainingOnEpochTrainingResult(stopAfterNumCalls);
mockTrainingContext();
mockTrainingListeners(true, false);
// Act
context.sut.run();
thread.run();
// Assert
assertEquals(stopAfterNumCalls + 1, context.listener.onEpochTrainingResultCallCount); // +1: The call that returns stop is counted
assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: onNewEpoch is called on the epoch that onEpochTrainingResult() will stop
assertEquals(0, thread.getEpochCount());
assertEquals(1, thread.getEpisodeCount());
assertEquals(0, thread.getStepCount());
}
@Test
public void when_run_expect_preAndPostEpochCalled() {
public void when_notifyEpochTrainingResultReturnsStop_expect_threadStopped() {
// Arrange
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
mockTrainingContext();
mockTrainingListeners(false, true);
// Act
context.sut.run();
thread.run();
// Assert
assertEquals(numberOfEpochs, context.sut.preEpochCallCount);
assertEquals(numberOfEpochs, context.sut.postEpochCallCount);
assertEquals(1, thread.getEpochCount());
assertEquals(1, thread.getEpisodeCount());
assertEquals(10, thread.getStepCount()); // one epoch is by default 10 steps
}
@Test
public void when_run_expect_preAndPostEpisodeCalled() {
// Arrange
mockTrainingContext(100, 10, 5);
mockTrainingListeners(false, false);
// Act
thread.run();
// Assert
assertEquals(20, thread.getEpochCount());
assertEquals(10, thread.getEpisodeCount());
assertEquals(100, thread.getStepCount());
verify(thread, times(10)).preEpisode(); // over 100 steps there will be 10 episodes
verify(thread, times(10)).postEpisode();
}
@Test
public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
// Arrange
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
mockTrainingContext(100, 10, 5);
mockTrainingListeners(false, false);
// Act
context.sut.run();
thread.run();
// Assert
assertEquals(numberOfEpochs, context.listener.statEntries.size());
int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 };
double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
+ 1.0; // Reward from trainSubEpoch()
for(int i = 0; i < numberOfEpochs; ++i) {
IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
assertEquals(i, statEntry.getEpochCounter());
assertEquals(expectedReward, statEntry.getReward(), 0.0001);
}
assertEquals(20, thread.getEpochCount());
assertEquals(10, thread.getEpisodeCount());
assertEquals(100, thread.getStepCount());
// Over 100 steps there will be 20 training iterations, so there will be 20 calls to notifyEpochTrainingResult
verify(mockTrainingListeners, times(20)).notifyEpochTrainingResult(eq(thread), any(IDataManager.StatEntry.class));
}
@Test
public void when_run_expect_trainSubEpochCalled() {
// Arrange
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
mockTrainingContext(100, 10, 5);
mockTrainingListeners(false, false);
// Act
context.sut.run();
thread.run();
// Assert
assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) {
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
assertEquals(2, params.nstep);
assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
for(int j = 0; j < expectedObservation.length; ++j){
assertEquals(expectedObservation[j], 255.0 * params.obs.getData().getDouble(j), 0.00001);
}
}
assertEquals(20, thread.getEpochCount());
assertEquals(10, thread.getEpisodeCount());
assertEquals(100, thread.getStepCount());
// There should be 20 calls to trainSubEpoch(), each with 5 steps
verify(thread, times(20)).trainSubEpoch(any(Observation.class), eq(5));
}
private static class TestContext {
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockNeuralNet neuralNet = new MockNeuralNet();
public final MockObservationSpace observationSpace = new MockObservationSpace();
public final MockMDP mdp = new MockMDP(observationSpace);
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0);
public final TrainingListenerList listeners = new TrainingListenerList();
public final MockTrainingListener listener = new MockTrainingListener();
public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);
@Test
public void when_remainingEpisodeLengthSmallerThanNSteps_expect_trainSubEpochCalledWithMinimumValue() {
public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
int currentEpisodeSteps = 95;
mockTrainingContext(1000, 100, 10);
mockTrainingListeners(false, true);
public TestContext(int numEpochs) {
asyncGlobal.setMaxLoops(numEpochs);
listeners.add(listener);
sut.setHistoryProcessor(historyProcessor);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
// mock that we are already 95 steps into the episode
doAnswer(invocationOnMock -> {
for (int i = 0; i < currentEpisodeSteps; i++) {
thread.incrementSteps();
}
}
public static class MockAsyncThread extends AsyncThread<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
public int preEpochCallCount = 0;
public int postEpochCallCount = 0;
private final MockAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet;
private final IAsyncLearningConfiguration conf;
private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) {
super(asyncGlobal, mdp, listeners, threadNumber, 0);
this.asyncGlobal = asyncGlobal;
this.neuralNet = neuralNet;
this.conf = conf;
}
@Override
protected void preEpoch() {
++preEpochCallCount;
super.preEpoch();
}
@Override
protected void postEpoch() {
++postEpochCallCount;
super.postEpoch();
}
@Override
protected MockNeuralNet getCurrent() {
return neuralNet;
}
@Override
protected IAsyncGlobal getAsyncGlobal() {
return asyncGlobal;
}
@Override
protected IAsyncLearningConfiguration getConf() {
return conf;
}
@Override
protected Policy getPolicy(MockNeuralNet net) {
return null;
}).when(thread).preEpisode();
mockTrainingListeners(false, true);
// Act
thread.run();
// Assert
assertEquals(1, thread.getEpochCount());
assertEquals(1, thread.getEpisodeCount());
assertEquals(100, thread.getStepCount());
// There should be 1 call to trainSubEpoch() with 5 steps, as that is all that remains of the episode
verify(thread, times(1)).trainSubEpoch(any(Observation.class), eq(5));
}
@Override
protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
asyncGlobal.increaseCurrentLoop();
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
for(int i = 0; i < nstep; ++i) {
incrementStep();
}
return new SubEpochReturn(nstep, null, 1.0, 1.0);
}
@AllArgsConstructor
@Getter
public static class TrainSubEpochParams {
Observation obs;
int nstep;
}
}
}
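As a side note, stubbing side effects onto a void method, as done for preEpisode() above, requires the doAnswer(...).when(mock).method() form. A minimal sketch with a hypothetical EpisodeHooks interface (not part of RL4J):

import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

interface EpisodeHooks {
    void preEpisode();
}

class VoidStubbingSketch {
    public static void main(String[] args) {
        final int[] stepsAlreadyTaken = {0};
        EpisodeHooks hooks = mock(EpisodeHooks.class);

        // when(hooks.preEpisode()) does not compile for a void method, so use doAnswer(...).when(...)
        doAnswer(invocation -> {
            stepsAlreadyTaken[0] = 95; // pretend the episode is already 95 steps in
            return null;
        }).when(hooks).preEpisode();

        hooks.preEpisode();
        System.out.println(stepsAlreadyTaken[0]); // prints 95
    }
}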

View File

@@ -1,160 +0,0 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class A3CUpdateAlgorithmTest {
@Test
public void refac_calcGradient_non_terminal() {
// Arrange
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 });
MockMDP mdpMock = new MockMDP(observationSpace);
MockActorCritic actorCriticMock = new MockActorCritic();
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), mdpMock.getActionSpace().getSize(), -1, gamma);
INDArray[] originalObservations = new INDArray[] {
Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }),
Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }),
Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }),
Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }),
};
int[] actions = new int[] { 0, 1, 2, 1 };
double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 };
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
for(int i = 0; i < originalObservations.length; ++i) {
experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
}
// Act
sut.computeGradients(actorCriticMock, experience);
// Assert
assertEquals(1, actorCriticMock.gradientParams.size());
// Inputs
INDArray input = actorCriticMock.gradientParams.get(0).getLeft();
for(int i = 0; i < 4; ++i) {
for(int j = 0; j < 5; ++j) {
assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001);
}
}
INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0];
INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1];
assertEquals(4, targets.shape()[0]);
assertEquals(1, targets.shape()[1]);
// FIXME: check targets values once fixed
assertEquals(4, logSoftmax.shape()[0]);
assertEquals(5, logSoftmax.shape()[1]);
// FIXME: check logSoftmax values once fixed
}
public class MockActorCritic implements IActorCritic {
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[] { batch.mul(-1.0) };
}
@Override
public IActorCritic clone() {
return this;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IActorCritic from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
return new Gradient[0];
}
@Override
public void applyGradient(Gradient[] gradient, int batchSize) {
}
@Override
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
}
@Override
public void save(String pathValue, String pathPolicy) throws IOException {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}
}

View File

@@ -0,0 +1,93 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class AdvantageActorCriticUpdateAlgorithmTest {
@Mock
AsyncGlobal<NeuralNet> mockAsyncGlobal;
@Mock
IActorCritic mockActorCritic;
@Test
public void refac_calcGradient_non_terminal() {
// Arrange
int[] observationShape = new int[]{5};
double gamma = 0.9;
AdvantageActorCriticUpdateAlgorithm algorithm = new AdvantageActorCriticUpdateAlgorithm(false, observationShape, 1, gamma);
INDArray[] originalObservations = new INDArray[]{
Nd4j.create(new double[]{0.0, 0.1, 0.2, 0.3, 0.4}),
Nd4j.create(new double[]{1.0, 1.1, 1.2, 1.3, 1.4}),
Nd4j.create(new double[]{2.0, 2.1, 2.2, 2.3, 2.4}),
Nd4j.create(new double[]{3.0, 3.1, 3.2, 3.3, 3.4}),
};
int[] actions = new int[]{0, 1, 2, 1};
double[] rewards = new double[]{0.1, 1.0, 10.0, 100.0};
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
for (int i = 0; i < originalObservations.length; ++i) {
experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
}
when(mockActorCritic.outputAll(any(INDArray.class))).thenAnswer(invocation -> {
INDArray batch = invocation.getArgument(0);
return new INDArray[]{batch.mul(-1.0)};
});
ArgumentCaptor<INDArray> inputArgumentCaptor = ArgumentCaptor.forClass(INDArray.class);
ArgumentCaptor<INDArray[]> criticActorArgumentCaptor = ArgumentCaptor.forClass(INDArray[].class);
// Act
algorithm.computeGradients(mockActorCritic, experience);
verify(mockActorCritic, times(1)).gradient(inputArgumentCaptor.capture(), criticActorArgumentCaptor.capture());
assertEquals(Nd4j.stack(0, originalObservations), inputArgumentCaptor.getValue());
// TODO: the AdvantageActorCritic update algorithm is not implemented correctly yet; once it is fixed, these assertions can be enabled
// assertEquals(Nd4j.zeros(1), criticActorArgumentCaptor.getValue()[0]);
// assertEquals(Nd4j.zeros(1), criticActorArgumentCaptor.getValue()[1]);
}
}
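A short, self-contained sketch of the ArgumentCaptor pattern used above to inspect the arguments passed to gradient(). The List mock below is only a stand-in for illustration:

import org.mockito.ArgumentCaptor;
import java.util.List;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

class CaptorSketch {
    @SuppressWarnings("unchecked")
    public static void main(String[] args) {
        List<String> sink = (List<String>) mock(List.class);

        sink.add("batch-0");

        // Capture the argument during verification, then assert on it afterwards.
        ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
        verify(sink, times(1)).add(captor.capture());
        System.out.println(captor.getValue()); // prints batch-0
    }
}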

View File

@@ -1,3 +1,19 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.listener;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;

View File

@@ -1,11 +1,30 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockDQN;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -14,17 +33,21 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
@RunWith(MockitoJUnitRunner.class)
public class QLearningUpdateAlgorithmTest {
@Mock
AsyncGlobal mockAsyncGlobal;
@Test
public void when_isTerminal_expect_initRewardIs0() {
// Arrange
MockDQN dqnMock = new MockDQN();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0);
final Observation observation = new Observation(Nd4j.zeros(1));
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.zeros(1)), 0, 0.0, true));
add(new StateActionPair<Integer>(observation, 0, 0.0, true));
}
};
@@ -38,12 +61,11 @@ public class QLearningUpdateAlgorithmTest {
@Test
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
// Arrange
MockDQN globalDQNMock = new MockDQN();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0);
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }));
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
add(new StateActionPair<Integer>(observation, 0, 0.0, false));
}
};
MockDQN dqnMock = new MockDQN();
@@ -57,35 +79,11 @@ public class QLearningUpdateAlgorithmTest {
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
}
@Test
public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() {
// Arrange
MockDQN globalDQNMock = new MockDQN();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0);
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
}
};
MockDQN dqnMock = new MockDQN();
// Act
sut.computeGradients(dqnMock, experience);
// Assert
assertEquals(1, globalDQNMock.outputAllParams.size());
assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
}
@Test
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
// Arrange
double gamma = 0.9;
MockDQN globalDQNMock = new MockDQN();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma);
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));

View File

@@ -1,20 +1,56 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.listener;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.mockito.Mock;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class TrainingListenerListTest {
@Mock
IEpochTrainer mockTrainer;
@Mock
ILearning mockLearning;
@Mock
IDataManager.StatEntry mockStatEntry;
@Test
public void when_listIsEmpty_expect_notifyReturnTrue() {
// Arrange
TrainingListenerList sut = new TrainingListenerList();
TrainingListenerList trainingListenerList = new TrainingListenerList();
// Act
boolean resultTrainingStarted = sut.notifyTrainingStarted();
boolean resultNewEpoch = sut.notifyNewEpoch(null);
boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null);
boolean resultTrainingStarted = trainingListenerList.notifyTrainingStarted();
boolean resultNewEpoch = trainingListenerList.notifyNewEpoch(null);
boolean resultEpochFinished = trainingListenerList.notifyEpochTrainingResult(null, null);
// Assert
assertTrue(resultTrainingStarted);
@@ -25,54 +61,56 @@ public class TrainingListenerListTest {
@Test
public void when_firstListenerStops_expect_otherListenersNotCalled() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
listener1.setRemainingTrainingStartCallCount(0);
listener1.setRemainingOnNewEpochCallCount(0);
listener1.setRemainingonTrainingProgressCallCount(0);
listener1.setRemainingOnEpochTrainingResult(0);
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
TrainingListener listener1 = mock(TrainingListener.class);
TrainingListener listener2 = mock(TrainingListener.class);
TrainingListenerList trainingListenerList = new TrainingListenerList();
trainingListenerList.add(listener1);
trainingListenerList.add(listener2);
when(listener1.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);
when(listener1.onNewEpoch(eq(mockTrainer))).thenReturn(TrainingListener.ListenerResponse.STOP);
when(listener1.onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry))).thenReturn(TrainingListener.ListenerResponse.STOP);
when(listener1.onTrainingProgress(eq(mockLearning))).thenReturn(TrainingListener.ListenerResponse.STOP);
// Act
sut.notifyTrainingStarted();
sut.notifyNewEpoch(null);
sut.notifyEpochTrainingResult(null, null);
sut.notifyTrainingProgress(null);
sut.notifyTrainingFinished();
trainingListenerList.notifyTrainingStarted();
trainingListenerList.notifyNewEpoch(mockTrainer);
trainingListenerList.notifyEpochTrainingResult(mockTrainer, null);
trainingListenerList.notifyTrainingProgress(mockLearning);
trainingListenerList.notifyTrainingFinished();
// Assert
assertEquals(1, listener1.onTrainingStartCallCount);
assertEquals(0, listener2.onTrainingStartCallCount);
assertEquals(1, listener1.onNewEpochCallCount);
assertEquals(0, listener2.onNewEpochCallCount);
verify(listener1, times(1)).onTrainingStart();
verify(listener2, never()).onTrainingStart();
assertEquals(1, listener1.onEpochTrainingResultCallCount);
assertEquals(0, listener2.onEpochTrainingResultCallCount);
verify(listener1, times(1)).onNewEpoch(eq(mockTrainer));
verify(listener2, never()).onNewEpoch(eq(mockTrainer));
assertEquals(1, listener1.onTrainingProgressCallCount);
assertEquals(0, listener2.onTrainingProgressCallCount);
verify(listener1, times(1)).onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry));
verify(listener2, never()).onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry));
assertEquals(1, listener1.onTrainingEndCallCount);
assertEquals(1, listener2.onTrainingEndCallCount);
verify(listener1, times(1)).onTrainingProgress(eq(mockLearning));
verify(listener2, never()).onTrainingProgress(eq(mockLearning));
verify(listener1, times(1)).onTrainingEnd();
verify(listener2, times(1)).onTrainingEnd();
}
@Test
public void when_allListenersContinue_expect_listReturnsTrue() {
// Arrange
MockTrainingListener listener1 = new MockTrainingListener();
MockTrainingListener listener2 = new MockTrainingListener();
TrainingListenerList sut = new TrainingListenerList();
sut.add(listener1);
sut.add(listener2);
TrainingListener listener1 = mock(TrainingListener.class);
TrainingListener listener2 = mock(TrainingListener.class);
TrainingListenerList trainingListenerList = new TrainingListenerList();
trainingListenerList.add(listener1);
trainingListenerList.add(listener2);
// Act
boolean resultTrainingStarted = sut.notifyTrainingStarted();
boolean resultNewEpoch = sut.notifyNewEpoch(null);
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
boolean resultProgress = sut.notifyTrainingProgress(null);
boolean resultTrainingStarted = trainingListenerList.notifyTrainingStarted();
boolean resultNewEpoch = trainingListenerList.notifyNewEpoch(null);
boolean resultEpochTrainingResult = trainingListenerList.notifyEpochTrainingResult(null, null);
boolean resultProgress = trainingListenerList.notifyTrainingProgress(null);
// Assert
assertTrue(resultTrainingStarted);

View File

@@ -1,151 +1,117 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class SyncLearningTest {
@Mock
TrainingListener mockTrainingListener;
SyncLearning<Box, INDArray, ActionSpace<INDArray>, NeuralNet> syncLearning;
@Mock
ILearningConfiguration mockLearningConfiguration;
@Before
public void setup() {
syncLearning = mock(SyncLearning.class, Mockito.withSettings()
.useConstructor()
.defaultAnswer(Mockito.CALLS_REAL_METHODS));
syncLearning.addListener(mockTrainingListener);
when(syncLearning.trainEpoch()).thenAnswer(invocation -> {
//syncLearning.incrementEpoch();
syncLearning.incrementStep();
return new MockStatEntry(syncLearning.getEpochCount(), syncLearning.getStepCount(), 1.0);
});
when(syncLearning.getConfiguration()).thenReturn(mockLearningConfiguration);
when(mockLearningConfiguration.getMaxStep()).thenReturn(100);
}
@Test
public void when_training_expect_listenersToBeCalled() {
// Arrange
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
// Act
sut.train();
syncLearning.train();
verify(mockTrainingListener, times(1)).onTrainingStart();
verify(mockTrainingListener, times(100)).onNewEpoch(eq(syncLearning));
verify(mockTrainingListener, times(100)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
verify(mockTrainingListener, times(1)).onTrainingEnd();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(10, listener.onNewEpochCallCount);
assertEquals(10, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
@Test
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.setRemainingTrainingStartCallCount(0);
when(mockTrainingListener.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);
// Act
sut.train();
syncLearning.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(0, listener.onNewEpochCallCount);
assertEquals(0, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
verify(mockTrainingListener, times(1)).onTrainingStart();
verify(mockTrainingListener, times(0)).onNewEpoch(eq(syncLearning));
verify(mockTrainingListener, times(0)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
verify(mockTrainingListener, times(1)).onTrainingEnd();
}
@Test
public void when_newEpochCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.setRemainingOnNewEpochCallCount(2);
when(mockTrainingListener.onNewEpoch(eq(syncLearning)))
.thenReturn(TrainingListener.ListenerResponse.CONTINUE)
.thenReturn(TrainingListener.ListenerResponse.CONTINUE)
.thenReturn(TrainingListener.ListenerResponse.STOP);
// Act
sut.train();
syncLearning.train();
verify(mockTrainingListener, times(1)).onTrainingStart();
verify(mockTrainingListener, times(3)).onNewEpoch(eq(syncLearning));
verify(mockTrainingListener, times(2)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
verify(mockTrainingListener, times(1)).onTrainingEnd();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(3, listener.onNewEpochCallCount);
assertEquals(2, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
@Test
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
// Arrange
LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
listener.setRemainingOnEpochTrainingResult(2);
when(mockTrainingListener.onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class)))
.thenReturn(TrainingListener.ListenerResponse.CONTINUE)
.thenReturn(TrainingListener.ListenerResponse.CONTINUE)
.thenReturn(TrainingListener.ListenerResponse.STOP);
// Act
sut.train();
syncLearning.train();
assertEquals(1, listener.onTrainingStartCallCount);
assertEquals(3, listener.onNewEpochCallCount);
assertEquals(3, listener.onEpochTrainingResultCallCount);
assertEquals(1, listener.onTrainingEndCallCount);
}
public static class MockSyncLearning extends SyncLearning {
private final ILearningConfiguration conf;
@Getter
private int currentEpochStep = 0;
public MockSyncLearning(ILearningConfiguration conf) {
this.conf = conf;
}
@Override
protected void preEpoch() { currentEpochStep = 0; }
@Override
protected void postEpoch() { }
@Override
protected IDataManager.StatEntry trainEpoch() {
setStepCounter(getStepCounter() + 1);
return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0);
}
@Override
public NeuralNet getNeuralNet() {
return null;
}
@Override
public IPolicy getPolicy() {
return null;
}
@Override
public ILearningConfiguration getConfiguration() {
return conf;
}
@Override
public MDP getMdp() {
return null;
}
verify(mockTrainingListener, times(1)).onTrainingStart();
verify(mockTrainingListener, times(3)).onNewEpoch(eq(syncLearning));
verify(mockTrainingListener, times(3)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
verify(mockTrainingListener, times(1)).onTrainingEnd();
}
}

View File

@@ -17,6 +17,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
@ -26,11 +27,21 @@ import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.api.rng.Random;
@ -40,150 +51,146 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class QLearningDiscreteTest {
QLearningDiscrete<Encodable> qLearningDiscrete;
@Mock
IHistoryProcessor mockHistoryProcessor;
@Mock
IHistoryProcessor.Configuration mockHistoryConfiguration;
@Mock
MDP<Encodable, Integer, DiscreteSpace> mockMDP;
@Mock
DiscreteSpace mockActionSpace;
@Mock
ObservationSpace<Encodable> mockObservationSpace;
@Mock
IDQN mockDQN;
@Mock
QLearningConfiguration mockQlearningConfiguration;
int[] observationShape = new int[]{3, 10, 10};
int totalObservationSize = 1;
private void setupMDPMocks() {
when(mockObservationSpace.getShape()).thenReturn(observationShape);
when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
int dataLength = 1;
for (int d : observationShape) {
dataLength *= d;
}
}
private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) {
when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize);
when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor);
when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay);
when(mockQlearningConfiguration.getSeed()).thenReturn(123L);
qLearningDiscrete = mock(
QLearningDiscrete.class,
Mockito.withSettings()
.useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0)
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
);
}
private void mockHistoryProcessor(int skipFrames) {
when(mockHistoryConfiguration.getRescaledHeight()).thenReturn(observationShape[1]);
when(mockHistoryConfiguration.getRescaledWidth()).thenReturn(observationShape[2]);
when(mockHistoryConfiguration.getOffsetX()).thenReturn(0);
when(mockHistoryConfiguration.getOffsetY()).thenReturn(0);
when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
}
@Before
public void setup() {
setupMDPMocks();
for (int i : observationShape) {
totalObservationSize *= i;
}
}
@Test
public void refac_QLearningDiscrete_trainStep() {
public void when_singleTrainStep_expect_correctValues() {
// Arrange
MockObservationSpace observationSpace = new MockObservationSpace();
MockDQN dqn = new MockDQN();
MockRandom random = new MockRandom(new double[]{
0.7309677600860596,
0.8314409852027893,
0.2405363917350769,
0.6063451766967773,
0.6374173760414124,
0.3090505599975586,
0.5504369735717773,
0.11700659990310669
},
new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
MockMDP mdp = new MockMDP(observationSpace, random);
        mockTestContext(100, 0, 2, 1.0, 10);
int initStepCount = 8;
// An example observation and 2 Q values output (2 actions)
Observation observation = new Observation(Nd4j.zeros(observationShape));
when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
QLearningConfiguration conf = QLearningConfiguration.builder()
.seed(0L)
.maxEpochStep(24)
.maxStep(0)
.expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000)
.updateStart(initStepCount)
.rewardFactor(1.0)
.gamma(0)
.errorClamp(0)
.minEpsilon(0)
.epsilonNbStep(0)
.doubleDQN(true)
.build();
MockDataManager dataManager = new MockDataManager(false);
MockExperienceHandler experienceHandler = new MockExperienceHandler();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hp);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
// Act
IDataManager.StatEntry result = sut.trainEpoch();
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
// Assert
// HistoryProcessor calls
double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
assertEquals(expectedRecords.length, hp.recordCalls.size());
for (int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
}
assertEquals(0, hp.startMonitorCallCount);
assertEquals(0, hp.stopMonitorCallCount);
assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
// DQN calls
assertEquals(1, dqn.fitParams.size());
assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
assertEquals(14, dqn.outputParams.size());
double[][] expectedDQNOutput = new double[][]{
new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
};
for (int i = 0; i < expectedDQNOutput.length; ++i) {
INDArray outputParam = dqn.outputParams.get(i);
StepReply<Observation> stepReply = stepReturn.getStepReply();
assertEquals(5, outputParam.shape()[1]);
assertEquals(1, outputParam.shape()[2]);
assertEquals(0, stepReply.getReward(), 1e-5);
assertFalse(stepReply.isDone());
assertFalse(stepReply.getObservation().isSkipped());
assertEquals(observation.getData().reshape(observationShape), stepReply.getObservation().getData().reshape(observationShape));
double[] expectedRow = expectedDQNOutput[i];
for (int j = 0; j < expectedRow.length; ++j) {
assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
}
}
// MDP calls
assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
@Test
public void when_singleTrainStepSkippedFrames_expect_correctValues() {
// Arrange
        mockTestContext(100, 0, 2, 1.0, 10);
// ExperienceHandler calls
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
double[][] expectedTrObservations = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
};
mockHistoryProcessor(2);
assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size());
for(int i = 0; i < expectedTrRewards.length; ++i) {
StateActionPair<Integer> stateActionPair = experienceHandler.addExperienceArgs.get(i);
assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001);
assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction());
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001);
}
}
assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001);
// An example observation and 2 Q values output (2 actions)
Observation observation = new Observation(Nd4j.zeros(observationShape));
when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
// trainEpoch result
assertEquals(initStepCount + 16, result.getStepCounter());
assertEquals(300.0, result.getReward(), 0.00001);
assertTrue(dqn.hasBeenReset);
assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset);
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
// Act
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
// Assert
assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
StepReply<Observation> stepReply = stepReturn.getStepReply();
assertEquals(0, stepReply.getReward(), 1e-5);
assertFalse(stepReply.isDone());
assertTrue(stepReply.getObservation().isSkipped());
}
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
int epsilonNbStep, Random rnd) {
super(mdp, dqn, conf, epsilonNbStep, rnd);
addListener(new DataManagerTrainingListener(dataManager));
setExperienceHandler(experienceHandler);
}
    //TODO: there are many more test cases here that could be added and improved upon
@Override
protected DataSet setTarget(List<Transition<Integer>> transitions) {
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
}
@Override
public IDataManager.StatEntry trainEpoch() {
return super.trainEpoch();
}
}
}
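
A note on the partial-mock construction used in mockTestContext above: the following is a minimal, self-contained sketch of useConstructor combined with CALLS_REAL_METHODS, which is what lets the test stub mockDQN.output(...) while trainStep(...) still runs the real QLearningDiscrete code. The Counter class is hypothetical, not RL4J code; doReturn(...).when(...) is used for stubbing because, unlike when(...).thenReturn(...), it does not invoke the real method while the stub is being recorded.

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;

import org.junit.Test;

public class PartialMockSketch {

    // Hypothetical class, used only to illustrate the pattern.
    public static class Counter {
        private final int start;

        public Counter(int start) {
            this.start = start;
        }

        public int next() {
            return start + step();
        }

        protected int step() {
            return 1;
        }
    }

    @Test
    public void partialMockKeepsRealBehaviourUnlessStubbed() {
        // The real constructor runs, and unstubbed methods execute real code.
        Counter counter = mock(Counter.class, withSettings()
                .useConstructor(10)
                .defaultAnswer(CALLS_REAL_METHODS));

        assertEquals(11, counter.next());

        // Individual methods can still be stubbed.
        doReturn(5).when(counter).step();
        assertEquals(15, counter.next());
    }
}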

View File

@ -1,79 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
public class MockMDP implements MDP<Object, Integer, DiscreteSpace> {
private final int maxSteps;
private final DiscreteSpace actionSpace = new DiscreteSpace(1);
private final MockObservationSpace observationSpace = new MockObservationSpace();
private int currentStep = 0;
public MockMDP(int maxSteps) {
this.maxSteps = maxSteps;
}
@Override
public ObservationSpace<Object> getObservationSpace() {
return observationSpace;
}
@Override
public DiscreteSpace getActionSpace() {
return actionSpace;
}
@Override
public Object reset() {
return null;
}
@Override
public void close() {
}
@Override
public StepReply<Object> step(Integer integer) {
return new StepReply<Object>(null, 1.0, isDone(), null);
}
@Override
public boolean isDone() {
return currentStep >= maxSteps;
}
@Override
public MDP<Object, Integer, DiscreteSpace> newInstance() {
return null;
}
private static class MockObservationSpace implements ObservationSpace {
@Override
public String getName() {
return null;
}
@Override
public int[] getShape() {
return new int[0];
}
@Override
public INDArray getLow() {
return null;
}
@Override
public INDArray getHigh() {
return null;
}
}
}
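
For context on the contract this (now removed) mock implements: the following is a minimal sketch, not part of this change, of the episode loop the MDP interface implies — reset the environment, step it with actions drawn from its action space until it reports done, then close it.

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.DiscreteSpace;

public class EpisodeLoopSketch {

    // Runs a single episode against any discrete-action MDP and returns
    // the accumulated (undiscounted) reward.
    public static double runOneEpisode(MDP<Object, Integer, DiscreteSpace> mdp) {
        mdp.reset();
        double totalReward = 0.0;
        while (!mdp.isDone()) {
            Integer action = mdp.getActionSpace().randomAction();
            StepReply<Object> reply = mdp.step(action);
            totalReward += reply.getReward();
        }
        mdp.close();
        return totalReward;
    }
}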

View File

@ -257,9 +257,9 @@ public class PolicyTest {
}
@Override
protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
return super.refacInitMdp(mdpWrapper, hp, epochStepCounter);
return super.refacInitMdp(mdpWrapper, hp);
}
}
}

View File

@ -1,37 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.support;
import lombok.AllArgsConstructor;
import lombok.Value;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
@Value
@AllArgsConstructor
public class MockAsyncConfiguration implements IAsyncLearningConfiguration {
private Long seed;
private int maxEpochStep;
private int maxStep;
private int updateStart;
private double rewardFactor;
private double gamma;
private double errorClamp;
private int numThreads;
private int nStep;
private int learnerUpdateFrequency;
}

View File

@ -1,75 +0,0 @@
package org.deeplearning4j.rl4j.support;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.network.NeuralNet;
import java.util.concurrent.atomic.AtomicInteger;
public class MockAsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {
@Getter
private final NN current;
public boolean hasBeenStarted = false;
public boolean hasBeenTerminated = false;
public int enqueueCallCount = 0;
@Setter
private int maxLoops;
@Setter
private int numLoopsStopRunning;
private int currentLoop = 0;
public MockAsyncGlobal() {
this(null);
}
public MockAsyncGlobal(NN current) {
maxLoops = Integer.MAX_VALUE;
numLoopsStopRunning = Integer.MAX_VALUE;
this.current = current;
}
@Override
public boolean isRunning() {
return currentLoop < numLoopsStopRunning;
}
@Override
public void terminate() {
hasBeenTerminated = true;
}
@Override
public boolean isTrainingComplete() {
return currentLoop >= maxLoops;
}
@Override
public void start() {
hasBeenStarted = true;
}
@Override
public AtomicInteger getT() {
return null;
}
@Override
public NN getTarget() {
return current;
}
@Override
public void enqueue(Gradient[] gradient, Integer nstep) {
++enqueueCallCount;
}
public void increaseCurrentLoop() {
++currentLoop;
}
}

View File

@ -1,46 +0,0 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation;
import java.util.ArrayList;
import java.util.List;
public class MockExperienceHandler implements ExperienceHandler<Integer, Transition<Integer>> {
public List<StateActionPair<Integer>> addExperienceArgs = new ArrayList<StateActionPair<Integer>>();
public Observation finalObservation;
public boolean isGenerateTrainingBatchCalled;
public boolean isResetCalled;
@Override
public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) {
addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal));
}
@Override
public void setFinalObservation(Observation observation) {
finalObservation = observation;
}
@Override
public List<Transition<Integer>> generateTrainingBatch() {
isGenerateTrainingBatchCalled = true;
return new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(null, 0, 0.0, false));
}
};
}
@Override
public void reset() {
isResetCalled = true;
}
@Override
public int getTrainingBatchSize() {
return 1;
}
}
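
For context, the following is a minimal sketch, not part of this change, of the call sequence this mock records and that the learner is expected to follow: add experience while an episode runs, close the episode with the final observation, draw a training batch, then reset for the next episode. The shapes, action and reward values are arbitrary.

import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.factory.Nd4j;

import java.util.List;

public class ExperienceHandlerFlowSketch {

    public static void runOneEpisode(ExperienceHandler<Integer, Transition<Integer>> handler) {
        // Accumulate (observation, action, reward, isTerminal) tuples during the episode.
        for (int step = 0; step < 3; step++) {
            Observation observation = new Observation(Nd4j.zeros(1, 4));
            handler.addExperience(observation, 0, 1.0, false);
        }

        // Close the episode with the observation that follows the last action.
        handler.setFinalObservation(new Observation(Nd4j.zeros(1, 4)));

        // Draw a batch for training, then clear the handler for the next episode.
        List<Transition<Integer>> batch = handler.generateTrainingBatch();
        System.out.println("training batch size: " + batch.size());
        handler.reset();
    }
}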

View File

@ -1,77 +0,0 @@
package org.deeplearning4j.rl4j.support;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
public class MockTrainingListener implements TrainingListener {
private final MockAsyncGlobal asyncGlobal;
public int onTrainingStartCallCount = 0;
public int onTrainingEndCallCount = 0;
public int onNewEpochCallCount = 0;
public int onEpochTrainingResultCallCount = 0;
public int onTrainingProgressCallCount = 0;
@Setter
private int remainingTrainingStartCallCount = Integer.MAX_VALUE;
@Setter
private int remainingOnNewEpochCallCount = Integer.MAX_VALUE;
@Setter
private int remainingOnEpochTrainingResult = Integer.MAX_VALUE;
@Setter
    private int remainingOnTrainingProgressCallCount = Integer.MAX_VALUE;
public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
public MockTrainingListener() {
this(null);
}
public MockTrainingListener(MockAsyncGlobal asyncGlobal) {
this.asyncGlobal = asyncGlobal;
}
@Override
public ListenerResponse onTrainingStart() {
++onTrainingStartCallCount;
--remainingTrainingStartCallCount;
return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
++onNewEpochCallCount;
--remainingOnNewEpochCallCount;
return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
++onEpochTrainingResultCallCount;
--remainingOnEpochTrainingResult;
statEntries.add(statEntry);
return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
++onTrainingProgressCallCount;
        --remainingOnTrainingProgressCallCount;
if(asyncGlobal != null) {
asyncGlobal.increaseCurrentLoop();
}
        return remainingOnTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
++onTrainingEndCallCount;
}
}
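
As a usage note, the following is a minimal sketch, not part of this change, of a listener built on the same contract as the mock above: returning ListenerResponse.STOP from any callback asks the trainer to halt, which is what the SyncLearning tests earlier in this change verify. The class name and epoch limit are illustrative only.

import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.util.IDataManager;

public class StopAfterNEpochsListener implements TrainingListener {

    private final int maxEpochs;
    private int epochCount = 0;

    public StopAfterNEpochsListener(int maxEpochs) {
        this.maxEpochs = maxEpochs;
    }

    @Override
    public ListenerResponse onTrainingStart() {
        return ListenerResponse.CONTINUE;
    }

    @Override
    public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
        // Ask the trainer to stop once the epoch budget is exhausted.
        return ++epochCount > maxEpochs ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
    }

    @Override
    public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
        return ListenerResponse.CONTINUE;
    }

    @Override
    public ListenerResponse onTrainingProgress(ILearning learning) {
        return ListenerResponse.CONTINUE;
    }

    @Override
    public void onTrainingEnd() {
        // No cleanup needed in this sketch.
    }
}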

View File

@ -1,19 +0,0 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import java.util.ArrayList;
import java.util.List;
public class MockUpdateAlgorithm implements UpdateAlgorithm<MockNeuralNet> {
public final List<List<StateActionPair<Integer>>> experiences = new ArrayList<List<StateActionPair<Integer>>>();
@Override
public Gradient[] computeGradients(MockNeuralNet current, List<StateActionPair<Integer>> experience) {
experiences.add(experience);
return new Gradient[0];
}
}

View File

@ -151,20 +151,26 @@ public class DataManagerTrainingListenerTest {
private static class TestTrainer implements IEpochTrainer, ILearning
{
@Override
public int getStepCounter() {
public int getStepCount() {
return 0;
}
@Override
public int getEpochCounter() {
public int getEpochCount() {
return 0;
}
@Override
public int getCurrentEpochStep() {
public int getEpisodeCount() {
return 0;
}
@Override
public int getCurrentEpisodeStepCount() {
return 0;
}
@Getter
@Setter
private IHistoryProcessor historyProcessor;