From f1debe8c077e97d616a9e76893529f037c148473 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Fri, 10 Apr 2020 19:50:40 -0400 Subject: [PATCH] RL4J: Add ExperienceHandler (#369) * Added ExperienceHandler Signed-off-by: Alexandre Boulanger * Added getTrainingBatchSize() Signed-off-by: Alexandre Boulanger --- .../rl4j/experience/ExperienceHandler.java | 54 +++++ .../ReplayMemoryExperienceHandler.java | 111 ++++++++++ .../StateActionExperienceHandler.java | 67 ++++++ .../rl4j/experience/StateActionPair.java | 49 +++++ .../learning/async/AsyncThreadDiscrete.java | 72 +++---- .../{MiniTrans.java => UpdateAlgorithm.java} | 66 +++--- .../async/a3c/discrete/A3CThreadDiscrete.java | 61 +----- .../a3c/discrete/A3CUpdateAlgorithm.java | 113 ++++++++++ .../AsyncNStepQLearningThreadDiscrete.java | 40 +--- .../discrete/QLearningUpdateAlgorithm.java | 88 ++++++++ .../rl4j/learning/sync/ExpReplay.java | 5 +- .../rl4j/learning/sync/IExpReplay.java | 6 +- .../learning/sync/qlearning/QLearning.java | 24 --- .../qlearning/discrete/QLearningDiscrete.java | 42 ++-- .../ReplayMemoryExperienceHandlerTest.java | 107 ++++++++++ .../StateActionExperienceHandlerTest.java | 82 ++++++++ .../async/AsyncThreadDiscreteTest.java | 82 +++----- .../a3c/discrete/A3CThreadDiscreteTest.java | 197 ------------------ .../a3c/discrete/A3CUpdateAlgorithmTest.java | 160 ++++++++++++++ ...AsyncNStepQLearningThreadDiscreteTest.java | 98 --------- .../QLearningUpdateAlgorithmTest.java | 115 ++++++++++ .../discrete/QLearningDiscreteTest.java | 62 +++--- .../deeplearning4j/rl4j/support/MockDQN.java | 6 +- .../rl4j/support/MockExpReplay.java | 22 -- .../rl4j/support/MockExperienceHandler.java | 46 ++++ .../rl4j/support/MockObservationSpace.java | 14 +- .../rl4j/support/MockUpdateAlgorithm.java | 19 ++ 27 files changed, 1183 insertions(+), 625 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionPair.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/{MiniTrans.java => UpdateAlgorithm.java} (57%) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java delete mode 100644 
rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java new file mode 100644 index 000000000..1ec4f05c1 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.List; + +/** + * A common interface to all classes capable of handling experience generated by the agents in a learning context. + * + * @param <A> Action type + * @param <E> Experience type + * + * @author Alexandre Boulanger + */ +public interface ExperienceHandler<A, E> { + void addExperience(Observation observation, A action, double reward, boolean isTerminal); + + /** + * Called when the episode is done with the last observation + * @param observation + */ + void setFinalObservation(Observation observation); + + /** + * @return The size of the list that will be returned by generateTrainingBatch(). + */ + int getTrainingBatchSize(); + + /** + * The elements are returned in the historical order (i.e. in the order they happened) + * @return The list of experience elements + */ + List<E> generateTrainingBatch(); + + /** + * Signal the experience handler that a new episode is starting + */ + void reset(); +}
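To make the contract concrete, here is a minimal sketch of how a learning loop is expected to drive this interface. It is illustration only, not part of the patch: env, pickAction and learnFrom are hypothetical stand-ins.

// Hypothetical driver loop; env, pickAction and learnFrom are stand-ins, not RL4J APIs.
ExperienceHandler<Integer, StateActionPair<Integer>> handler = new StateActionExperienceHandler<>();
handler.reset();                                        // a new episode is starting
while (!env.isDone()) {
    Observation obs = env.currentObservation();
    int action = pickAction(obs);
    double reward = env.step(action);
    handler.addExperience(obs, action, reward, env.isDone());
}
handler.setFinalObservation(env.currentObservation());  // episode is done; hand over the last observation
if (handler.getTrainingBatchSize() > 0) {
    learnFrom(handler.generateTrainingBatch());         // elements come back in historical order
}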
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java new file mode 100644 index 000000000..74b7e3f05 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java @@ -0,0 +1,111 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import lombok.EqualsAndHashCode; +import org.deeplearning4j.rl4j.learning.sync.ExpReplay; +import org.deeplearning4j.rl4j.learning.sync.IExpReplay; +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.observation.Observation; +import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; + +/** + * An experience handler that stores the experience in a replay memory. See https://arxiv.org/abs/1312.5602 + * The experience container is a {@link Transition Transition} that stores the tuple observation-action-reward-nextObservation, + * as well as whether or not the episode ended after the Transition + * + * @param <A> Action type + */ +@EqualsAndHashCode +public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> { + private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000; + private static final int DEFAULT_BATCH_SIZE = 32; + + private IExpReplay<A> expReplay; + + private Transition<A> pendingTransition; + + public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) { + this.expReplay = expReplay; + } + + public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) { + this(new ExpReplay<A>(maxReplayMemorySize, batchSize, random)); + } + + public void addExperience(Observation observation, A action, double reward, boolean isTerminal) { + setNextObservationOnPending(observation); + pendingTransition = new Transition<>(observation, action, reward, isTerminal); + } + + public void setFinalObservation(Observation observation) { + setNextObservationOnPending(observation); + pendingTransition = null; + } + + @Override + public int getTrainingBatchSize() { + return expReplay.getBatchSize(); + } + + /** + * @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call. + */ + @Override + public List<Transition<A>> generateTrainingBatch() { + return expReplay.getBatch(); + } + + @Override + public void reset() { + pendingTransition = null; + } + + private void setNextObservationOnPending(Observation observation) { + if(pendingTransition != null) { + pendingTransition.setNextObservation(observation); + expReplay.store(pendingTransition); + } + } + + public class Builder { + private int maxReplayMemorySize = DEFAULT_MAX_REPLAY_MEMORY_SIZE; + private int batchSize = DEFAULT_BATCH_SIZE; + private Random random = Nd4j.getRandom(); + + public Builder maxReplayMemorySize(int value) { + maxReplayMemorySize = value; + return this; + } + + public Builder batchSize(int value) { + batchSize = value; + return this; + } + + public Builder random(Random value) { + random = value; + return this; + } + + public ReplayMemoryExperienceHandler<A> build() { + return new ReplayMemoryExperienceHandler<A>(maxReplayMemorySize, batchSize, random); + } + } +}
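The pending-transition mechanics above deserve a concrete trace: a transition can only be stored once its next observation is known, so every call stores the previous step, not the current one. A sketch, with obs1..obs3 standing in for real Observation instances:

// Illustrative trace; obs1, obs2 and obs3 are hypothetical Observations.
ReplayMemoryExperienceHandler<Integer> handler =
        new ReplayMemoryExperienceHandler<>(new ExpReplay<Integer>(150000, 32, Nd4j.getRandom()));
handler.addExperience(obs1, 0, 1.0, false); // nothing stored yet; the obs1 transition is pending
handler.addExperience(obs2, 1, 0.0, false); // the obs1 transition receives obs2 as next observation and is stored
handler.setFinalObservation(obs3);          // the obs2 transition receives obs3 and is stored; nothing is left pending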
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java new file mode 100644 index 000000000..39338c6c0 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java @@ -0,0 +1,67 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.ArrayList; +import java.util.List; + +/** + * A simple {@link ExperienceHandler experience handler} that stores the experiences. + * Note: Calling {@link StateActionExperienceHandler#generateTrainingBatch() generateTrainingBatch()} will clear the stored experiences + * + * @param <A> Action type + * + * @author Alexandre Boulanger + */ +public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> { + + private List<StateActionPair<A>> stateActionPairs; + + public void setFinalObservation(Observation observation) { + // Do nothing + } + + public void addExperience(Observation observation, A action, double reward, boolean isTerminal) { + stateActionPairs.add(new StateActionPair<A>(observation, action, reward, isTerminal)); + } + + @Override + public int getTrainingBatchSize() { + return stateActionPairs.size(); + } + + /** + * The elements are returned in the historical order (i.e. in the order they happened) + * Note: the experience store is cleared after calling this method. + * + * @return The list of experience elements + */ + @Override + public List<StateActionPair<A>> generateTrainingBatch() { + List<StateActionPair<A>> result = stateActionPairs; + stateActionPairs = new ArrayList<>(); + + return result; + } + + @Override + public void reset() { + stateActionPairs = new ArrayList<>(); + } + +}
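Note the hand-off in generateTrainingBatch(): the internal list itself is returned and replaced with a fresh one, so the caller owns the batch and the store is empty afterwards. A short sketch (obs stands in for a real Observation):

// Illustrative only; obs is a hypothetical Observation.
StateActionExperienceHandler<Integer> handler = new StateActionExperienceHandler<>();
handler.reset();                            // required before the first episode; this creates the store
handler.addExperience(obs, 0, 1.0, false);
List<StateActionPair<Integer>> batch = handler.generateTrainingBatch(); // size 1, historical order
assert handler.generateTrainingBatch().isEmpty();                       // the store was handed off by the first call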
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionPair.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionPair.java new file mode 100644 index 000000000..49e9ad3b5 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionPair.java @@ -0,0 +1,49 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.experience; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.deeplearning4j.rl4j.observation.Observation; + +/** + * A simple experience container. Used by {@link StateActionExperienceHandler StateActionExperienceHandler}. + * + * @param <A> Action type + * + * @author Alexandre Boulanger + */ +@AllArgsConstructor +public class StateActionPair<A> { + + /** + * The observation before the action is taken + */ + @Getter + private final Observation observation; + + @Getter + private final A action; + + @Getter + private final double reward; + + /** + * True if the episode ended after the action has been taken. + */ + @Getter + private final boolean terminal; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index a72abfa62..ac9853045 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -18,9 +18,12 @@ package org.deeplearning4j.rl4j.learning.async; +import lombok.AccessLevel; import lombok.Getter; +import lombok.Setter; import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -28,10 +31,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Stack; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. @@ -45,13 +44,39 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet> @Getter private NN current; - public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, MDP<O, Integer, DiscreteSpace> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { + @Setter(AccessLevel.PROTECTED) + private UpdateAlgorithm<NN> updateAlgorithm; + + // TODO: Make it configurable with a builder + @Setter(AccessLevel.PROTECTED) + private ExperienceHandler<Integer, StateActionPair<Integer>> experienceHandler = new StateActionExperienceHandler<>(); + + public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, + MDP<O, Integer, DiscreteSpace> mdp, + TrainingListenerList listeners, + int threadNumber, + int deviceNum) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); synchronized (asyncGlobal) { current = (NN)asyncGlobal.getCurrent().clone(); } } + + // TODO: Add an actor-learner class and be able to inject the update algorithm + protected abstract UpdateAlgorithm<NN> buildUpdateAlgorithm(); + + @Override + public void setHistoryProcessor(IHistoryProcessor historyProcessor) { + super.setHistoryProcessor(historyProcessor); + updateAlgorithm = buildUpdateAlgorithm(); + } + + @Override + protected void preEpoch() { + experienceHandler.reset(); + } + + /** * "Subepoch" corresponds to the t_max-step iterations * that stack rewards with t_max MiniTrans @@ -65,13 +90,11 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet> synchronized (getAsyncGlobal()) { current.copy(getAsyncGlobal().getCurrent()); } - Stack<MiniTrans<Integer>> rewards = new Stack<>(); Observation obs = sObs; IPolicy policy = getPolicy(current); - Integer action; - Integer lastAction = getMdp().getActionSpace().noOp(); + Integer action = getMdp().getActionSpace().noOp(); IHistoryProcessor hp = getHistoryProcessor(); int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1; @@ -82,21 +105,15 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet> while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) { //if step of training, just repeat lastAction - if (obs.isSkipped()) { - action = lastAction; - } else { + if (!obs.isSkipped()) { action = policy.nextAction(obs); } StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action); accuReward += stepReply.getReward() * getConf().getRewardFactor(); -//if it's not a skipped frame, you can do a step of training if (!obs.isSkipped()) { - - INDArray[] output = current.outputAll(obs.getData()); - rewards.add(new MiniTrans(obs.getData(), action, output, accuReward)); - + experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone()); accuReward = 0; } @@ -104,29 +121,14 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet> reward += stepReply.getReward(); incrementStep(); - lastAction = action; } //a bit of a trick usable because of how the stack is treated to init R - // FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored. - - if (getMdp().isDone() && getCurrentEpochStep() < lastStep) - rewards.add(new MiniTrans(obs.getData(), null, null, 0)); - else { - INDArray[] output = null; - if (getConf().getLearnerUpdateFrequency() == -1) - output = current.outputAll(obs.getData()); - else synchronized (getAsyncGlobal()) { - output = getAsyncGlobal().getTarget().outputAll(obs.getData()); - } - double maxQ = Nd4j.max(output[0]).getDouble(0); - rewards.add(new MiniTrans(obs.getData(), null, output, maxQ)); + if (getMdp().isDone() && getCurrentEpochStep() < lastStep) { + experienceHandler.setFinalObservation(obs); } - getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep()); + getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep()); return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore()); } - - public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards); }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/MiniTrans.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java similarity index 57% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/MiniTrans.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java index 88bca6b0e..16ca1c3f8 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/MiniTrans.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java @@ -1,40 +1,26 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.async; - -import lombok.AllArgsConstructor; -import lombok.Value; -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. - * - * Its called a MiniTrans because it is similar to a Transition - * but without a next observation - * - * It is stacked and then processed by AsyncNStepQL or A3C - * following the paper implementation https://arxiv.org/abs/1602.01783 paper. - * - */ -@AllArgsConstructor -@Value -public class MiniTrans<A> { - INDArray obs; - A action; - INDArray[] output; - double reward; -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.learning.async; + +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.network.NeuralNet; + +import java.util.List; + +public interface UpdateAlgorithm<NN extends NeuralNet> { + Gradient[] computeGradients(NN current, List<StateActionPair<Integer>> experience); +}
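Both concrete implementations below (A3CUpdateAlgorithm and QLearningUpdateAlgorithm) share the same n-step bookkeeping: walk the experience list backwards and accumulate R_i = r_i + gamma * R_{i+1}, seeded with a bootstrap value. A minimal sketch of that recursion, factored out for illustration only (nStepReturns is a hypothetical helper, not part of this patch):

// Illustrative only: mirrors the backward loops of the update algorithms below.
static double[] nStepReturns(List<StateActionPair<Integer>> experience, double bootstrap, double gamma) {
    double[] returns = new double[experience.size()];
    double r = bootstrap; // 0.0 if the episode ended; V(s_last) or a max-Q estimate otherwise
    for (int i = experience.size() - 1; i >= 0; i--) {
        r = experience.get(i).getReward() + gamma * r; // R_i = r_i + gamma * R_{i+1}
        returns[i] = r;
    }
    return returns;
}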
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index c2a16d6b4..d189edca1 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -18,11 +18,7 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; import lombok.Getter; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; -import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; -import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.async.*; import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -34,9 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; - -import java.util.Stack; +import org.nd4j.linalg.api.rng.Random; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. @@ -67,6 +61,8 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete< if(seed != null) { rnd.setSeed(seed + threadNumber); } + + setUpdateAlgorithm(buildUpdateAlgorithm()); } @Override @@ -74,52 +70,9 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete< return new ACPolicy(net, rnd); } - /** - * calc the gradients based on the n-step rewards - */ @Override - public Gradient[] calcGradient(IActorCritic iac, Stack<MiniTrans<Integer>> rewards) { - MiniTrans<Integer> minTrans = rewards.pop(); - - int size = rewards.size(); - - //if recurrent then train as a time serie with a batch size of 1 - boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent(); - - int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() - : getHistoryProcessor().getConf().getShape(); - int[] nshape = recurrent ? Learning.makeShape(1, shape, size) - : Learning.makeShape(size, shape); - - INDArray input = Nd4j.create(nshape); - INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1); - INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size) - : Nd4j.zeros(size, getMdp().getActionSpace().getSize()); - - double r = minTrans.getReward(); - for (int i = size - 1; i >= 0; i--) { - minTrans = rewards.pop(); - - r = minTrans.getReward() + conf.getGamma() * r; - if (recurrent) { - input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(minTrans.getObs()); - } else { - input.putRow(i, minTrans.getObs()); - } - - //the critic - targets.putScalar(i, r); - - //the actor - double expectedV = minTrans.getOutput()[0].getDouble(0); - double advantage = r - expectedV; - if (recurrent) { - logSoftmax.putScalar(0, minTrans.getAction(), i, advantage); - } else { - logSoftmax.putScalar(i, minTrans.getAction(), advantage); - } - } - - return iac.gradient(input, new INDArray[] {targets, logSoftmax}); + protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() { + int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); + return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java new file mode 100644 index 000000000..261cc788f --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithm.java @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.learning.async.a3c.discrete; + +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; +import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; +import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; + +import java.util.List; + +public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> { + + private final IAsyncGlobal<IActorCritic> asyncGlobal; + private final int[] shape; + private final int actionSpaceSize; + private final int targetDqnUpdateFreq; + private final double gamma; + private final boolean recurrent; + + public A3CUpdateAlgorithm(IAsyncGlobal<IActorCritic> asyncGlobal, + int[] shape, + int actionSpaceSize, + int targetDqnUpdateFreq, + double gamma) { + + this.asyncGlobal = asyncGlobal; + + //if recurrent then train as a time series with a batch size of 1 + recurrent = asyncGlobal.getCurrent().isRecurrent(); + this.shape = shape; + this.actionSpaceSize = actionSpaceSize; + this.targetDqnUpdateFreq = targetDqnUpdateFreq; + this.gamma = gamma; + } + + @Override + public Gradient[] computeGradients(IActorCritic current, List<StateActionPair<Integer>> experience) { + int size = experience.size(); + + int[] nshape = recurrent ? Learning.makeShape(1, shape, size) + : Learning.makeShape(size, shape); + + INDArray input = Nd4j.create(nshape); + INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1); + INDArray logSoftmax = recurrent ? Nd4j.zeros(1, actionSpaceSize, size) + : Nd4j.zeros(size, actionSpaceSize); + + StateActionPair<Integer> stateActionPair = experience.get(size - 1); + double r; + if(stateActionPair.isTerminal()) { + r = 0; + } + else { + INDArray[] output = null; + if (targetDqnUpdateFreq == -1) + output = current.outputAll(stateActionPair.getObservation().getData()); + else synchronized (asyncGlobal) { + output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData()); + } + r = output[0].getDouble(0); + } + + for (int i = size - 1; i >= 0; --i) { + stateActionPair = experience.get(i); + + INDArray observationData = stateActionPair.getObservation().getData(); + + INDArray[] output = current.outputAll(observationData); + + r = stateActionPair.getReward() + gamma * r; + if (recurrent) { + input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData); + } else { + input.putRow(i, observationData); + } + + //the critic + targets.putScalar(i, r); + + //the actor + double expectedV = output[0].getDouble(0); + double advantage = r - expectedV; + if (recurrent) { + logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage); + } else { + logSoftmax.putScalar(i, stateActionPair.getAction(), advantage); + } + } + + // targets -> value, critic + // logSoftmax -> policy, actor + return current.gradient(input, new INDArray[] {targets, logSoftmax}); + } +}
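A hand-worked pass of computeGradients above may help; the numbers are assumed for illustration (non-recurrent case, gamma = 0.9):

// experience: [(obs0, action=1, reward=0.0), (obs1, action=0, reward=1.0)], episode not terminal,
// critic value output[0] for the last observation = 0.5 (the bootstrap)
// i=1: r = 1.0 + 0.9 * 0.5  = 1.45   -> targets[1] = 1.45,  logSoftmax[1][0] = 1.45  - V(obs1)
// i=0: r = 0.0 + 0.9 * 1.45 = 1.305  -> targets[0] = 1.305, logSoftmax[0][1] = 1.305 - V(obs0)
// All other logSoftmax entries stay 0, so only the actions actually taken receive a policy-gradient signal.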
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 71199efaf..bd4dc16e8 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -18,11 +18,9 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import lombok.Getter; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; -import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -32,12 +30,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.api.rng.Random; -import java.util.Stack; - /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. */ @@ -65,6 +60,8 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn if(seed != null) { rnd.setSeed(seed + threadNumber); } + + setUpdateAlgorithm(buildUpdateAlgorithm()); } public Policy<O, Integer> getPolicy(IDQN nn) { return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(), rnd, conf.getMinEpsilon(), this); } - - - //calc the gradient based on the n-step rewards - public Gradient[] calcGradient(IDQN current, Stack<MiniTrans<Integer>> rewards) { - - MiniTrans<Integer> minTrans = rewards.pop(); - - int size = rewards.size(); - - int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() - : getHistoryProcessor().getConf().getShape(); - int[] nshape = Learning.makeShape(size, shape); - INDArray input = Nd4j.create(nshape); - INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize()); - - double r = minTrans.getReward(); - for (int i = size - 1; i >= 0; i--) { - minTrans = rewards.pop(); - - r = minTrans.getReward() + conf.getGamma() * r; - input.putRow(i, minTrans.getObs()); - INDArray row = minTrans.getOutput()[0]; - row = row.putScalar(minTrans.getAction(), r); - targets.putRow(i, row); - } - - return current.gradient(input, targets); + @Override + protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() { + int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); + return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma()); } }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java new file mode 100644 index 000000000..beae271b1 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java @@ -0,0 +1,88 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.learning.async.nstep.discrete; + +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; +import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; +import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; + +public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> { + + private final IAsyncGlobal<IDQN> asyncGlobal; + private final int[] shape; + private final int actionSpaceSize; + private final int targetDqnUpdateFreq; + private final double gamma; + + public QLearningUpdateAlgorithm(IAsyncGlobal<IDQN> asyncGlobal, + int[] shape, + int actionSpaceSize, + int targetDqnUpdateFreq, + double gamma) { + + this.asyncGlobal = asyncGlobal; + this.shape = shape; + this.actionSpaceSize = actionSpaceSize; + this.targetDqnUpdateFreq = targetDqnUpdateFreq; + this.gamma = gamma; + } + + @Override + public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) { + int size = experience.size(); + + int[] nshape = Learning.makeShape(size, shape); + INDArray input = Nd4j.create(nshape); + INDArray targets = Nd4j.create(size, actionSpaceSize); + + StateActionPair<Integer> stateActionPair = experience.get(size - 1); + + double r; + if(stateActionPair.isTerminal()) { + r = 0; + } + else { + INDArray[] output = null; + if (targetDqnUpdateFreq == -1) + output = current.outputAll(stateActionPair.getObservation().getData()); + else synchronized (asyncGlobal) { + output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData()); + } + r = Nd4j.max(output[0]).getDouble(0); + } + + for (int i = size - 1; i >= 0; i--) { + stateActionPair = experience.get(i); + + input.putRow(i, stateActionPair.getObservation().getData()); + + r = stateActionPair.getReward() + gamma * r; + INDArray[] output = current.outputAll(stateActionPair.getObservation().getData()); + INDArray row = output[0]; + row = row.putScalar(stateActionPair.getAction(), r); + targets.putRow(i, row); + } + + return current.gradient(input, targets); + } +}
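For orientation, a hedged sketch of how this algorithm is wired and invoked, mirroring buildUpdateAlgorithm() in the thread class above; asyncGlobal, mdp, conf, current and experienceHandler are assumed to be in scope:

// Hypothetical wiring; the shape value is an assumption (no history processor).
UpdateAlgorithm<IDQN> updateAlgorithm = new QLearningUpdateAlgorithm(
        asyncGlobal,                     // shared IAsyncGlobal<IDQN>
        new int[] { 4 },                 // observation shape (assumed)
        mdp.getActionSpace().getSize(),  // action-space size
        conf.getTargetDqnUpdateFreq(),   // -1 bootstraps from the thread-local network instead of the target
        conf.getGamma());
Gradient[] gradients = updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch());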
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java index 2defc1d75..93b4d1bb5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java @@ -80,6 +80,9 @@ public class ExpReplay<A> implements IExpReplay<A> { //log.info("size: "+storage.size()); } - + public int getBatchSize() { + int storageSize = storage.size(); + return Math.min(storageSize, batchSize); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java index 02a4c8af5..eaef5f0f8 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java @@ -32,6 +32,11 @@ import java.util.ArrayList; */ public interface IExpReplay<A> { + /** + * @return The size of the batch that will be returned by getBatch() + */ + int getBatchSize(); + /** * @return a batch of uniformly sampled transitions */ ArrayList<Transition<A>> getBatch(); @@ -42,5 +47,4 @@ public interface IExpReplay<A> { * @param transition a new transition to store */ void store(Transition<A> transition); - } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 40704d4e9..7bef13e59 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -60,32 +60,8 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>> implements TargetQNetworkSource, EpochStepCounter { - // FIXME Changed for refac - // @Getter - // final private IExpReplay<A> expReplay; - @Getter - @Setter(AccessLevel.PROTECTED) - protected IExpReplay<A> expReplay; - protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper(); - public QLearning(QLearningConfiguration conf) { - this(conf, getSeededRandom(conf.getSeed())); - } - - public QLearning(QLearningConfiguration conf, Random random) { - expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random); - } - - private static Random getSeededRandom(Long seed) { - Random rnd = Nd4j.getRandom(); - if(seed != null) { - rnd.setSeed(seed); - } - - return rnd; - } - protected abstract EpsGreedy<O, A, AS> getEgPolicy(); public abstract MDP<O, A, AS> getMdp(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index e97415e29..1b9e667ae 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -21,6 +21,8 @@ import lombok.AccessLevel; import lombok.Getter; import lombok.Setter; import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; @@ -42,7 +44,7 @@ import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; -import java.util.ArrayList; +import java.util.List; /** @@ -71,10 +73,12 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> + private ExperienceHandler<Integer, Transition<Integer>> experienceHandler; + protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() { return mdp; } @@ -85,7 +89,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) { - super(conf); this.configuration = conf; this.mdp = new LegacyMDPWrapper<>(mdp, null, this); qNetwork = dqn; @@ -98,6 +101,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> + experienceHandler = new ReplayMemoryExperienceHandler<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random); } public MDP<O, Integer, DiscreteSpace> getMdp() { @@ -114,7 +118,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> protected QLearning.QLStepReturn<Observation> trainStep(Observation obs) { - Integer action; - boolean isHistoryProcessor = getHistoryProcessor() != null; int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1; int historyLength = isHistoryProcessor ?
getHistoryProcessor().getConf().getHistoryLength() : 1; @@ -142,37 +144,28 @@ public abstract class QLearningDiscrete extends QLearning stepReply = mdp.step(action); - + StepReply stepReply = mdp.step(lastAction); accuReward += stepReply.getReward() * configuration.getRewardFactor(); //if it's not a skipped frame, you can do a step of training if (!obs.isSkipped()) { // Add experience - if (pendingTransition != null) { - pendingTransition.setNextObservation(obs); - getExpReplay().store(pendingTransition); - } - pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone()); + experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone()); accuReward = 0; // Update NN // FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"? if (getStepCounter() > updateStart) { - DataSet targets = setTarget(getExpReplay().getBatch()); + DataSet targets = setTarget(experienceHandler.generateTrainingBatch()); getQNetwork().fit(targets.getFeatures(), targets.getLabels()); } } @@ -180,7 +173,7 @@ public abstract class QLearningDiscrete extends QLearning(maxQ, getQNetwork().getLatestScore(), stepReply); } - protected DataSet setTarget(ArrayList> transitions) { + protected DataSet setTarget(List> transitions) { if (transitions.size() == 0) throw new IllegalArgumentException("too few transitions"); @@ -189,9 +182,6 @@ public abstract class QLearningDiscrete extends QLearning { + + public final List> addedTransitions = new ArrayList<>(); + + @Override + public ArrayList> getBatch() { + return null; + } + + @Override + public void store(Transition transition) { + addedTransitions.add(transition); + } + + @Override + public int getBatchSize() { + return addedTransitions.size(); + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java new file mode 100644 index 000000000..7334ff87a --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java @@ -0,0 +1,82 @@ +package org.deeplearning4j.rl4j.experience; + +import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Test; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; + +import static org.junit.Assert.*; + +public class StateActionExperienceHandlerTest { + + @Test + public void when_addingExperience_expect_generateTrainingBatchReturnsIt() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(); + sut.reset(); + Observation observation = new Observation(Nd4j.zeros(1)); + sut.addExperience(observation, 123, 234.0, true); + + // Act + List> result = sut.generateTrainingBatch(); + + // Assert + assertEquals(1, result.size()); + assertSame(observation, result.get(0).getObservation()); + assertEquals(123, (int)result.get(0).getAction()); + assertEquals(234.0, result.get(0).getReward(), 0.00001); + assertTrue(result.get(0).isTerminal()); + } + + @Test + public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(); + sut.reset(); + sut.addExperience(null, 1, 1.0, false); + sut.addExperience(null, 2, 2.0, false); + sut.addExperience(null, 3, 3.0, false); + + // Act + List> result = sut.generateTrainingBatch(); + + // Assert + assertEquals(3, result.size()); + assertEquals(1, 
(int)result.get(0).getAction()); + assertEquals(2, (int)result.get(1).getAction()); + assertEquals(3, (int)result.get(2).getAction()); + } + + @Test + public void when_gettingExperience_expect_experienceStoreIsCleared() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(); + sut.reset(); + sut.addExperience(null, 1, 1.0, false); + + // Act + List> firstResult = sut.generateTrainingBatch(); + List> secondResult = sut.generateTrainingBatch(); + + // Assert + assertEquals(1, firstResult.size()); + assertEquals(0, secondResult.size()); + } + + @Test + public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(); + sut.reset(); + sut.addExperience(null, 1, 1.0, false); + sut.addExperience(null, 2, 2.0, false); + sut.addExperience(null, 3, 3.0, false); + + // Act + int size = sut.getTrainingBatchSize(); + + // Assert + assertEquals(3, size); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index 72f374db5..320b53a0e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -18,9 +18,12 @@ package org.deeplearning4j.rl4j.learning.async; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; +import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; @@ -31,7 +34,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; import java.util.List; -import java.util.Stack; import static org.junit.Assert.assertEquals; @@ -51,7 +53,9 @@ public class AsyncThreadDiscreteTest { TrainingListenerList listeners = new TrainingListenerList(); MockPolicy policyMock = new MockPolicy(); MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5); - TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); + MockExperienceHandler experienceHandlerMock = new MockExperienceHandler(); + MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm(); + TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock); sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); // Act @@ -60,8 +64,8 @@ public class AsyncThreadDiscreteTest { // Assert assertEquals(2, sut.trainSubEpochResults.size()); double[][] expectedLastObservations = new double[][] { - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, }; double[] expectedSubEpochReturnRewards = new 
double[] { 42.0, 58.0 }; for(int i = 0; i < 2; ++i) { @@ -102,62 +106,22 @@ public class AsyncThreadDiscreteTest { } } - // NeuralNetwork - assertEquals(2, nnMock.copyCallCount); - double[][] expectedNNInputs = new double[][] { + // ExperienceHandler + double[][] expectedExperienceHandlerInputs = new double[][] { new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans }; - assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size()); - for(int i = 0; i < expectedNNInputs.length; ++i) { - double[] expectedRow = expectedNNInputs[i]; - INDArray input = nnMock.outputAllInputs.get(i); + assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size()); + for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) { + double[] expectedRow = expectedExperienceHandlerInputs[i]; + INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData(); assertEquals(expectedRow.length, input.shape()[1]); for(int j = 0; j < expectedRow.length; ++j) { assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001); } } - - int arrayIdx = 0; - double[][][] expectedMinitransObs = new double[][][] { - new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation - }, - new double[][] { - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation - } - }; - double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 }; - double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 }; - - assertEquals(2, sut.rewards.size()); - for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) { - Stack> miniTransStack = sut.rewards.get(rewardIdx); - - for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) { - MiniTrans minitrans = miniTransStack.get(i); - - // Observation - double[] expectedRow = expectedMinitransObs[rewardIdx][i]; - INDArray realRewards = minitrans.getObs(); - assertEquals(expectedRow.length, realRewards.shape()[1]); - for (int j = 0; j < expectedRow.length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001); - } - - assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001); - assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001); - ++arrayIdx; - } - } } public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete { @@ -167,22 +131,19 @@ public class AsyncThreadDiscreteTest { private final MockAsyncConfiguration config; public final List trainSubEpochResults = new ArrayList(); - public final List>> rewards = new ArrayList>>(); public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy, - MockAsyncConfiguration config, IHistoryProcessor hp) { + MockAsyncConfiguration config, IHistoryProcessor hp, + ExperienceHandler> experienceHandler, + UpdateAlgorithm 
updateAlgorithm) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); this.asyncGlobal = asyncGlobal; this.policy = policy; this.config = config; setHistoryProcessor(hp); - } - - @Override - public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack> rewards) { - this.rewards.add(rewards); - return new Gradient[0]; + setExperienceHandler(experienceHandler); + setUpdateAlgorithm(updateAlgorithm); } @Override @@ -200,6 +161,11 @@ public class AsyncThreadDiscreteTest { return policy; } + @Override + protected UpdateAlgorithm buildUpdateAlgorithm() { + return null; + } + @Override public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) { asyncGlobal.increaseCurrentLoop(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java deleted file mode 100644 index b812a5582..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java +++ /dev/null @@ -1,197 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.async.a3c.discrete; - -import org.deeplearning4j.nn.api.NeuralNetwork; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.async.MiniTrans; -import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; -import org.deeplearning4j.rl4j.network.NeuralNet; -import org.deeplearning4j.rl4j.network.ac.IActorCritic; -import org.deeplearning4j.rl4j.support.*; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.List; -import java.util.Stack; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - -public class A3CThreadDiscreteTest { - - @Test - public void refac_calcGradient() { - // Arrange - double gamma = 0.9; - MockObservationSpace observationSpace = new MockObservationSpace(); - MockMDP mdpMock = new MockMDP(observationSpace); - A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build(); - MockActorCritic actorCriticMock = new MockActorCritic(); - IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); - A3CThreadDiscrete sut = new A3CThreadDiscrete(mdpMock, asyncGlobalMock, config, 0, null, 0); - MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); - sut.setHistoryProcessor(hpMock); - - double[][] minitransObs = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - }; - double[] outputs = new double[] { 1.0, 2.0, 3.0 }; - double[] rewards = new double[] { 0.0, 0.0, 3.0 }; - - Stack> minitransList = new Stack>(); - for(int i = 0; i < 3; ++i) { - INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1); - INDArray[] output = new INDArray[] { - Nd4j.zeros(5) - }; - output[0].putScalar(i, outputs[i]); - minitransList.push(new MiniTrans<>(obs, i, output, rewards[i])); - } - minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans - - // Act - sut.calcGradient(actorCriticMock, minitransList); - - // Assert - assertEquals(1, actorCriticMock.gradientParams.size()); - INDArray input = actorCriticMock.gradientParams.get(0).getFirst(); - INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond(); - - assertEquals(minitransObs.length, input.shape()[0]); - for(int i = 0; i < minitransObs.length; ++i) { - double[] expectedRow = minitransObs[i]; - assertEquals(expectedRow.length, input.shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001); - } - } - - double latestReward = (gamma * 4.0) + 3.0; - double[] expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward }; - for(int i = 0; i < expectedLabels0.length; ++i) { - assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001); - } - double[][] expectedLabels1 = new double[][] { - new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 }, - new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 }, - new double[] { 0.0, 
0.0, latestReward, 0.0, 0.0 }, - }; - - assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape()); - - for(int i = 0; i < expectedLabels1.length; ++i) { - double[] expectedRow = expectedLabels1[i]; - assertEquals(expectedRow.length, labels[1].shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001); - } - } - - } - - public class MockActorCritic implements IActorCritic { - - public final List> gradientParams = new ArrayList<>(); - - @Override - public NeuralNetwork[] getNeuralNetworks() { - return new NeuralNetwork[0]; - } - - @Override - public boolean isRecurrent() { - return false; - } - - @Override - public void reset() { - - } - - @Override - public void fit(INDArray input, INDArray[] labels) { - - } - - @Override - public INDArray[] outputAll(INDArray batch) { - return new INDArray[0]; - } - - @Override - public IActorCritic clone() { - return this; - } - - @Override - public void copy(NeuralNet from) { - - } - - @Override - public void copy(IActorCritic from) { - - } - - @Override - public Gradient[] gradient(INDArray input, INDArray[] labels) { - gradientParams.add(new Pair(input, labels)); - return new Gradient[0]; - } - - @Override - public void applyGradient(Gradient[] gradient, int batchSize) { - - } - - @Override - public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException { - - } - - @Override - public void save(String pathValue, String pathPolicy) throws IOException { - - } - - @Override - public double getLatestScore() { - return 0; - } - - @Override - public void save(OutputStream os) throws IOException { - - } - - @Override - public void save(String filename) throws IOException { - - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java new file mode 100644 index 000000000..1434796f3 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CUpdateAlgorithmTest.java @@ -0,0 +1,160 @@ +package org.deeplearning4j.rl4j.learning.async.a3c.discrete; + +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.support.MockAsyncGlobal; +import org.deeplearning4j.rl4j.support.MockMDP; +import org.deeplearning4j.rl4j.support.MockObservationSpace; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class A3CUpdateAlgorithmTest { + + @Test + public void refac_calcGradient_non_terminal() { + // Arrange + double gamma = 0.9; + MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 }); + MockMDP mdpMock = new MockMDP(observationSpace); + MockActorCritic actorCriticMock = new MockActorCritic(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); + A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), 
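// By the shape of this call, A3CUpdateAlgorithm is built from the shared network holder,
// the observation shape, the action-space size, the target-network update frequency
// (-1 here, presumably meaning the target network is never consulted) and gamma.
// Judging from the test this one replaces, the first label array carries the critic's
// discounted-return targets (shape batchSize x 1) while the second places the actor's
// signal at the taken action's index (shape batchSize x actionCount); the assertions
// below only pin those shapes for now (see the FIXMEs). A minimal usage sketch:
//
//     A3CUpdateAlgorithm algorithm = new A3CUpdateAlgorithm(
//             asyncGlobalMock, new int[] { 5 }, mdpMock.getActionSpace().getSize(), -1, 0.9);
//     Gradient[] gradients = algorithm.computeGradients(actorCriticMock, experience);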
mdpMock.getActionSpace().getSize(), -1, gamma); + + + INDArray[] originalObservations = new INDArray[] { + Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }), + Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }), + Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }), + Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }), + }; + int[] actions = new int[] { 0, 1, 2, 1 }; + double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 }; + + List> experience = new ArrayList>(); + for(int i = 0; i < originalObservations.length; ++i) { + experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false)); + } + + // Act + sut.computeGradients(actorCriticMock, experience); + + // Assert + assertEquals(1, actorCriticMock.gradientParams.size()); + + // Inputs + INDArray input = actorCriticMock.gradientParams.get(0).getLeft(); + for(int i = 0; i < 4; ++i) { + for(int j = 0; j < 5; ++j) { + assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001); + } + } + + INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0]; + INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1]; + + assertEquals(4, targets.shape()[0]); + assertEquals(1, targets.shape()[1]); + + // FIXME: check targets values once fixed + + assertEquals(4, logSoftmax.shape()[0]); + assertEquals(5, logSoftmax.shape()[1]); + + // FIXME: check logSoftmax values once fixed + + } + + public class MockActorCritic implements IActorCritic { + + public final List> gradientParams = new ArrayList<>(); + + @Override + public NeuralNetwork[] getNeuralNetworks() { + return new NeuralNetwork[0]; + } + + @Override + public boolean isRecurrent() { + return false; + } + + @Override + public void reset() { + + } + + @Override + public void fit(INDArray input, INDArray[] labels) { + + } + + @Override + public INDArray[] outputAll(INDArray batch) { + return new INDArray[] { batch.mul(-1.0) }; + } + + @Override + public IActorCritic clone() { + return this; + } + + @Override + public void copy(NeuralNet from) { + + } + + @Override + public void copy(IActorCritic from) { + + } + + @Override + public Gradient[] gradient(INDArray input, INDArray[] labels) { + gradientParams.add(new Pair(input, labels)); + return new Gradient[0]; + } + + @Override + public void applyGradient(Gradient[] gradient, int batchSize) { + + } + + @Override + public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException { + + } + + @Override + public void save(String pathValue, String pathPolicy) throws IOException { + + } + + @Override + public double getLatestScore() { + return 0; + } + + @Override + public void save(OutputStream os) throws IOException { + + } + + @Override + public void save(String filename) throws IOException { + + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java deleted file mode 100644 index 2a8c5b832..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java +++ /dev/null @@ -1,98 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2020 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.async.nstep.discrete; - -import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.async.MiniTrans; -import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; -import org.deeplearning4j.rl4j.support.*; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Stack; - -import static org.junit.Assert.assertEquals; - -public class AsyncNStepQLearningThreadDiscreteTest { - - @Test - public void refac_calcGradient() { - // Arrange - double gamma = 0.9; - MockObservationSpace observationSpace = new MockObservationSpace(); - MockMDP mdpMock = new MockMDP(observationSpace); - AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build(); - MockDQN dqnMock = new MockDQN(); - IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); - MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); - AsyncNStepQLearningThreadDiscrete sut = new AsyncNStepQLearningThreadDiscrete(mdpMock, asyncGlobalMock, config, null, 0, 0); - MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); - sut.setHistoryProcessor(hpMock); - - double[][] minitransObs = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - }; - double[] outputs = new double[] { 1.0, 2.0, 3.0 }; - double[] rewards = new double[] { 0.0, 0.0, 3.0 }; - - Stack> minitransList = new Stack>(); - for(int i = 0; i < 3; ++i) { - INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1); - INDArray[] output = new INDArray[] { - Nd4j.zeros(5) - }; - output[0].putScalar(i, outputs[i]); - minitransList.push(new MiniTrans<>(obs, i, output, rewards[i])); - } - minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans - - // Act - sut.calcGradient(dqnMock, minitransList); - - // Assert - assertEquals(1, dqnMock.gradientParams.size()); - INDArray input = dqnMock.gradientParams.get(0).getFirst(); - INDArray labels = dqnMock.gradientParams.get(0).getSecond(); - - assertEquals(minitransObs.length, input.shape()[0]); - for(int i = 0; i < minitransObs.length; ++i) { - double[] expectedRow = minitransObs[i]; - assertEquals(expectedRow.length, input.shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001); - } - } - - double latestReward = (gamma * 4.0) + 3.0; - double[][] expectedLabels = new double[][] { - new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 }, - new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 }, - new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 }, - }; - assertEquals(minitransObs.length, labels.shape()[0]); - for(int 
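// The expectedLabels matrix above encodes the usual n-step Q-learning target rule:
// each row starts from the network's own Q-value output and only the column of the
// action actually taken (here action i at step i) is overwritten with the accumulated
// discounted return. A minimal sketch of that rule for one row (illustration only):
//
//     r = rewards[i] + gamma * r;           // walking the stack backwards
//     INDArray row = qValues.getRow(i);     // start from the current Q-values
//     row.putScalar(actions[i], r);         // overwrite only the taken action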
i = 0; i < minitransObs.length; ++i) { - double[] expectedRow = expectedLabels[i]; - assertEquals(expectedRow.length, labels.shape()[1]); - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001); - } - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java new file mode 100644 index 000000000..35465d26a --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java @@ -0,0 +1,115 @@ +package org.deeplearning4j.rl4j.learning.async.nstep.discrete; + +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.support.MockAsyncGlobal; +import org.deeplearning4j.rl4j.support.MockDQN; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class QLearningUpdateAlgorithmTest { + + @Test + public void when_isTerminal_expect_initRewardIs0() { + // Arrange + MockDQN dqnMock = new MockDQN(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0); + List> experience = new ArrayList>() { + { + add(new StateActionPair(new Observation(Nd4j.zeros(1)), 0, 0.0, true)); + } + }; + + // Act + sut.computeGradients(dqnMock, experience); + + // Assert + assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); + } + + @Test + public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() { + // Arrange + MockDQN globalDQNMock = new MockDQN(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0); + List> experience = new ArrayList>() { + { + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false)); + } + }; + MockDQN dqnMock = new MockDQN(); + + // Act + sut.computeGradients(dqnMock, experience); + + // Assert + assertEquals(2, dqnMock.outputAllParams.size()); + assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001); + assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); + } + + @Test + public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() { + // Arrange + MockDQN globalDQNMock = new MockDQN(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0); + List> experience = new ArrayList>() { + { + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false)); + } + }; + MockDQN dqnMock = new MockDQN(); + + // Act + sut.computeGradients(dqnMock, experience); + + // Assert + assertEquals(1, globalDQNMock.outputAllParams.size()); + assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001); + assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); + } + + @Test + public void 
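// Taken together, the three tests above pin down how the accumulated reward is seeded
// before the experience list is walked backwards: 0.0 when the last experience is
// terminal, max Q from the thread's current network when no target update frequency is
// configured (-1), and max Q from the global (target) network otherwise. (Note that the
// second and third arrangements actually add a non-terminal experience, isTerminal =
// false, despite the "terminal" in their method names.) A minimal sketch of that
// selection, assuming the getTarget() accessor on the async-global holder:
//
//     double initReward = isTerminal
//             ? 0.0
//             : (targetUpdateFreq == -1 ? current : asyncGlobal.getTarget())
//                     .outputAll(lastObservation)[0].maxNumber().doubleValue();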
when_callingWithMultipleExperiences_expect_gradientsAreValid() { + // Arrange + double gamma = 0.9; + MockDQN globalDQNMock = new MockDQN(); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock); + UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma); + List> experience = new ArrayList>() { + { + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false)); + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true)); + } + }; + MockDQN dqnMock = new MockDQN(); + + // Act + sut.computeGradients(dqnMock, experience); + + // Assert + // input side -- should be a stack of observations + INDArray input = dqnMock.gradientParams.get(0).getLeft(); + assertEquals(-1.1, input.getDouble(0, 0), 0.00001); + assertEquals(-1.2, input.getDouble(0, 1), 0.00001); + assertEquals(-2.1, input.getDouble(1, 0), 0.00001); + assertEquals(-2.2, input.getDouble(1, 1), 0.00001); + + // target side + INDArray target = dqnMock.gradientParams.get(0).getRight(); + assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001); + assertEquals(1.2, target.getDouble(0, 1), 0.00001); + assertEquals(2.1, target.getDouble(1, 0), 0.00001); + assertEquals(2.0, target.getDouble(1, 1), 0.00001); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index fe8dd6acc..9d77084d5 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -17,6 +17,8 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; @@ -75,8 +77,8 @@ public class QLearningDiscreteTest { .build(); MockDataManager dataManager = new MockDataManager(false); - MockExpReplay expReplay = new MockExpReplay(); - TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); + MockExperienceHandler experienceHandler = new MockExperienceHandler(); + TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); sut.setHistoryProcessor(hp); @@ -93,7 +95,6 @@ public class QLearningDiscreteTest { for (int i = 0; i < expectedRecords.length; ++i) { assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); } - assertEquals(0, hp.startMonitorCallCount); assertEquals(0, hp.stopMonitorCallCount); @@ -133,30 +134,31 @@ public class QLearningDiscreteTest { // MDP calls assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray()); - // ExpReplay calls - double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0}; - int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4}; - double[] 
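// The target-side assertions in when_callingWithMultipleExperiences_expect_gradientsAreValid
// follow directly from MockDQN.outputAll returning batch.mul(-1.0): the label matrix starts
// out as the negated observations, so the entries of actions that were not taken survive
// untouched (1.2 and 2.1), while the taken action's entry is overwritten with the return,
// 2.0 for the terminal step and 1.0 + 0.9 * 2.0 = 2.8 for the step before it:
//
//     target[i][actions[i]] = rewards[i] + gamma * r;   // r accumulated from the end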
expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0}; - double[][] expectedTrObservations = new double[][]{ - new double[]{0.0, 2.0, 4.0, 6.0, 8.0}, - new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, - new double[]{4.0, 6.0, 8.0, 10.0, 12.0}, - new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, - new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, - new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, - new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, - new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, + // ExperienceHandler calls + double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 }; + int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; + double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; + double[][] expectedTrObservations = new double[][] { + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, + new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, + new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, + new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, + new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, }; - assertEquals(expectedTrObservations.length, expReplay.transitions.size()); - for (int i = 0; i < expectedTrRewards.length; ++i) { - Transition tr = expReplay.transitions.get(i); - assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); - assertEquals(expectedTrActions[i], tr.getAction()); - assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001); - for (int j = 0; j < expectedTrObservations[i].length; ++j) { - assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001); + + assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size()); + for(int i = 0; i < expectedTrRewards.length; ++i) { + StateActionPair stateActionPair = experienceHandler.addExperienceArgs.get(i); + assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001); + assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction()); + for(int j = 0; j < expectedTrObservations[i].length; ++j) { + assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001); } } + assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001); // trainEpoch result assertEquals(initStepCount + 16, result.getStepCounter()); @@ -167,20 +169,16 @@ public class QLearningDiscreteTest { public static class TestQLearningDiscrete extends QLearningDiscrete { public TestQLearningDiscrete(MDP mdp, IDQN dqn, - QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, + QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler> experienceHandler, int epsilonNbStep, Random rnd) { super(mdp, dqn, conf, epsilonNbStep, rnd); addListener(new DataManagerTrainingListener(dataManager)); - setExpReplay(expReplay); + setExperienceHandler(experienceHandler); } @Override - protected DataSet setTarget(ArrayList> transitions) { - return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0})); - } - - public void setExpReplay(IExpReplay exp) { - this.expReplay = exp; + protected DataSet setTarget(List> transitions) { + return new 
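// With this change the test drives QLearningDiscrete through the ExperienceHandler seam
// instead of poking an IExpReplay directly. A minimal sketch of the per-step contract the
// assertions above rely on (hypothetical training loop, for illustration only):
//
//     experienceHandler.addExperience(obs, action, reward, isTerminal);   // every step
//     if (isTerminal) {
//         experienceHandler.setFinalObservation(lastObs);                 // episode end
//     }
//     if (experienceHandler.getTrainingBatchSize() >= batchSize) {
//         DataSet target = setTarget(experienceHandler.generateTrainingBatch());
//     }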
org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); } @Override diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java index 28d7f3914..6f20d82ca 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java @@ -19,6 +19,7 @@ public class MockDQN implements IDQN { public final List outputParams = new ArrayList<>(); public final List> fitParams = new ArrayList<>(); public final List> gradientParams = new ArrayList<>(); + public final List outputAllParams = new ArrayList<>(); @Override public NeuralNetwork[] getNeuralNetworks() { @@ -58,7 +59,8 @@ public class MockDQN implements IDQN { @Override public INDArray[] outputAll(INDArray batch) { - return new INDArray[0]; + outputAllParams.add(batch); + return new INDArray[] { batch.mul(-1.0) }; } @Override @@ -109,4 +111,4 @@ public class MockDQN implements IDQN { public void save(String filename) throws IOException { } -} +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java deleted file mode 100644 index d1fa84c04..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.deeplearning4j.rl4j.support; - -import org.deeplearning4j.rl4j.learning.sync.IExpReplay; -import org.deeplearning4j.rl4j.learning.sync.Transition; - -import java.util.ArrayList; -import java.util.List; - -public class MockExpReplay implements IExpReplay { - - public List> transitions = new ArrayList<>(); - - @Override - public ArrayList> getBatch() { - return null; - } - - @Override - public void store(Transition transition) { - transitions.add(transition); - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java new file mode 100644 index 000000000..13ea5d93a --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExperienceHandler.java @@ -0,0 +1,46 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.experience.StateActionPair; +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.ArrayList; +import java.util.List; + +public class MockExperienceHandler implements ExperienceHandler> { + public List> addExperienceArgs = new ArrayList>(); + public Observation finalObservation; + public boolean isGenerateTrainingBatchCalled; + public boolean isResetCalled; + + @Override + public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) { + addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal)); + } + + @Override + public void setFinalObservation(Observation observation) { + finalObservation = observation; + } + + @Override + public List> generateTrainingBatch() { + isGenerateTrainingBatchCalled = true; + return new ArrayList>() { + { + add(new Transition(null, 0, 0.0, false)); + } + }; + } + + @Override + public void reset() { + isResetCalled = true; + } + + @Override + public int 
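// MockDQN.outputAll now records every batch it is asked to evaluate and returns the
// negated input, giving the update-algorithm tests deterministic, easily distinguishable
// "Q-values" to assert against. For example:
//
//     MockDQN dqn = new MockDQN();
//     INDArray out = dqn.outputAll(Nd4j.create(new double[] { -123.0, -234.0 }))[0];
//     // out is now [ 123.0, 234.0 ] and dqn.outputAllParams holds the recorded batch
//
// MockExperienceHandler follows the same pattern: it captures each addExperience call as
// a StateActionPair and hands back a fixed one-element training batch, so the method
// below consistently reports a batch size of 1.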
getTrainingBatchSize() {
+        return 1;
+    }
+}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java
index 5395242b2..ffba71b5a 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java
@@ -5,6 +5,16 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 
 public class MockObservationSpace implements ObservationSpace {
 
+    private final int[] shape;
+
+    public MockObservationSpace() {
+        this(new int[] { 1 });
+    }
+
+    public MockObservationSpace(int[] shape) {
+        this.shape = shape;
+    }
+
     @Override
     public String getName() {
         return null;
@@ -12,7 +22,7 @@ public class MockObservationSpace implements ObservationSpace {
 
     @Override
     public int[] getShape() {
-        return new int[] { 1 };
+        return shape;
     }
 
     @Override
@@ -24,4 +34,4 @@ public class MockObservationSpace implements ObservationSpace {
     public INDArray getHigh() {
         return null;
     }
-}
+}
\ No newline at end of file
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java
new file mode 100644
index 000000000..dbe2fe1fc
--- /dev/null
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockUpdateAlgorithm.java
@@ -0,0 +1,19 @@
+package org.deeplearning4j.rl4j.support;
+
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.experience.StateActionPair;
+import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class MockUpdateAlgorithm implements UpdateAlgorithm<MockNeuralNet> {
+
+    public final List<List<StateActionPair<Integer>>> experiences = new ArrayList<List<StateActionPair<Integer>>>();
+
+    @Override
+    public Gradient[] computeGradients(MockNeuralNet current, List<StateActionPair<Integer>> experience) {
+        experiences.add(experience);
+        return new Gradient[0];
+    }
+}
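Taken together, these mocks outline the seam the patch introduces: a thread accumulates
StateActionPairs through an ExperienceHandler and hands the generated batch to an
UpdateAlgorithm, which turns it into gradients for the shared network. A minimal custom
implementation against the new interface, assuming the same UpdateAlgorithm shape that
MockUpdateAlgorithm implements (the class name and logging body are illustrative, not
part of this patch):

import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.NeuralNet;

import java.util.List;

public class LoggingUpdateAlgorithm implements UpdateAlgorithm<NeuralNet> {

    // Walks the batch in historical order (oldest experience first, as produced by
    // the experience handler) and returns no gradients; a real implementation would
    // compute its targets here, as QLearningUpdateAlgorithm and A3CUpdateAlgorithm do.
    @Override
    public Gradient[] computeGradients(NeuralNet current, List<StateActionPair<Integer>> experience) {
        for (StateActionPair<Integer> pair : experience) {
            System.out.println("action=" + pair.getAction() + ", reward=" + pair.getReward());
        }
        return new Gradient[0];
    }
}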