diff --git a/.gitignore b/.gitignore index ad2e28e6f..fd33cb142 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,6 @@ venv2/ # Ignore the nd4j files that are created by javacpp at build to stop merge conflicts nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java + +# Ignore meld temp files +*.orig diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java index 330c06887..37b097dbf 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java @@ -19,6 +19,7 @@ package org.deeplearning4j.rl4j.mdp; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.ObservationSpace; /** @@ -31,20 +32,20 @@ import org.deeplearning4j.rl4j.space.ObservationSpace; * in a "functionnal manner" if step return a mdp * */ -public interface MDP> { +public interface MDP> { - ObservationSpace getObservationSpace(); + ObservationSpace getObservationSpace(); - AS getActionSpace(); + ACTION_SPACE getActionSpace(); - O reset(); + OBSERVATION reset(); void close(); - StepReply step(A action); + StepReply step(ACTION action); boolean isDone(); - MDP newInstance(); + MDP newInstance(); } diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java index 0b149d394..d20b3e159 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java @@ -17,24 +17,24 @@ package org.deeplearning4j.rl4j.space; /** - * @param the type of Action + * @param the type of Action * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. *

* Should contain contextual information about the Action space, which is the space of all the actions that could be available. * Also must know how to return a randomly uniformly sampled action. */ -public interface ActionSpace { +public interface ActionSpace { /** * @return A random action, */ - A randomAction(); + ACTION randomAction(); - Object encode(A action); + Object encode(ACTION action); int getSize(); - A noOp(); + ACTION noOp(); } diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index a93ea6345..bbb66a9e9 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -121,6 +121,13 @@ ${datavec.version} + + org.mockito + mockito-core + 3.3.3 + test + + diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java index 1ec4f05c1..0017925df 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java @@ -1,54 +1,54 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -package org.deeplearning4j.rl4j.experience; - -import org.deeplearning4j.rl4j.observation.Observation; - -import java.util.List; - -/** - * A common interface to all classes capable of handling experience generated by the agents in a learning context. - * - * @param Action type - * @param Experience type - * - * @author Alexandre Boulanger - */ -public interface ExperienceHandler { - void addExperience(Observation observation, A action, double reward, boolean isTerminal); - - /** - * Called when the episode is done with the last observation - * @param observation - */ - void setFinalObservation(Observation observation); - - /** - * @return The size of the list that will be returned by generateTrainingBatch(). - */ - int getTrainingBatchSize(); - - /** - * The elements are returned in the historical order (i.e. in the order they happened) - * @return The list of experience elements - */ - List generateTrainingBatch(); - - /** - * Signal the experience handler that a new episode is starting - */ - void reset(); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.List; + +/** + * A common interface to all classes capable of handling experience generated by the agents in a learning context. + * + * @param Action type + * @param Experience type + * + * @author Alexandre Boulanger + */ +public interface ExperienceHandler { + void addExperience(Observation observation, A action, double reward, boolean isTerminal); + + /** + * Called when the episode is done with the last observation + * @param observation + */ + void setFinalObservation(Observation observation); + + /** + * @return The size of the list that will be returned by generateTrainingBatch(). + */ + int getTrainingBatchSize(); + + /** + * The elements are returned in the historical order (i.e. in the order they happened) + * @return The list of experience elements + */ + List generateTrainingBatch(); + + /** + * Signal the experience handler that a new episode is starting + */ + void reset(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java index 39338c6c0..4c6b95c89 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java @@ -30,7 +30,7 @@ import java.util.List; */ public class StateActionExperienceHandler implements ExperienceHandler> { - private List> stateActionPairs; + private List> stateActionPairs = new ArrayList<>(); public void setFinalObservation(Observation observation) { // Do nothing diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java deleted file mode 100644 index 533209ed7..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java +++ /dev/null @@ -1,21 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning; - -public interface EpochStepCounter { - int getCurrentEpochStep(); -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java index f113ce157..082357b9a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,9 +29,11 @@ import org.deeplearning4j.rl4j.mdp.MDP; * @author Alexandre Boulanger * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. */ -public interface IEpochTrainer extends EpochStepCounter { - int getStepCounter(); - int getEpochCounter(); +public interface IEpochTrainer { + int getStepCount(); + int getEpochCount(); + int getEpisodeCount(); + int getCurrentEpisodeStepCount(); IHistoryProcessor getHistoryProcessor(); MDP getMdp(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java index 2041953ff..a8a09bc0b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java @@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.learning; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.Data; import lombok.Value; import org.nd4j.linalg.api.ndarray.INDArray; @@ -51,7 +52,7 @@ public interface IHistoryProcessor { @AllArgsConstructor @Builder - @Value + @Data public static class Configuration { @Builder.Default int historyLength = 4; @Builder.Default int rescaledWidth = 84; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index 43ed508b0..0d1f0ae20 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -21,19 +21,20 @@ import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16. * * A common interface that any training method should implement */ -public interface ILearning> { +public interface ILearning> { IPolicy getPolicy(); void train(); - int getStepCounter(); + int getStepCount(); ILearningConfiguration getConfiguration(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index 833094929..ca9451ea2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -38,13 +38,13 @@ import org.nd4j.linalg.factory.Nd4j; * */ @Slf4j -public abstract class Learning, NN extends NeuralNet> +public abstract class Learning, NN extends NeuralNet> implements ILearning, NeuralNetFetchable { @Getter @Setter - private int stepCounter = 0; + protected int stepCount = 0; @Getter @Setter - private int epochCounter = 0; + private int epochCount = 0; @Getter @Setter private IHistoryProcessor historyProcessor = null; @@ -73,11 +73,11 @@ public abstract class Learning, NN extends Neura public abstract NN getNeuralNet(); public void incrementStep() { - stepCounter++; + stepCount++; } public void incrementEpoch() { - epochCounter++; + epochCount++; } public void setHistoryProcessor(HistoryProcessor.Configuration conf) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java index 01c519b57..75388dd9b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java @@ -20,13 +20,11 @@ package org.deeplearning4j.rl4j.learning.async; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.network.NeuralNet; -import org.nd4j.linalg.primitives.Pair; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. @@ -52,69 +50,75 @@ import java.util.concurrent.atomic.AtomicInteger; * structure */ @Slf4j -public class AsyncGlobal extends Thread implements IAsyncGlobal { +public class AsyncGlobal implements IAsyncGlobal { - @Getter final private NN current; - final private ConcurrentLinkedQueue> queue; - final private IAsyncLearningConfiguration configuration; - private final IAsyncLearning learning; - @Getter - private AtomicInteger T = new AtomicInteger(0); - @Getter - private NN target; - @Getter - private boolean running = true; - public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) { + private NN target; + + final private IAsyncLearningConfiguration configuration; + + @Getter + private final Lock updateLock; + + /** + * The number of times the gradient has been updated by worker threads + */ + @Getter + private int workerUpdateCount; + + @Getter + private int stepCount; + + public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration) { this.current = initial; target = (NN) initial.clone(); this.configuration = configuration; - this.learning = learning; - queue = new ConcurrentLinkedQueue<>(); + + // This is used to sync between + updateLock = new ReentrantLock(); } public boolean isTrainingComplete() { - return T.get() >= configuration.getMaxStep(); + return stepCount >= configuration.getMaxStep(); } - public void enqueue(Gradient[] gradient, Integer nstep) { - if (running && !isTrainingComplete()) { - queue.add(new Pair<>(gradient, nstep)); + public void applyGradient(Gradient[] gradient, int nstep) { + + if (isTrainingComplete()) { + return; } + + try { + updateLock.lock(); + + current.applyGradient(gradient, nstep); + + stepCount += nstep; + workerUpdateCount++; + + int targetUpdateFrequency = configuration.getLearnerUpdateFrequency(); + + // If we have a target update frequency, this means we only want to update the workers after a certain number of async updates + // This can lead to more stable training + if (targetUpdateFrequency != -1 && workerUpdateCount % targetUpdateFrequency == 0) { + log.info("Updating target network at updates={} steps={}", workerUpdateCount, stepCount); + } else { + target.copy(current); + } + } finally { + updateLock.unlock(); + } + } @Override - public void run() { - - while (!isTrainingComplete() && running) { - if (!queue.isEmpty()) { - Pair pair = queue.poll(); - T.addAndGet(pair.getSecond()); - Gradient[] gradient = pair.getFirst(); - synchronized (this) { - current.applyGradient(gradient, pair.getSecond()); - } - if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond()) - / configuration.getLearnerUpdateFrequency()) { - log.info("TARGET UPDATE at T = " + T.get()); - synchronized (this) { - target.copy(current); - } - } - } - } - - } - - /** - * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too. - */ - public void terminate() { - if (running) { - running = false; - queue.clear(); - learning.terminate(); + public NN getTarget() { + try { + updateLock.lock(); + return target; + } finally { + updateLock.unlock(); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index 1c3c83972..ab6284396 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -40,9 +40,9 @@ import org.nd4j.linalg.factory.Nd4j; * @author Alexandre Boulanger */ @Slf4j -public abstract class AsyncLearning, NN extends NeuralNet> - extends Learning - implements IAsyncLearning { +public abstract class AsyncLearning, NN extends NeuralNet> + extends Learning + implements IAsyncLearning { private Thread monitorThread = null; @@ -69,10 +69,6 @@ public abstract class AsyncLearning getAsyncGlobal(); - protected void startGlobalThread() { - getAsyncGlobal().start(); - } - protected boolean isTrainingComplete() { return getAsyncGlobal().isTrainingComplete(); } @@ -87,7 +83,6 @@ public abstract class AsyncLearning, NN extends NeuralNet> +public abstract class AsyncThread, NN extends NeuralNet> extends Thread implements IEpochTrainer { @Getter private int threadNumber; + @Getter protected final int deviceNum; + + /** + * The number of steps that this async thread has produced + */ @Getter @Setter - private int stepCounter = 0; + protected int stepCount = 0; + + /** + * The number of epochs (updates) that this thread has sent to the global learner + */ @Getter @Setter - private int epochCounter = 0; + protected int epochCount = 0; + + /** + * The number of environment episodes that have been played out + */ + @Getter @Setter + protected int episodeCount = 0; + + /** + * The number of steps in the current episode + */ + @Getter + protected int currentEpisodeStepCount = 0; + + /** + * If the current episode needs to be reset + */ + boolean episodeComplete = true; + @Getter @Setter private IHistoryProcessor historyProcessor; - @Getter - private int currentEpochStep = 0; - - private boolean isEpochStarted = false; - private final LegacyMDPWrapper mdp; + private boolean isEpisodeStarted = false; + private final LegacyMDPWrapper mdp; private final TrainingListenerList listeners; - public AsyncThread(IAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { - this.mdp = new LegacyMDPWrapper(mdp, null, this); + public AsyncThread(MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { + this.mdp = new LegacyMDPWrapper(mdp, null); this.listeners = listeners; this.threadNumber = threadNumber; this.deviceNum = deviceNum; } - public MDP getMdp() { + public MDP getMdp() { return mdp.getWrappedMDP(); } - protected LegacyMDPWrapper getLegacyMDPWrapper() { + protected LegacyMDPWrapper getLegacyMDPWrapper() { return mdp; } @@ -92,13 +117,13 @@ public abstract class AsyncThread, NN extends Ne mdp.setHistoryProcessor(historyProcessor); } - protected void postEpoch() { + protected void postEpisode() { if (getHistoryProcessor() != null) getHistoryProcessor().stopMonitor(); } - protected void preEpoch() { + protected void preEpisode() { // Do nothing } @@ -125,74 +150,69 @@ public abstract class AsyncThread, NN extends Ne */ @Override public void run() { - try { - RunContext context = new RunContext(); - Nd4j.getAffinityManager().unsafeSetDevice(deviceNum); + RunContext context = new RunContext(); + Nd4j.getAffinityManager().unsafeSetDevice(deviceNum); - log.info("ThreadNum-" + threadNumber + " Started!"); + log.info("ThreadNum-" + threadNumber + " Started!"); - while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) { - if (!isEpochStarted) { - boolean canContinue = startNewEpoch(context); - if (!canContinue) { - break; - } - } + while (!getAsyncGlobal().isTrainingComplete()) { - handleTraining(context); - - if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) { - boolean canContinue = finishEpoch(context); - if (!canContinue) { - break; - } - - ++epochCounter; - } + if (episodeComplete) { + startEpisode(context); + } + + if(!startEpoch(context)) { + break; + } + + episodeComplete = handleTraining(context); + + if(!finishEpoch(context)) { + break; + } + + if(episodeComplete) { + finishEpisode(context); } - } - finally { - terminateWork(); } } - private void handleTraining(RunContext context) { - int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep); - SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); + private boolean finishEpoch(RunContext context) { + epochCount++; + IDataManager.StatEntry statEntry = new AsyncStatEntry(stepCount, epochCount, context.rewards, currentEpisodeStepCount, context.score); + return listeners.notifyEpochTrainingResult(this, statEntry); + } + + private boolean startEpoch(RunContext context) { + return listeners.notifyNewEpoch(this); + } + + private boolean handleTraining(RunContext context) { + int maxTrainSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpisodeStepCount); + SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps); context.obs = subEpochReturn.getLastObs(); context.rewards += subEpochReturn.getReward(); context.score = subEpochReturn.getScore(); + + return subEpochReturn.isEpisodeComplete(); } - private boolean startNewEpoch(RunContext context) { + private void startEpisode(RunContext context) { getCurrent().reset(); Learning.InitMdp initMdp = refacInitMdp(); context.obs = initMdp.getLastObs(); context.rewards = initMdp.getReward(); - isEpochStarted = true; - preEpoch(); - - return listeners.notifyNewEpoch(this); + preEpisode(); + episodeCount++; } - private boolean finishEpoch(RunContext context) { - isEpochStarted = false; - postEpoch(); - IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score); + private void finishEpisode(RunContext context) { + postEpisode(); - log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards); - - return listeners.notifyEpochTrainingResult(this, statEntry); - } - - private void terminateWork() { - getAsyncGlobal().terminate(); - if(isEpochStarted) { - postEpoch(); - } + log.info("ThreadNum-{} Episode step: {}, Episode: {}, Epoch: {}, reward: {}", threadNumber, currentEpisodeStepCount, episodeCount, epochCount, context.rewards); } protected abstract NN getCurrent(); @@ -201,35 +221,35 @@ public abstract class AsyncThread, NN extends Ne protected abstract IAsyncLearningConfiguration getConf(); - protected abstract IPolicy getPolicy(NN net); + protected abstract IPolicy getPolicy(NN net); protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep); private Learning.InitMdp refacInitMdp() { - currentEpochStep = 0; + currentEpisodeStepCount = 0; double reward = 0; - LegacyMDPWrapper mdp = getLegacyMDPWrapper(); + LegacyMDPWrapper mdp = getLegacyMDPWrapper(); Observation observation = mdp.reset(); - A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP + ACTION action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP while (observation.isSkipped() && !mdp.isDone()) { StepReply stepReply = mdp.step(action); reward += stepReply.getReward(); observation = stepReply.getObservation(); - incrementStep(); + incrementSteps(); } return new Learning.InitMdp(0, observation, reward); } - public void incrementStep() { - ++stepCounter; - ++currentEpochStep; + public void incrementSteps() { + stepCount++; + currentEpisodeStepCount++; } @AllArgsConstructor @@ -239,6 +259,7 @@ public abstract class AsyncThread, NN extends Ne Observation lastObs; double reward; double score; + boolean episodeComplete; } @AllArgsConstructor diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index ac9853045..fcce92a4a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -24,6 +24,9 @@ import lombok.Setter; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -31,15 +34,19 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Stack; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. - * + *

* Async Learning specialized for the Discrete Domain - * */ -public abstract class AsyncThreadDiscrete - extends AsyncThread { +public abstract class AsyncThreadDiscrete + extends AsyncThread { @Getter private NN current; @@ -48,7 +55,7 @@ public abstract class AsyncThreadDiscrete private UpdateAlgorithm updateAlgorithm; // TODO: Make it configurable with a builder - @Setter(AccessLevel.PROTECTED) + @Setter(AccessLevel.PROTECTED) @Getter private ExperienceHandler experienceHandler = new StateActionExperienceHandler(); public AsyncThreadDiscrete(IAsyncGlobal asyncGlobal, @@ -56,9 +63,9 @@ public abstract class AsyncThreadDiscrete TrainingListenerList listeners, int threadNumber, int deviceNum) { - super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); + super(mdp, listeners, threadNumber, deviceNum); synchronized (asyncGlobal) { - current = (NN)asyncGlobal.getCurrent().clone(); + current = (NN) asyncGlobal.getTarget().clone(); } } @@ -72,7 +79,7 @@ public abstract class AsyncThreadDiscrete } @Override - protected void preEpoch() { + protected void preEpisode() { experienceHandler.reset(); } @@ -81,28 +88,23 @@ public abstract class AsyncThreadDiscrete * "Subepoch" correspond to the t_max-step iterations * that stack rewards with t_max MiniTrans * - * @param sObs the obs to start from - * @param nstep the number of max nstep (step until t_max or state is terminal) + * @param sObs the obs to start from + * @param trainingSteps the number of training steps * @return subepoch training informations */ - public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) { + public SubEpochReturn trainSubEpoch(Observation sObs, int trainingSteps) { - synchronized (getAsyncGlobal()) { - current.copy(getAsyncGlobal().getCurrent()); - } + current.copy(getAsyncGlobal().getTarget()); Observation obs = sObs; IPolicy policy = getPolicy(current); Integer action = getMdp().getActionSpace().noOp(); - IHistoryProcessor hp = getHistoryProcessor(); - int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1; double reward = 0; double accuReward = 0; - int stepAtStart = getCurrentEpochStep(); - int lastStep = nstep * skipFrame + stepAtStart; - while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) { + + while (!getMdp().isDone() && experienceHandler.getTrainingBatchSize() != trainingSteps) { //if step of training, just repeat lastAction if (!obs.isSkipped()) { @@ -115,20 +117,26 @@ public abstract class AsyncThreadDiscrete if (!obs.isSkipped()) { experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone()); accuReward = 0; + + incrementSteps(); } obs = stepReply.getObservation(); reward += stepReply.getReward(); - incrementStep(); } - if (getMdp().isDone() && getCurrentEpochStep() < lastStep) { + boolean episodeComplete = getMdp().isDone() || getConf().getMaxEpochStep() == currentEpisodeStepCount; + + if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) { experienceHandler.setFinalObservation(obs); } - getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep()); + int experienceSize = experienceHandler.getTrainingBatchSize(); - return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore()); + getAsyncGlobal().applyGradient(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), experienceSize); + + return new SubEpochReturn(experienceSize, obs, reward, current.getLatestScore(), episodeComplete); } + } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java index df3d476f9..b9725499a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,17 +23,29 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import java.util.concurrent.atomic.AtomicInteger; public interface IAsyncGlobal { - boolean isRunning(); + boolean isTrainingComplete(); - void start(); /** - * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded. + * The number of updates that have been applied by worker threads. */ - void terminate(); + int getWorkerUpdateCount(); - AtomicInteger getT(); - NN getCurrent(); + /** + * The total number of environment steps that have been processed. + */ + int getStepCount(); + + /** + * A copy of the global network that is updated after a certain number of worker episodes. + */ NN getTarget(); - void enqueue(Gradient[] gradient, Integer nstep); + + /** + * Apply gradients to the global network + * @param gradient + * @param batchSize + */ + void applyGradient(Gradient[] gradient, int batchSize); + } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java index 16ca1c3f8..c5bb7c84c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java @@ -1,26 +1,26 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -package org.deeplearning4j.rl4j.learning.async; - -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.experience.StateActionPair; -import org.deeplearning4j.rl4j.network.NeuralNet; - -import java.util.List; - -public interface UpdateAlgorithm { - Gradient[] computeGradients(NN current, List> experience); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.learning.async; + +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.network.NeuralNet; + +import java.util.List; + +public interface UpdateAlgorithm { + Gradient[] computeGradients(NN current, List> experience); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index 0608ec5cc..8c5b07903 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -57,7 +57,7 @@ public abstract class A3CDiscrete extends AsyncLearning(iActorCritic, conf, this); + asyncGlobal = new AsyncGlobal<>(iActorCritic, conf); Long seed = conf.getSeed(); Random rnd = Nd4j.getRandom(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index d189edca1..adf68489e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -27,7 +27,6 @@ import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.api.rng.Random; @@ -73,6 +72,6 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< @Override protected UpdateAlgorithm buildUpdateAlgorithm() { int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); - return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma()); + return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.java similarity index 68% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.java index 261cc788f..658d2bf61 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.java @@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.nd4j.linalg.api.ndarray.INDArray; @@ -27,28 +26,25 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.List; -public class A3CUpdateAlgorithm implements UpdateAlgorithm { +/** + * The Advantage Actor-Critic update algorithm can be used by A2C and A3C algorithms alike + */ +public class AdvantageActorCriticUpdateAlgorithm implements UpdateAlgorithm { - private final IAsyncGlobal asyncGlobal; private final int[] shape; private final int actionSpaceSize; - private final int targetDqnUpdateFreq; private final double gamma; private final boolean recurrent; - public A3CUpdateAlgorithm(IAsyncGlobal asyncGlobal, - int[] shape, - int actionSpaceSize, - int targetDqnUpdateFreq, - double gamma) { - - this.asyncGlobal = asyncGlobal; + public AdvantageActorCriticUpdateAlgorithm(boolean recurrent, + int[] shape, + int actionSpaceSize, + double gamma) { //if recurrent then train as a time serie with a batch size of 1 - recurrent = asyncGlobal.getCurrent().isRecurrent(); + this.recurrent = recurrent; this.shape = shape; this.actionSpaceSize = actionSpaceSize; - this.targetDqnUpdateFreq = targetDqnUpdateFreq; this.gamma = gamma; } @@ -65,18 +61,12 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm { : Nd4j.zeros(size, actionSpaceSize); StateActionPair stateActionPair = experience.get(size - 1); - double r; - if(stateActionPair.isTerminal()) { - r = 0; - } - else { - INDArray[] output = null; - if (targetDqnUpdateFreq == -1) - output = current.outputAll(stateActionPair.getObservation().getData()); - else synchronized (asyncGlobal) { - output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData()); - } - r = output[0].getDouble(0); + double value; + if (stateActionPair.isTerminal()) { + value = 0; + } else { + INDArray[] output = current.outputAll(stateActionPair.getObservation().getData()); + value = output[0].getDouble(0); } for (int i = size - 1; i >= 0; --i) { @@ -86,7 +76,7 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm { INDArray[] output = current.outputAll(observationData); - r = stateActionPair.getReward() + gamma * r; + value = stateActionPair.getReward() + gamma * value; if (recurrent) { input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData); } else { @@ -94,11 +84,11 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm { } //the critic - targets.putScalar(i, r); + targets.putScalar(i, value); //the actor double expectedV = output[0].getDouble(0); - double advantage = r - expectedV; + double advantage = value - expectedV; if (recurrent) { logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage); } else { @@ -108,6 +98,6 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm { // targets -> value, critic // logSoftmax -> policy, actor - return current.gradient(input, new INDArray[] {targets, logSoftmax}); + return current.gradient(input, new INDArray[]{targets, logSoftmax}); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index 9a8049f6f..a4c0b643b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -50,7 +50,7 @@ public abstract class AsyncNStepQLearningDiscrete public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { this.mdp = mdp; this.configuration = conf; - this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this); + this.asyncGlobal = new AsyncGlobal<>(dqn, conf); } @Override @@ -59,7 +59,7 @@ public abstract class AsyncNStepQLearningDiscrete } public IDQN getNeuralNet() { - return asyncGlobal.getCurrent(); + return asyncGlobal.getTarget(); } public IPolicy getPolicy() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index bd4dc16e8..34a2c07a4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -30,8 +30,8 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. @@ -57,7 +57,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn rnd = Nd4j.getRandom(); Long seed = conf.getSeed(); - if(seed != null) { + if (seed != null) { rnd.setSeed(seed + threadNumber); } @@ -72,6 +72,6 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn @Override protected UpdateAlgorithm buildUpdateAlgorithm() { int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); - return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma()); + return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java index beae271b1..79c9666a2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java @@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.nd4j.linalg.api.ndarray.INDArray; @@ -28,22 +27,16 @@ import java.util.List; public class QLearningUpdateAlgorithm implements UpdateAlgorithm { - private final IAsyncGlobal asyncGlobal; private final int[] shape; private final int actionSpaceSize; - private final int targetDqnUpdateFreq; private final double gamma; - public QLearningUpdateAlgorithm(IAsyncGlobal asyncGlobal, - int[] shape, + public QLearningUpdateAlgorithm(int[] shape, int actionSpaceSize, - int targetDqnUpdateFreq, double gamma) { - this.asyncGlobal = asyncGlobal; this.shape = shape; this.actionSpaceSize = actionSpaceSize; - this.targetDqnUpdateFreq = targetDqnUpdateFreq; this.gamma = gamma; } @@ -58,16 +51,11 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm { StateActionPair stateActionPair = experience.get(size - 1); double r; - if(stateActionPair.isTerminal()) { + if (stateActionPair.isTerminal()) { r = 0; - } - else { + } else { INDArray[] output = null; - if (targetDqnUpdateFreq == -1) - output = current.outputAll(stateActionPair.getObservation().getData()); - else synchronized (asyncGlobal) { - output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData()); - } + output = current.outputAll(stateActionPair.getObservation().getData()); r = Nd4j.max(output[0]).getDouble(0); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java index 1e7cf3f2e..1639597ae 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java @@ -20,8 +20,14 @@ public interface IAsyncLearningConfiguration extends ILearningConfiguration { int getNumThreads(); + /** + * The number of steps to collect for each worker thread between each global update + */ int getNStep(); + /** + * The frequency of worker thread gradient updates to perform a copy of the current working network to the target network + */ int getLearnerUpdateFrequency(); int getMaxStep(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java index c42756145..e50e50114 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java @@ -25,6 +25,7 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.listener.*; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager; /** @@ -35,8 +36,8 @@ import org.deeplearning4j.rl4j.util.IDataManager; * @author Alexandre Boulanger */ @Slf4j -public abstract class SyncLearning, NN extends NeuralNet> - extends Learning implements IEpochTrainer { +public abstract class SyncLearning, NN extends NeuralNet> + extends Learning implements IEpochTrainer { private final TrainingListenerList listeners = new TrainingListenerList(); @@ -85,7 +86,7 @@ public abstract class SyncLearning, NN extends N boolean canContinue = listeners.notifyTrainingStarted(); if (canContinue) { - while (getStepCounter() < getConfiguration().getMaxStep()) { + while (this.getStepCount() < getConfiguration().getMaxStep()) { preEpoch(); canContinue = listeners.notifyNewEpoch(this); if (!canContinue) { @@ -100,14 +101,14 @@ public abstract class SyncLearning, NN extends N postEpoch(); - if(getEpochCounter() % progressMonitorFrequency == 0) { + if(getEpochCount() % progressMonitorFrequency == 0) { canContinue = listeners.notifyTrainingProgress(this); if (!canContinue) { break; } } - log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); + log.info("Epoch: " + getEpochCount() + ", reward: " + statEntry.getReward()); incrementEpoch(); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 7bef13e59..d12db5d67 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -19,21 +19,16 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; -import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.Setter; import lombok.Value; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.rl4j.learning.EpochStepCounter; -import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; +import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; -import org.deeplearning4j.rl4j.learning.sync.ExpReplay; -import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.SyncLearning; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -43,8 +38,6 @@ import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager.StatEntry; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; -import org.nd4j.linalg.api.rng.Random; -import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; @@ -58,7 +51,7 @@ import java.util.List; @Slf4j public abstract class QLearning> extends SyncLearning - implements TargetQNetworkSource, EpochStepCounter { + implements TargetQNetworkSource, IEpochTrainer { protected abstract LegacyMDPWrapper getLegacyMDPWrapper(); @@ -90,7 +83,10 @@ public abstract class QLearning trainStep(Observation obs); @Getter - private int currentEpochStep = 0; + private int episodeCount; + + @Getter + private int currentEpisodeStepCount = 0; protected StatEntry trainEpoch() { resetNetworks(); @@ -104,9 +100,9 @@ public abstract class QLearning scores = new ArrayList<>(); - while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) { + while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) { - if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) { + if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) { updateTargetNetwork(); } @@ -132,20 +128,20 @@ public abstract class QLearning refacInitMdp() { - currentEpochStep = 0; + currentEpisodeStepCount = 0; double reward = 0; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 1b9e667ae..b2ad597d0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -47,6 +47,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.List; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. *

@@ -90,7 +91,7 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) { this.configuration = conf; - this.mdp = new LegacyMDPWrapper<>(mdp, null, this); + this.mdp = new LegacyMDPWrapper<>(mdp, null); qNetwork = dqn; targetQNetwork = dqn.clone(); policy = new DQNPolicy(getQNetwork()); @@ -164,13 +165,13 @@ public abstract class QLearningDiscrete extends QLearning updateStart) { + if (this.getStepCount() > updateStart) { DataSet targets = setTarget(experienceHandler.generateTrainingBatch()); getQNetwork().fit(targets.getFeatures(), targets.getLabels()); } } - return new QLStepReturn(maxQ, getQNetwork().getLatestScore(), stepReply); + return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply); } protected DataSet setTarget(List> transitions) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java index c8a0c38d9..0444aa32d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java @@ -17,6 +17,7 @@ package org.deeplearning4j.rl4j.observation; import lombok.Getter; +import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; /** diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java index 3eaeec4dc..133fbdb61 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java @@ -40,7 +40,7 @@ public class EncodableToImageWritableTransform implements Operation> extends Policy { +public class EpsGreedy> extends Policy { final private Policy policy; final private MDP mdp; @@ -57,8 +58,8 @@ public class EpsGreedy> extends Policy { public A nextAction(INDArray input) { double ep = getEpsilon(); - if (learning.getStepCounter() % 500 == 1) - log.info("EP: " + ep + " " + learning.getStepCounter()); + if (learning.getStepCount() % 500 == 1) + log.info("EP: " + ep + " " + learning.getStepCount()); if (rnd.nextDouble() > ep) return policy.nextAction(input); else @@ -70,6 +71,6 @@ public class EpsGreedy> extends Policy { } public double getEpsilon() { - return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep)); + return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep)); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java index 0bef2a757..f87971a89 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java @@ -7,7 +7,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; -public interface IPolicy { +public interface IPolicy { > double play(MDP mdp, IHistoryProcessor hp); A nextAction(INDArray input); A nextAction(Observation observation); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index 7719df612..d5fa59766 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -16,10 +16,7 @@ package org.deeplearning4j.rl4j.policy; -import lombok.Getter; -import lombok.Setter; import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; @@ -27,6 +24,7 @@ import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; /** @@ -36,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; * * A Policy responsability is to choose the next action given a state */ -public abstract class Policy implements IPolicy { +public abstract class Policy implements IPolicy { public abstract NeuralNet getNeuralNet(); @@ -54,10 +52,9 @@ public abstract class Policy implements IPolicy { public > double play(MDP mdp, IHistoryProcessor hp) { resetNetworks(); - RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter(); - LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp, epochStepCounter); + LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp); - Learning.InitMdp initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter); + Learning.InitMdp initMdp = refacInitMdp(mdpWrapper, hp); Observation obs = initMdp.getLastObs(); double reward = initMdp.getReward(); @@ -79,7 +76,6 @@ public abstract class Policy implements IPolicy { reward += stepReply.getReward(); obs = stepReply.getObservation(); - epochStepCounter.incrementEpochStep(); } return reward; @@ -89,8 +85,7 @@ public abstract class Policy implements IPolicy { getNeuralNet().reset(); } - protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { - epochStepCounter.setCurrentEpochStep(0); + protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { double reward = 0; @@ -104,21 +99,9 @@ public abstract class Policy implements IPolicy { reward += stepReply.getReward(); observation = stepReply.getObservation(); - epochStepCounter.incrementEpochStep(); } return new Learning.InitMdp(0, observation, reward); } - public class RefacEpochStepCounter implements EpochStepCounter { - - @Getter - @Setter - private int currentEpochStep = 0; - - public void incrementEpochStep() { - ++currentEpochStep; - } - - } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java index bffafdb76..0caa29453 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java @@ -278,7 +278,7 @@ public class DataManager implements IDataManager { Path infoPath = Paths.get(getInfo()); Info info = new Info(iLearning.getClass().getSimpleName(), iLearning.getMdp().getClass().getSimpleName(), - iLearning.getConfiguration(), iLearning.getStepCounter(), System.currentTimeMillis()); + iLearning.getConfiguration(), iLearning.getStepCount(), System.currentTimeMillis()); String toWrite = toJson(info); Files.write(infoPath, toWrite.getBytes(), StandardOpenOption.TRUNCATE_EXISTING); @@ -300,12 +300,12 @@ public class DataManager implements IDataManager { if (!saveData) return; - save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning); + save(getModelDir() + "/" + learning.getStepCount() + ".training", learning); if(learning instanceof NeuralNetFetchable) { try { - ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model"); + ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCount() + ".model"); } catch (UnsupportedOperationException e) { - String path = getModelDir() + "/" + learning.getStepCounter(); + String path = getModelDir() + "/" + learning.getStepCount(); ((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(path + "_value.model", path + "_policy.model"); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java index 83b8d71da..aa4ec4d17 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java @@ -40,7 +40,7 @@ public class DataManagerTrainingListener implements TrainingListener { if (trainer instanceof AsyncThread) { filename += ((AsyncThread) trainer).getThreadNumber() + "-"; } - filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4"; + filename += trainer.getEpochCount() + "-" + trainer.getStepCount() + ".mp4"; hp.startMonitor(filename, shape); } @@ -66,7 +66,7 @@ public class DataManagerTrainingListener implements TrainingListener { @Override public ListenerResponse onTrainingProgress(ILearning learning) { try { - int stepCounter = learning.getStepCounter(); + int stepCounter = learning.getStepCount(); if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) { dataManager.save(learning); lastSave = stepCounter; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java index b0f46ef57..981f35379 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java @@ -8,7 +8,6 @@ import org.datavec.image.transform.CropImageTransform; import org.datavec.image.transform.MultiImageTransform; import org.datavec.image.transform.ResizeImageTransform; import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; @@ -30,10 +29,10 @@ import java.util.Map; import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY; -public class LegacyMDPWrapper> implements MDP { +public class LegacyMDPWrapper> implements MDP { @Getter - private final MDP wrappedMDP; + private final MDP wrappedMDP; @Getter private final WrapperObservationSpace observationSpace; private final int[] shape; @@ -44,16 +43,14 @@ public class LegacyMDPWrapper> implements MDP wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) { + public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor) { this.wrappedMDP = wrappedMDP; this.shape = wrappedMDP.getObservationSpace().getShape(); this.observationSpace = new WrapperObservationSpace(shape); this.historyProcessor = historyProcessor; - this.epochStepCounter = epochStepCounter; setHistoryProcessor(historyProcessor); } @@ -63,6 +60,7 @@ public class LegacyMDPWrapper> implements MDP> implements MDP> implements MDP step(A a) { IHistoryProcessor historyProcessor = getHistoryProcessor(); - StepReply rawStepReply = wrappedMDP.step(a); + StepReply rawStepReply = wrappedMDP.step(a); INDArray rawObservation = getInput(rawStepReply.getObservation()); if(historyProcessor != null) { historyProcessor.record(rawObservation); } - int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1; + int stepOfObservation = steps++; Map channelsData = buildChannelsData(rawStepReply.getObservation()); Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone()); return new StepReply(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); } - private void record(O obs) { + private void record(OBSERVATION obs) { INDArray rawObservation = getInput(obs); IHistoryProcessor historyProcessor = getHistoryProcessor(); @@ -141,7 +139,7 @@ public class LegacyMDPWrapper> implements MDP buildChannelsData(final O obs) { + private Map buildChannelsData(final OBSERVATION obs) { return new HashMap() {{ put("data", obs); }}; @@ -159,11 +157,11 @@ public class LegacyMDPWrapper> implements MDP newInstance() { - return new LegacyMDPWrapper(wrappedMDP.newInstance(), historyProcessor, epochStepCounter); + return new LegacyMDPWrapper<>(wrappedMDP.newInstance(), historyProcessor); } - private INDArray getInput(O obs) { - INDArray arr = Nd4j.create(((Encodable)obs).toArray()); + private INDArray getInput(OBSERVATION obs) { + INDArray arr = Nd4j.create(obs.toArray()); int[] shape = observationSpace.getShape(); if (shape.length == 1) return arr.reshape(new long[] {1, arr.length()}); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java index 619fd813f..765a14c8f 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java @@ -1,107 +1,107 @@ -package org.deeplearning4j.rl4j.experience; - -import org.deeplearning4j.rl4j.learning.sync.IExpReplay; -import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -public class ReplayMemoryExperienceHandlerTest { - @Test - public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() { - // Arrange - TestExpReplay expReplayMock = new TestExpReplay(); - ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); - - // Act - sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); - int numStoredTransitions = expReplayMock.addedTransitions.size(); - sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); - - // Assert - assertEquals(0, numStoredTransitions); - assertEquals(1, expReplayMock.addedTransitions.size()); - } - - @Test - public void when_addingExperience_expect_transitionsAreCorrect() { - // Arrange - TestExpReplay expReplayMock = new TestExpReplay(); - ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); - - // Act - sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); - sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); - sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); - - // Assert - assertEquals(2, expReplayMock.addedTransitions.size()); - - assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001); - assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction()); - assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001); - assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001); - - assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001); - assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction()); - assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001); - assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001); - - } - - @Test - public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() { - // Arrange - TestExpReplay expReplayMock = new TestExpReplay(); - ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); - - // Act - sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); - sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 2.0 }))); - sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false); - - // Assert - assertEquals(1, expReplayMock.addedTransitions.size()); - assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction()); - } - - @Test - public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() { - // Arrange - TestExpReplay expReplayMock = new TestExpReplay(); - ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); - sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); - sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); - sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); - - // Act - int size = sut.getTrainingBatchSize(); - // Assert - assertEquals(2, size); - } - - private static class TestExpReplay implements IExpReplay { - - public final List> addedTransitions = new ArrayList<>(); - - @Override - public ArrayList> getBatch() { - return null; - } - - @Override - public void store(Transition transition) { - addedTransitions.add(transition); - } - - @Override - public int getBatchSize() { - return addedTransitions.size(); - } - } -} +package org.deeplearning4j.rl4j.experience; + +import org.deeplearning4j.rl4j.learning.sync.IExpReplay; +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Test; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class ReplayMemoryExperienceHandlerTest { + @Test + public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() { + // Arrange + TestExpReplay expReplayMock = new TestExpReplay(); + ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); + + // Act + sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); + int numStoredTransitions = expReplayMock.addedTransitions.size(); + sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); + + // Assert + assertEquals(0, numStoredTransitions); + assertEquals(1, expReplayMock.addedTransitions.size()); + } + + @Test + public void when_addingExperience_expect_transitionsAreCorrect() { + // Arrange + TestExpReplay expReplayMock = new TestExpReplay(); + ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); + + // Act + sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); + sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); + sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); + + // Assert + assertEquals(2, expReplayMock.addedTransitions.size()); + + assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001); + assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction()); + assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001); + assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001); + + assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001); + assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction()); + assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001); + assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001); + + } + + @Test + public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() { + // Arrange + TestExpReplay expReplayMock = new TestExpReplay(); + ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); + + // Act + sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); + sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 2.0 }))); + sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false); + + // Assert + assertEquals(1, expReplayMock.addedTransitions.size()); + assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction()); + } + + @Test + public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() { + // Arrange + TestExpReplay expReplayMock = new TestExpReplay(); + ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); + sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); + sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); + sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); + + // Act + int size = sut.getTrainingBatchSize(); + // Assert + assertEquals(2, size); + } + + private static class TestExpReplay implements IExpReplay { + + public final List> addedTransitions = new ArrayList<>(); + + @Override + public ArrayList> getBatch() { + return null; + } + + @Override + public void store(Transition transition) { + addedTransitions.add(transition); + } + + @Override + public int getBatchSize() { + return addedTransitions.size(); + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java index f2941feef..b1faa70bf 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java @@ -18,132 +18,93 @@ package org.deeplearning4j.rl4j.learning.async; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; -import org.deeplearning4j.rl4j.mdp.MDP; -import org.deeplearning4j.rl4j.policy.IPolicy; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.support.MockAsyncConfiguration; -import org.deeplearning4j.rl4j.support.MockAsyncGlobal; -import org.deeplearning4j.rl4j.support.MockEncodable; -import org.deeplearning4j.rl4j.support.MockNeuralNet; -import org.deeplearning4j.rl4j.support.MockPolicy; -import org.deeplearning4j.rl4j.support.MockTrainingListener; +import org.deeplearning4j.rl4j.learning.listener.TrainingListener; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Box; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) public class AsyncLearningTest { - @Test - public void when_training_expect_AsyncGlobalStarted() { - // Arrange - TestContext context = new TestContext(); - context.asyncGlobal.setMaxLoops(1); + AsyncLearning, NeuralNet> asyncLearning; - // Act - context.sut.train(); + @Mock + TrainingListener mockTrainingListener; - // Assert - assertTrue(context.asyncGlobal.hasBeenStarted); - assertTrue(context.asyncGlobal.hasBeenTerminated); + @Mock + AsyncGlobal mockAsyncGlobal; + + @Mock + IAsyncLearningConfiguration mockConfiguration; + + @Before + public void setup() { + asyncLearning = mock(AsyncLearning.class, Mockito.withSettings() + .useConstructor() + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + + asyncLearning.addListener(mockTrainingListener); + + when(asyncLearning.getAsyncGlobal()).thenReturn(mockAsyncGlobal); + when(asyncLearning.getConfiguration()).thenReturn(mockConfiguration); + + // Don't actually start any threads in any of these tests + when(mockConfiguration.getNumThreads()).thenReturn(0); } @Test public void when_trainStartReturnsStop_expect_noTraining() { // Arrange - TestContext context = new TestContext(); - context.listener.setRemainingTrainingStartCallCount(0); + when(mockTrainingListener.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP); + // Act - context.sut.train(); + asyncLearning.train(); // Assert - assertEquals(1, context.listener.onTrainingStartCallCount); - assertEquals(1, context.listener.onTrainingEndCallCount); - assertEquals(0, context.policy.playCallCount); - assertTrue(context.asyncGlobal.hasBeenTerminated); + verify(mockTrainingListener, times(1)).onTrainingStart(); + verify(mockTrainingListener, times(1)).onTrainingEnd(); } @Test public void when_trainingIsComplete_expect_trainingStop() { // Arrange - TestContext context = new TestContext(); + when(mockAsyncGlobal.isTrainingComplete()).thenReturn(true); // Act - context.sut.train(); + asyncLearning.train(); // Assert - assertEquals(1, context.listener.onTrainingStartCallCount); - assertEquals(1, context.listener.onTrainingEndCallCount); - assertTrue(context.asyncGlobal.hasBeenTerminated); + verify(mockTrainingListener, times(1)).onTrainingStart(); + verify(mockTrainingListener, times(1)).onTrainingEnd(); } @Test public void when_training_expect_onTrainingProgressCalled() { // Arrange - TestContext context = new TestContext(); + asyncLearning.setProgressMonitorFrequency(100); + when(mockTrainingListener.onTrainingProgress(eq(asyncLearning))).thenReturn(TrainingListener.ListenerResponse.STOP); // Act - context.sut.train(); + asyncLearning.train(); // Assert - assertEquals(1, context.listener.onTrainingProgressCallCount); + verify(mockTrainingListener, times(1)).onTrainingStart(); + verify(mockTrainingListener, times(1)).onTrainingEnd(); + verify(mockTrainingListener, times(1)).onTrainingProgress(eq(asyncLearning)); } - - - public static class TestContext { - MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0,0, 0, 0, 0); - public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); - public final MockPolicy policy = new MockPolicy(); - public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy); - public final MockTrainingListener listener = new MockTrainingListener(asyncGlobal); - - public TestContext() { - sut.addListener(listener); - asyncGlobal.setMaxLoops(1); - sut.setProgressMonitorFrequency(1); - } - } - - public static class TestAsyncLearning extends AsyncLearning { - private final IAsyncLearningConfiguration conf; - private final IAsyncGlobal asyncGlobal; - private final IPolicy policy; - - public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) { - this.conf = conf; - this.asyncGlobal = asyncGlobal; - this.policy = policy; - } - - @Override - public IPolicy getPolicy() { - return policy; - } - - @Override - public IAsyncLearningConfiguration getConfiguration() { - return conf; - } - - @Override - protected AsyncThread newThread(int i, int deviceAffinity) { - return null; - } - - @Override - public MDP getMdp() { - return null; - } - - @Override - protected IAsyncGlobal getAsyncGlobal() { - return asyncGlobal; - } - - @Override - public MockNeuralNet getNeuralNet() { - return null; - } - } - } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index 320b53a0e..5f2a8ab31 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -1,6 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. + * Copyright (c) 2020 Konduit K. K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -17,161 +16,230 @@ package org.deeplearning4j.rl4j.learning.async; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.experience.ExperienceHandler; -import org.deeplearning4j.rl4j.experience.StateActionPair; -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; -import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; -import org.deeplearning4j.rl4j.policy.IPolicy; +import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.ObservationSpace; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; +import org.junit.Before; import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.nd4j.linalg.factory.Nd4j; -import java.util.ArrayList; -import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class AsyncThreadDiscreteTest { + + AsyncThreadDiscrete asyncThreadDiscrete; + + @Mock + IAsyncLearningConfiguration mockAsyncConfiguration; + + @Mock + UpdateAlgorithm mockUpdateAlgorithm; + + @Mock + IAsyncGlobal mockAsyncGlobal; + + @Mock + Policy mockGlobalCurrentPolicy; + + @Mock + NeuralNet mockGlobalTargetNetwork; + + @Mock + MDP mockMDP; + + @Mock + LegacyMDPWrapper mockLegacyMDPWrapper; + + @Mock + DiscreteSpace mockActionSpace; + + @Mock + ObservationSpace mockObservationSpace; + + @Mock + TrainingListenerList mockTrainingListenerList; + + @Mock + Observation mockObservation; + + int[] observationShape = new int[]{3, 10, 10}; + int actionSize = 4; + + private void setupMDPMocks() { + + when(mockActionSpace.noOp()).thenReturn(0); + when(mockMDP.getActionSpace()).thenReturn(mockActionSpace); + + when(mockObservationSpace.getShape()).thenReturn(observationShape); + when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace); + + } + + private void setupNNMocks() { + when(mockAsyncGlobal.getTarget()).thenReturn(mockGlobalTargetNetwork); + when(mockGlobalTargetNetwork.clone()).thenReturn(mockGlobalTargetNetwork); + } + + @Before + public void setup() { + + setupMDPMocks(); + setupNNMocks(); + + asyncThreadDiscrete = mock(AsyncThreadDiscrete.class, Mockito.withSettings() + .useConstructor(mockAsyncGlobal, mockMDP, mockTrainingListenerList, 0, 0) + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + + asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm); + + when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration); + when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0); + when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal); + when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy); + + when(mockGlobalCurrentPolicy.nextAction(any(Observation.class))).thenReturn(0); + + when(asyncThreadDiscrete.getLegacyMDPWrapper()).thenReturn(mockLegacyMDPWrapper); + + } + @Test - public void refac_AsyncThreadDiscrete_trainSubEpoch() { + public void when_episodeCompletes_expect_stepsToBeInLineWithEpisodeLenth() { + // Arrange - int numEpochs = 1; - MockNeuralNet nnMock = new MockNeuralNet(); - IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); - MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock); - asyncGlobalMock.setMaxLoops(hpConf.getSkipFrame() * numEpochs); - MockObservationSpace observationSpace = new MockObservationSpace(); - MockMDP mdpMock = new MockMDP(observationSpace); - TrainingListenerList listeners = new TrainingListenerList(); - MockPolicy policyMock = new MockPolicy(); - MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5); - MockExperienceHandler experienceHandlerMock = new MockExperienceHandler(); - MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm(); - TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock); - sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); + int episodeRemaining = 5; + int remainingTrainingSteps = 10; + + // return done after 4 steps (the episode finishes before nsteps) + when(mockMDP.isDone()).thenAnswer(invocation -> + asyncThreadDiscrete.getStepCount() == episodeRemaining + ); + + when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null)); // Act - sut.run(); + AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps); // Assert - assertEquals(2, sut.trainSubEpochResults.size()); - double[][] expectedLastObservations = new double[][] { - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, - }; - double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 }; - for(int i = 0; i < 2; ++i) { - AsyncThread.SubEpochReturn result = sut.trainSubEpochResults.get(i); - assertEquals(4, result.getSteps()); - assertEquals(expectedSubEpochReturnRewards[i], result.getReward(), 0.00001); - assertEquals(0.0, result.getScore(), 0.00001); - - double[] expectedLastObservation = expectedLastObservations[i]; - assertEquals(expectedLastObservation.length, result.getLastObs().getData().shape()[1]); - for(int j = 0; j < expectedLastObservation.length; ++j) { - assertEquals(expectedLastObservation[j], 255.0 * result.getLastObs().getData().getDouble(j), 0.00001); - } - } - assertEquals(2, asyncGlobalMock.enqueueCallCount); - - // HistoryProcessor - double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, }; - assertEquals(expectedRecordValues.length, hpMock.recordCalls.size()); - for(int i = 0; i < expectedRecordValues.length; ++i) { - assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001); - } - - // Policy - double[][] expectedPolicyInputs = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - }; - assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size()); - for(int i = 0; i < expectedPolicyInputs.length; ++i) { - double[] expectedRow = expectedPolicyInputs[i]; - INDArray input = policyMock.actionInputs.get(i); - assertEquals(expectedRow.length, input.shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001); - } - } - - // ExperienceHandler - double[][] expectedExperienceHandlerInputs = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - }; - assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size()); - for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) { - double[] expectedRow = expectedExperienceHandlerInputs[i]; - INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData(); - assertEquals(expectedRow.length, input.shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001); - } - } + assertTrue(subEpochReturn.isEpisodeComplete()); + assertEquals(5, subEpochReturn.getSteps()); } - public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete { + @Test + public void when_episodeCompletesDueToMaxStepsReached_expect_isEpisodeComplete() { - private final MockAsyncGlobal asyncGlobal; - private final MockPolicy policy; - private final MockAsyncConfiguration config; + // Arrange + int remainingTrainingSteps = 50; - public final List trainSubEpochResults = new ArrayList(); + // Episode does not complete due to MDP + when(mockMDP.isDone()).thenReturn(false); - public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP mdp, - TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy, - MockAsyncConfiguration config, IHistoryProcessor hp, - ExperienceHandler> experienceHandler, - UpdateAlgorithm updateAlgorithm) { - super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); - this.asyncGlobal = asyncGlobal; - this.policy = policy; - this.config = config; - setHistoryProcessor(hp); - setExperienceHandler(experienceHandler); - setUpdateAlgorithm(updateAlgorithm); - } + when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null)); - @Override - protected IAsyncGlobal getAsyncGlobal() { - return asyncGlobal; - } + when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(50); - @Override - protected IAsyncLearningConfiguration getConf() { - return config; - } + // Act + AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps); - @Override - protected IPolicy getPolicy(MockNeuralNet net) { - return policy; - } + // Assert + assertTrue(subEpochReturn.isEpisodeComplete()); + assertEquals(50, subEpochReturn.getSteps()); - @Override - protected UpdateAlgorithm buildUpdateAlgorithm() { - return null; - } - - @Override - public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) { - asyncGlobal.increaseCurrentLoop(); - SubEpochReturn result = super.trainSubEpoch(sObs, nstep); - trainSubEpochResults.add(result); - return result; - } } + + @Test + public void when_episodeLongerThanNsteps_expect_returnNStepLength() { + + // Arrange + int episodeRemaining = 5; + int remainingTrainingSteps = 4; + + // return done after 4 steps (the episode finishes before nsteps) + when(mockMDP.isDone()).thenAnswer(invocation -> + asyncThreadDiscrete.getStepCount() == episodeRemaining + ); + + when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null)); + + // Act + AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps); + + // Assert + assertFalse(subEpochReturn.isEpisodeComplete()); + assertEquals(remainingTrainingSteps, subEpochReturn.getSteps()); + } + + @Test + public void when_framesAreSkipped_expect_proportionateStepCounterUpdates() { + int skipFrames = 2; + int remainingTrainingSteps = 10; + + // Episode does not complete due to MDP + when(mockMDP.isDone()).thenReturn(false); + + AtomicInteger stepCount = new AtomicInteger(); + + // Use skipFrames to return if observations are skipped or not + when(mockLegacyMDPWrapper.step(anyInt())).thenAnswer(invocationOnMock -> { + + boolean isSkipped = stepCount.incrementAndGet() % skipFrames != 0; + + Observation mockObs = new Observation(isSkipped ? null : Nd4j.create(observationShape)); + return new StepReply<>(mockObs, 0.0, false, null); + }); + + + // Act + AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps); + + // Assert + assertFalse(subEpochReturn.isEpisodeComplete()); + assertEquals(remainingTrainingSteps, subEpochReturn.getSteps()); + assertEquals((remainingTrainingSteps - 1) * skipFrames + 1, stepCount.get()); + } + + @Test + public void when_preEpisodeCalled_expect_experienceHandlerReset() { + + // Arrange + int trainingSteps = 100; + for (int i = 0; i < trainingSteps; i++) { + asyncThreadDiscrete.getExperienceHandler().addExperience(mockObservation, 0, 0.0, false); + } + + int experienceHandlerSizeBeforeReset = asyncThreadDiscrete.getExperienceHandler().getTrainingBatchSize(); + + // Act + asyncThreadDiscrete.preEpisode(); + + // Assert + assertEquals(100, experienceHandlerSizeBeforeReset); + assertEquals(0, asyncThreadDiscrete.getExperienceHandler().getTrainingBatchSize()); + + + } + } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index ff29960f1..117465de3 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -1,220 +1,277 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async; -import lombok.AllArgsConstructor; -import lombok.Getter; -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; -import org.deeplearning4j.rl4j.policy.Policy; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.support.MockAsyncConfiguration; -import org.deeplearning4j.rl4j.support.MockAsyncGlobal; -import org.deeplearning4j.rl4j.support.MockEncodable; -import org.deeplearning4j.rl4j.support.MockHistoryProcessor; -import org.deeplearning4j.rl4j.support.MockMDP; -import org.deeplearning4j.rl4j.support.MockNeuralNet; -import org.deeplearning4j.rl4j.support.MockObservationSpace; -import org.deeplearning4j.rl4j.support.MockTrainingListener; +import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Box; +import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.util.IDataManager; +import org.junit.Before; import org.junit.Test; - -import java.util.ArrayList; -import java.util.List; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.shade.guava.base.Preconditions; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class AsyncThreadTest { - @Test - public void when_newEpochStarted_expect_neuralNetworkReset() { - // Arrange - int numberOfEpochs = 5; - TestContext context = new TestContext(numberOfEpochs); + @Mock + ActionSpace mockActionSpace; - // Act - context.sut.run(); + @Mock + ObservationSpace mockObservationSpace; - // Assert - assertEquals(numberOfEpochs, context.neuralNet.resetCallCount); + @Mock + IAsyncLearningConfiguration mockAsyncConfiguration; + + @Mock + NeuralNet mockNeuralNet; + + @Mock + IAsyncGlobal mockAsyncGlobal; + + @Mock + MDP> mockMDP; + + @Mock + TrainingListenerList mockTrainingListeners; + + int[] observationShape = new int[]{3, 10, 10}; + int actionSize = 4; + + AsyncThread, NeuralNet> thread; + + @Before + public void setup() { + setupMDPMocks(); + setupThreadMocks(); + } + + private void setupThreadMocks() { + + thread = mock(AsyncThread.class, Mockito.withSettings() + .useConstructor(mockMDP, mockTrainingListeners, 0, 0) + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + + when(thread.getAsyncGlobal()).thenReturn(mockAsyncGlobal); + when(thread.getCurrent()).thenReturn(mockNeuralNet); + } + + private void setupMDPMocks() { + + when(mockObservationSpace.getShape()).thenReturn(observationShape); + when(mockActionSpace.noOp()).thenReturn(Nd4j.zeros(actionSize)); + + when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace); + when(mockMDP.getActionSpace()).thenReturn(mockActionSpace); + + int dataLength = 1; + for (int d : observationShape) { + dataLength *= d; + } + + when(mockMDP.reset()).thenReturn(new Box(new double[dataLength])); + } + + private void mockTrainingListeners() { + mockTrainingListeners(false, false); + } + + private void mockTrainingListeners(boolean stopOnNotifyNewEpoch, boolean stopOnNotifyEpochTrainingResult) { + when(mockTrainingListeners.notifyNewEpoch(eq(thread))).thenReturn(!stopOnNotifyNewEpoch); + when(mockTrainingListeners.notifyEpochTrainingResult(eq(thread), any(IDataManager.StatEntry.class))).thenReturn(!stopOnNotifyEpochTrainingResult); + } + + private void mockTrainingContext() { + mockTrainingContext(1000, 100, 10); + } + + private void mockTrainingContext(int maxSteps, int maxStepsPerEpisode, int nstep) { + + // Some conditions of this test harness + Preconditions.checkArgument(maxStepsPerEpisode >= nstep, "episodeLength must be greater than or equal to nstep"); + Preconditions.checkArgument(maxStepsPerEpisode % nstep == 0, "episodeLength must be a multiple of nstep"); + + Observation mockObs = new Observation(Nd4j.zeros(observationShape)); + + when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode); + when(mockAsyncConfiguration.getNStep()).thenReturn(nstep); + when(thread.getConf()).thenReturn(mockAsyncConfiguration); + + // if we hit the max step count + when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps); + + when(thread.trainSubEpoch(any(Observation.class), anyInt())).thenAnswer(invocationOnMock -> { + int steps = invocationOnMock.getArgument(1); + thread.stepCount += steps; + thread.currentEpisodeStepCount += steps; + boolean isEpisodeComplete = thread.getCurrentEpisodeStepCount() % maxStepsPerEpisode == 0; + return new AsyncThread.SubEpochReturn(steps, mockObs, 0.0, 0.0, isEpisodeComplete); + }); } @Test - public void when_onNewEpochReturnsStop_expect_threadStopped() { + public void when_episodeComplete_expect_neuralNetworkReset() { + // Arrange - int stopAfterNumCalls = 1; - TestContext context = new TestContext(100000); - context.listener.setRemainingOnNewEpochCallCount(stopAfterNumCalls); + mockTrainingContext(100, 10, 10); + mockTrainingListeners(); // Act - context.sut.run(); + thread.run(); // Assert - assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: The call that returns stop is counted - assertEquals(stopAfterNumCalls, context.listener.onEpochTrainingResultCallCount); + verify(mockNeuralNet, times(10)).reset(); // there are 10 episodes so the network should be reset between each + assertEquals(10, thread.getEpochCount()); // We are performing a training iteration every 10 steps, so there should be 10 epochs + assertEquals(10, thread.getEpisodeCount()); // There should be 10 completed episodes + assertEquals(100, thread.getStepCount()); // 100 steps overall } @Test - public void when_epochTrainingResultReturnsStop_expect_threadStopped() { + public void when_notifyNewEpochReturnsStop_expect_threadStopped() { // Arrange - int stopAfterNumCalls = 1; - TestContext context = new TestContext(100000); - context.listener.setRemainingOnEpochTrainingResult(stopAfterNumCalls); + mockTrainingContext(); + mockTrainingListeners(true, false); // Act - context.sut.run(); + thread.run(); // Assert - assertEquals(stopAfterNumCalls + 1, context.listener.onEpochTrainingResultCallCount); // +1: The call that returns stop is counted - assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: onNewEpoch is called on the epoch that onEpochTrainingResult() will stop + assertEquals(0, thread.getEpochCount()); + assertEquals(1, thread.getEpisodeCount()); + assertEquals(0, thread.getStepCount()); } @Test - public void when_run_expect_preAndPostEpochCalled() { + public void when_notifyEpochTrainingResultReturnsStop_expect_threadStopped() { // Arrange - int numberOfEpochs = 5; - TestContext context = new TestContext(numberOfEpochs); + mockTrainingContext(); + mockTrainingListeners(false, true); // Act - context.sut.run(); + thread.run(); // Assert - assertEquals(numberOfEpochs, context.sut.preEpochCallCount); - assertEquals(numberOfEpochs, context.sut.postEpochCallCount); + assertEquals(1, thread.getEpochCount()); + assertEquals(1, thread.getEpisodeCount()); + assertEquals(10, thread.getStepCount()); // one epoch is by default 10 steps + } + + @Test + public void when_run_expect_preAndPostEpisodeCalled() { + // Arrange + mockTrainingContext(100, 10, 5); + mockTrainingListeners(false, false); + + // Act + thread.run(); + + // Assert + assertEquals(20, thread.getEpochCount()); + assertEquals(10, thread.getEpisodeCount()); + assertEquals(100, thread.getStepCount()); + + verify(thread, times(10)).preEpisode(); // over 100 steps there will be 10 episodes + verify(thread, times(10)).postEpisode(); } @Test public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() { // Arrange - int numberOfEpochs = 5; - TestContext context = new TestContext(numberOfEpochs); + mockTrainingContext(100, 10, 5); + mockTrainingListeners(false, false); // Act - context.sut.run(); + thread.run(); // Assert - assertEquals(numberOfEpochs, context.listener.statEntries.size()); - int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 }; - double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init - + 1.0; // Reward from trainSubEpoch() - for(int i = 0; i < numberOfEpochs; ++i) { - IDataManager.StatEntry statEntry = context.listener.statEntries.get(i); - assertEquals(expectedStepCounter[i], statEntry.getStepCounter()); - assertEquals(i, statEntry.getEpochCounter()); - assertEquals(expectedReward, statEntry.getReward(), 0.0001); - } + assertEquals(20, thread.getEpochCount()); + assertEquals(10, thread.getEpisodeCount()); + assertEquals(100, thread.getStepCount()); + + // Over 100 steps there will be 20 training iterations, so there will be 20 calls to notifyEpochTrainingResult + verify(mockTrainingListeners, times(20)).notifyEpochTrainingResult(eq(thread), any(IDataManager.StatEntry.class)); } @Test public void when_run_expect_trainSubEpochCalled() { // Arrange - int numberOfEpochs = 5; - TestContext context = new TestContext(numberOfEpochs); + mockTrainingContext(100, 10, 5); + mockTrainingListeners(false, false); // Act - context.sut.run(); + thread.run(); // Assert - assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size()); - double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }; - for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) { - MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i); - assertEquals(2, params.nstep); - assertEquals(expectedObservation.length, params.obs.getData().shape()[1]); - for(int j = 0; j < expectedObservation.length; ++j){ - assertEquals(expectedObservation[j], 255.0 * params.obs.getData().getDouble(j), 0.00001); + assertEquals(20, thread.getEpochCount()); + assertEquals(10, thread.getEpisodeCount()); + assertEquals(100, thread.getStepCount()); + + // There should be 20 calls to trainsubepoch with 5 steps per epoch + verify(thread, times(20)).trainSubEpoch(any(Observation.class), eq(5)); + } + + @Test + public void when_remainingEpisodeLengthSmallerThanNSteps_expect_trainSubEpochCalledWithMinimumValue() { + + int currentEpisodeSteps = 95; + mockTrainingContext(1000, 100, 10); + mockTrainingListeners(false, true); + + // want to mock that we are 95 steps into the episode + doAnswer(invocationOnMock -> { + for (int i = 0; i < currentEpisodeSteps; i++) { + thread.incrementSteps(); } - } - } - - private static class TestContext { - public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); - public final MockNeuralNet neuralNet = new MockNeuralNet(); - public final MockObservationSpace observationSpace = new MockObservationSpace(); - public final MockMDP mdp = new MockMDP(observationSpace); - public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0); - public final TrainingListenerList listeners = new TrainingListenerList(); - public final MockTrainingListener listener = new MockTrainingListener(); - public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); - public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf); - - public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners); - - public TestContext(int numEpochs) { - asyncGlobal.setMaxLoops(numEpochs); - listeners.add(listener); - sut.setHistoryProcessor(historyProcessor); - sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); - } - } - - public static class MockAsyncThread extends AsyncThread { - - public int preEpochCallCount = 0; - public int postEpochCallCount = 0; - - private final MockAsyncGlobal asyncGlobal; - private final MockNeuralNet neuralNet; - private final IAsyncLearningConfiguration conf; - - private final List trainSubEpochParams = new ArrayList(); - - public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) { - super(asyncGlobal, mdp, listeners, threadNumber, 0); - - this.asyncGlobal = asyncGlobal; - this.neuralNet = neuralNet; - this.conf = conf; - } - - @Override - protected void preEpoch() { - ++preEpochCallCount; - super.preEpoch(); - } - - @Override - protected void postEpoch() { - ++postEpochCallCount; - super.postEpoch(); - } - - @Override - protected MockNeuralNet getCurrent() { - return neuralNet; - } - - @Override - protected IAsyncGlobal getAsyncGlobal() { - return asyncGlobal; - } - - @Override - protected IAsyncLearningConfiguration getConf() { - return conf; - } - - @Override - protected Policy getPolicy(MockNeuralNet net) { return null; - } + }).when(thread).preEpisode(); - @Override - protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) { - asyncGlobal.increaseCurrentLoop(); - trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep)); - for(int i = 0; i < nstep; ++i) { - incrementStep(); - } - return new SubEpochReturn(nstep, null, 1.0, 1.0); - } + mockTrainingListeners(false, true); - @AllArgsConstructor - @Getter - public static class TrainSubEpochParams { - Observation obs; - int nstep; - } + // Act + thread.run(); + + // Assert + assertEquals(1, thread.getEpochCount()); + assertEquals(1, thread.getEpisodeCount()); + assertEquals(100, thread.getStepCount()); + + // There should be 1 call to trainsubepoch with 5 steps as this is the remaining episode steps + verify(thread, times(1)).trainSubEpoch(any(Observation.class), eq(5)); } + } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java deleted file mode 100644 index 1434796f3..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java +++ /dev/null @@ -1,160 +0,0 @@ -package org.deeplearning4j.rl4j.learning.async.a3c.discrete; - -import org.deeplearning4j.nn.api.NeuralNetwork; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.experience.StateActionPair; -import org.deeplearning4j.rl4j.network.NeuralNet; -import org.deeplearning4j.rl4j.network.ac.IActorCritic; -import org.deeplearning4j.rl4j.observation.Observation; -import org.deeplearning4j.rl4j.support.MockAsyncGlobal; -import org.deeplearning4j.rl4j.support.MockMDP; -import org.deeplearning4j.rl4j.support.MockObservationSpace; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -public class A3CUpdateAlgorithmTest { - - @Test - public void refac_calcGradient_non_terminal() { - // Arrange - double gamma = 0.9; - MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 }); - MockMDP mdpMock = new MockMDP(observationSpace); - MockActorCritic actorCriticMock = new MockActorCritic(); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); - A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), mdpMock.getActionSpace().getSize(), -1, gamma); - - - INDArray[] originalObservations = new INDArray[] { - Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }), - Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }), - Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }), - Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }), - }; - int[] actions = new int[] { 0, 1, 2, 1 }; - double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 }; - - List> experience = new ArrayList>(); - for(int i = 0; i < originalObservations.length; ++i) { - experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false)); - } - - // Act - sut.computeGradients(actorCriticMock, experience); - - // Assert - assertEquals(1, actorCriticMock.gradientParams.size()); - - // Inputs - INDArray input = actorCriticMock.gradientParams.get(0).getLeft(); - for(int i = 0; i < 4; ++i) { - for(int j = 0; j < 5; ++j) { - assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001); - } - } - - INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0]; - INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1]; - - assertEquals(4, targets.shape()[0]); - assertEquals(1, targets.shape()[1]); - - // FIXME: check targets values once fixed - - assertEquals(4, logSoftmax.shape()[0]); - assertEquals(5, logSoftmax.shape()[1]); - - // FIXME: check logSoftmax values once fixed - - } - - public class MockActorCritic implements IActorCritic { - - public final List> gradientParams = new ArrayList<>(); - - @Override - public NeuralNetwork[] getNeuralNetworks() { - return new NeuralNetwork[0]; - } - - @Override - public boolean isRecurrent() { - return false; - } - - @Override - public void reset() { - - } - - @Override - public void fit(INDArray input, INDArray[] labels) { - - } - - @Override - public INDArray[] outputAll(INDArray batch) { - return new INDArray[] { batch.mul(-1.0) }; - } - - @Override - public IActorCritic clone() { - return this; - } - - @Override - public void copy(NeuralNet from) { - - } - - @Override - public void copy(IActorCritic from) { - - } - - @Override - public Gradient[] gradient(INDArray input, INDArray[] labels) { - gradientParams.add(new Pair(input, labels)); - return new Gradient[0]; - } - - @Override - public void applyGradient(Gradient[] gradient, int batchSize) { - - } - - @Override - public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException { - - } - - @Override - public void save(String pathValue, String pathPolicy) throws IOException { - - } - - @Override - public double getLatestScore() { - return 0; - } - - @Override - public void save(OutputStream os) throws IOException { - - } - - @Override - public void save(String filename) throws IOException { - - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java new file mode 100644 index 000000000..131dbce63 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java @@ -0,0 +1,93 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.async.a3c.discrete; + +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) +public class AdvantageActorCriticUpdateAlgorithmTest { + + @Mock + AsyncGlobal mockAsyncGlobal; + + @Mock + IActorCritic mockActorCritic; + + @Test + public void refac_calcGradient_non_terminal() { + // Arrange + int[] observationShape = new int[]{5}; + double gamma = 0.9; + AdvantageActorCriticUpdateAlgorithm algorithm = new AdvantageActorCriticUpdateAlgorithm(false, observationShape, 1, gamma); + + INDArray[] originalObservations = new INDArray[]{ + Nd4j.create(new double[]{0.0, 0.1, 0.2, 0.3, 0.4}), + Nd4j.create(new double[]{1.0, 1.1, 1.2, 1.3, 1.4}), + Nd4j.create(new double[]{2.0, 2.1, 2.2, 2.3, 2.4}), + Nd4j.create(new double[]{3.0, 3.1, 3.2, 3.3, 3.4}), + }; + + int[] actions = new int[]{0, 1, 2, 1}; + double[] rewards = new double[]{0.1, 1.0, 10.0, 100.0}; + + List> experience = new ArrayList>(); + for (int i = 0; i < originalObservations.length; ++i) { + experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false)); + } + + when(mockActorCritic.outputAll(any(INDArray.class))).thenAnswer(invocation -> { + INDArray batch = invocation.getArgument(0); + return new INDArray[]{batch.mul(-1.0)}; + }); + + ArgumentCaptor inputArgumentCaptor = ArgumentCaptor.forClass(INDArray.class); + ArgumentCaptor criticActorArgumentCaptor = ArgumentCaptor.forClass(INDArray[].class); + + // Act + algorithm.computeGradients(mockActorCritic, experience); + + verify(mockActorCritic, times(1)).gradient(inputArgumentCaptor.capture(), criticActorArgumentCaptor.capture()); + + assertEquals(Nd4j.stack(0, originalObservations), inputArgumentCaptor.getValue()); + + //TODO: the actual AdvantageActorCritic Algo is not implemented correctly, so needs to be fixed, then we can test these +// assertEquals(Nd4j.zeros(1), criticActorArgumentCaptor.getValue()[0]); +// assertEquals(Nd4j.zeros(1), criticActorArgumentCaptor.getValue()[1]); + + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java index 56b8494a0..d73a2a9f7 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java @@ -1,3 +1,19 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async.listener; import org.deeplearning4j.rl4j.learning.IEpochTrainer; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java index 35465d26a..f44437d67 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java @@ -1,11 +1,30 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; import org.deeplearning4j.rl4j.observation.Observation; -import org.deeplearning4j.rl4j.support.MockAsyncGlobal; import org.deeplearning4j.rl4j.support.MockDQN; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -14,17 +33,21 @@ import java.util.List; import static org.junit.Assert.assertEquals; +@RunWith(MockitoJUnitRunner.class) public class QLearningUpdateAlgorithmTest { + @Mock + AsyncGlobal mockAsyncGlobal; + @Test public void when_isTerminal_expect_initRewardIs0() { // Arrange MockDQN dqnMock = new MockDQN(); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); - UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0); + final Observation observation = new Observation(Nd4j.zeros(1)); List> experience = new ArrayList>() { { - add(new StateActionPair(new Observation(Nd4j.zeros(1)), 0, 0.0, true)); + add(new StateActionPair(observation, 0, 0.0, true)); } }; @@ -38,12 +61,11 @@ public class QLearningUpdateAlgorithmTest { @Test public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() { // Arrange - MockDQN globalDQNMock = new MockDQN(); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); - UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0); + final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 })); List> experience = new ArrayList>() { { - add(new StateActionPair(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false)); + add(new StateActionPair(observation, 0, 0.0, false)); } }; MockDQN dqnMock = new MockDQN(); @@ -57,35 +79,11 @@ public class QLearningUpdateAlgorithmTest { assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); } - @Test - public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() { - // Arrange - MockDQN globalDQNMock = new MockDQN(); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); - UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0); - List> experience = new ArrayList>() { - { - add(new StateActionPair(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false)); - } - }; - MockDQN dqnMock = new MockDQN(); - - // Act - sut.computeGradients(dqnMock, experience); - - // Assert - assertEquals(1, globalDQNMock.outputAllParams.size()); - assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001); - assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); - } - @Test public void when_callingWithMultipleExperiences_expect_gradientsAreValid() { // Arrange double gamma = 0.9; - MockDQN globalDQNMock = new MockDQN(); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); - UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma); List> experience = new ArrayList>() { { add(new StateActionPair(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false)); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java index a926864e4..fc67ac151 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java @@ -1,20 +1,56 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.listener; -import org.deeplearning4j.rl4j.support.MockTrainingListener; +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; +import org.mockito.Mock; import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class TrainingListenerListTest { + + @Mock + IEpochTrainer mockTrainer; + + @Mock + ILearning mockLearning; + + @Mock + IDataManager.StatEntry mockStatEntry; + @Test public void when_listIsEmpty_expect_notifyReturnTrue() { // Arrange - TrainingListenerList sut = new TrainingListenerList(); + TrainingListenerList trainingListenerList = new TrainingListenerList(); // Act - boolean resultTrainingStarted = sut.notifyTrainingStarted(); - boolean resultNewEpoch = sut.notifyNewEpoch(null); - boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null); + boolean resultTrainingStarted = trainingListenerList.notifyTrainingStarted(); + boolean resultNewEpoch = trainingListenerList.notifyNewEpoch(null); + boolean resultEpochFinished = trainingListenerList.notifyEpochTrainingResult(null, null); // Assert assertTrue(resultTrainingStarted); @@ -25,54 +61,56 @@ public class TrainingListenerListTest { @Test public void when_firstListerStops_expect_othersListnersNotCalled() { // Arrange - MockTrainingListener listener1 = new MockTrainingListener(); - listener1.setRemainingTrainingStartCallCount(0); - listener1.setRemainingOnNewEpochCallCount(0); - listener1.setRemainingonTrainingProgressCallCount(0); - listener1.setRemainingOnEpochTrainingResult(0); - MockTrainingListener listener2 = new MockTrainingListener(); - TrainingListenerList sut = new TrainingListenerList(); - sut.add(listener1); - sut.add(listener2); + TrainingListener listener1 = mock(TrainingListener.class); + TrainingListener listener2 = mock(TrainingListener.class); + TrainingListenerList trainingListenerList = new TrainingListenerList(); + trainingListenerList.add(listener1); + trainingListenerList.add(listener2); + + when(listener1.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP); + when(listener1.onNewEpoch(eq(mockTrainer))).thenReturn(TrainingListener.ListenerResponse.STOP); + when(listener1.onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry))).thenReturn(TrainingListener.ListenerResponse.STOP); + when(listener1.onTrainingProgress(eq(mockLearning))).thenReturn(TrainingListener.ListenerResponse.STOP); // Act - sut.notifyTrainingStarted(); - sut.notifyNewEpoch(null); - sut.notifyEpochTrainingResult(null, null); - sut.notifyTrainingProgress(null); - sut.notifyTrainingFinished(); + trainingListenerList.notifyTrainingStarted(); + trainingListenerList.notifyNewEpoch(mockTrainer); + trainingListenerList.notifyEpochTrainingResult(mockTrainer, null); + trainingListenerList.notifyTrainingProgress(mockLearning); + trainingListenerList.notifyTrainingFinished(); // Assert - assertEquals(1, listener1.onTrainingStartCallCount); - assertEquals(0, listener2.onTrainingStartCallCount); - assertEquals(1, listener1.onNewEpochCallCount); - assertEquals(0, listener2.onNewEpochCallCount); + verify(listener1, times(1)).onTrainingStart(); + verify(listener2, never()).onTrainingStart(); - assertEquals(1, listener1.onEpochTrainingResultCallCount); - assertEquals(0, listener2.onEpochTrainingResultCallCount); + verify(listener1, times(1)).onNewEpoch(eq(mockTrainer)); + verify(listener2, never()).onNewEpoch(eq(mockTrainer)); - assertEquals(1, listener1.onTrainingProgressCallCount); - assertEquals(0, listener2.onTrainingProgressCallCount); + verify(listener1, times(1)).onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry)); + verify(listener2, never()).onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry)); - assertEquals(1, listener1.onTrainingEndCallCount); - assertEquals(1, listener2.onTrainingEndCallCount); + verify(listener1, times(1)).onTrainingProgress(eq(mockLearning)); + verify(listener2, never()).onTrainingProgress(eq(mockLearning)); + + verify(listener1, times(1)).onTrainingEnd(); + verify(listener2, times(1)).onTrainingEnd(); } @Test public void when_allListenersContinue_expect_listReturnsTrue() { // Arrange - MockTrainingListener listener1 = new MockTrainingListener(); - MockTrainingListener listener2 = new MockTrainingListener(); - TrainingListenerList sut = new TrainingListenerList(); - sut.add(listener1); - sut.add(listener2); + TrainingListener listener1 = mock(TrainingListener.class); + TrainingListener listener2 = mock(TrainingListener.class); + TrainingListenerList trainingListenerList = new TrainingListenerList(); + trainingListenerList.add(listener1); + trainingListenerList.add(listener2); // Act - boolean resultTrainingStarted = sut.notifyTrainingStarted(); - boolean resultNewEpoch = sut.notifyNewEpoch(null); - boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null); - boolean resultProgress = sut.notifyTrainingProgress(null); + boolean resultTrainingStarted = trainingListenerList.notifyTrainingStarted(); + boolean resultNewEpoch = trainingListenerList.notifyNewEpoch(null); + boolean resultEpochTrainingResult = trainingListenerList.notifyEpochTrainingResult(null, null); + boolean resultProgress = trainingListenerList.notifyTrainingProgress(null); // Assert assertTrue(resultTrainingStarted); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java index 22e4be3f6..c17bc93d8 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -1,151 +1,117 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - package org.deeplearning4j.rl4j.learning.sync; -import lombok.Getter; +import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; -import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration; -import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; -import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; +import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; -import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; -import org.deeplearning4j.rl4j.policy.IPolicy; -import org.deeplearning4j.rl4j.support.MockTrainingListener; +import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.util.IDataManager; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.nd4j.linalg.api.ndarray.INDArray; -import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class SyncLearningTest { + @Mock + TrainingListener mockTrainingListener; + + SyncLearning, NeuralNet> syncLearning; + + @Mock + ILearningConfiguration mockLearningConfiguration; + + @Before + public void setup() { + + syncLearning = mock(SyncLearning.class, Mockito.withSettings() + .useConstructor() + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + + syncLearning.addListener(mockTrainingListener); + + when(syncLearning.trainEpoch()).thenAnswer(invocation -> { + //syncLearning.incrementEpoch(); + syncLearning.incrementStep(); + return new MockStatEntry(syncLearning.getEpochCount(), syncLearning.getStepCount(), 1.0); + }); + + when(syncLearning.getConfiguration()).thenReturn(mockLearningConfiguration); + when(mockLearningConfiguration.getMaxStep()).thenReturn(100); + } + @Test public void when_training_expect_listenersToBeCalled() { - // Arrange - QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); - MockTrainingListener listener = new MockTrainingListener(); - MockSyncLearning sut = new MockSyncLearning(lconfig); - sut.addListener(listener); // Act - sut.train(); + syncLearning.train(); + + verify(mockTrainingListener, times(1)).onTrainingStart(); + verify(mockTrainingListener, times(100)).onNewEpoch(eq(syncLearning)); + verify(mockTrainingListener, times(100)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class)); + verify(mockTrainingListener, times(1)).onTrainingEnd(); - assertEquals(1, listener.onTrainingStartCallCount); - assertEquals(10, listener.onNewEpochCallCount); - assertEquals(10, listener.onEpochTrainingResultCallCount); - assertEquals(1, listener.onTrainingEndCallCount); } @Test public void when_trainingStartCanContinueFalse_expect_trainingStopped() { // Arrange - QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); - MockTrainingListener listener = new MockTrainingListener(); - MockSyncLearning sut = new MockSyncLearning(lconfig); - sut.addListener(listener); - listener.setRemainingTrainingStartCallCount(0); + when(mockTrainingListener.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP); // Act - sut.train(); + syncLearning.train(); - assertEquals(1, listener.onTrainingStartCallCount); - assertEquals(0, listener.onNewEpochCallCount); - assertEquals(0, listener.onEpochTrainingResultCallCount); - assertEquals(1, listener.onTrainingEndCallCount); + verify(mockTrainingListener, times(1)).onTrainingStart(); + verify(mockTrainingListener, times(0)).onNewEpoch(eq(syncLearning)); + verify(mockTrainingListener, times(0)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class)); + verify(mockTrainingListener, times(1)).onTrainingEnd(); } @Test public void when_newEpochCanContinueFalse_expect_trainingStopped() { // Arrange - QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); - MockTrainingListener listener = new MockTrainingListener(); - MockSyncLearning sut = new MockSyncLearning(lconfig); - sut.addListener(listener); - listener.setRemainingOnNewEpochCallCount(2); + when(mockTrainingListener.onNewEpoch(eq(syncLearning))) + .thenReturn(TrainingListener.ListenerResponse.CONTINUE) + .thenReturn(TrainingListener.ListenerResponse.CONTINUE) + .thenReturn(TrainingListener.ListenerResponse.STOP); // Act - sut.train(); + syncLearning.train(); + + verify(mockTrainingListener, times(1)).onTrainingStart(); + verify(mockTrainingListener, times(3)).onNewEpoch(eq(syncLearning)); + verify(mockTrainingListener, times(2)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class)); + verify(mockTrainingListener, times(1)).onTrainingEnd(); - assertEquals(1, listener.onTrainingStartCallCount); - assertEquals(3, listener.onNewEpochCallCount); - assertEquals(2, listener.onEpochTrainingResultCallCount); - assertEquals(1, listener.onTrainingEndCallCount); } @Test public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() { // Arrange - LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); - MockTrainingListener listener = new MockTrainingListener(); - MockSyncLearning sut = new MockSyncLearning(lconfig); - sut.addListener(listener); - listener.setRemainingOnEpochTrainingResult(2); + when(mockTrainingListener.onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class))) + .thenReturn(TrainingListener.ListenerResponse.CONTINUE) + .thenReturn(TrainingListener.ListenerResponse.CONTINUE) + .thenReturn(TrainingListener.ListenerResponse.STOP); // Act - sut.train(); + syncLearning.train(); - assertEquals(1, listener.onTrainingStartCallCount); - assertEquals(3, listener.onNewEpochCallCount); - assertEquals(3, listener.onEpochTrainingResultCallCount); - assertEquals(1, listener.onTrainingEndCallCount); - } - - public static class MockSyncLearning extends SyncLearning { - - private final ILearningConfiguration conf; - - @Getter - private int currentEpochStep = 0; - - public MockSyncLearning(ILearningConfiguration conf) { - this.conf = conf; - } - - @Override - protected void preEpoch() { currentEpochStep = 0; } - - @Override - protected void postEpoch() { } - - @Override - protected IDataManager.StatEntry trainEpoch() { - setStepCounter(getStepCounter() + 1); - return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0); - } - - @Override - public NeuralNet getNeuralNet() { - return null; - } - - @Override - public IPolicy getPolicy() { - return null; - } - - @Override - public ILearningConfiguration getConfiguration() { - return conf; - } - - @Override - public MDP getMdp() { - return null; - } + verify(mockTrainingListener, times(1)).onTrainingStart(); + verify(mockTrainingListener, times(3)).onNewEpoch(eq(syncLearning)); + verify(mockTrainingListener, times(3)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class)); + verify(mockTrainingListener, times(1)).onTrainingEnd(); } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 9d77084d5..82129e0df 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; +import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; @@ -26,11 +27,21 @@ import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.api.rng.Random; @@ -40,150 +51,146 @@ import java.util.ArrayList; import java.util.List; import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class QLearningDiscreteTest { + + QLearningDiscrete qLearningDiscrete; + + @Mock + IHistoryProcessor mockHistoryProcessor; + + @Mock + IHistoryProcessor.Configuration mockHistoryConfiguration; + + @Mock + MDP mockMDP; + + @Mock + DiscreteSpace mockActionSpace; + + @Mock + ObservationSpace mockObservationSpace; + + @Mock + IDQN mockDQN; + + @Mock + QLearningConfiguration mockQlearningConfiguration; + + int[] observationShape = new int[]{3, 10, 10}; + int totalObservationSize = 1; + + private void setupMDPMocks() { + + when(mockObservationSpace.getShape()).thenReturn(observationShape); + + when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace); + when(mockMDP.getActionSpace()).thenReturn(mockActionSpace); + + int dataLength = 1; + for (int d : observationShape) { + dataLength *= d; + } + } + + + private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) { + when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize); + when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor); + when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay); + when(mockQlearningConfiguration.getSeed()).thenReturn(123L); + + qLearningDiscrete = mock( + QLearningDiscrete.class, + Mockito.withSettings() + .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0) + .defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + } + + private void mockHistoryProcessor(int skipFrames) { + when(mockHistoryConfiguration.getRescaledHeight()).thenReturn(observationShape[1]); + when(mockHistoryConfiguration.getRescaledWidth()).thenReturn(observationShape[2]); + + when(mockHistoryConfiguration.getOffsetX()).thenReturn(0); + when(mockHistoryConfiguration.getOffsetY()).thenReturn(0); + + when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]); + when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]); + when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames); + when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration); + + qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor); + } + + @Before + public void setup() { + setupMDPMocks(); + + for (int i : observationShape) { + totalObservationSize *= i; + } + + } + @Test - public void refac_QLearningDiscrete_trainStep() { + public void when_singleTrainStep_expect_correctValues() { + // Arrange - MockObservationSpace observationSpace = new MockObservationSpace(); - MockDQN dqn = new MockDQN(); - MockRandom random = new MockRandom(new double[]{ - 0.7309677600860596, - 0.8314409852027893, - 0.2405363917350769, - 0.6063451766967773, - 0.6374173760414124, - 0.3090505599975586, - 0.5504369735717773, - 0.11700659990310669 - }, - new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}); - MockMDP mdp = new MockMDP(observationSpace, random); + mockTestContext(100,0,2,1.0, 10); - int initStepCount = 8; + // An example observation and 2 Q values output (2 actions) + Observation observation = new Observation(Nd4j.zeros(observationShape)); + when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f})); - QLearningConfiguration conf = QLearningConfiguration.builder() - .seed(0L) - .maxEpochStep(24) - .maxStep(0) - .expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000) - .updateStart(initStepCount) - .rewardFactor(1.0) - .gamma(0) - .errorClamp(0) - .minEpsilon(0) - .epsilonNbStep(0) - .doubleDQN(true) - .build(); - - MockDataManager dataManager = new MockDataManager(false); - MockExperienceHandler experienceHandler = new MockExperienceHandler(); - TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random); - IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); - MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); - sut.setHistoryProcessor(hp); - sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); - List> results = new ArrayList<>(); + when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null)); // Act - IDataManager.StatEntry result = sut.trainEpoch(); + QLearning.QLStepReturn stepReturn = qLearningDiscrete.trainStep(observation); // Assert - // HistoryProcessor calls - double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - assertEquals(expectedRecords.length, hp.recordCalls.size()); - for (int i = 0; i < expectedRecords.length; ++i) { - assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); - } - assertEquals(0, hp.startMonitorCallCount); - assertEquals(0, hp.stopMonitorCallCount); + assertEquals(1.0, stepReturn.getMaxQ(), 1e-5); - // DQN calls - assertEquals(1, dqn.fitParams.size()); - assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001); - assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001); - assertEquals(14, dqn.outputParams.size()); - double[][] expectedDQNOutput = new double[][]{ - new double[]{0.0, 2.0, 4.0, 6.0, 8.0}, - new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, - new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, - new double[]{4.0, 6.0, 8.0, 10.0, 12.0}, - new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, - new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, - new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, - new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, - new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, - new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, - new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, - new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, - new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, - new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, - }; - for (int i = 0; i < expectedDQNOutput.length; ++i) { - INDArray outputParam = dqn.outputParams.get(i); + StepReply stepReply = stepReturn.getStepReply(); - assertEquals(5, outputParam.shape()[1]); - assertEquals(1, outputParam.shape()[2]); + assertEquals(0, stepReply.getReward(), 1e-5); + assertFalse(stepReply.isDone()); + assertFalse(stepReply.getObservation().isSkipped()); + assertEquals(observation.getData().reshape(observationShape), stepReply.getObservation().getData().reshape(observationShape)); - double[] expectedRow = expectedDQNOutput[i]; - for (int j = 0; j < expectedRow.length; ++j) { - assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001); - } - } - - // MDP calls - assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray()); - - // ExperienceHandler calls - double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 }; - int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; - double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; - double[][] expectedTrObservations = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, - new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, - new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, - new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, - }; - - assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size()); - for(int i = 0; i < expectedTrRewards.length; ++i) { - StateActionPair stateActionPair = experienceHandler.addExperienceArgs.get(i); - assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001); - assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction()); - for(int j = 0; j < expectedTrObservations[i].length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001); - } - } - assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001); - - // trainEpoch result - assertEquals(initStepCount + 16, result.getStepCounter()); - assertEquals(300.0, result.getReward(), 0.00001); - assertTrue(dqn.hasBeenReset); - assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset); } - public static class TestQLearningDiscrete extends QLearningDiscrete { - public TestQLearningDiscrete(MDP mdp, IDQN dqn, - QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler> experienceHandler, - int epsilonNbStep, Random rnd) { - super(mdp, dqn, conf, epsilonNbStep, rnd); - addListener(new DataManagerTrainingListener(dataManager)); - setExperienceHandler(experienceHandler); - } + @Test + public void when_singleTrainStepSkippedFrames_expect_correctValues() { + // Arrange + mockTestContext(100,0,2,1.0, 10); - @Override - protected DataSet setTarget(List> transitions) { - return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); - } + mockHistoryProcessor(2); - @Override - public IDataManager.StatEntry trainEpoch() { - return super.trainEpoch(); - } + // An example observation and 2 Q values output (2 actions) + Observation observation = new Observation(Nd4j.zeros(observationShape)); + when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f})); + + when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null)); + + // Act + QLearning.QLStepReturn stepReturn = qLearningDiscrete.trainStep(observation); + + // Assert + assertEquals(1.0, stepReturn.getMaxQ(), 1e-5); + + StepReply stepReply = stepReturn.getStepReply(); + + assertEquals(0, stepReply.getReward(), 1e-5); + assertFalse(stepReply.isDone()); + assertTrue(stepReply.getObservation().isSkipped()); } + + //TODO: there are much more test cases here that can be improved upon + } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockMDP.java deleted file mode 100644 index 4352b9bee..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockMDP.java +++ /dev/null @@ -1,79 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.support; - -import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.rl4j.mdp.MDP; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.ObservationSpace; -import org.nd4j.linalg.api.ndarray.INDArray; - -public class MockMDP implements MDP { - - private final int maxSteps; - private final DiscreteSpace actionSpace = new DiscreteSpace(1); - private final MockObservationSpace observationSpace = new MockObservationSpace(); - - private int currentStep = 0; - - public MockMDP(int maxSteps) { - - this.maxSteps = maxSteps; - } - - @Override - public ObservationSpace getObservationSpace() { - return observationSpace; - } - - @Override - public DiscreteSpace getActionSpace() { - return actionSpace; - } - - @Override - public Object reset() { - return null; - } - - @Override - public void close() { - - } - - @Override - public StepReply step(Integer integer) { - return new StepReply(null, 1.0, isDone(), null); - } - - @Override - public boolean isDone() { - return currentStep >= maxSteps; - } - - @Override - public MDP newInstance() { - return null; - } - - private static class MockObservationSpace implements ObservationSpace { - - @Override - public String getName() { - return null; - } - - @Override - public int[] getShape() { - return new int[0]; - } - - @Override - public INDArray getLow() { - return null; - } - - @Override - public INDArray getHigh() { - return null; - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 0dc16df09..7db92a599 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -257,9 +257,9 @@ public class PolicyTest { } @Override - protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { + protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength)); - return super.refacInitMdp(mdpWrapper, hp, epochStepCounter); + return super.refacInitMdp(mdpWrapper, hp); } } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java deleted file mode 100644 index 08689b032..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java +++ /dev/null @@ -1,37 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.support; - -import lombok.AllArgsConstructor; -import lombok.Value; -import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; - -@Value -@AllArgsConstructor -public class MockAsyncConfiguration implements IAsyncLearningConfiguration { - - private Long seed; - private int maxEpochStep; - private int maxStep; - private int updateStart; - private double rewardFactor; - private double gamma; - private double errorClamp; - private int numThreads; - private int nStep; - private int learnerUpdateFrequency; -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java deleted file mode 100644 index 33dc82314..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java +++ /dev/null @@ -1,75 +0,0 @@ -package org.deeplearning4j.rl4j.support; - -import lombok.Getter; -import lombok.Setter; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; -import org.deeplearning4j.rl4j.network.NeuralNet; - -import java.util.concurrent.atomic.AtomicInteger; - -public class MockAsyncGlobal implements IAsyncGlobal { - - @Getter - private final NN current; - - public boolean hasBeenStarted = false; - public boolean hasBeenTerminated = false; - - public int enqueueCallCount = 0; - - @Setter - private int maxLoops; - @Setter - private int numLoopsStopRunning; - private int currentLoop = 0; - - public MockAsyncGlobal() { - this(null); - } - - public MockAsyncGlobal(NN current) { - maxLoops = Integer.MAX_VALUE; - numLoopsStopRunning = Integer.MAX_VALUE; - this.current = current; - } - - @Override - public boolean isRunning() { - return currentLoop < numLoopsStopRunning; - } - - @Override - public void terminate() { - hasBeenTerminated = true; - } - - @Override - public boolean isTrainingComplete() { - return currentLoop >= maxLoops; - } - - @Override - public void start() { - hasBeenStarted = true; - } - - @Override - public AtomicInteger getT() { - return null; - } - - @Override - public NN getTarget() { - return current; - } - - @Override - public void enqueue(Gradient[] gradient, Integer nstep) { - ++enqueueCallCount; - } - - public void increaseCurrentLoop() { - ++currentLoop; - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java deleted file mode 100644 index 13ea5d93a..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java +++ /dev/null @@ -1,46 +0,0 @@ -package org.deeplearning4j.rl4j.support; - -import org.deeplearning4j.rl4j.experience.ExperienceHandler; -import org.deeplearning4j.rl4j.experience.StateActionPair; -import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.observation.Observation; - -import java.util.ArrayList; -import java.util.List; - -public class MockExperienceHandler implements ExperienceHandler> { - public List> addExperienceArgs = new ArrayList>(); - public Observation finalObservation; - public boolean isGenerateTrainingBatchCalled; - public boolean isResetCalled; - - @Override - public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) { - addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal)); - } - - @Override - public void setFinalObservation(Observation observation) { - finalObservation = observation; - } - - @Override - public List> generateTrainingBatch() { - isGenerateTrainingBatchCalled = true; - return new ArrayList>() { - { - add(new Transition(null, 0, 0.0, false)); - } - }; - } - - @Override - public void reset() { - isResetCalled = true; - } - - @Override - public int getTrainingBatchSize() { - return 1; - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java deleted file mode 100644 index d4e696248..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java +++ /dev/null @@ -1,77 +0,0 @@ -package org.deeplearning4j.rl4j.support; - -import lombok.Setter; -import org.deeplearning4j.rl4j.learning.IEpochTrainer; -import org.deeplearning4j.rl4j.learning.ILearning; -import org.deeplearning4j.rl4j.learning.listener.*; -import org.deeplearning4j.rl4j.util.IDataManager; - -import java.util.ArrayList; -import java.util.List; - -public class MockTrainingListener implements TrainingListener { - - private final MockAsyncGlobal asyncGlobal; - public int onTrainingStartCallCount = 0; - public int onTrainingEndCallCount = 0; - public int onNewEpochCallCount = 0; - public int onEpochTrainingResultCallCount = 0; - public int onTrainingProgressCallCount = 0; - - @Setter - private int remainingTrainingStartCallCount = Integer.MAX_VALUE; - @Setter - private int remainingOnNewEpochCallCount = Integer.MAX_VALUE; - @Setter - private int remainingOnEpochTrainingResult = Integer.MAX_VALUE; - @Setter - private int remainingonTrainingProgressCallCount = Integer.MAX_VALUE; - - public final List statEntries = new ArrayList<>(); - - public MockTrainingListener() { - this(null); - } - - public MockTrainingListener(MockAsyncGlobal asyncGlobal) { - this.asyncGlobal = asyncGlobal; - } - - - @Override - public ListenerResponse onTrainingStart() { - ++onTrainingStartCallCount; - --remainingTrainingStartCallCount; - return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE; - } - - @Override - public ListenerResponse onNewEpoch(IEpochTrainer trainer) { - ++onNewEpochCallCount; - --remainingOnNewEpochCallCount; - return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE; - } - - @Override - public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) { - ++onEpochTrainingResultCallCount; - --remainingOnEpochTrainingResult; - statEntries.add(statEntry); - return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE; - } - - @Override - public ListenerResponse onTrainingProgress(ILearning learning) { - ++onTrainingProgressCallCount; - --remainingonTrainingProgressCallCount; - if(asyncGlobal != null) { - asyncGlobal.increaseCurrentLoop(); - } - return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE; - } - - @Override - public void onTrainingEnd() { - ++onTrainingEndCallCount; - } -} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java deleted file mode 100644 index dbe2fe1fc..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java +++ /dev/null @@ -1,19 +0,0 @@ -package org.deeplearning4j.rl4j.support; - -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.experience.StateActionPair; -import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; - -import java.util.ArrayList; -import java.util.List; - -public class MockUpdateAlgorithm implements UpdateAlgorithm { - - public final List>> experiences = new ArrayList>>(); - - @Override - public Gradient[] computeGradients(MockNeuralNet current, List> experience) { - experiences.add(experience); - return new Gradient[0]; - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java index 3a2d5230a..0818889e5 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java @@ -151,20 +151,26 @@ public class DataManagerTrainingListenerTest { private static class TestTrainer implements IEpochTrainer, ILearning { @Override - public int getStepCounter() { + public int getStepCount() { return 0; } @Override - public int getEpochCounter() { + public int getEpochCount() { return 0; } @Override - public int getCurrentEpochStep() { + public int getEpisodeCount() { return 0; } + @Override + public int getCurrentEpisodeStepCount() { + return 0; + } + + @Getter @Setter private IHistoryProcessor historyProcessor;