RL4J: Sanitize async learner (#327)

* refactoring global async to use a much simpler update procedure with a single global lock (a minimal sketch of the pattern follows the commit metadata below)

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* simplification of the async learning algorithms, stabilization fixes + better hyperparameters

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* started experimenting with mockito for tests

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Working on refactoring tests for async classes and trying to make async simpler

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* more work on mockito tests, making some tests much less complex and more explicit about what they are testing

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* some fixes from merging

* do not allow copying of the current network to worker threads, fixing debug line

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* adding some more tests around PR review

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Adding more tests after review comments

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* a few more tests and fixes from PR review

* remove rename of maxEpochStep to maxStepsPerEpisode as we agreed to review this in a separate PR

* 2019 instead of 2018 on copyright header

* adding konduit copyright to files

* some more copyright headers

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

Co-authored-by: Alexandre Boulanger <aboulang2002@yahoo.com>
Chris Bamford, 2020-04-20 03:21:01 +01:00, committed by GitHub
parent 455a5d112d
commit 74420bca31
54 changed files with 1423 additions and 1712 deletions
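The first two bullets are the heart of the change: the old design pushed worker gradients onto a ConcurrentLinkedQueue drained by a dedicated global thread, while the new design has workers apply gradients directly under one lock. A minimal, self-contained sketch of that pattern (plain Java; GlobalState and the double[] gradient are stand-ins, not RL4J types):

import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// Sketch of the "single global lock" update procedure described above.
final class GlobalState {
    private final Lock updateLock = new ReentrantLock();
    private int stepCount;
    private final int maxSteps;

    GlobalState(int maxSteps) { this.maxSteps = maxSteps; }

    boolean isTrainingComplete() { return stepCount >= maxSteps; }

    // Workers call this directly instead of enqueueing gradients for a consumer thread.
    void applyGradient(double[] gradient, int nSteps) {
        updateLock.lock();
        try {
            // apply the gradient to the global parameters here
            stepCount += nSteps;
        } finally {
            updateLock.unlock();
        }
    }
}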

.gitignore
@@ -70,3 +70,6 @@ venv2/
 # Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
 nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
 nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
+# Ignore meld temp files
+*.orig

@@ -19,6 +19,7 @@ package org.deeplearning4j.rl4j.mdp;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
 /**
@@ -31,20 +32,20 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
  * in a "functionnal manner" if step return a mdp
  *
  */
-public interface MDP<O, A, AS extends ActionSpace<A>> {
-    ObservationSpace<O> getObservationSpace();
-    AS getActionSpace();
-    O reset();
+public interface MDP<OBSERVATION, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {
+    ObservationSpace<OBSERVATION> getObservationSpace();
+    ACTION_SPACE getActionSpace();
+    OBSERVATION reset();
     void close();
-    StepReply<O> step(A action);
+    StepReply<OBSERVATION> step(ACTION action);
     boolean isDone();
-    MDP<O, A, AS> newInstance();
+    MDP<OBSERVATION, ACTION, ACTION_SPACE> newInstance();
 }

@@ -17,24 +17,24 @@
 package org.deeplearning4j.rl4j.space;
 /**
- * @param <A> the type of Action
+ * @param <ACTION> the type of Action
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
  * <p>
  * Should contain contextual information about the Action space, which is the space of all the actions that could be available.
  * Also must know how to return a randomly uniformly sampled action.
  */
-public interface ActionSpace<A> {
+public interface ActionSpace<ACTION> {
     /**
      * @return A random action,
      */
-    A randomAction();
-    Object encode(A action);
+    ACTION randomAction();
+    Object encode(ACTION action);
     int getSize();
-    A noOp();
+    ACTION noOp();
 }

@@ -121,6 +121,13 @@
             <version>${datavec.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>3.3.3</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
     <profiles>

@@ -1,54 +1,54 @@
 /*******************************************************************************
  * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
  * https://www.apache.org/licenses/LICENSE-2.0.
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  * License for the specific language governing permissions and limitations
  * under the License.
  *
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 package org.deeplearning4j.rl4j.experience;
 import org.deeplearning4j.rl4j.observation.Observation;
 import java.util.List;
 /**
  * A common interface to all classes capable of handling experience generated by the agents in a learning context.
  *
  * @param <A> Action type
  * @param <E> Experience type
  *
  * @author Alexandre Boulanger
  */
 public interface ExperienceHandler<A, E> {
     void addExperience(Observation observation, A action, double reward, boolean isTerminal);
     /**
      * Called when the episode is done with the last observation
      * @param observation
      */
     void setFinalObservation(Observation observation);
     /**
      * @return The size of the list that will be returned by generateTrainingBatch().
      */
     int getTrainingBatchSize();
     /**
      * The elements are returned in the historical order (i.e. in the order they happened)
      * @return The list of experience elements
      */
     List<E> generateTrainingBatch();
     /**
      * Signal the experience handler that a new episode is starting
      */
     void reset();
 }
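For illustration, a hedged sketch of a trivial ExperienceHandler implementation satisfying the contract above (not part of this PR; the Object[] experience type is just for brevity):

package org.deeplearning4j.rl4j.experience;

import org.deeplearning4j.rl4j.observation.Observation;
import java.util.ArrayList;
import java.util.List;

// Buffers raw (observation, action, reward, terminal) tuples in order.
public class BufferingExperienceHandler<A> implements ExperienceHandler<A, Object[]> {
    private List<Object[]> buffer = new ArrayList<>();

    @Override
    public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
        buffer.add(new Object[] {observation, action, reward, isTerminal});
    }

    @Override
    public void setFinalObservation(Observation observation) {
        buffer.add(new Object[] {observation, null, 0.0, true});
    }

    @Override
    public int getTrainingBatchSize() {
        return buffer.size();
    }

    @Override
    public List<Object[]> generateTrainingBatch() {
        List<Object[]> batch = buffer;   // hand over the buffer in historical order
        buffer = new ArrayList<>();
        return batch;
    }

    @Override
    public void reset() {
        buffer = new ArrayList<>();
    }
}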

@@ -30,7 +30,7 @@ import java.util.List;
  */
 public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
-    private List<StateActionPair<A>> stateActionPairs;
+    private List<StateActionPair<A>> stateActionPairs = new ArrayList<>();
     public void setFinalObservation(Observation observation) {
         // Do nothing

@@ -1,21 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2020 Konduit K.K.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-package org.deeplearning4j.rl4j.learning;
-public interface EpochStepCounter {
-    int getCurrentEpochStep();
-}

@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -28,9 +29,11 @@ import org.deeplearning4j.rl4j.mdp.MDP;
  * @author Alexandre Boulanger
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
  */
-public interface IEpochTrainer extends EpochStepCounter {
-    int getStepCounter();
-    int getEpochCounter();
+public interface IEpochTrainer {
+    int getStepCount();
+    int getEpochCount();
+    int getEpisodeCount();
+    int getCurrentEpisodeStepCount();
     IHistoryProcessor getHistoryProcessor();
     MDP getMdp();
 }

@@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.learning;
 import lombok.AllArgsConstructor;
 import lombok.Builder;
+import lombok.Data;
 import lombok.Value;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -51,7 +52,7 @@ public interface IHistoryProcessor {
 @AllArgsConstructor
 @Builder
-@Value
+@Data
 public static class Configuration {
     @Builder.Default int historyLength = 4;
     @Builder.Default int rescaledWidth = 84;
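The practical effect of swapping @Value for @Data is that a built Configuration gains setters and becomes mutable; a hedged sketch of the difference:

// @Builder.Default fields as shown above; field names beyond those are assumptions.
IHistoryProcessor.Configuration conf = IHistoryProcessor.Configuration.builder()
        .historyLength(4)
        .rescaledWidth(84)
        .build();
conf.setHistoryLength(8);   // compiles with @Data; @Value generated no setters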

@@ -21,19 +21,20 @@ import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16.
  *
  * A common interface that any training method should implement
  */
-public interface ILearning<O, A, AS extends ActionSpace<A>> {
+public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
     IPolicy<O, A> getPolicy();
     void train();
-    int getStepCounter();
+    int getStepCount();
     ILearningConfiguration getConfiguration();

@@ -38,13 +38,13 @@ import org.nd4j.linalg.factory.Nd4j;
  *
  */
 @Slf4j
-public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
+public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
         implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
     @Getter @Setter
-    private int stepCounter = 0;
+    protected int stepCount = 0;
     @Getter @Setter
-    private int epochCounter = 0;
+    private int epochCount = 0;
     @Getter @Setter
     private IHistoryProcessor historyProcessor = null;
@@ -73,11 +73,11 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
     public abstract NN getNeuralNet();
     public void incrementStep() {
-        stepCounter++;
+        stepCount++;
     }
     public void incrementEpoch() {
-        epochCounter++;
+        epochCount++;
     }
     public void setHistoryProcessor(HistoryProcessor.Configuration conf) {

@@ -20,13 +20,11 @@ package org.deeplearning4j.rl4j.learning.async;
 import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.gradient.Gradient;
-import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
 import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.nd4j.linalg.primitives.Pair;
-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@@ -52,69 +50,75 @@
  * structure
  */
 @Slf4j
-public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
+public class AsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {
-    @Getter
     final private NN current;
-    final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
-    final private IAsyncLearningConfiguration configuration;
-    private final IAsyncLearning learning;
-    @Getter
-    private AtomicInteger T = new AtomicInteger(0);
-    @Getter
-    private NN target;
-    @Getter
-    private boolean running = true;
-    public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) {
+    private NN target;
+    final private IAsyncLearningConfiguration configuration;
+    @Getter
+    private final Lock updateLock;
+    /**
+     * The number of times the gradient has been updated by worker threads
+     */
+    @Getter
+    private int workerUpdateCount;
+    @Getter
+    private int stepCount;
+    public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration) {
         this.current = initial;
         target = (NN) initial.clone();
         this.configuration = configuration;
-        this.learning = learning;
-        queue = new ConcurrentLinkedQueue<>();
+        // This is used to sync between
+        updateLock = new ReentrantLock();
     }
     public boolean isTrainingComplete() {
-        return T.get() >= configuration.getMaxStep();
+        return stepCount >= configuration.getMaxStep();
     }
-    public void enqueue(Gradient[] gradient, Integer nstep) {
-        if (running && !isTrainingComplete()) {
-            queue.add(new Pair<>(gradient, nstep));
-        }
-    }
+    public void applyGradient(Gradient[] gradient, int nstep) {
+        if (isTrainingComplete()) {
+            return;
+        }
+        try {
+            updateLock.lock();
+            current.applyGradient(gradient, nstep);
+            stepCount += nstep;
+            workerUpdateCount++;
+            int targetUpdateFrequency = configuration.getLearnerUpdateFrequency();
+            // If we have a target update frequency, this means we only want to update the workers after a certain number of async updates
+            // This can lead to more stable training
+            if (targetUpdateFrequency != -1 && workerUpdateCount % targetUpdateFrequency == 0) {
+                log.info("Updating target network at updates={} steps={}", workerUpdateCount, stepCount);
+            } else {
+                target.copy(current);
+            }
+        } finally {
+            updateLock.unlock();
+        }
+    }
     @Override
-    public void run() {
-        while (!isTrainingComplete() && running) {
-            if (!queue.isEmpty()) {
-                Pair<Gradient[], Integer> pair = queue.poll();
-                T.addAndGet(pair.getSecond());
-                Gradient[] gradient = pair.getFirst();
-                synchronized (this) {
-                    current.applyGradient(gradient, pair.getSecond());
-                }
-                if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond())
-                                / configuration.getLearnerUpdateFrequency()) {
-                    log.info("TARGET UPDATE at T = " + T.get());
-                    synchronized (this) {
-                        target.copy(current);
-                    }
-                }
-            }
-        }
-    }
-    /**
-     * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
-     */
-    public void terminate() {
-        if (running) {
-            running = false;
-            queue.clear();
-            learning.terminate();
-        }
-    }
+    public NN getTarget() {
+        try {
+            updateLock.lock();
+            return target;
+        } finally {
+            updateLock.unlock();
+        }
+    }
 }
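Taken together, the worker-side interaction with the new AsyncGlobal reduces to three calls. A hedged sketch (it mirrors trainSubEpoch in AsyncThreadDiscrete further down; imports from the RL4J packages shown in this diff are assumed):

// Snapshot the target, compute gradients locally, apply them under the single lock.
<NN extends NeuralNet> void workerUpdate(IAsyncGlobal<NN> asyncGlobal,
                                         UpdateAlgorithm<NN> updateAlgorithm,
                                         List<StateActionPair<Integer>> batch) {
    NN copy = (NN) asyncGlobal.getTarget().clone();     // lock-guarded read
    Gradient[] gradients = updateAlgorithm.computeGradients(copy, batch);
    asyncGlobal.applyGradient(gradients, batch.size()); // lock-guarded write
}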

@@ -40,9 +40,9 @@ import org.nd4j.linalg.factory.Nd4j;
  * @author Alexandre Boulanger
  */
 @Slf4j
-public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
-        extends Learning<O, A, AS, NN>
+public abstract class AsyncLearning<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
+        extends Learning<OBSERVATION, ACTION, ACTION_SPACE, NN>
         implements IAsyncLearning {
     private Thread monitorThread = null;
@@ -69,10 +69,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
     protected abstract IAsyncGlobal<NN> getAsyncGlobal();
-    protected void startGlobalThread() {
-        getAsyncGlobal().start();
-    }
     protected boolean isTrainingComplete() {
         return getAsyncGlobal().isTrainingComplete();
     }
@@ -87,7 +83,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
     private int progressMonitorFrequency = 20000;
     private void launchThreads() {
-        startGlobalThread();
         for (int i = 0; i < getConfiguration().getNumThreads(); i++) {
             Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
             t.start();
@@ -99,8 +94,8 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
      * @return The current step
      */
     @Override
-    public int getStepCounter() {
-        return getAsyncGlobal().getT().get();
+    public int getStepCount() {
+        return getAsyncGlobal().getStepCount();
     }
     /**
@@ -129,14 +124,13 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
             monitorTraining();
         }
-        cleanupPostTraining();
         listeners.notifyTrainingFinished();
     }
     protected void monitorTraining() {
         try {
             monitorThread = Thread.currentThread();
-            while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
+            while (canContinue && !isTrainingComplete()) {
                 canContinue = listeners.notifyTrainingProgress(this);
                 if (!canContinue) {
                     return;
@@ -152,11 +146,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
         monitorThread = null;
     }
-    protected void cleanupPostTraining() {
-        // Worker threads stops automatically when the global thread stops
-        getAsyncGlobal().terminate();
-    }
     /**
      * Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
     */

@@ -32,6 +32,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.nd4j.linalg.factory.Nd4j;
@@ -47,39 +48,63 @@
  * @author Alexandre Boulanger
  */
 @Slf4j
-public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
+public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
         extends Thread implements IEpochTrainer {
     @Getter
     private int threadNumber;
     @Getter
     protected final int deviceNum;
+    /**
+     * The number of steps that this async thread has produced
+     */
     @Getter @Setter
-    private int stepCounter = 0;
+    protected int stepCount = 0;
+    /**
+     * The number of epochs (updates) that this thread has sent to the global learner
+     */
     @Getter @Setter
-    private int epochCounter = 0;
+    protected int epochCount = 0;
+    /**
+     * The number of environment episodes that have been played out
+     */
+    @Getter @Setter
+    protected int episodeCount = 0;
+    /**
+     * The number of steps in the current episode
+     */
+    @Getter
+    protected int currentEpisodeStepCount = 0;
+    /**
+     * If the current episode needs to be reset
+     */
+    boolean episodeComplete = true;
     @Getter @Setter
     private IHistoryProcessor historyProcessor;
-    @Getter
-    private int currentEpochStep = 0;
-    private boolean isEpochStarted = false;
-    private final LegacyMDPWrapper<O, A, AS> mdp;
+    private boolean isEpisodeStarted = false;
+    private final LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> mdp;
     private final TrainingListenerList listeners;
-    public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
-        this.mdp = new LegacyMDPWrapper<O, A, AS>(mdp, null, this);
+    public AsyncThread(MDP<OBSERVATION, ACTION, ACTION_SPACE> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
+        this.mdp = new LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE>(mdp, null);
         this.listeners = listeners;
         this.threadNumber = threadNumber;
         this.deviceNum = deviceNum;
     }
-    public MDP<O, A, AS> getMdp() {
+    public MDP<OBSERVATION, ACTION, ACTION_SPACE> getMdp() {
         return mdp.getWrappedMDP();
     }
-    protected LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper() {
+    protected LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> getLegacyMDPWrapper() {
         return mdp;
     }
@@ -92,13 +117,13 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
         mdp.setHistoryProcessor(historyProcessor);
     }
-    protected void postEpoch() {
+    protected void postEpisode() {
         if (getHistoryProcessor() != null)
             getHistoryProcessor().stopMonitor();
     }
-    protected void preEpoch() {
+    protected void preEpisode() {
         // Do nothing
     }
@@ -125,74 +150,69 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
      */
     @Override
     public void run() {
-        try {
-            RunContext context = new RunContext();
-            Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
-            log.info("ThreadNum-" + threadNumber + " Started!");
-            while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
-                if (!isEpochStarted) {
-                    boolean canContinue = startNewEpoch(context);
-                    if (!canContinue) {
-                        break;
-                    }
-                }
-                handleTraining(context);
-                if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) {
-                    boolean canContinue = finishEpoch(context);
-                    if (!canContinue) {
-                        break;
-                    }
-                    ++epochCounter;
-                }
-            }
-        }
-        finally {
-            terminateWork();
-        }
-    }
+        RunContext context = new RunContext();
+        Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
+        log.info("ThreadNum-" + threadNumber + " Started!");
+        while (!getAsyncGlobal().isTrainingComplete()) {
+            if (episodeComplete) {
+                startEpisode(context);
+            }
+            if(!startEpoch(context)) {
+                break;
+            }
+            episodeComplete = handleTraining(context);
+            if(!finishEpoch(context)) {
+                break;
+            }
+            if(episodeComplete) {
+                finishEpisode(context);
+            }
+        }
+    }
-    private void handleTraining(RunContext context) {
-        int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep);
-        SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
-        context.obs = subEpochReturn.getLastObs();
-        context.rewards += subEpochReturn.getReward();
-        context.score = subEpochReturn.getScore();
-    }
+    private boolean finishEpoch(RunContext context) {
+        epochCount++;
+        IDataManager.StatEntry statEntry = new AsyncStatEntry(stepCount, epochCount, context.rewards, currentEpisodeStepCount, context.score);
+        return listeners.notifyEpochTrainingResult(this, statEntry);
+    }
+    private boolean startEpoch(RunContext context) {
+        return listeners.notifyNewEpoch(this);
+    }
+    private boolean handleTraining(RunContext context) {
+        int maxTrainSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpisodeStepCount);
+        SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps);
+        context.obs = subEpochReturn.getLastObs();
+        context.rewards += subEpochReturn.getReward();
+        context.score = subEpochReturn.getScore();
+        return subEpochReturn.isEpisodeComplete();
+    }
-    private boolean startNewEpoch(RunContext context) {
-        getCurrent().reset();
-        Learning.InitMdp<Observation> initMdp = refacInitMdp();
-        context.obs = initMdp.getLastObs();
-        context.rewards = initMdp.getReward();
-        isEpochStarted = true;
-        preEpoch();
-        return listeners.notifyNewEpoch(this);
-    }
+    private void startEpisode(RunContext context) {
+        getCurrent().reset();
+        Learning.InitMdp<Observation> initMdp = refacInitMdp();
+        context.obs = initMdp.getLastObs();
+        context.rewards = initMdp.getReward();
+        preEpisode();
+        episodeCount++;
+    }
-    private boolean finishEpoch(RunContext context) {
-        isEpochStarted = false;
-        postEpoch();
-        IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score);
-        log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards);
-        return listeners.notifyEpochTrainingResult(this, statEntry);
-    }
-    private void terminateWork() {
-        getAsyncGlobal().terminate();
-        if(isEpochStarted) {
-            postEpoch();
-        }
-    }
+    private void finishEpisode(RunContext context) {
+        postEpisode();
+        log.info("ThreadNum-{} Episode step: {}, Episode: {}, Epoch: {}, reward: {}", threadNumber, currentEpisodeStepCount, episodeCount, epochCount, context.rewards);
+    }
     protected abstract NN getCurrent();
@@ -201,35 +221,35 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
     protected abstract IAsyncLearningConfiguration getConf();
-    protected abstract IPolicy<O, A> getPolicy(NN net);
+    protected abstract IPolicy<OBSERVATION, ACTION> getPolicy(NN net);
     protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
     private Learning.InitMdp<Observation> refacInitMdp() {
-        currentEpochStep = 0;
+        currentEpisodeStepCount = 0;
        double reward = 0;
-        LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
+        LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> mdp = getLegacyMDPWrapper();
        Observation observation = mdp.reset();
-        A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
+        ACTION action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
        while (observation.isSkipped() && !mdp.isDone()) {
            StepReply<Observation> stepReply = mdp.step(action);
            reward += stepReply.getReward();
            observation = stepReply.getObservation();
-            incrementStep();
+            incrementSteps();
        }
        return new Learning.InitMdp(0, observation, reward);
     }
-    public void incrementStep() {
-        ++stepCounter;
-        ++currentEpochStep;
+    public void incrementSteps() {
+        stepCount++;
+        currentEpisodeStepCount++;
     }
     @AllArgsConstructor
@@ -239,6 +259,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
         Observation lastObs;
         double reward;
         double score;
+        boolean episodeComplete;
     }
     @AllArgsConstructor
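A self-contained toy version of the reworked run() loop above, to make the episode/epoch bookkeeping concrete (all names are local stand-ins, not RL4J classes; an "epoch" here is one sub-epoch of nStep=5 collected steps followed by one global update):

public class WorkerLoopSketch {
    static boolean episodeComplete = true;
    static int stepCount = 0, epochCount = 0, episodeCount = 0;

    public static void main(String[] args) {
        java.util.Random rnd = new java.util.Random(123);
        while (stepCount < 100) {                  // stands in for isTrainingComplete()
            if (episodeComplete) {                 // startEpisode(): reset the environment
                episodeCount++;
            }
            stepCount += 5;                        // handleTraining(): collect nStep steps
            episodeComplete = rnd.nextInt(4) == 0; // did the episode hit a terminal state?
            epochCount++;                          // finishEpoch(): one update per sub-epoch
        }
        System.out.printf("steps=%d epochs=%d episodes=%d%n", stepCount, epochCount, episodeCount);
    }
}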

@@ -24,6 +24,9 @@ import lombok.Setter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.experience.ExperienceHandler;
 import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
+import org.deeplearning4j.rl4j.experience.ExperienceHandler;
+import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
+import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
@@ -31,15 +34,19 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import java.util.Stack;
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
- *
+ * <p>
  * Async Learning specialized for the Discrete Domain
- *
  */
-public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
+public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
         extends AsyncThread<O, Integer, DiscreteSpace, NN> {
     @Getter
     private NN current;
@@ -48,7 +55,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
     private UpdateAlgorithm<NN> updateAlgorithm;
     // TODO: Make it configurable with a builder
-    @Setter(AccessLevel.PROTECTED)
+    @Setter(AccessLevel.PROTECTED) @Getter
     private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
     public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
@@ -56,9 +63,9 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
                     TrainingListenerList listeners,
                     int threadNumber,
                     int deviceNum) {
-        super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
+        super(mdp, listeners, threadNumber, deviceNum);
         synchronized (asyncGlobal) {
-            current = (NN)asyncGlobal.getCurrent().clone();
+            current = (NN) asyncGlobal.getTarget().clone();
         }
     }
@@ -72,7 +79,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
     }
     @Override
-    protected void preEpoch() {
+    protected void preEpisode() {
         experienceHandler.reset();
     }
@@ -81,28 +88,23 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
      * "Subepoch" correspond to the t_max-step iterations
      * that stack rewards with t_max MiniTrans
      *
      * @param sObs the obs to start from
-     * @param nstep the number of max nstep (step until t_max or state is terminal)
+     * @param trainingSteps the number of training steps
      * @return subepoch training informations
      */
-    public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
-        synchronized (getAsyncGlobal()) {
-            current.copy(getAsyncGlobal().getCurrent());
-        }
+    public SubEpochReturn trainSubEpoch(Observation sObs, int trainingSteps) {
+        current.copy(getAsyncGlobal().getTarget());
         Observation obs = sObs;
         IPolicy<O, Integer> policy = getPolicy(current);
         Integer action = getMdp().getActionSpace().noOp();
-        IHistoryProcessor hp = getHistoryProcessor();
-        int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
         double reward = 0;
         double accuReward = 0;
-        int stepAtStart = getCurrentEpochStep();
-        int lastStep = nstep * skipFrame + stepAtStart;
-        while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
+        while (!getMdp().isDone() && experienceHandler.getTrainingBatchSize() != trainingSteps) {
             //if step of training, just repeat lastAction
             if (!obs.isSkipped()) {
@@ -115,20 +117,26 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
             if (!obs.isSkipped()) {
                 experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
                 accuReward = 0;
+                incrementSteps();
             }
             obs = stepReply.getObservation();
             reward += stepReply.getReward();
-            incrementStep();
         }
-        if (getMdp().isDone() && getCurrentEpochStep() < lastStep) {
+        boolean episodeComplete = getMdp().isDone() || getConf().getMaxEpochStep() == currentEpisodeStepCount;
+        if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) {
             experienceHandler.setFinalObservation(obs);
         }
-        getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep());
-        return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
+        int experienceSize = experienceHandler.getTrainingBatchSize();
+        getAsyncGlobal().applyGradient(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), experienceSize);
+        return new SubEpochReturn(experienceSize, obs, reward, current.getLatestScore(), episodeComplete);
     }
 }

@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -22,17 +23,29 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 import java.util.concurrent.atomic.AtomicInteger;
 public interface IAsyncGlobal<NN extends NeuralNet> {
-    boolean isRunning();
     boolean isTrainingComplete();
-    void start();
     /**
-     * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
+     * The number of updates that have been applied by worker threads.
      */
-    void terminate();
-    AtomicInteger getT();
-    NN getCurrent();
+    int getWorkerUpdateCount();
+    /**
+     * The total number of environment steps that have been processed.
+     */
+    int getStepCount();
+    /**
+     * A copy of the global network that is updated after a certain number of worker episodes.
+     */
     NN getTarget();
-    void enqueue(Gradient[] gradient, Integer nstep);
+    /**
+     * Apply gradients to the global network
+     * @param gradient
+     * @param batchSize
+     */
+    void applyGradient(Gradient[] gradient, int batchSize);
 }

@@ -1,26 +1,26 @@
 /*******************************************************************************
  * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
  * https://www.apache.org/licenses/LICENSE-2.0.
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  * License for the specific language governing permissions and limitations
  * under the License.
  *
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 package org.deeplearning4j.rl4j.learning.async;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import java.util.List;
 public interface UpdateAlgorithm<NN extends NeuralNet> {
     Gradient[] computeGradients(NN current, List<StateActionPair<Integer>> experience);
 }
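Any gradient computation can be plugged in behind this interface. A hedged, do-nothing sketch of an implementation, for shape only (not part of the PR):

package org.deeplearning4j.rl4j.learning.async;

import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.network.NeuralNet;
import java.util.List;

public class NoOpUpdateAlgorithm<NN extends NeuralNet> implements UpdateAlgorithm<NN> {
    @Override
    public Gradient[] computeGradients(NN current, List<StateActionPair<Integer>> experience) {
        // a real algorithm builds inputs/targets from the experience
        // and calls current.gradient(...) here
        return new Gradient[0];
    }
}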

@@ -57,7 +57,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
         this.iActorCritic = iActorCritic;
         this.mdp = mdp;
         this.configuration = conf;
-        asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
+        asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
         Long seed = conf.getSeed();
         Random rnd = Nd4j.getRandom();

@@ -27,7 +27,6 @@ import org.deeplearning4j.rl4j.policy.ACPolicy;
 import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.api.rng.Random;
@@ -73,6 +72,6 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
     @Override
     protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
-        return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma());
+        return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma());
     }
 }

@@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.Learning;
-import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
 import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -27,28 +26,25 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
 import java.util.List;
-public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
+/**
+ * The Advantage Actor-Critic update algorithm can be used by A2C and A3C algorithms alike
+ */
+public class AdvantageActorCriticUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
-    private final IAsyncGlobal asyncGlobal;
     private final int[] shape;
     private final int actionSpaceSize;
-    private final int targetDqnUpdateFreq;
     private final double gamma;
     private final boolean recurrent;
-    public A3CUpdateAlgorithm(IAsyncGlobal asyncGlobal,
+    public AdvantageActorCriticUpdateAlgorithm(boolean recurrent,
                     int[] shape,
                     int actionSpaceSize,
-                    int targetDqnUpdateFreq,
                     double gamma) {
-        this.asyncGlobal = asyncGlobal;
         //if recurrent then train as a time serie with a batch size of 1
-        recurrent = asyncGlobal.getCurrent().isRecurrent();
+        this.recurrent = recurrent;
         this.shape = shape;
         this.actionSpaceSize = actionSpaceSize;
-        this.targetDqnUpdateFreq = targetDqnUpdateFreq;
         this.gamma = gamma;
     }
@@ -65,18 +61,12 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
                         : Nd4j.zeros(size, actionSpaceSize);
         StateActionPair<Integer> stateActionPair = experience.get(size - 1);
-        double r;
-        if(stateActionPair.isTerminal()) {
-            r = 0;
-        }
-        else {
-            INDArray[] output = null;
-            if (targetDqnUpdateFreq == -1)
-                output = current.outputAll(stateActionPair.getObservation().getData());
-            else synchronized (asyncGlobal) {
-                output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
-            }
-            r = output[0].getDouble(0);
-        }
+        double value;
+        if (stateActionPair.isTerminal()) {
+            value = 0;
+        } else {
+            INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
+            value = output[0].getDouble(0);
+        }
         for (int i = size - 1; i >= 0; --i) {
@@ -86,7 +76,7 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
             INDArray[] output = current.outputAll(observationData);
-            r = stateActionPair.getReward() + gamma * r;
+            value = stateActionPair.getReward() + gamma * value;
             if (recurrent) {
                 input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData);
             } else {
@@ -94,11 +84,11 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
             }
             //the critic
-            targets.putScalar(i, r);
+            targets.putScalar(i, value);
             //the actor
             double expectedV = output[0].getDouble(0);
-            double advantage = r - expectedV;
+            double advantage = value - expectedV;
             if (recurrent) {
                 logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage);
             } else {
@@ -108,6 +98,6 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
         // targets -> value, critic
         // logSoftmax -> policy, actor
-        return current.gradient(input, new INDArray[] {targets, logSoftmax});
+        return current.gradient(input, new INDArray[]{targets, logSoftmax});
     }
 }
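The backward pass above computes an n-step return for the critic and an advantage for the actor. A self-contained numeric sketch of that recursion (made-up numbers, plain Java):

public class AdvantageSketch {
    public static void main(String[] args) {
        double gamma = 0.99;
        double[] rewards = {1.0, 0.0, 1.0};
        double[] criticV = {0.7, 0.5, 0.9};   // V(s_i) predicted by the critic
        double value = 0.0;                   // terminal episode, as in the isTerminal() branch

        for (int i = rewards.length - 1; i >= 0; --i) {
            value = rewards[i] + gamma * value;    // n-step return: target for the critic
            double advantage = value - criticV[i]; // the actor is trained on the advantage
            System.out.printf("i=%d target=%.4f advantage=%.4f%n", i, value, advantage);
        }
    }
}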

@@ -50,7 +50,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
     public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
         this.mdp = mdp;
         this.configuration = conf;
-        this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
+        this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
     }
     @Override
@@ -59,7 +59,7 @@
     }
     public IDQN getNeuralNet() {
-        return asyncGlobal.getCurrent();
+        return asyncGlobal.getTarget();
     }
     public IPolicy<O, Integer> getPolicy() {

@@ -30,8 +30,8 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
 import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
-import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@@ -57,7 +57,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
         rnd = Nd4j.getRandom();
         Long seed = conf.getSeed();
-        if(seed != null) {
+        if (seed != null) {
             rnd.setSeed(seed + threadNumber);
         }
@@ -72,6 +72,6 @@
     @Override
     protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
-        return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma());
+        return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma());
     }
 }

@@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.Learning;
-import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
 import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -28,22 +27,16 @@ import java.util.List;
 public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
-    private final IAsyncGlobal asyncGlobal;
     private final int[] shape;
     private final int actionSpaceSize;
-    private final int targetDqnUpdateFreq;
     private final double gamma;
-    public QLearningUpdateAlgorithm(IAsyncGlobal asyncGlobal,
-                    int[] shape,
+    public QLearningUpdateAlgorithm(int[] shape,
                     int actionSpaceSize,
-                    int targetDqnUpdateFreq,
                     double gamma) {
-        this.asyncGlobal = asyncGlobal;
         this.shape = shape;
         this.actionSpaceSize = actionSpaceSize;
-        this.targetDqnUpdateFreq = targetDqnUpdateFreq;
         this.gamma = gamma;
     }
@@ -58,16 +51,11 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
         StateActionPair<Integer> stateActionPair = experience.get(size - 1);
         double r;
-        if(stateActionPair.isTerminal()) {
+        if (stateActionPair.isTerminal()) {
             r = 0;
-        }
-        else {
+        } else {
             INDArray[] output = null;
-            if (targetDqnUpdateFreq == -1)
-                output = current.outputAll(stateActionPair.getObservation().getData());
-            else synchronized (asyncGlobal) {
-                output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
-            }
+            output = current.outputAll(stateActionPair.getObservation().getData());
             r = Nd4j.max(output[0]).getDouble(0);
         }
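The same backward recursion drives the n-step Q target: bootstrap from max_a Q(s_last, a) on the current network (zero if terminal), then fold rewards back. A self-contained numeric sketch (made-up numbers):

public class NStepQTargetSketch {
    public static void main(String[] args) {
        double gamma = 0.99;
        double[] rewards = {0.0, 1.0, 0.0};
        double maxQLast = 2.5;            // max_a Q(s_last, a) from the current network
        double r = maxQLast;              // would be 0 for a terminal last state

        for (int i = rewards.length - 1; i >= 0; --i) {
            r = rewards[i] + gamma * r;   // training target for Q(s_i, a_i)
            System.out.printf("i=%d target=%.4f%n", i, r);
        }
    }
}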

@@ -20,8 +20,14 @@ public interface IAsyncLearningConfiguration extends ILearningConfiguration {
     int getNumThreads();
+    /**
+     * The number of steps to collect for each worker thread between each global update
+     */
     int getNStep();
+    /**
+     * The frequency of worker thread gradient updates to perform a copy of the current working network to the target network
+     */
     int getLearnerUpdateFrequency();
     int getMaxStep();
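A hedged sketch of how the three values documented above interact in the new design (plain Java constants, not an actual RL4J configuration class):

// Assumed example values, following the javadoc above.
int nStep = 5;                    // steps a worker collects between global updates
int learnerUpdateFrequency = 10;  // worker updates between current -> target copies
int maxStep = 500_000;            // total environment steps before training stops
// Each applyGradient() call contributes nStep steps, so the target network is
// intended to refresh roughly every nStep * learnerUpdateFrequency global steps.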

@@ -25,6 +25,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.learning.listener.*;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager;
 /**
@@ -35,8 +36,8 @@ import org.deeplearning4j.rl4j.util.IDataManager;
  * @author Alexandre Boulanger
  */
 @Slf4j
-public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
-        extends Learning<O, A, AS, NN> implements IEpochTrainer {
+public abstract class SyncLearning<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
+        extends Learning<OBSERVATION, ACTION, ACTION_SPACE, NN> implements IEpochTrainer {
     private final TrainingListenerList listeners = new TrainingListenerList();
@@ -85,7 +86,7 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
     boolean canContinue = listeners.notifyTrainingStarted();
     if (canContinue) {
-        while (getStepCounter() < getConfiguration().getMaxStep()) {
+        while (this.getStepCount() < getConfiguration().getMaxStep()) {
             preEpoch();
             canContinue = listeners.notifyNewEpoch(this);
             if (!canContinue) {
@@ -100,14 +101,14 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
             postEpoch();
-            if(getEpochCounter() % progressMonitorFrequency == 0) {
+            if(getEpochCount() % progressMonitorFrequency == 0) {
                 canContinue = listeners.notifyTrainingProgress(this);
                 if (!canContinue) {
                     break;
                 }
             }
-            log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
+            log.info("Epoch: " + getEpochCount() + ", reward: " + statEntry.getReward());
             incrementEpoch();
         }
     }

@@ -19,21 +19,16 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning;
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
 import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
-import lombok.AccessLevel;
 import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
-import lombok.Setter;
 import lombok.Value;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.learning.EpochStepCounter;
+import org.deeplearning4j.rl4j.learning.IEpochTrainer;
-import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
 import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
-import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
-import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
 import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
@@ -43,8 +38,6 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
-import org.nd4j.linalg.api.rng.Random;
-import org.nd4j.linalg.factory.Nd4j;
 import java.util.ArrayList;
 import java.util.List;
@@ -58,7 +51,7 @@ import java.util.List;
 @Slf4j
 public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
         extends SyncLearning<O, A, AS, IDQN>
-        implements TargetQNetworkSource, EpochStepCounter {
+        implements TargetQNetworkSource, IEpochTrainer {
     protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
@@ -90,7 +83,10 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
     protected abstract QLStepReturn<Observation> trainStep(Observation obs);
     @Getter
-    private int currentEpochStep = 0;
+    private int episodeCount;
+    @Getter
+    private int currentEpisodeStepCount = 0;
     protected StatEntry trainEpoch() {
         resetNetworks();
@@ -104,9 +100,9 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
         double meanQ = 0;
         int numQ = 0;
         List<Double> scores = new ArrayList<>();
-        while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
-            if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
+        while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
+            if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
                 updateTargetNetwork();
             }
@@ -132,20 +128,20 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
         meanQ /= (numQ + 0.001); //avoid div zero
-        StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, currentEpochStep, scores,
+        StatEntry statEntry = new QLStatEntry(this.getStepCount(), getEpochCount(), reward, currentEpisodeStepCount, scores,
                         getEgPolicy().getEpsilon(), startQ, meanQ);
         return statEntry;
     }
     protected void finishEpoch(Observation observation) {
-        // Do Nothing
+        episodeCount++;
     }
     @Override
     public void incrementStep() {
         super.incrementStep();
-        ++currentEpochStep;
+        ++currentEpisodeStepCount;
     }
     protected void resetNetworks() {
@@ -154,7 +150,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
     }
     private InitMdp<Observation> refacInitMdp() {
-        currentEpochStep = 0;
+        currentEpisodeStepCount = 0;
         double reward = 0;

@@ -47,6 +47,7 @@ import org.nd4j.linalg.factory.Nd4j;

 import java.util.List;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
  * <p>
@@ -90,7 +91,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
     public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
                     int epsilonNbStep, Random random) {
         this.configuration = conf;
-        this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
+        this.mdp = new LegacyMDPWrapper<>(mdp, null);
         qNetwork = dqn;
         targetQNetwork = dqn.clone;
         policy = new DQNPolicy(getQNetwork());
@@ -164,13 +165,13 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
                 // Update NN
                 // FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
-                if (getStepCounter() > updateStart) {
+                if (this.getStepCount() > updateStart) {
                     DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
                     getQNetwork().fit(targets.getFeatures(), targets.getLabels());
                 }
             }

-        return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
+        return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
     }

     protected DataSet setTarget(List<Transition<Integer>> transitions) {

@@ -17,6 +17,7 @@
 package org.deeplearning4j.rl4j.observation;

 import lombok.Getter;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;

 /**

@@ -40,7 +40,7 @@ public class EncodableToImageWritableTransform implements Operation<Encodable, I

     @Override
     public ImageWritable transform(Encodable encodable) {
-        INDArray indArray = Nd4j.create((encodable).toArray()).reshape(height, width, colorChannels);
+        INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels);
         Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
         return new ImageWritable(converter.convert(mat));
     }

@@ -25,6 +25,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
@@ -40,7 +41,7 @@ import org.nd4j.linalg.api.rng.Random;
  */
 @AllArgsConstructor
 @Slf4j
-public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
+public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {

     final private Policy<O, A> policy;
     final private MDP<O, A, AS> mdp;
@@ -57,8 +58,8 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
     public A nextAction(INDArray input) {
         double ep = getEpsilon();
-        if (learning.getStepCounter() % 500 == 1)
-            log.info("EP: " + ep + " " + learning.getStepCounter());
+        if (learning.getStepCount() % 500 == 1)
+            log.info("EP: " + ep + " " + learning.getStepCount());
         if (rnd.nextDouble() > ep)
             return policy.nextAction(input);
         else
@@ -70,6 +71,6 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
     }

     public double getEpsilon() {
-        return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep));
+        return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep));
     }
 }
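The schedule in getEpsilon() anneals exploration linearly from 1.0 down to minEpsilon over epsilonNbStep steps once updateStart warm-up steps have passed. A minimal standalone sketch of the same formula (the warm-up, horizon and floor values below are hypothetical, chosen only to show the shape of the decay):

public class EpsilonScheduleSketch {
    // Same clamped linear decay as EpsGreedy.getEpsilon(); all constants are illustrative.
    static double epsilon(int stepCount, int updateStart, int epsilonNbStep, double minEpsilon) {
        return Math.min(1.0, Math.max(minEpsilon, 1.0 - (stepCount - updateStart) * 1.0 / epsilonNbStep));
    }

    public static void main(String[] args) {
        for (int step : new int[] {0, 100, 600, 1100, 5000}) {
            // 1.0 during warm-up, linear decay afterwards, clamped at the 0.1 floor
            System.out.println(step + " -> " + epsilon(step, 100, 1000, 0.1));
        }
    }
}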

@@ -7,7 +7,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;

-public interface IPolicy<O, A> {
+public interface IPolicy<O extends Encodable, A> {
     <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
     A nextAction(INDArray input);
     A nextAction(Observation observation);

@@ -16,10 +16,7 @@
 package org.deeplearning4j.rl4j.policy;

-import lombok.Getter;
-import lombok.Setter;
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.learning.EpochStepCounter;
 import org.deeplearning4j.rl4j.learning.HistoryProcessor;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.Learning;
@@ -27,6 +24,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;

 /**
@@ -36,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
  *
  * A Policy's responsibility is to choose the next action given a state
  */
-public abstract class Policy<O, A> implements IPolicy<O, A> {
+public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {

     public abstract NeuralNet getNeuralNet();

@@ -54,10 +52,9 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
     public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
         resetNetworks();

-        RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter();
-        LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, epochStepCounter);
+        LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp);

-        Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter);
+        Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp);
         Observation obs = initMdp.getLastObs();

         double reward = initMdp.getReward();
@@ -79,7 +76,6 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
             reward += stepReply.getReward();
             obs = stepReply.getObservation();
-            epochStepCounter.incrementEpochStep();
         }

         return reward;
@@ -89,8 +85,7 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
         getNeuralNet().reset();
     }

-    protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
-        epochStepCounter.setCurrentEpochStep(0);
+    protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {

         double reward = 0;
@@ -104,21 +99,9 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
             reward += stepReply.getReward();
             observation = stepReply.getObservation();
-            epochStepCounter.incrementEpochStep();
         }

         return new Learning.InitMdp(0, observation, reward);
     }

-    public class RefacEpochStepCounter implements EpochStepCounter {
-        @Getter
-        @Setter
-        private int currentEpochStep = 0;
-
-        public void incrementEpochStep() {
-            ++currentEpochStep;
-        }
-    }
 }
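With RefacEpochStepCounter gone, play() owns all of its episode bookkeeping through the wrapper. A hypothetical caller sketch (the policy, mdp and history-processor instances are assumed to come from elsewhere; PolicyEvaluationSketch and evaluateOnce are illustrative names, not RL4J API):

import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;

// Sketch: one-episode evaluation. play() resets the network, builds its own
// LegacyMDPWrapper internally, runs the episode and returns the summed reward.
class PolicyEvaluationSketch {
    static <O extends Encodable, A, AS extends ActionSpace<A>> double evaluateOnce(
            IPolicy<O, A> policy, MDP<O, A, AS> mdp, IHistoryProcessor hp) {
        return policy.play(mdp, hp);
    }
}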

@@ -278,7 +278,7 @@ public class DataManager implements IDataManager {

         Path infoPath = Paths.get(getInfo());
         Info info = new Info(iLearning.getClass().getSimpleName(), iLearning.getMdp().getClass().getSimpleName(),
-                        iLearning.getConfiguration(), iLearning.getStepCounter(), System.currentTimeMillis());
+                        iLearning.getConfiguration(), iLearning.getStepCount(), System.currentTimeMillis());
         String toWrite = toJson(info);
         Files.write(infoPath, toWrite.getBytes(), StandardOpenOption.TRUNCATE_EXISTING);
@@ -300,12 +300,12 @@ public class DataManager implements IDataManager {
         if (!saveData)
             return;

-        save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
+        save(getModelDir() + "/" + learning.getStepCount() + ".training", learning);

         if(learning instanceof NeuralNetFetchable) {
             try {
-                ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
+                ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCount() + ".model");
             } catch (UnsupportedOperationException e) {
-                String path = getModelDir() + "/" + learning.getStepCounter();
+                String path = getModelDir() + "/" + learning.getStepCount();
                 ((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(path + "_value.model", path + "_policy.model");
             }
         }

@@ -40,7 +40,7 @@ public class DataManagerTrainingListener implements TrainingListener {
             if (trainer instanceof AsyncThread) {
                 filename += ((AsyncThread) trainer).getThreadNumber() + "-";
             }
-            filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4";
+            filename += trainer.getEpochCount() + "-" + trainer.getStepCount() + ".mp4";
             hp.startMonitor(filename, shape);
         }
@@ -66,7 +66,7 @@ public class DataManagerTrainingListener implements TrainingListener {
     @Override
     public ListenerResponse onTrainingProgress(ILearning learning) {
         try {
-            int stepCounter = learning.getStepCounter();
+            int stepCounter = learning.getStepCount();
             if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) {
                 dataManager.save(learning);
                 lastSave = stepCounter;
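The saving branch above is a plain step-delta throttle: a model is saved at most once every Constants.MODEL_SAVE_FREQ steps, however often progress callbacks arrive. The same pattern in isolation (the 10_000 below is a placeholder, not the real value of MODEL_SAVE_FREQ):

// Save at most once every SAVE_FREQ steps, as onTrainingProgress() does.
class CheckpointThrottle {
    static final int SAVE_FREQ = 10_000; // placeholder; the real value lives in Constants
    private int lastSave = -SAVE_FREQ;   // so the very first check triggers a save

    boolean shouldSave(int stepCount) {
        if (stepCount - lastSave >= SAVE_FREQ) {
            lastSave = stepCount;
            return true;
        }
        return false;
    }
}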

@@ -8,7 +8,6 @@ import org.datavec.image.transform.CropImageTransform;
 import org.datavec.image.transform.MultiImageTransform;
 import org.datavec.image.transform.ResizeImageTransform;
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.learning.EpochStepCounter;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.observation.Observation;
@@ -30,10 +29,10 @@ import java.util.Map;

 import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY;

-public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
+public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {

     @Getter
-    private final MDP<O, A, AS> wrappedMDP;
+    private final MDP<OBSERVATION, A, AS> wrappedMDP;
     @Getter
     private final WrapperObservationSpace observationSpace;
     private final int[] shape;
@@ -44,16 +43,14 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
     @Getter(AccessLevel.PRIVATE)
     private IHistoryProcessor historyProcessor;

-    private final EpochStepCounter epochStepCounter;
-
     private int skipFrame = 1;
+    private int steps = 0;

-    public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
+    public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) {
         this.wrappedMDP = wrappedMDP;
         this.shape = wrappedMDP.getObservationSpace().getShape();
         this.observationSpace = new WrapperObservationSpace(shape);
         this.historyProcessor = historyProcessor;
-        this.epochStepCounter = epochStepCounter;
         setHistoryProcessor(historyProcessor);
     }
@@ -63,6 +60,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
         createTransformProcess();
     }

+    //TODO: this transform process should be decoupled from the history processor and configured separately by the end-user
     private void createTransformProcess() {
         IHistoryProcessor historyProcessor = getHistoryProcessor();
@@ -103,7 +101,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
     public Observation reset() {
         transformProcess.reset();

-        O rawResetResponse = wrappedMDP.reset();
+        OBSERVATION rawResetResponse = wrappedMDP.reset();
         record(rawResetResponse);

         if(historyProcessor != null) {
@@ -118,21 +116,21 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
     public StepReply<Observation> step(A a) {
         IHistoryProcessor historyProcessor = getHistoryProcessor();

-        StepReply<O> rawStepReply = wrappedMDP.step(a);
+        StepReply<OBSERVATION> rawStepReply = wrappedMDP.step(a);
         INDArray rawObservation = getInput(rawStepReply.getObservation());

         if(historyProcessor != null) {
             historyProcessor.record(rawObservation);
         }

-        int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
+        int stepOfObservation = steps++;

         Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
         Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());

         return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
     }

-    private void record(O obs) {
+    private void record(OBSERVATION obs) {
         INDArray rawObservation = getInput(obs);

         IHistoryProcessor historyProcessor = getHistoryProcessor();
@@ -141,7 +139,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
         }
     }

-    private Map<String, Object> buildChannelsData(final O obs) {
+    private Map<String, Object> buildChannelsData(final OBSERVATION obs) {
         return new HashMap<String, Object>() {{
             put("data", obs);
         }};
@@ -159,11 +157,11 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob

     @Override
     public MDP<Observation, A, AS> newInstance() {
-        return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), historyProcessor, epochStepCounter);
+        return new LegacyMDPWrapper<>(wrappedMDP.newInstance(), historyProcessor);
     }

-    private INDArray getInput(O obs) {
-        INDArray arr = Nd4j.create(((Encodable)obs).toArray());
+    private INDArray getInput(OBSERVATION obs) {
+        INDArray arr = Nd4j.create(obs.toArray());
         int[] shape = observationSpace.getShape();
         if (shape.length == 1)
             return arr.reshape(new long[] {1, arr.length()});
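A side note on the new steps field: the wrapper now numbers observations itself, and the transform process uses that index (together with isDone) to decide which frames survive frame-skipping; a dropped frame comes back as an observation without data. A rough sketch of how a caller can tell the two apart, following the convention the async tests below rely on (new Observation(null) for a skipped frame):

// Sketch only: `wrapper` is a LegacyMDPWrapper and `action` a valid action; both assumed.
Observation observation = wrapper.step(action).getObservation();
if (observation.getData() != null) {
    // usable frame: feed it to the policy and the experience handler
} else {
    // skipped frame: no new data this step (callers typically repeat the last action)
}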

@@ -1,107 +1,107 @@
 package org.deeplearning4j.rl4j.experience;

 import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.junit.Test;
 import org.nd4j.linalg.factory.Nd4j;

 import java.util.ArrayList;
 import java.util.List;

 import static org.junit.Assert.assertEquals;

 public class ReplayMemoryExperienceHandlerTest {
     @Test
     public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
         // Arrange
         TestExpReplay expReplayMock = new TestExpReplay();
         ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);

         // Act
         sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
         int numStoredTransitions = expReplayMock.addedTransitions.size();
         sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);

         // Assert
         assertEquals(0, numStoredTransitions);
         assertEquals(1, expReplayMock.addedTransitions.size());
     }

     @Test
     public void when_addingExperience_expect_transitionsAreCorrect() {
         // Arrange
         TestExpReplay expReplayMock = new TestExpReplay();
         ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);

         // Act
         sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
         sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
         sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));

         // Assert
         assertEquals(2, expReplayMock.addedTransitions.size());

         assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001);
         assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
         assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001);
         assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001);

         assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001);
         assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction());
         assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001);
         assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001);
     }

     @Test
     public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
         // Arrange
         TestExpReplay expReplayMock = new TestExpReplay();
         ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);

         // Act
         sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
         sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 2.0 })));
         sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);

         // Assert
         assertEquals(1, expReplayMock.addedTransitions.size());
         assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
     }

     @Test
     public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
         // Arrange
         TestExpReplay expReplayMock = new TestExpReplay();
         ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
         sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
         sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
         sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));

         // Act
         int size = sut.getTrainingBatchSize();

         // Assert
         assertEquals(2, size);
     }

     private static class TestExpReplay implements IExpReplay<Integer> {

         public final List<Transition<Integer>> addedTransitions = new ArrayList<>();

         @Override
         public ArrayList<Transition<Integer>> getBatch() {
             return null;
         }

         @Override
         public void store(Transition<Integer> transition) {
             addedTransitions.add(transition);
         }

         @Override
         public int getBatchSize() {
             return addedTransitions.size();
         }
     }
 }
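Taken together, these tests pin down the handler's one-step lag: an experience is only written to the store once its next observation is known, and setFinalObservation() completes the pending one at episode end. A condensed usage sketch of that contract (expReplay stands for any IExpReplay&lt;Integer&gt; instance, e.g. the TestExpReplay stub above; the values are the same illustrative ones the tests use):

// Each addExperience() call completes the previously buffered transition.
ReplayMemoryExperienceHandler handler = new ReplayMemoryExperienceHandler(expReplay);
handler.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); // buffered only; nothing stored yet
handler.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); // stores transition 1 with next = 2.0
handler.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));          // flushes transition 2 with next = 3.0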

@@ -18,132 +18,93 @@
 package org.deeplearning4j.rl4j.learning.async;

 import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
-import org.deeplearning4j.rl4j.mdp.MDP;
-import org.deeplearning4j.rl4j.policy.IPolicy;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
-import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
-import org.deeplearning4j.rl4j.support.MockEncodable;
-import org.deeplearning4j.rl4j.support.MockNeuralNet;
-import org.deeplearning4j.rl4j.support.MockPolicy;
-import org.deeplearning4j.rl4j.support.MockTrainingListener;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
+import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Box;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;

-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class AsyncLearningTest {

-    @Test
-    public void when_training_expect_AsyncGlobalStarted() {
-        // Arrange
-        TestContext context = new TestContext();
-        context.asyncGlobal.setMaxLoops(1);
-        // Act
-        context.sut.train();
-        // Assert
-        assertTrue(context.asyncGlobal.hasBeenStarted);
-        assertTrue(context.asyncGlobal.hasBeenTerminated);
+    AsyncLearning<Box, INDArray, ActionSpace<INDArray>, NeuralNet> asyncLearning;
+
+    @Mock
+    TrainingListener mockTrainingListener;
+
+    @Mock
+    AsyncGlobal<NeuralNet> mockAsyncGlobal;
+
+    @Mock
+    IAsyncLearningConfiguration mockConfiguration;
+
+    @Before
+    public void setup() {
+        asyncLearning = mock(AsyncLearning.class, Mockito.withSettings()
+                .useConstructor()
+                .defaultAnswer(Mockito.CALLS_REAL_METHODS));
+        asyncLearning.addListener(mockTrainingListener);
+
+        when(asyncLearning.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
+        when(asyncLearning.getConfiguration()).thenReturn(mockConfiguration);
+
+        // Don't actually start any threads in any of these tests
+        when(mockConfiguration.getNumThreads()).thenReturn(0);
     }

     @Test
     public void when_trainStartReturnsStop_expect_noTraining() {
         // Arrange
-        TestContext context = new TestContext();
-        context.listener.setRemainingTrainingStartCallCount(0);
+        when(mockTrainingListener.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);
         // Act
-        context.sut.train();
+        asyncLearning.train();
         // Assert
-        assertEquals(1, context.listener.onTrainingStartCallCount);
-        assertEquals(1, context.listener.onTrainingEndCallCount);
-        assertEquals(0, context.policy.playCallCount);
-        assertTrue(context.asyncGlobal.hasBeenTerminated);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
     }

     @Test
     public void when_trainingIsComplete_expect_trainingStop() {
         // Arrange
-        TestContext context = new TestContext();
+        when(mockAsyncGlobal.isTrainingComplete()).thenReturn(true);
         // Act
-        context.sut.train();
+        asyncLearning.train();
         // Assert
-        assertEquals(1, context.listener.onTrainingStartCallCount);
-        assertEquals(1, context.listener.onTrainingEndCallCount);
-        assertTrue(context.asyncGlobal.hasBeenTerminated);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
     }

     @Test
     public void when_training_expect_onTrainingProgressCalled() {
         // Arrange
-        TestContext context = new TestContext();
+        asyncLearning.setProgressMonitorFrequency(100);
+        when(mockTrainingListener.onTrainingProgress(eq(asyncLearning))).thenReturn(TrainingListener.ListenerResponse.STOP);
         // Act
-        context.sut.train();
+        asyncLearning.train();
         // Assert
-        assertEquals(1, context.listener.onTrainingProgressCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
+        verify(mockTrainingListener, times(1)).onTrainingProgress(eq(asyncLearning));
     }

-    public static class TestContext {
-        MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0, 0, 0, 0, 0);
-        public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
-        public final MockPolicy policy = new MockPolicy();
-        public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
-        public final MockTrainingListener listener = new MockTrainingListener(asyncGlobal);
-
-        public TestContext() {
-            sut.addListener(listener);
-            asyncGlobal.setMaxLoops(1);
-            sut.setProgressMonitorFrequency(1);
-        }
-    }
-
-    public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
-        private final IAsyncLearningConfiguration conf;
-        private final IAsyncGlobal asyncGlobal;
-        private final IPolicy<MockEncodable, Integer> policy;
-
-        public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
-            this.conf = conf;
-            this.asyncGlobal = asyncGlobal;
-            this.policy = policy;
-        }
-
-        @Override
-        public IPolicy getPolicy() {
-            return policy;
-        }
-
-        @Override
-        public IAsyncLearningConfiguration getConfiguration() {
-            return conf;
-        }
-
-        @Override
-        protected AsyncThread newThread(int i, int deviceAffinity) {
-            return null;
-        }
-
-        @Override
-        public MDP getMdp() {
-            return null;
-        }
-
-        @Override
-        protected IAsyncGlobal getAsyncGlobal() {
-            return asyncGlobal;
-        }
-
-        @Override
-        public MockNeuralNet getNeuralNet() {
-            return null;
-        }
-    }
 }
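The setup above leans on Mockito's constructor-backed partial mocks: withSettings().useConstructor() instantiates the abstract class for real, CALLS_REAL_METHODS keeps its concrete logic, and only the accessors the test cares about are stubbed. A self-contained sketch of the same pattern on a hypothetical abstract class (Trainer is not an RL4J type):

import org.mockito.Mockito;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

abstract class Trainer {                       // hypothetical stand-in for AsyncLearning
    abstract int maxSteps();
    boolean shouldStop(int step) {             // concrete logic we want to keep testable
        return step >= maxSteps();
    }
}

class PartialMockSketch {
    public static void main(String[] args) {
        Trainer trainer = mock(Trainer.class, Mockito.withSettings()
                .useConstructor()                            // run the real constructor
                .defaultAnswer(Mockito.CALLS_REAL_METHODS)); // keep concrete methods real
        when(trainer.maxSteps()).thenReturn(100);            // stub only the abstract part
        System.out.println(trainer.shouldStop(99));          // false
        System.out.println(trainer.shouldStop(100));         // true
    }
}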

@@ -1,6 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- * Copyright (c) 2020 Konduit K.K.
+ * Copyright (c) 2020 Konduit K. K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -17,161 +16,230 @@
 package org.deeplearning4j.rl4j.learning.async;

-import org.deeplearning4j.nn.gradient.Gradient;
-import org.deeplearning4j.rl4j.experience.ExperienceHandler;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
-import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
-import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
-import org.deeplearning4j.rl4j.policy.IPolicy;
+import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.support.*;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
+import org.junit.Before;
 import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.nd4j.linalg.factory.Nd4j;

-import java.util.ArrayList;
-import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;

 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class AsyncThreadDiscreteTest {

+    AsyncThreadDiscrete<Encodable, NeuralNet> asyncThreadDiscrete;
+
+    @Mock
+    IAsyncLearningConfiguration mockAsyncConfiguration;
+
+    @Mock
+    UpdateAlgorithm<NeuralNet> mockUpdateAlgorithm;
+
+    @Mock
+    IAsyncGlobal<NeuralNet> mockAsyncGlobal;
+
+    @Mock
+    Policy<Encodable, Integer> mockGlobalCurrentPolicy;
+
+    @Mock
+    NeuralNet mockGlobalTargetNetwork;
+
+    @Mock
+    MDP<Encodable, Integer, DiscreteSpace> mockMDP;
+
+    @Mock
+    LegacyMDPWrapper<Encodable, Integer, DiscreteSpace> mockLegacyMDPWrapper;
+
+    @Mock
+    DiscreteSpace mockActionSpace;
+
+    @Mock
+    ObservationSpace<Encodable> mockObservationSpace;
+
+    @Mock
+    TrainingListenerList mockTrainingListenerList;
+
+    @Mock
+    Observation mockObservation;
+
+    int[] observationShape = new int[]{3, 10, 10};
+    int actionSize = 4;
+
+    private void setupMDPMocks() {
+        when(mockActionSpace.noOp()).thenReturn(0);
+        when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
+        when(mockObservationSpace.getShape()).thenReturn(observationShape);
+        when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
+    }
+
+    private void setupNNMocks() {
+        when(mockAsyncGlobal.getTarget()).thenReturn(mockGlobalTargetNetwork);
+        when(mockGlobalTargetNetwork.clone()).thenReturn(mockGlobalTargetNetwork);
+    }
+
+    @Before
+    public void setup() {
+        setupMDPMocks();
+        setupNNMocks();
+
+        asyncThreadDiscrete = mock(AsyncThreadDiscrete.class, Mockito.withSettings()
+                .useConstructor(mockAsyncGlobal, mockMDP, mockTrainingListenerList, 0, 0)
+                .defaultAnswer(Mockito.CALLS_REAL_METHODS));
+
+        asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
+
+        when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration);
+        when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
+        when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
+        when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);
+        when(mockGlobalCurrentPolicy.nextAction(any(Observation.class))).thenReturn(0);
+
+        when(asyncThreadDiscrete.getLegacyMDPWrapper()).thenReturn(mockLegacyMDPWrapper);
+    }
+
     @Test
-    public void refac_AsyncThreadDiscrete_trainSubEpoch() {
+    public void when_episodeCompletes_expect_stepsToBeInLineWithEpisodeLength() {
         // Arrange
-        int numEpochs = 1;
-        MockNeuralNet nnMock = new MockNeuralNet();
-        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
-        MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
-        asyncGlobalMock.setMaxLoops(hpConf.getSkipFrame() * numEpochs);
-        MockObservationSpace observationSpace = new MockObservationSpace();
-        MockMDP mdpMock = new MockMDP(observationSpace);
-        TrainingListenerList listeners = new TrainingListenerList();
-        MockPolicy policyMock = new MockPolicy();
-        MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0, 0, 0, 0, 0, 0, 2, 5);
-        MockExperienceHandler experienceHandlerMock = new MockExperienceHandler();
-        MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm();
-        TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock);
-        sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
+        int episodeRemaining = 5;
+        int remainingTrainingSteps = 10;
+
+        // return done after 5 steps (the episode finishes before nsteps)
+        when(mockMDP.isDone()).thenAnswer(invocation ->
+                asyncThreadDiscrete.getStepCount() == episodeRemaining
+        );
+
+        when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));

         // Act
-        sut.run();
+        AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);

         // Assert
-        assertEquals(2, sut.trainSubEpochResults.size());
-        double[][] expectedLastObservations = new double[][] {
-            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
-            new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
-        };
-        double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
-        for(int i = 0; i < 2; ++i) {
-            AsyncThread.SubEpochReturn result = sut.trainSubEpochResults.get(i);
-            assertEquals(4, result.getSteps());
-            assertEquals(expectedSubEpochReturnRewards[i], result.getReward(), 0.00001);
-            assertEquals(0.0, result.getScore(), 0.00001);
-
-            double[] expectedLastObservation = expectedLastObservations[i];
-            assertEquals(expectedLastObservation.length, result.getLastObs().getData().shape()[1]);
-            for(int j = 0; j < expectedLastObservation.length; ++j) {
-                assertEquals(expectedLastObservation[j], 255.0 * result.getLastObs().getData().getDouble(j), 0.00001);
-            }
-        }
-        assertEquals(2, asyncGlobalMock.enqueueCallCount);
-
-        // HistoryProcessor
-        double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, };
-        assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
-        for(int i = 0; i < expectedRecordValues.length; ++i) {
-            assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);
-        }
-
-        // Policy
-        double[][] expectedPolicyInputs = new double[][] {
-            new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
-            new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
-            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
-            new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
-        };
-        assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
-        for(int i = 0; i < expectedPolicyInputs.length; ++i) {
-            double[] expectedRow = expectedPolicyInputs[i];
-            INDArray input = policyMock.actionInputs.get(i);
-            assertEquals(expectedRow.length, input.shape()[1]);
-            for(int j = 0; j < expectedRow.length; ++j) {
-                assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
-            }
-        }
-
-        // ExperienceHandler
-        double[][] expectedExperienceHandlerInputs = new double[][] {
-            new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
-            new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
-            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
-            new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
-        };
-        assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size());
-        for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) {
-            double[] expectedRow = expectedExperienceHandlerInputs[i];
-            INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData();
-            assertEquals(expectedRow.length, input.shape()[1]);
-            for(int j = 0; j < expectedRow.length; ++j) {
-                assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
-            }
-        }
+        assertTrue(subEpochReturn.isEpisodeComplete());
+        assertEquals(5, subEpochReturn.getSteps());
     }

-    public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
-
-        private final MockAsyncGlobal asyncGlobal;
-        private final MockPolicy policy;
-        private final MockAsyncConfiguration config;
-
-        public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
-
-        public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
-                        TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
-                        MockAsyncConfiguration config, IHistoryProcessor hp,
-                        ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
-                        UpdateAlgorithm<MockNeuralNet> updateAlgorithm) {
-            super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
-            this.asyncGlobal = asyncGlobal;
-            this.policy = policy;
-            this.config = config;
-            setHistoryProcessor(hp);
-            setExperienceHandler(experienceHandler);
-            setUpdateAlgorithm(updateAlgorithm);
-        }
-
-        @Override
-        protected IAsyncGlobal<MockNeuralNet> getAsyncGlobal() {
-            return asyncGlobal;
-        }
-
-        @Override
-        protected IAsyncLearningConfiguration getConf() {
-            return config;
-        }
-
-        @Override
-        protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
-            return policy;
-        }
-
-        @Override
-        protected UpdateAlgorithm<MockNeuralNet> buildUpdateAlgorithm() {
-            return null;
-        }
-
-        @Override
-        public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
-            asyncGlobal.increaseCurrentLoop();
-            SubEpochReturn result = super.trainSubEpoch(sObs, nstep);
-            trainSubEpochResults.add(result);
-            return result;
-        }
-    }
+    @Test
+    public void when_episodeCompletesDueToMaxStepsReached_expect_isEpisodeComplete() {
+        // Arrange
+        int remainingTrainingSteps = 50;
+
+        // Episode does not complete due to MDP
+        when(mockMDP.isDone()).thenReturn(false);
+        when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));
+        when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(50);
+
+        // Act
+        AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
+
+        // Assert
+        assertTrue(subEpochReturn.isEpisodeComplete());
+        assertEquals(50, subEpochReturn.getSteps());
+    }
+
+    @Test
+    public void when_episodeLongerThanNsteps_expect_returnNStepLength() {
+        // Arrange
+        int episodeRemaining = 5;
+        int remainingTrainingSteps = 4;
+
+        // return done after 5 steps (the episode is longer than nsteps)
+        when(mockMDP.isDone()).thenAnswer(invocation ->
+                asyncThreadDiscrete.getStepCount() == episodeRemaining
+        );
+
+        when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));
+
+        // Act
+        AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
+
+        // Assert
+        assertFalse(subEpochReturn.isEpisodeComplete());
+        assertEquals(remainingTrainingSteps, subEpochReturn.getSteps());
+    }
+
+    @Test
+    public void when_framesAreSkipped_expect_proportionateStepCounterUpdates() {
+        int skipFrames = 2;
+        int remainingTrainingSteps = 10;
+
+        // Episode does not complete due to MDP
+        when(mockMDP.isDone()).thenReturn(false);
+
+        AtomicInteger stepCount = new AtomicInteger();
+
+        // Use skipFrames to return if observations are skipped or not
+        when(mockLegacyMDPWrapper.step(anyInt())).thenAnswer(invocationOnMock -> {
+            boolean isSkipped = stepCount.incrementAndGet() % skipFrames != 0;
+            Observation mockObs = new Observation(isSkipped ? null : Nd4j.create(observationShape));
+            return new StepReply<>(mockObs, 0.0, false, null);
+        });
+
+        // Act
+        AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
+
+        // Assert
+        assertFalse(subEpochReturn.isEpisodeComplete());
+        assertEquals(remainingTrainingSteps, subEpochReturn.getSteps());
+        assertEquals((remainingTrainingSteps - 1) * skipFrames + 1, stepCount.get());
+    }
+
+    @Test
+    public void when_preEpisodeCalled_expect_experienceHandlerReset() {
+        // Arrange
+        int trainingSteps = 100;
+        for (int i = 0; i < trainingSteps; i++) {
+            asyncThreadDiscrete.getExperienceHandler().addExperience(mockObservation, 0, 0.0, false);
+        }
+        int experienceHandlerSizeBeforeReset = asyncThreadDiscrete.getExperienceHandler().getTrainingBatchSize();
+
+        // Act
+        asyncThreadDiscrete.preEpisode();
+
+        // Assert
+        assertEquals(100, experienceHandlerSizeBeforeReset);
+        assertEquals(0, asyncThreadDiscrete.getExperienceHandler().getTrainingBatchSize());
+    }
 }
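The expectation in the frame-skip test above encodes a small piece of arithmetic: one way to read it is that the first usable frame costs a single raw MDP step and every further usable frame costs skipFrames raw steps, so n counted steps consume (n - 1) * skipFrames + 1 raw calls. Checked with the test's own numbers (a sketch, not part of the test):

int remainingTrainingSteps = 10; // usable steps the sub-epoch must count
int skipFrames = 2;              // only every 2nd raw frame is usable
int rawSteps = (remainingTrainingSteps - 1) * skipFrames + 1;
System.out.println(rawSteps);    // 19, matching the stepCount.get() assertion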

@ -1,220 +1,277 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.async;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.support.MockAsyncConfiguration; import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import java.util.ArrayList; import org.mockito.Mock;
import java.util.List; import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.base.Preconditions;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class AsyncThreadTest { public class AsyncThreadTest {
@Test @Mock
public void when_newEpochStarted_expect_neuralNetworkReset() { ActionSpace<INDArray> mockActionSpace;
// Arrange
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
// Act @Mock
context.sut.run(); ObservationSpace<Box> mockObservationSpace;
// Assert @Mock
assertEquals(numberOfEpochs, context.neuralNet.resetCallCount); IAsyncLearningConfiguration mockAsyncConfiguration;
@Mock
NeuralNet mockNeuralNet;
@Mock
IAsyncGlobal<NeuralNet> mockAsyncGlobal;
@Mock
MDP<Box, INDArray, ActionSpace<INDArray>> mockMDP;
@Mock
TrainingListenerList mockTrainingListeners;
int[] observationShape = new int[]{3, 10, 10};
int actionSize = 4;
AsyncThread<Box, INDArray, ActionSpace<INDArray>, NeuralNet> thread;
@Before
public void setup() {
setupMDPMocks();
setupThreadMocks();
}
private void setupThreadMocks() {
thread = mock(AsyncThread.class, Mockito.withSettings()
.useConstructor(mockMDP, mockTrainingListeners, 0, 0)
.defaultAnswer(Mockito.CALLS_REAL_METHODS));
when(thread.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
when(thread.getCurrent()).thenReturn(mockNeuralNet);
}
private void setupMDPMocks() {
when(mockObservationSpace.getShape()).thenReturn(observationShape);
when(mockActionSpace.noOp()).thenReturn(Nd4j.zeros(actionSize));
when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
int dataLength = 1;
for (int d : observationShape) {
dataLength *= d;
}
when(mockMDP.reset()).thenReturn(new Box(new double[dataLength]));
}
private void mockTrainingListeners() {
mockTrainingListeners(false, false);
}
private void mockTrainingListeners(boolean stopOnNotifyNewEpoch, boolean stopOnNotifyEpochTrainingResult) {
when(mockTrainingListeners.notifyNewEpoch(eq(thread))).thenReturn(!stopOnNotifyNewEpoch);
when(mockTrainingListeners.notifyEpochTrainingResult(eq(thread), any(IDataManager.StatEntry.class))).thenReturn(!stopOnNotifyEpochTrainingResult);
}
private void mockTrainingContext() {
mockTrainingContext(1000, 100, 10);
}
private void mockTrainingContext(int maxSteps, int maxStepsPerEpisode, int nstep) {
// Some conditions of this test harness
Preconditions.checkArgument(maxStepsPerEpisode >= nstep, "episodeLength must be greater than or equal to nstep");
Preconditions.checkArgument(maxStepsPerEpisode % nstep == 0, "episodeLength must be a multiple of nstep");
Observation mockObs = new Observation(Nd4j.zeros(observationShape));
when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
when(thread.getConf()).thenReturn(mockAsyncConfiguration);
// if we hit the max step count
when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);
when(thread.trainSubEpoch(any(Observation.class), anyInt())).thenAnswer(invocationOnMock -> {
int steps = invocationOnMock.getArgument(1);
thread.stepCount += steps;
thread.currentEpisodeStepCount += steps;
boolean isEpisodeComplete = thread.getCurrentEpisodeStepCount() % maxStepsPerEpisode == 0;
return new AsyncThread.SubEpochReturn(steps, mockObs, 0.0, 0.0, isEpisodeComplete);
});
-    @Test
-    public void when_onNewEpochReturnsStop_expect_threadStopped() {
-        // Arrange
-        int stopAfterNumCalls = 1;
-        TestContext context = new TestContext(100000);
-        context.listener.setRemainingOnNewEpochCallCount(stopAfterNumCalls);
-
-        // Act
-        context.sut.run();
-
-        // Assert
-        assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: The call that returns stop is counted
-        assertEquals(stopAfterNumCalls, context.listener.onEpochTrainingResultCallCount);
-    }
-
-    @Test
-    public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
-        // Arrange
-        int stopAfterNumCalls = 1;
-        TestContext context = new TestContext(100000);
-        context.listener.setRemainingOnEpochTrainingResult(stopAfterNumCalls);
-
-        // Act
-        context.sut.run();
-
-        // Assert
-        assertEquals(stopAfterNumCalls + 1, context.listener.onEpochTrainingResultCallCount); // +1: The call that returns stop is counted
-        assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: onNewEpoch is called on the epoch that onEpochTrainingResult() will stop
-    }
-
-    @Test
-    public void when_run_expect_preAndPostEpochCalled() {
-        // Arrange
-        int numberOfEpochs = 5;
-        TestContext context = new TestContext(numberOfEpochs);
-
-        // Act
-        context.sut.run();
-
-        // Assert
-        assertEquals(numberOfEpochs, context.sut.preEpochCallCount);
-        assertEquals(numberOfEpochs, context.sut.postEpochCallCount);
-    }
-
-    @Test
-    public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
-        // Arrange
-        int numberOfEpochs = 5;
-        TestContext context = new TestContext(numberOfEpochs);
-
-        // Act
-        context.sut.run();
-
-        // Assert
-        assertEquals(numberOfEpochs, context.listener.statEntries.size());
-        int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 };
-        double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
-            + 1.0; // Reward from trainSubEpoch()
-        for(int i = 0; i < numberOfEpochs; ++i) {
-            IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
-            assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
-            assertEquals(i, statEntry.getEpochCounter());
-            assertEquals(expectedReward, statEntry.getReward(), 0.0001);
-        }
-    }
-
-    @Test
-    public void when_run_expect_trainSubEpochCalled() {
-        // Arrange
-        int numberOfEpochs = 5;
-        TestContext context = new TestContext(numberOfEpochs);
-
-        // Act
-        context.sut.run();
-
-        // Assert
-        assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
-        double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
-        for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) {
-            MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
-            assertEquals(2, params.nstep);
-            assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
-            for(int j = 0; j < expectedObservation.length; ++j){
-                assertEquals(expectedObservation[j], 255.0 * params.obs.getData().getDouble(j), 0.00001);
-            }
-        }
-    }
-
-    private static class TestContext {
-        public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
-        public final MockNeuralNet neuralNet = new MockNeuralNet();
-        public final MockObservationSpace observationSpace = new MockObservationSpace();
-        public final MockMDP mdp = new MockMDP(observationSpace);
-        public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0);
-        public final TrainingListenerList listeners = new TrainingListenerList();
-        public final MockTrainingListener listener = new MockTrainingListener();
-        public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
-        public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);
-        public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
-
-        public TestContext(int numEpochs) {
-            asyncGlobal.setMaxLoops(numEpochs);
-            listeners.add(listener);
-            sut.setHistoryProcessor(historyProcessor);
-            sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
-        }
-    }
-
-    public static class MockAsyncThread extends AsyncThread<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
-        public int preEpochCallCount = 0;
-        public int postEpochCallCount = 0;
-
-        private final MockAsyncGlobal asyncGlobal;
-        private final MockNeuralNet neuralNet;
-        private final IAsyncLearningConfiguration conf;
-
-        private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
-
-        public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) {
-            super(asyncGlobal, mdp, listeners, threadNumber, 0);
-            this.asyncGlobal = asyncGlobal;
-            this.neuralNet = neuralNet;
-            this.conf = conf;
-        }
-
-        @Override
-        protected void preEpoch() {
-            ++preEpochCallCount;
-            super.preEpoch();
-        }
-
-        @Override
-        protected void postEpoch() {
-            ++postEpochCallCount;
-            super.postEpoch();
-        }
-
-        @Override
-        protected MockNeuralNet getCurrent() {
-            return neuralNet;
-        }
-
-        @Override
-        protected IAsyncGlobal getAsyncGlobal() {
-            return asyncGlobal;
-        }
-
-        @Override
-        protected IAsyncLearningConfiguration getConf() {
-            return conf;
-        }
-
-        @Override
-        protected Policy getPolicy(MockNeuralNet net) {
-            return null;
-        }
-
-        @Override
-        protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
-            asyncGlobal.increaseCurrentLoop();
-            trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
-            for(int i = 0; i < nstep; ++i) {
-                incrementStep();
-            }
-            return new SubEpochReturn(nstep, null, 1.0, 1.0);
-        }
-
-        @AllArgsConstructor
-        @Getter
-        public static class TrainSubEpochParams {
-            Observation obs;
-            int nstep;
-        }
-    }
-}
+    }
+
+    @Test
+    public void when_episodeComplete_expect_neuralNetworkReset() {
+        // Arrange
+        mockTrainingContext(100, 10, 10);
+        mockTrainingListeners();
+
+        // Act
+        thread.run();
+
+        // Assert
+        verify(mockNeuralNet, times(10)).reset(); // there are 10 episodes so the network should be reset between each
+        assertEquals(10, thread.getEpochCount()); // We are performing a training iteration every 10 steps, so there should be 10 epochs
+        assertEquals(10, thread.getEpisodeCount()); // There should be 10 completed episodes
+        assertEquals(100, thread.getStepCount()); // 100 steps overall
+    }
+
+    @Test
+    public void when_notifyNewEpochReturnsStop_expect_threadStopped() {
+        // Arrange
+        mockTrainingContext();
+        mockTrainingListeners(true, false);
+
+        // Act
+        thread.run();
+
+        // Assert
+        assertEquals(0, thread.getEpochCount());
+        assertEquals(1, thread.getEpisodeCount());
+        assertEquals(0, thread.getStepCount());
+    }
+
+    @Test
+    public void when_notifyEpochTrainingResultReturnsStop_expect_threadStopped() {
+        // Arrange
+        mockTrainingContext();
+        mockTrainingListeners(false, true);
+
+        // Act
+        thread.run();
+
+        // Assert
+        assertEquals(1, thread.getEpochCount());
+        assertEquals(1, thread.getEpisodeCount());
+        assertEquals(10, thread.getStepCount()); // one epoch is by default 10 steps
+    }
+
+    @Test
+    public void when_run_expect_preAndPostEpisodeCalled() {
+        // Arrange
+        mockTrainingContext(100, 10, 5);
+        mockTrainingListeners(false, false);
+
+        // Act
+        thread.run();
+
+        // Assert
+        assertEquals(20, thread.getEpochCount());
+        assertEquals(10, thread.getEpisodeCount());
+        assertEquals(100, thread.getStepCount());
+
+        verify(thread, times(10)).preEpisode(); // over 100 steps there will be 10 episodes
+        verify(thread, times(10)).postEpisode();
+    }
+
+    @Test
+    public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
+        // Arrange
+        mockTrainingContext(100, 10, 5);
+        mockTrainingListeners(false, false);
+
+        // Act
+        thread.run();
+
+        // Assert
+        assertEquals(20, thread.getEpochCount());
+        assertEquals(10, thread.getEpisodeCount());
+        assertEquals(100, thread.getStepCount());
+
+        // Over 100 steps there will be 20 training iterations, so there will be 20 calls to notifyEpochTrainingResult
+        verify(mockTrainingListeners, times(20)).notifyEpochTrainingResult(eq(thread), any(IDataManager.StatEntry.class));
+    }
+
+    @Test
+    public void when_run_expect_trainSubEpochCalled() {
+        // Arrange
+        mockTrainingContext(100, 10, 5);
+        mockTrainingListeners(false, false);
+
+        // Act
+        thread.run();
+
+        // Assert
+        assertEquals(20, thread.getEpochCount());
+        assertEquals(10, thread.getEpisodeCount());
+        assertEquals(100, thread.getStepCount());
+
+        // There should be 20 calls to trainSubEpoch with 5 steps per epoch
+        verify(thread, times(20)).trainSubEpoch(any(Observation.class), eq(5));
+    }
+
+    @Test
+    public void when_remainingEpisodeLengthSmallerThanNSteps_expect_trainSubEpochCalledWithMinimumValue() {
+        int currentEpisodeSteps = 95;
+        mockTrainingContext(1000, 100, 10);
+
+        // want to mock that we are 95 steps into the episode
+        doAnswer(invocationOnMock -> {
+            for (int i = 0; i < currentEpisodeSteps; i++) {
+                thread.incrementSteps();
+            }
+            return null;
+        }).when(thread).preEpisode();
+
+        mockTrainingListeners(false, true);
+
+        // Act
+        thread.run();
+
+        // Assert
+        assertEquals(1, thread.getEpochCount());
+        assertEquals(1, thread.getEpisodeCount());
+        assertEquals(100, thread.getStepCount());
+
+        // There should be 1 call to trainSubEpoch with 5 steps as this is the remaining episode steps
+        verify(thread, times(1)).trainSubEpoch(any(Observation.class), eq(5));
+    }
+}
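Note: the tests above replace the old hand-rolled mock subclasses with Mockito partial mocks, built with useConstructor() plus CALLS_REAL_METHODS so the real AsyncThread control flow runs while selected methods are stubbed. A minimal, self-contained sketch of that pattern follows; the Worker class is hypothetical and not part of RL4J.

    import org.mockito.Mockito;
    import static org.mockito.Mockito.*;

    public class PartialMockSketch {

        // Stand-in for a class like AsyncThread: real control flow, abstract collaborator.
        public static abstract class Worker {
            protected int steps = 0;

            public void run() {                // real method, executed by the partial mock
                while (steps < 3) {
                    steps += doWork();
                }
            }

            public abstract int doWork();      // stubbed below instead of subclassed
        }

        public static void main(String[] args) {
            Worker worker = mock(Worker.class, Mockito.withSettings()
                    .defaultAnswer(Mockito.CALLS_REAL_METHODS));

            // doReturn() is used so that stubbing never invokes a real method on the partial mock.
            doReturn(1).when(worker).doWork();

            worker.run();                      // runs the real loop three times
            verify(worker, times(3)).doWork(); // interactions are still recorded
        }
    }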


@@ -1,160 +0,0 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class A3CUpdateAlgorithmTest {
@Test
public void refac_calcGradient_non_terminal() {
// Arrange
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 });
MockMDP mdpMock = new MockMDP(observationSpace);
MockActorCritic actorCriticMock = new MockActorCritic();
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), mdpMock.getActionSpace().getSize(), -1, gamma);
INDArray[] originalObservations = new INDArray[] {
Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }),
Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }),
Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }),
Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }),
};
int[] actions = new int[] { 0, 1, 2, 1 };
double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 };
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
for(int i = 0; i < originalObservations.length; ++i) {
experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
}
// Act
sut.computeGradients(actorCriticMock, experience);
// Assert
assertEquals(1, actorCriticMock.gradientParams.size());
// Inputs
INDArray input = actorCriticMock.gradientParams.get(0).getLeft();
for(int i = 0; i < 4; ++i) {
for(int j = 0; j < 5; ++j) {
assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001);
}
}
INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0];
INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1];
assertEquals(4, targets.shape()[0]);
assertEquals(1, targets.shape()[1]);
// FIXME: check targets values once fixed
assertEquals(4, logSoftmax.shape()[0]);
assertEquals(5, logSoftmax.shape()[1]);
// FIXME: check logSoftmax values once fixed
}
public class MockActorCritic implements IActorCritic {
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[] { batch.mul(-1.0) };
}
@Override
public IActorCritic clone() {
return this;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IActorCritic from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
return new Gradient[0];
}
@Override
public void applyGradient(Gradient[] gradient, int batchSize) {
}
@Override
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
}
@Override
public void save(String pathValue, String pathPolicy) throws IOException {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}
}


@@ -0,0 +1,93 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class AdvantageActorCriticUpdateAlgorithmTest {
@Mock
AsyncGlobal<NeuralNet> mockAsyncGlobal;
@Mock
IActorCritic mockActorCritic;
@Test
public void refac_calcGradient_non_terminal() {
// Arrange
int[] observationShape = new int[]{5};
double gamma = 0.9;
AdvantageActorCriticUpdateAlgorithm algorithm = new AdvantageActorCriticUpdateAlgorithm(false, observationShape, 1, gamma);
INDArray[] originalObservations = new INDArray[]{
Nd4j.create(new double[]{0.0, 0.1, 0.2, 0.3, 0.4}),
Nd4j.create(new double[]{1.0, 1.1, 1.2, 1.3, 1.4}),
Nd4j.create(new double[]{2.0, 2.1, 2.2, 2.3, 2.4}),
Nd4j.create(new double[]{3.0, 3.1, 3.2, 3.3, 3.4}),
};
int[] actions = new int[]{0, 1, 2, 1};
double[] rewards = new double[]{0.1, 1.0, 10.0, 100.0};
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
for (int i = 0; i < originalObservations.length; ++i) {
experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
}
when(mockActorCritic.outputAll(any(INDArray.class))).thenAnswer(invocation -> {
INDArray batch = invocation.getArgument(0);
return new INDArray[]{batch.mul(-1.0)};
});
ArgumentCaptor<INDArray> inputArgumentCaptor = ArgumentCaptor.forClass(INDArray.class);
ArgumentCaptor<INDArray[]> criticActorArgumentCaptor = ArgumentCaptor.forClass(INDArray[].class);
// Act
algorithm.computeGradients(mockActorCritic, experience);
verify(mockActorCritic, times(1)).gradient(inputArgumentCaptor.capture(), criticActorArgumentCaptor.capture());
assertEquals(Nd4j.stack(0, originalObservations), inputArgumentCaptor.getValue());
//TODO: the actual AdvantageActorCritic Algo is not implemented correctly, so needs to be fixed, then we can test these
// assertEquals(Nd4j.zeros(1), criticActorArgumentCaptor.getValue()[0]);
// assertEquals(Nd4j.zeros(1), criticActorArgumentCaptor.getValue()[1]);
}
}
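For reference, the targets that the commented-out assertions will eventually check come from the usual discounted n-step return recursion, R_i = r_i + gamma * R_{i+1}. A standalone arithmetic sketch with the same gamma and rewards as the test, assuming a zero bootstrap value for the final observation (a simplification; the actual algorithm bootstraps from the critic's value estimate):

    public class NStepReturnSketch {
        public static void main(String[] args) {
            double gamma = 0.9;
            double[] rewards = {0.1, 1.0, 10.0, 100.0}; // same rewards as the test
            double bootstrap = 0.0;                     // assumed 0 here for simplicity

            double r = bootstrap;
            double[] targets = new double[rewards.length];
            for (int i = rewards.length - 1; i >= 0; i--) {
                r = rewards[i] + gamma * r;             // R_i = r_i + gamma * R_{i+1}
                targets[i] = r;
            }

            // targets[0] = 0.1 + 0.9 * (1.0 + 0.9 * (10.0 + 0.9 * 100.0)) = 82.0
            System.out.println(java.util.Arrays.toString(targets));
        }
    }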


@@ -1,3 +1,19 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.listener;

import org.deeplearning4j.rl4j.learning.IEpochTrainer;


@@ -1,11 +1,30 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;

import org.deeplearning4j.rl4j.experience.StateActionPair;
+import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.observation.Observation;
-import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockDQN;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@@ -14,17 +33,21 @@
import java.util.List;

import static org.junit.Assert.assertEquals;

+@RunWith(MockitoJUnitRunner.class)
public class QLearningUpdateAlgorithmTest {

+    @Mock
+    AsyncGlobal mockAsyncGlobal;
+
    @Test
    public void when_isTerminal_expect_initRewardIs0() {
        // Arrange
        MockDQN dqnMock = new MockDQN();
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
-        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0);
+        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0);
+        final Observation observation = new Observation(Nd4j.zeros(1));
        List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
            {
-                add(new StateActionPair<Integer>(new Observation(Nd4j.zeros(1)), 0, 0.0, true));
+                add(new StateActionPair<Integer>(observation, 0, 0.0, true));
            }
        };

@@ -38,12 +61,11 @@
    @Test
    public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
        // Arrange
-        MockDQN globalDQNMock = new MockDQN();
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
-        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0);
+        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0);
+        final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }));
        List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
            {
-                add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
+                add(new StateActionPair<Integer>(observation, 0, 0.0, false));
            }
        };
        MockDQN dqnMock = new MockDQN();

@@ -57,35 +79,11 @@
        assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
    }

-    @Test
-    public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() {
-        // Arrange
-        MockDQN globalDQNMock = new MockDQN();
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
-        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0);
-        List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
-            {
-                add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
-            }
-        };
-        MockDQN dqnMock = new MockDQN();
-
-        // Act
-        sut.computeGradients(dqnMock, experience);
-
-        // Assert
-        assertEquals(1, globalDQNMock.outputAllParams.size());
-        assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
-        assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
-    }
-
    @Test
    public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
        // Arrange
        double gamma = 0.9;
-        MockDQN globalDQNMock = new MockDQN();
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
-        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma);
+        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma);
        List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
            {
                add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));
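The tests above pin down the bootstrap rule for the n-step Q-learning target: a terminal episode starts the accumulated reward at 0, while a truncated one starts it at max_a Q(s_last, a) from the current network. A standalone sketch of that arithmetic, using the values from when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent (the test's MockDQN appears to negate its input, so the observation { -123.0, -234.0 } yields Q-values { 123.0, 234.0 }):

    public class QTargetSketch {
        public static void main(String[] args) {
            double gamma = 1.0;
            boolean isTerminal = false;
            double[] qLast = {123.0, 234.0};  // Q-values of the last observation

            // 0 for terminal episodes, max-Q bootstrap otherwise
            double r = isTerminal ? 0.0 : Math.max(qLast[0], qLast[1]);

            double reward = 0.0;              // reward of the single transition in the test
            double target = reward + gamma * r;
            System.out.println(target);       // 234.0, the value the test asserts
        }
    }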


@@ -1,20 +1,56 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.listener;

-import org.deeplearning4j.rl4j.support.MockTrainingListener;
+import org.deeplearning4j.rl4j.learning.IEpochTrainer;
+import org.deeplearning4j.rl4j.learning.ILearning;
+import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
+import org.mockito.Mock;

import static org.junit.Assert.*;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;

public class TrainingListenerListTest {

+    @Mock
+    IEpochTrainer mockTrainer;
+
+    @Mock
+    ILearning mockLearning;
+
+    @Mock
+    IDataManager.StatEntry mockStatEntry;
+
    @Test
    public void when_listIsEmpty_expect_notifyReturnTrue() {
        // Arrange
-        TrainingListenerList sut = new TrainingListenerList();
+        TrainingListenerList trainingListenerList = new TrainingListenerList();

        // Act
-        boolean resultTrainingStarted = sut.notifyTrainingStarted();
-        boolean resultNewEpoch = sut.notifyNewEpoch(null);
-        boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null);
+        boolean resultTrainingStarted = trainingListenerList.notifyTrainingStarted();
+        boolean resultNewEpoch = trainingListenerList.notifyNewEpoch(null);
+        boolean resultEpochFinished = trainingListenerList.notifyEpochTrainingResult(null, null);

        // Assert
        assertTrue(resultTrainingStarted);

@@ -25,54 +61,56 @@
    @Test
    public void when_firstListerStops_expect_othersListnersNotCalled() {
        // Arrange
-        MockTrainingListener listener1 = new MockTrainingListener();
-        listener1.setRemainingTrainingStartCallCount(0);
-        listener1.setRemainingOnNewEpochCallCount(0);
-        listener1.setRemainingonTrainingProgressCallCount(0);
-        listener1.setRemainingOnEpochTrainingResult(0);
-        MockTrainingListener listener2 = new MockTrainingListener();
-        TrainingListenerList sut = new TrainingListenerList();
-        sut.add(listener1);
-        sut.add(listener2);
+        TrainingListener listener1 = mock(TrainingListener.class);
+        TrainingListener listener2 = mock(TrainingListener.class);
+        TrainingListenerList trainingListenerList = new TrainingListenerList();
+        trainingListenerList.add(listener1);
+        trainingListenerList.add(listener2);
+
+        when(listener1.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);
+        when(listener1.onNewEpoch(eq(mockTrainer))).thenReturn(TrainingListener.ListenerResponse.STOP);
+        when(listener1.onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry))).thenReturn(TrainingListener.ListenerResponse.STOP);
+        when(listener1.onTrainingProgress(eq(mockLearning))).thenReturn(TrainingListener.ListenerResponse.STOP);

        // Act
-        sut.notifyTrainingStarted();
-        sut.notifyNewEpoch(null);
-        sut.notifyEpochTrainingResult(null, null);
-        sut.notifyTrainingProgress(null);
-        sut.notifyTrainingFinished();
+        trainingListenerList.notifyTrainingStarted();
+        trainingListenerList.notifyNewEpoch(mockTrainer);
+        trainingListenerList.notifyEpochTrainingResult(mockTrainer, null);
+        trainingListenerList.notifyTrainingProgress(mockLearning);
+        trainingListenerList.notifyTrainingFinished();

        // Assert
-        assertEquals(1, listener1.onTrainingStartCallCount);
-        assertEquals(0, listener2.onTrainingStartCallCount);
-        assertEquals(1, listener1.onNewEpochCallCount);
-        assertEquals(0, listener2.onNewEpochCallCount);
-        assertEquals(1, listener1.onEpochTrainingResultCallCount);
-        assertEquals(0, listener2.onEpochTrainingResultCallCount);
-        assertEquals(1, listener1.onTrainingProgressCallCount);
-        assertEquals(0, listener2.onTrainingProgressCallCount);
-        assertEquals(1, listener1.onTrainingEndCallCount);
-        assertEquals(1, listener2.onTrainingEndCallCount);
+        verify(listener1, times(1)).onTrainingStart();
+        verify(listener2, never()).onTrainingStart();
+        verify(listener1, times(1)).onNewEpoch(eq(mockTrainer));
+        verify(listener2, never()).onNewEpoch(eq(mockTrainer));
+        verify(listener1, times(1)).onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry));
+        verify(listener2, never()).onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry));
+        verify(listener1, times(1)).onTrainingProgress(eq(mockLearning));
+        verify(listener2, never()).onTrainingProgress(eq(mockLearning));
+        verify(listener1, times(1)).onTrainingEnd();
+        verify(listener2, times(1)).onTrainingEnd();
    }

    @Test
    public void when_allListenersContinue_expect_listReturnsTrue() {
        // Arrange
-        MockTrainingListener listener1 = new MockTrainingListener();
-        MockTrainingListener listener2 = new MockTrainingListener();
-        TrainingListenerList sut = new TrainingListenerList();
-        sut.add(listener1);
-        sut.add(listener2);
+        TrainingListener listener1 = mock(TrainingListener.class);
+        TrainingListener listener2 = mock(TrainingListener.class);
+        TrainingListenerList trainingListenerList = new TrainingListenerList();
+        trainingListenerList.add(listener1);
+        trainingListenerList.add(listener2);

        // Act
-        boolean resultTrainingStarted = sut.notifyTrainingStarted();
-        boolean resultNewEpoch = sut.notifyNewEpoch(null);
-        boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
-        boolean resultProgress = sut.notifyTrainingProgress(null);
+        boolean resultTrainingStarted = trainingListenerList.notifyTrainingStarted();
+        boolean resultNewEpoch = trainingListenerList.notifyNewEpoch(null);
+        boolean resultEpochTrainingResult = trainingListenerList.notifyEpochTrainingResult(null, null);
+        boolean resultProgress = trainingListenerList.notifyTrainingProgress(null);

        // Assert
        assertTrue(resultTrainingStarted);

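These tests encode the contract that a STOP from one listener short-circuits the notification, so later listeners are never asked, while onTrainingEnd is still delivered to everyone. A minimal sketch of the short-circuiting loop; the names here are illustrative, not the actual TrainingListenerList implementation:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.Supplier;

    public class ListenerListSketch {
        enum Response { CONTINUE, STOP }

        public static void main(String[] args) {
            List<Supplier<Response>> listeners = new ArrayList<>();
            listeners.add(() -> Response.STOP);      // listener1 stops...
            listeners.add(() -> Response.CONTINUE);  // ...so listener2 is never notified

            boolean canContinue = true;
            for (Supplier<Response> listener : listeners) {
                if (listener.get() == Response.STOP) {
                    canContinue = false;
                    break;                           // remaining listeners are skipped
                }
            }
            System.out.println(canContinue);         // false
        }
    }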

@@ -1,151 +1,117 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync;

-import lombok.Getter;
+import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
-import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration;
-import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
-import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.policy.IPolicy;
-import org.deeplearning4j.rl4j.support.MockTrainingListener;
+import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.util.IDataManager;
+import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.nd4j.linalg.api.ndarray.INDArray;

-import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
public class SyncLearningTest {

+    @Mock
+    TrainingListener mockTrainingListener;
+
+    SyncLearning<Box, INDArray, ActionSpace<INDArray>, NeuralNet> syncLearning;
+
+    @Mock
+    ILearningConfiguration mockLearningConfiguration;
+
+    @Before
+    public void setup() {
+        syncLearning = mock(SyncLearning.class, Mockito.withSettings()
+                .useConstructor()
+                .defaultAnswer(Mockito.CALLS_REAL_METHODS));
+
+        syncLearning.addListener(mockTrainingListener);
+
+        when(syncLearning.trainEpoch()).thenAnswer(invocation -> {
+            //syncLearning.incrementEpoch();
+            syncLearning.incrementStep();
+            return new MockStatEntry(syncLearning.getEpochCount(), syncLearning.getStepCount(), 1.0);
+        });
+
+        when(syncLearning.getConfiguration()).thenReturn(mockLearningConfiguration);
+        when(mockLearningConfiguration.getMaxStep()).thenReturn(100);
+    }
+
    @Test
    public void when_training_expect_listenersToBeCalled() {
-        // Arrange
-        QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
-        MockTrainingListener listener = new MockTrainingListener();
-        MockSyncLearning sut = new MockSyncLearning(lconfig);
-        sut.addListener(listener);
-
        // Act
-        sut.train();
+        syncLearning.train();

-        assertEquals(1, listener.onTrainingStartCallCount);
-        assertEquals(10, listener.onNewEpochCallCount);
-        assertEquals(10, listener.onEpochTrainingResultCallCount);
-        assertEquals(1, listener.onTrainingEndCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(100)).onNewEpoch(eq(syncLearning));
+        verify(mockTrainingListener, times(100)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
    }

    @Test
    public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
        // Arrange
-        QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
-        MockTrainingListener listener = new MockTrainingListener();
-        MockSyncLearning sut = new MockSyncLearning(lconfig);
-        sut.addListener(listener);
-        listener.setRemainingTrainingStartCallCount(0);
+        when(mockTrainingListener.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);

        // Act
-        sut.train();
+        syncLearning.train();

-        assertEquals(1, listener.onTrainingStartCallCount);
-        assertEquals(0, listener.onNewEpochCallCount);
-        assertEquals(0, listener.onEpochTrainingResultCallCount);
-        assertEquals(1, listener.onTrainingEndCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(0)).onNewEpoch(eq(syncLearning));
+        verify(mockTrainingListener, times(0)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
    }

    @Test
    public void when_newEpochCanContinueFalse_expect_trainingStopped() {
        // Arrange
-        QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
-        MockTrainingListener listener = new MockTrainingListener();
-        MockSyncLearning sut = new MockSyncLearning(lconfig);
-        sut.addListener(listener);
-        listener.setRemainingOnNewEpochCallCount(2);
+        when(mockTrainingListener.onNewEpoch(eq(syncLearning)))
+                .thenReturn(TrainingListener.ListenerResponse.CONTINUE)
+                .thenReturn(TrainingListener.ListenerResponse.CONTINUE)
+                .thenReturn(TrainingListener.ListenerResponse.STOP);

        // Act
-        sut.train();
+        syncLearning.train();

-        assertEquals(1, listener.onTrainingStartCallCount);
-        assertEquals(3, listener.onNewEpochCallCount);
-        assertEquals(2, listener.onEpochTrainingResultCallCount);
-        assertEquals(1, listener.onTrainingEndCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(3)).onNewEpoch(eq(syncLearning));
+        verify(mockTrainingListener, times(2)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
    }

    @Test
    public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
        // Arrange
-        LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
-        MockTrainingListener listener = new MockTrainingListener();
-        MockSyncLearning sut = new MockSyncLearning(lconfig);
-        sut.addListener(listener);
-        listener.setRemainingOnEpochTrainingResult(2);
+        when(mockTrainingListener.onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class)))
+                .thenReturn(TrainingListener.ListenerResponse.CONTINUE)
+                .thenReturn(TrainingListener.ListenerResponse.CONTINUE)
+                .thenReturn(TrainingListener.ListenerResponse.STOP);

        // Act
-        sut.train();
+        syncLearning.train();

-        assertEquals(1, listener.onTrainingStartCallCount);
-        assertEquals(3, listener.onNewEpochCallCount);
-        assertEquals(3, listener.onEpochTrainingResultCallCount);
-        assertEquals(1, listener.onTrainingEndCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(3)).onNewEpoch(eq(syncLearning));
+        verify(mockTrainingListener, times(3)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
    }
-
-    public static class MockSyncLearning extends SyncLearning {
-
-        private final ILearningConfiguration conf;
-
-        @Getter
-        private int currentEpochStep = 0;
-
-        public MockSyncLearning(ILearningConfiguration conf) {
-            this.conf = conf;
-        }
-
-        @Override
-        protected void preEpoch() { currentEpochStep = 0; }
-
-        @Override
-        protected void postEpoch() { }
-
-        @Override
-        protected IDataManager.StatEntry trainEpoch() {
-            setStepCounter(getStepCounter() + 1);
-            return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0);
-        }
-
-        @Override
-        public NeuralNet getNeuralNet() {
-            return null;
-        }
-
-        @Override
-        public IPolicy getPolicy() {
-            return null;
-        }
-
-        @Override
-        public ILearningConfiguration getConfiguration() {
-            return conf;
-        }
-
-        @Override
-        public MDP getMdp() {
-            return null;
-        }
-    }
}
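The consecutive thenReturn() stubs above walk train() through exactly two completed epochs before the third notification stops it. A simplified sketch of the control flow those expectations imply; the loop body is an assumption for illustration, not the actual SyncLearning code:

    public class TrainLoopSketch {

        interface Listener {
            boolean onNewEpoch();               // false stands for ListenerResponse.STOP
            boolean onEpochTrainingResult();
        }

        static int train(Listener listener, int maxStep) {
            int stepCount = 0;
            while (stepCount < maxStep) {
                if (!listener.onNewEpoch()) break;            // stop before the epoch trains
                stepCount++;                                  // trainEpoch() advances one step here
                if (!listener.onEpochTrainingResult()) break; // stop after the result is reported
            }
            return stepCount;
        }

        public static void main(String[] args) {
            int[] newEpochCalls = {0};
            int steps = train(new Listener() {
                public boolean onNewEpoch() { return ++newEpochCalls[0] < 3; } // STOP on the 3rd call
                public boolean onEpochTrainingResult() { return true; }
            }, 100);
            System.out.println(steps + " " + newEpochCalls[0]); // 2 3: two epochs trained, three notifications
        }
    }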


@@ -17,6 +17,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

+import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;

@@ -26,11 +27,21 @@
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.observation.Observation;
+import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
+import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
+import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.api.rng.Random;

@@ -40,150 +51,146 @@ import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.*;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
public class QLearningDiscreteTest {

+    QLearningDiscrete<Encodable> qLearningDiscrete;
+
+    @Mock
+    IHistoryProcessor mockHistoryProcessor;
+
+    @Mock
+    IHistoryProcessor.Configuration mockHistoryConfiguration;
+
+    @Mock
+    MDP<Encodable, Integer, DiscreteSpace> mockMDP;
+
+    @Mock
+    DiscreteSpace mockActionSpace;
+
+    @Mock
+    ObservationSpace<Encodable> mockObservationSpace;
+
+    @Mock
+    IDQN mockDQN;
+
+    @Mock
+    QLearningConfiguration mockQlearningConfiguration;
+
+    int[] observationShape = new int[]{3, 10, 10};
+    int totalObservationSize = 1;
+
+    private void setupMDPMocks() {
+        when(mockObservationSpace.getShape()).thenReturn(observationShape);
+        when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
+        when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
+
+        int dataLength = 1;
+        for (int d : observationShape) {
+            dataLength *= d;
+        }
+    }
+
+    private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) {
+        when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize);
+        when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor);
+        when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay);
+        when(mockQlearningConfiguration.getSeed()).thenReturn(123L);
+
+        qLearningDiscrete = mock(
+                QLearningDiscrete.class,
+                Mockito.withSettings()
+                        .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0)
+                        .defaultAnswer(Mockito.CALLS_REAL_METHODS)
+        );
+    }
+
+    private void mockHistoryProcessor(int skipFrames) {
+        when(mockHistoryConfiguration.getRescaledHeight()).thenReturn(observationShape[1]);
+        when(mockHistoryConfiguration.getRescaledWidth()).thenReturn(observationShape[2]);
+        when(mockHistoryConfiguration.getOffsetX()).thenReturn(0);
+        when(mockHistoryConfiguration.getOffsetY()).thenReturn(0);
+        when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
+        when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
+        when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
+        when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
+
+        qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
+    }
+
+    @Before
+    public void setup() {
+        setupMDPMocks();
+
+        for (int i : observationShape) {
+            totalObservationSize *= i;
+        }
+    }
+
    @Test
-    public void refac_QLearningDiscrete_trainStep() {
+    public void when_singleTrainStep_expect_correctValues() {
        // Arrange
-        MockObservationSpace observationSpace = new MockObservationSpace();
-        MockDQN dqn = new MockDQN();
-        MockRandom random = new MockRandom(new double[]{
-                0.7309677600860596,
-                0.8314409852027893,
-                0.2405363917350769,
-                0.6063451766967773,
-                0.6374173760414124,
-                0.3090505599975586,
-                0.5504369735717773,
-                0.11700659990310669
-        },
-                new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
-        MockMDP mdp = new MockMDP(observationSpace, random);
-        int initStepCount = 8;
-
-        QLearningConfiguration conf = QLearningConfiguration.builder()
-                .seed(0L)
-                .maxEpochStep(24)
-                .maxStep(0)
-                .expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000)
-                .updateStart(initStepCount)
-                .rewardFactor(1.0)
-                .gamma(0)
-                .errorClamp(0)
-                .minEpsilon(0)
-                .epsilonNbStep(0)
-                .doubleDQN(true)
-                .build();
-        MockDataManager dataManager = new MockDataManager(false);
-        MockExperienceHandler experienceHandler = new MockExperienceHandler();
-        TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random);
-        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
-        MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
-        sut.setHistoryProcessor(hp);
-        sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
-        List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
+        mockTestContext(100,0,2,1.0, 10);
+
+        // An example observation and 2 Q values output (2 actions)
+        Observation observation = new Observation(Nd4j.zeros(observationShape));
+        when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
+
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));

        // Act
-        IDataManager.StatEntry result = sut.trainEpoch();
+        QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);

        // Assert
-        // HistoryProcessor calls
-        double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
-        assertEquals(expectedRecords.length, hp.recordCalls.size());
-        for (int i = 0; i < expectedRecords.length; ++i) {
-            assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
-        }
-        assertEquals(0, hp.startMonitorCallCount);
-        assertEquals(0, hp.stopMonitorCallCount);
-
-        // DQN calls
-        assertEquals(1, dqn.fitParams.size());
-        assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
-        assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
-        assertEquals(14, dqn.outputParams.size());
-        double[][] expectedDQNOutput = new double[][]{
-                new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
-                new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
-                new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
-                new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
-                new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
-                new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
-                new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
-                new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
-                new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
-                new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
-                new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
-                new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
-                new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
-                new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
-        };
-        for (int i = 0; i < expectedDQNOutput.length; ++i) {
-            INDArray outputParam = dqn.outputParams.get(i);
-
-            assertEquals(5, outputParam.shape()[1]);
-            assertEquals(1, outputParam.shape()[2]);
-
-            double[] expectedRow = expectedDQNOutput[i];
-            for (int j = 0; j < expectedRow.length; ++j) {
-                assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
-            }
-        }
-
-        // MDP calls
-        assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
-
-        // ExperienceHandler calls
-        double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
-        int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
-        double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
-        double[][] expectedTrObservations = new double[][] {
-                new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
-                new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
-                new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
-                new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
-                new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
-                new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
-                new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
-                new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
-        };
-        assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size());
-        for(int i = 0; i < expectedTrRewards.length; ++i) {
-            StateActionPair<Integer> stateActionPair = experienceHandler.addExperienceArgs.get(i);
-            assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001);
-            assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction());
-            for(int j = 0; j < expectedTrObservations[i].length; ++j) {
-                assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001);
-            }
-        }
-        assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001);
-
-        // trainEpoch result
-        assertEquals(initStepCount + 16, result.getStepCounter());
-        assertEquals(300.0, result.getReward(), 0.00001);
-        assertTrue(dqn.hasBeenReset);
-        assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset);
+        assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
+
+        StepReply<Observation> stepReply = stepReturn.getStepReply();
+
+        assertEquals(0, stepReply.getReward(), 1e-5);
+        assertFalse(stepReply.isDone());
+        assertFalse(stepReply.getObservation().isSkipped());
+        assertEquals(observation.getData().reshape(observationShape), stepReply.getObservation().getData().reshape(observationShape));
    }

-    public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
-        public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
-                        QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
-                        int epsilonNbStep, Random rnd) {
-            super(mdp, dqn, conf, epsilonNbStep, rnd);
-            addListener(new DataManagerTrainingListener(dataManager));
-            setExperienceHandler(experienceHandler);
-        }
-
-        @Override
-        protected DataSet setTarget(List<Transition<Integer>> transitions) {
-            return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
-        }
-
-        @Override
-        public IDataManager.StatEntry trainEpoch() {
-            return super.trainEpoch();
-        }
-    }
+    @Test
+    public void when_singleTrainStepSkippedFrames_expect_correctValues() {
+        // Arrange
+        mockTestContext(100,0,2,1.0, 10);
+
+        mockHistoryProcessor(2);
+
+        // An example observation and 2 Q values output (2 actions)
+        Observation observation = new Observation(Nd4j.zeros(observationShape));
+        when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
+
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
+
+        // Act
+        QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
+
+        // Assert
+        assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
+
+        StepReply<Observation> stepReply = stepReturn.getStepReply();
+
+        assertEquals(0, stepReply.getReward(), 1e-5);
+        assertFalse(stepReply.isDone());
+        assertTrue(stepReply.getObservation().isSkipped());
+    }
+
+    //TODO: there are much more test cases here that can be improved upon
}
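when_singleTrainStepSkippedFrames_expect_correctValues hinges on the history processor's skipFrame setting: with skipFrame = 2 only every second observation is processed and the others come back flagged as skipped. A standalone sketch of that rule, illustrative only:

    public class FrameSkipSketch {
        public static void main(String[] args) {
            int skipFrame = 2; // keep one observation out of every two
            for (int step = 0; step < 6; step++) {
                boolean skipped = step % skipFrame != 0;
                System.out.println("step " + step + " skipped=" + skipped);
            }
            // A skipped observation is still returned by trainStep(), but the learner
            // is expected to leave it out of the experience it accumulates.
        }
    }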


@@ -1,79 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
public class MockMDP implements MDP<Object, Integer, DiscreteSpace> {
private final int maxSteps;
private final DiscreteSpace actionSpace = new DiscreteSpace(1);
private final MockObservationSpace observationSpace = new MockObservationSpace();
private int currentStep = 0;
public MockMDP(int maxSteps) {
this.maxSteps = maxSteps;
}
@Override
public ObservationSpace<Object> getObservationSpace() {
return observationSpace;
}
@Override
public DiscreteSpace getActionSpace() {
return actionSpace;
}
@Override
public Object reset() {
return null;
}
@Override
public void close() {
}
@Override
public StepReply<Object> step(Integer integer) {
return new StepReply<Object>(null, 1.0, isDone(), null);
}
@Override
public boolean isDone() {
return currentStep >= maxSteps;
}
@Override
public MDP<Object, Integer, DiscreteSpace> newInstance() {
return null;
}
private static class MockObservationSpace implements ObservationSpace {
@Override
public String getName() {
return null;
}
@Override
public int[] getShape() {
return new int[0];
}
@Override
public INDArray getLow() {
return null;
}
@Override
public INDArray getHigh() {
return null;
}
}
}


@@ -257,9 +257,9 @@ public class PolicyTest {
        }

        @Override
-        protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
+        protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
            mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
-            return super.refacInitMdp(mdpWrapper, hp, epochStepCounter);
+            return super.refacInitMdp(mdpWrapper, hp);
        }
    }
}


@@ -1,37 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.support;
import lombok.AllArgsConstructor;
import lombok.Value;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
@Value
@AllArgsConstructor
public class MockAsyncConfiguration implements IAsyncLearningConfiguration {
private Long seed;
private int maxEpochStep;
private int maxStep;
private int updateStart;
private double rewardFactor;
private double gamma;
private double errorClamp;
private int numThreads;
private int nStep;
private int learnerUpdateFrequency;
}


@@ -1,75 +0,0 @@
package org.deeplearning4j.rl4j.support;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.network.NeuralNet;
import java.util.concurrent.atomic.AtomicInteger;
public class MockAsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {
@Getter
private final NN current;
public boolean hasBeenStarted = false;
public boolean hasBeenTerminated = false;
public int enqueueCallCount = 0;
@Setter
private int maxLoops;
@Setter
private int numLoopsStopRunning;
private int currentLoop = 0;
public MockAsyncGlobal() {
this(null);
}
public MockAsyncGlobal(NN current) {
maxLoops = Integer.MAX_VALUE;
numLoopsStopRunning = Integer.MAX_VALUE;
this.current = current;
}
@Override
public boolean isRunning() {
return currentLoop < numLoopsStopRunning;
}
@Override
public void terminate() {
hasBeenTerminated = true;
}
@Override
public boolean isTrainingComplete() {
return currentLoop >= maxLoops;
}
@Override
public void start() {
hasBeenStarted = true;
}
@Override
public AtomicInteger getT() {
return null;
}
@Override
public NN getTarget() {
return current;
}
@Override
public void enqueue(Gradient[] gradient, Integer nstep) {
++enqueueCallCount;
}
public void increaseCurrentLoop() {
++currentLoop;
}
}


@@ -1,46 +0,0 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation;
import java.util.ArrayList;
import java.util.List;
public class MockExperienceHandler implements ExperienceHandler<Integer, Transition<Integer>> {
public List<StateActionPair<Integer>> addExperienceArgs = new ArrayList<StateActionPair<Integer>>();
public Observation finalObservation;
public boolean isGenerateTrainingBatchCalled;
public boolean isResetCalled;
@Override
public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) {
addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal));
}
@Override
public void setFinalObservation(Observation observation) {
finalObservation = observation;
}
@Override
public List<Transition<Integer>> generateTrainingBatch() {
isGenerateTrainingBatchCalled = true;
return new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(null, 0, 0.0, false));
}
};
}
@Override
public void reset() {
isResetCalled = true;
}
@Override
public int getTrainingBatchSize() {
return 1;
}
}


@@ -1,77 +0,0 @@
package org.deeplearning4j.rl4j.support;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.util.ArrayList;
import java.util.List;
public class MockTrainingListener implements TrainingListener {
private final MockAsyncGlobal asyncGlobal;
public int onTrainingStartCallCount = 0;
public int onTrainingEndCallCount = 0;
public int onNewEpochCallCount = 0;
public int onEpochTrainingResultCallCount = 0;
public int onTrainingProgressCallCount = 0;
@Setter
private int remainingTrainingStartCallCount = Integer.MAX_VALUE;
@Setter
private int remainingOnNewEpochCallCount = Integer.MAX_VALUE;
@Setter
private int remainingOnEpochTrainingResult = Integer.MAX_VALUE;
@Setter
private int remainingonTrainingProgressCallCount = Integer.MAX_VALUE;
public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
public MockTrainingListener() {
this(null);
}
public MockTrainingListener(MockAsyncGlobal asyncGlobal) {
this.asyncGlobal = asyncGlobal;
}
@Override
public ListenerResponse onTrainingStart() {
++onTrainingStartCallCount;
--remainingTrainingStartCallCount;
return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
++onNewEpochCallCount;
--remainingOnNewEpochCallCount;
return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
++onEpochTrainingResultCallCount;
--remainingOnEpochTrainingResult;
statEntries.add(statEntry);
return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse onTrainingProgress(ILearning learning) {
++onTrainingProgressCallCount;
--remainingonTrainingProgressCallCount;
if(asyncGlobal != null) {
asyncGlobal.increaseCurrentLoop();
}
return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}
@Override
public void onTrainingEnd() {
++onTrainingEndCallCount;
}
}


@@ -1,19 +0,0 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import java.util.ArrayList;
import java.util.List;
public class MockUpdateAlgorithm implements UpdateAlgorithm<MockNeuralNet> {
public final List<List<StateActionPair<Integer>>> experiences = new ArrayList<List<StateActionPair<Integer>>>();
@Override
public Gradient[] computeGradients(MockNeuralNet current, List<StateActionPair<Integer>> experience) {
experiences.add(experience);
return new Gradient[0];
}
}


@@ -151,20 +151,26 @@ public class DataManagerTrainingListenerTest {
    private static class TestTrainer implements IEpochTrainer, ILearning
    {
        @Override
-        public int getStepCounter() {
+        public int getStepCount() {
            return 0;
        }

        @Override
-        public int getEpochCounter() {
+        public int getEpochCount() {
            return 0;
        }

        @Override
-        public int getCurrentEpochStep() {
+        public int getEpisodeCount() {
            return 0;
        }

+        @Override
+        public int getCurrentEpisodeStepCount() {
+            return 0;
+        }
+
        @Getter
        @Setter
        private IHistoryProcessor historyProcessor;