From 47c58cf69d1ae38506bbeda0c0948807be7a1cf4 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Tue, 26 Nov 2019 09:05:11 -0500 Subject: [PATCH] RL4J: Add Observation and LegacyMDPWrapper (#8368) * Added Observable & LegacyMDPWrapper Signed-off-by: unknown * Moved observation processing to LegacyMDPWrapper Signed-off-by: unknown * Observation using DataSets, changes in Transition and BaseTDTargetAlgorithm Signed-off-by: Alexandre Boulanger * Added javadoc to Transition new methods Signed-off-by: unknown --- .../rl4j/learning/ILearning.java | 2 +- .../rl4j/learning/Learning.java | 17 +- .../rl4j/learning/sync/SyncLearning.java | 2 +- .../rl4j/learning/sync/Transition.java | 119 ++++++-- .../learning/sync/qlearning/QLearning.java | 45 +++- .../qlearning/discrete/QLearningDiscrete.java | 70 ++--- .../BaseTDTargetAlgorithm.java | 42 +-- .../deeplearning4j/rl4j/network/dqn/DQN.java | 5 + .../deeplearning4j/rl4j/network/dqn/IDQN.java | 2 + .../rl4j/observation/Observation.java | 51 ++++ .../deeplearning4j/rl4j/policy/EpsGreedy.java | 8 +- .../deeplearning4j/rl4j/policy/IPolicy.java | 2 +- .../deeplearning4j/rl4j/policy/Policy.java | 2 +- .../rl4j/util/LegacyMDPWrapper.java | 139 ++++++++++ .../async/AsyncThreadDiscreteTest.java | 2 +- .../rl4j/learning/sync/ExpReplayTest.java | 61 +++-- .../rl4j/learning/sync/TransitionTest.java | 255 ++++++++++++++++++ .../discrete/QLearningDiscreteTest.java | 68 +++-- .../TDTargetAlgorithm/DoubleDQNTest.java | 22 +- .../TDTargetAlgorithm/StandardDQNTest.java | 23 +- .../rl4j/learning/sync/support/MockDQN.java | 8 +- .../rl4j/policy/PolicyTest.java | 2 +- .../deeplearning4j/rl4j/support/MockDQN.java | 6 + .../rl4j/support/MockHistoryProcessor.java | 6 +- .../deeplearning4j/rl4j/support/MockMDP.java | 1 - 25 files changed, 742 insertions(+), 218 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index e6c803bd2..3c4c94c6b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable; * * A common interface that any training method should implement */ -public interface ILearning> extends StepCountable { +public interface ILearning> extends StepCountable { IPolicy getPolicy(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index 2a8c80608..780a73752 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -39,7 +39,7 @@ import org.nd4j.linalg.factory.Nd4j; * */ @Slf4j -public abstract class Learning, NN extends NeuralNet> +public abstract class Learning, NN extends NeuralNet> implements ILearning, NeuralNetFetchable { @Getter @Setter @@ -53,8 +53,8 @@ public abstract class Learning return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0); } - public static > INDArray getInput(MDP mdp, O obs) { - INDArray arr = Nd4j.create(obs.toArray()); + public static > INDArray getInput(MDP mdp, O obs) { + INDArray arr = Nd4j.create(((Encodable)obs).toArray()); int[] shape = mdp.getObservationSpace().getShape(); if (shape.length == 1) return arr.reshape(new long[] {1, arr.length()}); @@ -62,7 +62,7 @@ public abstract class Learning return arr.reshape(shape); } - public static > InitMdp initMdp(MDP mdp, + public static > InitMdp initMdp(MDP mdp, IHistoryProcessor hp) { O obs = mdp.reset(); @@ -138,15 +138,6 @@ public abstract class Learning this.historyProcessor = historyProcessor; } - public INDArray getInput(O obs) { - return getInput(getMdp(), obs); - } - - public InitMdp initMdp() { - getNeuralNet().reset(); - return initMdp(getMdp(), getHistoryProcessor()); - } - @AllArgsConstructor @Value public static class InitMdp { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java index 35cc9e136..ed5d73a75 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java @@ -36,7 +36,7 @@ import org.deeplearning4j.rl4j.util.IDataManager; * @author Alexandre Boulanger */ @Slf4j -public abstract class SyncLearning, NN extends NeuralNet> +public abstract class SyncLearning, NN extends NeuralNet> extends Learning implements IEpochTrainer { private final TrainingListenerList listeners = new TrainingListenerList(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java index 7242dd64b..509b28a88 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java @@ -16,27 +16,56 @@ package org.deeplearning4j.rl4j.learning.sync; -import lombok.AllArgsConstructor; import lombok.Value; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +import java.util.List; /** - * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16. * * A transition is a SARS tuple * State, Action, Reward, (isTerminal), State + * + * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16. + * @author Alexandre Boulanger + * */ @Value -@AllArgsConstructor public class Transition { - INDArray[] observation; + Observation observation; A action; double reward; boolean isTerminal; INDArray nextObservation; + public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) { + this.observation = observation; + this.action = action; + this.reward = reward; + this.isTerminal = isTerminal; + + // To conserve memory, only the most recent frame of the next observation is kept (if history is used). + // The full nextObservation will be re-build from observation when needed. + long[] nextObservationShape = nextObservation.getData().shape().clone(); + nextObservationShape[0] = 1; + this.nextObservation = nextObservation.getData() + .get(new INDArrayIndex[] {NDArrayIndex.point(0)}) + .reshape(nextObservationShape); + } + + private Transition(Observation observation, A action, double reward, boolean isTerminal, INDArray nextObservation) { + this.observation = observation; + this.action = action; + this.reward = reward; + this.isTerminal = isTerminal; + this.nextObservation = nextObservation; + } + /** * concat an array history into a single INDArry of as many channel * as element in the history array @@ -53,36 +82,80 @@ public class Transition { * @return this transition duplicated */ public Transition dup() { - INDArray[] dupObservation = dup(observation); + Observation dupObservation = observation.dup(); INDArray nextObs = nextObservation.dup(); - return new Transition<>(dupObservation, action, reward, isTerminal, nextObs); + return new Transition(dupObservation, action, reward, isTerminal, nextObs); } /** - * Duplicate an history - * @param history the history to duplicate - * @return a duplicate of the history + * Stack along the 0-dimension all the observations of the batch in a INDArray. + * + * @param transitions A list of the transitions of the batch + * @param The type of the Action + * @return A INDArray of all of the batch's observations stacked along the 0-dimension. */ - public static INDArray[] dup(INDArray[] history) { - INDArray[] dupHistory = new INDArray[history.length]; - for (int i = 0; i < history.length; i++) { - dupHistory[i] = history[i].dup(); + public static INDArray buildStackedObservations(List> transitions) { + int size = transitions.size(); + long[] shape = getShape(transitions); + + INDArray[] array = new INDArray[size]; + for (int i = 0; i < size; i++) { + array[i] = transitions.get(i).getObservation().getData(); } - return dupHistory; + + return Nd4j.concat(0, array).reshape(shape); } /** - * append a pixel frame to an history (throwing the last frame) - * @param history the history on which to append - * @param append the pixel frame to append - * @return the appended history + * Stack along the 0-dimension all the next observations of the batch in a INDArray. + * + * @param transitions A list of the transitions of the batch + * @param The type of the Action + * @return A INDArray of all of the batch's next observations stacked along the 0-dimension. */ - public static INDArray[] append(INDArray[] history, INDArray append) { - INDArray[] appended = new INDArray[history.length]; - appended[0] = append; - System.arraycopy(history, 0, appended, 1, history.length - 1); - return appended; + public static INDArray buildStackedNextObservations(List> transitions) { + int size = transitions.size(); + long[] shape = getShape(transitions); + + INDArray[] array = new INDArray[size]; + + for (int i = 0; i < size; i++) { + Transition trans = transitions.get(i); + INDArray obs = trans.getObservation().getData(); + long historyLength = obs.shape()[0]; + + if(historyLength != 1) { + // To conserve memory, only the most recent frame of the next observation is kept (if history is used). + // We need to rebuild the frame-stack in addition to builing the batch-stack. + INDArray historyPart = obs.get(new INDArrayIndex[]{NDArrayIndex.interval(0, historyLength - 1)}); + array[i] = Nd4j.concat(0, trans.getNextObservation(), historyPart); + } + else { + array[i] = trans.getNextObservation(); + } + } + + return Nd4j.concat(0, array).reshape(shape); + } + + private static long[] getShape(List> transitions) { + INDArray observations = transitions.get(0).getObservation().getData(); + long[] observationShape = observations.shape(); + long[] stackedShape; + if(observationShape[0] == 1) { + // FIXME: Currently RL4J doesn't support 1D observations. So if we have a shape with 1 in the first dimension, we can use that dimension and don't need to add another one. + stackedShape = new long[observationShape.length]; + System.arraycopy(observationShape, 0, stackedShape, 0, observationShape.length); + } + else { + stackedShape = new long[observationShape.length + 1]; + System.arraycopy(observationShape, 1, stackedShape, 2, observationShape.length - 1); + stackedShape[1] = observationShape[1]; + } + stackedShape[0] = transitions.size(); + + return stackedShape; } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 98c62e565..098aefeaa 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -21,15 +21,20 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; import lombok.*; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.sync.ExpReplay; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.SyncLearning; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager.StatEntry; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.api.rng.Random; @@ -53,6 +58,8 @@ public abstract class QLearning expReplay; + protected abstract LegacyMDPWrapper getLegacyMDPWrapper(); + public QLearning(QLConfiguration conf) { this(conf, getSeededRandom(conf.getSeed())); } @@ -95,11 +102,11 @@ public abstract class QLearning trainStep(O obs); + protected abstract QLStepReturn trainStep(Observation obs); protected StatEntry trainEpoch() { - InitMdp initMdp = initMdp(); - O obs = initMdp.getLastObs(); + InitMdp initMdp = refacInitMdp(); + Observation obs = initMdp.getLastObs(); double reward = initMdp.getReward(); int step = initMdp.getSteps(); @@ -114,7 +121,7 @@ public abstract class QLearning stepR = trainStep(obs); + QLStepReturn stepR = trainStep(obs); if (!stepR.getMaxQ().isNaN()) { if (startQ.isNaN()) @@ -142,6 +149,36 @@ public abstract class QLearning refacInitMdp() { + LegacyMDPWrapper mdp = getLegacyMDPWrapper(); + IHistoryProcessor hp = getHistoryProcessor(); + + Observation observation = mdp.reset(); + + int step = 0; + double reward = 0; + + boolean isHistoryProcessor = hp != null; + + int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; + int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; + + while (step < requiredFrame && !mdp.isDone()) { + + A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP + + StepReply stepReply = mdp.step(action); + reward += stepReply.getReward(); + observation = stepReply.getObservation(); + + step++; + + } + + return new InitMdp(step, observation, reward); + + } + @AllArgsConstructor @Builder @Value diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 796780fb9..4d089d1ee 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -26,10 +26,12 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.dataset.api.DataSet; @@ -51,8 +53,7 @@ public abstract class QLearningDiscrete extends QLearning mdp; + private final LegacyMDPWrapper mdp; @Getter private DQNPolicy policy; @Getter @@ -65,11 +66,14 @@ public abstract class QLearningDiscrete extends QLearning getLegacyMDPWrapper() { + return mdp; + } + public QLearningDiscrete(MDP mdp, IDQN dqn, QLConfiguration conf, int epsilonNbStep) { this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed())); @@ -79,7 +83,7 @@ public abstract class QLearningDiscrete extends QLearning(mdp, this); qNetwork = dqn; targetQNetwork = dqn.clone(); policy = new DQNPolicy(getQNetwork()); @@ -92,6 +96,10 @@ public abstract class QLearningDiscrete extends QLearning getMdp() { + return mdp.getWrappedMDP(); + } + public void postEpoch() { if (getHistoryProcessor() != null) @@ -100,7 +108,6 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning trainStep(O obs) { + protected QLStepReturn trainStep(Observation obs) { Integer action; - INDArray input = getInput(obs); boolean isHistoryProcessor = getHistoryProcessor() != null; @@ -128,50 +134,25 @@ public abstract class QLearningDiscrete extends QLearning 2) - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()))); - - INDArray qs = getQNetwork().output(hstack); + INDArray qs = getQNetwork().output(obs); int maxAction = Learning.getMaxAction(qs); - maxQ = qs.getDouble(maxAction); - action = getEgPolicy().nextAction(hstack); + + action = getEgPolicy().nextAction(obs); } lastAction = action; - StepReply stepReply = getMdp().step(action); + StepReply stepReply = mdp.step(action); - INDArray ninput = getInput(stepReply.getObservation()); - - if (isHistoryProcessor) - getHistoryProcessor().record(ninput); + Observation nextObservation = stepReply.getObservation(); accuReward += stepReply.getReward() * configuration.getRewardFactor(); //if it's not a skipped frame, you can do a step of training if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) { - if (isHistoryProcessor) - getHistoryProcessor().add(ninput); - - INDArray[] nhistory = isHistoryProcessor ? getHistoryProcessor().getHistory() : new INDArray[] {ninput}; - - Transition trans = new Transition(history, action, accuReward, stepReply.isDone(), nhistory[0]); + Transition trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation); getExpReplay().store(trans); if (getStepCounter() > updateStart) { @@ -179,27 +160,16 @@ public abstract class QLearningDiscrete extends QLearning(maxQ, getQNetwork().getLatestScore(), stepReply); + return new QLStepReturn(maxQ, getQNetwork().getLatestScore(), stepReply); } protected DataSet setTarget(ArrayList> transitions) { if (transitions.size() == 0) throw new IllegalArgumentException("too few transitions"); - // TODO: Remove once we use DataSets in observations - int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() - : getHistoryProcessor().getConf().getShape(); - ((BaseTDTargetAlgorithm) tdTargetAlgorithm).setNShape(makeShape(transitions.size(), shape)); - - // TODO: Remove once we use DataSets in observations - if(getHistoryProcessor() != null) { - ((BaseTDTargetAlgorithm) tdTargetAlgorithm).setScale(getHistoryProcessor().getScale()); - } - return tdTargetAlgorithm.computeTDTargets(transitions); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java index f4f143ee9..ca4beb47e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java @@ -16,14 +16,10 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import lombok.Setter; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.INDArrayIndex; -import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.List; @@ -40,11 +36,6 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm trans = transitions.get(i); - - INDArray[] obsArray = trans.getObservation(); - if (observations.rank() == 2) { - observations.putRow(i, obsArray[0]); - } else { - for (int j = 0; j < obsArray.length; j++) { - observations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]); - } - } - - INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation()); - if (nextObservations.rank() == 2) { - nextObservations.putRow(i, nextObsArray[0]); - } else { - for (int j = 0; j < nextObsArray.length; j++) { - nextObservations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]); - } - } - } - - // TODO: Remove once we use DataSets in observations - if(scale != 1.0) { - observations.muli(1.0 / scale); - nextObservations.muli(1.0 / scale); - } + INDArray observations = Transition.buildStackedObservations(transitions); + INDArray nextObservations = Transition.buildStackedNextObservations(transitions); initComputation(observations, nextObservations); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java index f73be3b21..b3293c1b6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.api.ndarray.INDArray; @@ -70,6 +71,10 @@ public class DQN implements IDQN { return mln.output(batch); } + public INDArray output(Observation observation) { + return this.output(observation.getData()); + } + public INDArray[] outputAll(INDArray batch) { return new INDArray[] {output(batch)}; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java index c6ae2f5ac..af295d202 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java @@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.network.dqn; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -37,6 +38,7 @@ public interface IDQN extends NeuralNet { void fit(INDArray input, INDArray[] labels); INDArray output(INDArray batch); + INDArray output(Observation observation); INDArray[] outputAll(INDArray batch); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java new file mode 100644 index 000000000..7ca63baaf --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Presently only a dummy container. Will contain observation channels when done. + */ +public class Observation { + // TODO: Presently only a dummy container. Will contain observation channels when done. + + private final DataSet data; + + public Observation(INDArray[] data) { + this(new org.nd4j.linalg.dataset.DataSet(Nd4j.concat(0, data), null)); + } + + // FIXME: Remove -- only used in unit tests + public Observation(INDArray data) { + this.data = new org.nd4j.linalg.dataset.DataSet(data, null); + } + + private Observation(DataSet data) { + this.data = data; + } + + public Observation dup() { + return new Observation(new org.nd4j.linalg.dataset.DataSet(data.getFeatures().dup(), null)); + } + + public INDArray getData() { + return data.getFeatures(); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java index a7be53596..732ce1a0e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java @@ -21,8 +21,8 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.rl4j.learning.StepCountable; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; @@ -38,7 +38,7 @@ import org.nd4j.linalg.api.rng.Random; */ @AllArgsConstructor @Slf4j -public class EpsGreedy> extends Policy { +public class EpsGreedy> extends Policy { final private Policy policy; final private MDP mdp; @@ -61,8 +61,10 @@ public class EpsGreedy> extend return policy.nextAction(input); else return mdp.getActionSpace().randomAction(); + } - + public A nextAction(Observation observation) { + return this.nextAction(observation.getData()); } public float getEpsilon() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java index 1b5ae67af..885fa36a2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java @@ -6,7 +6,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; -public interface IPolicy { +public interface IPolicy { > double play(MDP mdp, IHistoryProcessor hp); A nextAction(INDArray input); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index 946bdca7b..84ef26f25 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -35,7 +35,7 @@ import org.nd4j.linalg.util.ArrayUtil; * * A Policy responsability is to choose the next action given a state */ -public abstract class Policy implements IPolicy { +public abstract class Policy implements IPolicy { public abstract NeuralNet getNeuralNet(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java new file mode 100644 index 000000000..efbf29603 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java @@ -0,0 +1,139 @@ +package org.deeplearning4j.rl4j.util; + +import lombok.Getter; +import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.ObservationSpace; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +public class LegacyMDPWrapper> implements MDP { + + @Getter + private final MDP wrappedMDP; + @Getter + private final WrapperObservationSpace observationSpace; + private final ILearning learning; + private int skipFrame; + + private int step = 0; + + public LegacyMDPWrapper(MDP wrappedMDP, ILearning learning) { + this.wrappedMDP = wrappedMDP; + this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape()); + this.learning = learning; + } + + @Override + public AS getActionSpace() { + return wrappedMDP.getActionSpace(); + } + + @Override + public Observation reset() { + INDArray rawObservation = getInput(wrappedMDP.reset()); + + IHistoryProcessor historyProcessor = learning.getHistoryProcessor(); + if(historyProcessor != null) { + historyProcessor.record(rawObservation.dup()); + rawObservation.muli(1.0 / historyProcessor.getScale()); + } + + Observation observation = new Observation(new INDArray[] { rawObservation }); + + if(historyProcessor != null) { + skipFrame = historyProcessor.getConf().getSkipFrame(); + historyProcessor.add(rawObservation); + } + step = 0; + + return observation; + } + + @Override + public void close() { + wrappedMDP.close(); + } + + @Override + public StepReply step(A a) { + IHistoryProcessor historyProcessor = learning.getHistoryProcessor(); + + StepReply rawStepReply = wrappedMDP.step(a); + INDArray rawObservation = getInput(rawStepReply.getObservation()); + + ++step; + + int requiredFrame = 0; + if(historyProcessor != null) { + historyProcessor.record(rawObservation.dup()); + rawObservation.muli(1.0 / historyProcessor.getScale()); + + requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1); + if ((learning.getStepCounter() % skipFrame == 0 && step >= requiredFrame) + || (step % skipFrame == 0 && step < requiredFrame )){ + historyProcessor.add(rawObservation); + } + } + + Observation observation; + if(historyProcessor != null && step >= requiredFrame) { + observation = new Observation(historyProcessor.getHistory()); + } + else { + observation = new Observation(new INDArray[] { rawObservation }); + } + + return new StepReply(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); + } + + @Override + public boolean isDone() { + return wrappedMDP.isDone(); + } + + @Override + public MDP newInstance() { + return new LegacyMDPWrapper(wrappedMDP.newInstance(), learning); + } + + private INDArray getInput(O obs) { + INDArray arr = Nd4j.create(obs.toArray()); + int[] shape = observationSpace.getShape(); + if (shape.length == 1) + return arr.reshape(new long[] {1, arr.length()}); + else + return arr.reshape(shape); + } + + public static class WrapperObservationSpace implements ObservationSpace { + + @Getter + private final int[] shape; + + public WrapperObservationSpace(int[] shape) { + + this.shape = shape; + } + + @Override + public String getName() { + return null; + } + + @Override + public INDArray getLow() { + return null; + } + + @Override + public INDArray getHigh() { + return null; + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index 1d6a9e909..e3658b8dd 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -50,7 +50,7 @@ public class AsyncThreadDiscreteTest { assertEquals(1, asyncGlobalMock.enqueueCallCount); // HistoryProcessor - assertEquals(10, hpMock.addCallCount); + assertEquals(10, hpMock.addCalls.size()); double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 }; assertEquals(expectedRecordValues.length, hpMock.recordCalls.size()); for(int i = 0; i < expectedRecordValues.length; ++i) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java index 44271adde..373c4b189 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java @@ -1,5 +1,6 @@ package org.deeplearning4j.rl4j.learning.sync; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.support.MockRandom; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; @@ -17,7 +18,8 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(2, 1, randomMock); // Act - Transition transition = new Transition(new INDArray[] { Nd4j.create(1) }, 123, 234, false, Nd4j.create(1)); + Transition transition = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 123, 234, false, new Observation(Nd4j.create(1))); sut.store(transition); List> results = sut.getBatch(1); @@ -34,9 +36,12 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(2, 1, randomMock); // Act - Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); - Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); - Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); + Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, false, new Observation(Nd4j.create(1))); + Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, false, new Observation(Nd4j.create(1))); + Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, false, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -73,9 +78,12 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); - Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); - Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); + Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, false, new Observation(Nd4j.create(1))); + Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, false, new Observation(Nd4j.create(1))); + Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, false, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -92,9 +100,12 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); - Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); - Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); + Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, false, new Observation(Nd4j.create(1))); + Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, false, new Observation(Nd4j.create(1))); + Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, false, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -120,11 +131,16 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); - Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); - Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); - Transition transition4 = new Transition(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1)); - Transition transition5 = new Transition(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1)); + Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, false, new Observation(Nd4j.create(1))); + Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, false, new Observation(Nd4j.create(1))); + Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, false, new Observation(Nd4j.create(1))); + Transition transition4 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 7, 8, false, new Observation(Nd4j.create(1))); + Transition transition5 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 9, 10, false, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); @@ -152,11 +168,16 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); - Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); - Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); - Transition transition4 = new Transition(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1)); - Transition transition5 = new Transition(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1)); + Transition transition1 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 1, 2, false, new Observation(Nd4j.create(1))); + Transition transition2 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 3, 4, false, new Observation(Nd4j.create(1))); + Transition transition3 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 5, 6, false, new Observation(Nd4j.create(1))); + Transition transition4 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 7, 8, false, new Observation(Nd4j.create(1))); + Transition transition5 = new Transition(new Observation(new INDArray[] { Nd4j.create(1) }), + 9, 10, false, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); sut.store(transition3); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java new file mode 100644 index 000000000..b74ebac11 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java @@ -0,0 +1,255 @@ +package org.deeplearning4j.rl4j.learning.sync; + +import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class TransitionTest { + @Test + public void when_callingCtorWithoutHistory_expect_2DObservationAndNextObservation() { + // Arrange + double[] obs = new double[] { 1.0, 2.0, 3.0 }; + Observation observation = buildObservation(obs); + + double[] nextObs = new double[] { 10.0, 20.0, 30.0 }; + Observation nextObservation = buildObservation(nextObs); + + // Act + Transition transition = new Transition(observation, 123, 234.0, false, nextObservation); + + // Assert + double[][] expectedObservation = new double[][] { obs }; + assertExpected(expectedObservation, transition.getObservation().getData()); + + double[][] expectedNextObservation = new double[][] { nextObs }; + assertExpected(expectedNextObservation, transition.getNextObservation()); + + assertEquals(123, transition.getAction()); + assertEquals(234.0, transition.getReward(), 0.0001); + } + + @Test + public void when_callingCtorWithHistory_expect_ObservationWithHistoryAndNextObservationWithout() { + // Arrange + double[][] obs = new double[][] { + { 0.0, 1.0, 2.0 }, + { 3.0, 4.0, 5.0 }, + { 6.0, 7.0, 8.0 }, + }; + Observation observation = buildObservation(obs); + + double[][] nextObs = new double[][] { + { 10.0, 11.0, 12.0 }, + { 0.0, 1.0, 2.0 }, + { 3.0, 4.0, 5.0 }, + }; + Observation nextObservation = buildObservation(nextObs); + + // Act + Transition transition = new Transition(observation, 123, 234.0, false, nextObservation); + + // Assert + assertExpected(obs, transition.getObservation().getData()); + + assertExpected(nextObs[0], transition.getNextObservation()); + + assertEquals(123, transition.getAction()); + assertEquals(234.0, transition.getReward(), 0.0001); + } + + @Test + public void when_CallingBuildStackedObservationsAndShapeRankIs2_expect_2DResultWithObservationsStackedOnDimension0() { + // Arrange + List> transitions = new ArrayList>(); + + double[] obs1 = new double[] { 0.0, 1.0, 2.0 }; + Observation observation1 = buildObservation(obs1); + Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 }); + transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1)); + + double[] obs2 = new double[] { 10.0, 11.0, 12.0 }; + Observation observation2 = buildObservation(obs2); + Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 }); + transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); + + // Act + INDArray result = Transition.buildStackedObservations(transitions); + + // Assert + double[][] expected = new double[][] { obs1, obs2 }; + assertExpected(expected, result); + } + + @Test + public void when_CallingBuildStackedObservationsAndShapeRankIsGreaterThan2_expect_ResultWithOneMoreDimensionAndObservationsStackedOnDimension0() { + // Arrange + List> transitions = new ArrayList>(); + + double[][] obs1 = new double[][] { + { 0.0, 1.0, 2.0 }, + { 3.0, 4.0, 5.0 }, + { 6.0, 7.0, 8.0 }, + }; + Observation observation1 = buildObservation(obs1); + + double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; + Observation nextObservation1 = buildNextObservation(obs1, nextObs1); + + transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); + + double[][] obs2 = new double[][] { + { 10.0, 11.0, 12.0 }, + { 13.0, 14.0, 15.0 }, + { 16.0, 17.0, 18.0 }, + }; + Observation observation2 = buildObservation(obs2); + + double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; + Observation nextObservation2 = buildNextObservation(obs2, nextObs2); + transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); + + // Act + INDArray result = Transition.buildStackedObservations(transitions); + + // Assert + double[][][] expected = new double[][][] { obs1, obs2 }; + assertExpected(expected, result); + } + + @Test + public void when_CallingBuildStackedNextObservationsAndShapeRankIs2_expect_2DResultWithObservationsStackedOnDimension0() { + // Arrange + List> transitions = new ArrayList>(); + + double[] obs1 = new double[] { 0.0, 1.0, 2.0 }; + double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; + Observation observation1 = buildObservation(obs1); + Observation nextObservation1 = buildObservation(nextObs1); + transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); + + double[] obs2 = new double[] { 10.0, 11.0, 12.0 }; + double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; + Observation observation2 = buildObservation(obs2); + Observation nextObservation2 = buildObservation(nextObs2); + transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); + + // Act + INDArray result = Transition.buildStackedNextObservations(transitions); + + // Assert + double[][] expected = new double[][] { nextObs1, nextObs2 }; + assertExpected(expected, result); + } + + @Test + public void when_CallingBuildStackedNextObservationsAndShapeRankIsGreaterThan2_expect_ResultWithOneMoreDimensionAndObservationsStackedOnDimension0() { + // Arrange + List> transitions = new ArrayList>(); + + double[][] obs1 = new double[][] { + { 0.0, 1.0, 2.0 }, + { 3.0, 4.0, 5.0 }, + { 6.0, 7.0, 8.0 }, + }; + Observation observation1 = buildObservation(obs1); + + double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; + Observation nextObservation1 = buildNextObservation(obs1, nextObs1); + + transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); + + double[][] obs2 = new double[][] { + { 10.0, 11.0, 12.0 }, + { 13.0, 14.0, 15.0 }, + { 16.0, 17.0, 18.0 }, + }; + Observation observation2 = buildObservation(obs2); + + double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; + Observation nextObservation2 = buildNextObservation(obs2, nextObs2); + + transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); + + // Act + INDArray result = Transition.buildStackedNextObservations(transitions); + + // Assert + double[][][] expected = new double[][][] { + new double[][] { nextObs1, obs1[0], obs1[1] }, + new double[][] { nextObs2, obs2[0], obs2[1] } + }; + assertExpected(expected, result); + } + + private Observation buildObservation(double[][] obs) { + INDArray[] history = new INDArray[] { + Nd4j.create(obs[0]).reshape(1, 3), + Nd4j.create(obs[1]).reshape(1, 3), + Nd4j.create(obs[2]).reshape(1, 3), + }; + return new Observation(history); + } + + private Observation buildObservation(double[] obs) { + return new Observation(new INDArray[] { Nd4j.create(obs).reshape(1, 3) }); + } + + private Observation buildNextObservation(double[][] obs, double[] nextObs) { + INDArray[] nextHistory = new INDArray[] { + Nd4j.create(nextObs).reshape(1, 3), + Nd4j.create(obs[0]).reshape(1, 3), + Nd4j.create(obs[1]).reshape(1, 3), + }; + return new Observation(nextHistory); + + } + + private void assertExpected(double[] expected, INDArray actual) { + long[] shape = actual.shape(); + assertEquals(2, shape.length); + assertEquals(1, shape[0]); + assertEquals(expected.length, shape[1]); + for(int i = 0; i < expected.length; ++i) { + assertEquals(expected[i], actual.getDouble(0, i), 0.0001); + } + } + + private void assertExpected(double[][] expected, INDArray actual) { + long[] shape = actual.shape(); + assertEquals(2, shape.length); + assertEquals(expected.length, shape[0]); + assertEquals(expected[0].length, shape[1]); + + for(int i = 0; i < expected.length; ++i) { + double[] expectedLine = expected[i]; + for(int j = 0; j < expectedLine.length; ++j) { + assertEquals(expectedLine[j], actual.getDouble(i, j), 0.0001); + } + } + } + + private void assertExpected(double[][][] expected, INDArray actual) { + long[] shape = actual.shape(); + assertEquals(3, shape.length); + assertEquals(expected.length, shape[0]); + assertEquals(expected[0].length, shape[1]); + assertEquals(expected[0][0].length, shape[2]); + + for(int i = 0; i < expected.length; ++i) { + double[][] expected2D = expected[i]; + for(int j = 0; j < expected2D.length; ++j) { + double[] expectedLine = expected2D[j]; + for (int k = 0; k < expectedLine.length; ++k) { + assertEquals(expectedLine[k], actual.getDouble(i, j, k), 0.0001); + } + } + } + + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index f36a5f197..59c28551b 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -40,7 +40,7 @@ public class QLearningDiscreteTest { new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); MockMDP mdp = new MockMDP(observationSpace, random); - QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, + QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000, 0, 1.0, 0, 0, 0, 0, true); MockDataManager dataManager = new MockDataManager(false); MockExpReplay expReplay = new MockExpReplay(); @@ -48,15 +48,10 @@ public class QLearningDiscreteTest { IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); sut.setHistoryProcessor(hp); - MockEncodable obs = new MockEncodable(-100); List> results = new ArrayList<>(); // Act - sut.initMdp(); - for(int step = 0; step < 16; ++step) { - results.add(sut.trainStep(obs)); - sut.incrementStep(); - } + IDataManager.StatEntry result = sut.trainEpoch(); // Assert // HistoryProcessor calls @@ -65,7 +60,11 @@ public class QLearningDiscreteTest { for(int i = 0; i < expectedRecords.length; ++i) { assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); } - assertEquals(13, hp.addCallCount); + double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 }; + assertEquals(expectedAdds.length, hp.addCalls.size()); + for(int i = 0; i < expectedAdds.length; ++i) { + assertEquals(expectedAdds[i], 255.0 * hp.addCalls.get(i).getDouble(0), 0.0001); + } assertEquals(0, hp.startMonitorCallCount); assertEquals(0, hp.stopMonitorCallCount); @@ -75,14 +74,14 @@ public class QLearningDiscreteTest { assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001); assertEquals(14, dqn.outputParams.size()); double[][] expectedDQNOutput = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 }, - new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 }, - new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 }, - new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 }, - new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 }, - new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 }, - new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 }, - new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 }, + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, @@ -108,13 +107,13 @@ public class QLearningDiscreteTest { // ExpReplay calls double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 }; int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; - double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, -100.0, 9.0, 11.0, 13.0, 15.0 }; + double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 }; double[][] expectedTrObservations = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 }, - new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 }, - new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 }, - new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 }, - new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 }, + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, @@ -123,26 +122,15 @@ public class QLearningDiscreteTest { Transition tr = expReplay.transitions.get(i); assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); assertEquals(expectedTrActions[i], tr.getAction()); - assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001); + assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001); for(int j = 0; j < expectedTrObservations[i].length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001); + assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(j, 0), 0.0001); } } - // trainStep results - assertEquals(16, results.size()); - double[] expectedMaxQ = new double[] { 6.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 }; - double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 }; - for(int i=0; i < 16; ++i) { - QLearning.QLStepReturn result = results.get(i); - if(i % 2 == 0) { - assertEquals(expectedMaxQ[i/2], 255.0 * result.getMaxQ(), 0.001); - assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001); - } - else { - assertTrue(result.getMaxQ().isNaN()); - } - } + // trainEpoch result + assertEquals(16, result.getStepCounter()); + assertEquals(300.0, result.getReward(), 0.00001); } public static class TestQLearningDiscrete extends QLearningDiscrete { @@ -163,5 +151,9 @@ public class QLearningDiscreteTest { this.expReplay = exp; } + @Override + public IDataManager.StatEntry trainEpoch() { + return super.trainEpoch(); + } } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java index e598b66ca..bb8af1950 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -3,6 +3,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorit import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.deeplearning4j.rl4j.observation.Observation; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; @@ -25,12 +26,12 @@ public class DoubleDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0}))); + add(new Transition(buildObservation(new double[]{1.1, 2.2}), + 0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); } }; DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); - sut.setNShape(new int[] { 1, 2 }); // Act DataSet result = sut.computeTDTargets(transitions); @@ -51,12 +52,12 @@ public class DoubleDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0}))); + add(new Transition(buildObservation(new double[]{1.1, 2.2}), + 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); } }; DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); - sut.setNShape(new int[] { 1, 2 }); // Act DataSet result = sut.computeTDTargets(transitions); @@ -77,14 +78,16 @@ public class DoubleDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0}))); - add(new Transition(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0}))); - add(new Transition(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0}))); + add(new Transition(buildObservation(new double[]{1.1, 2.2}), + 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); + add(new Transition(buildObservation(new double[]{3.3, 4.4}), + 1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); + add(new Transition(buildObservation(new double[]{5.5, 6.6}), + 0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); } }; DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); - sut.setNShape(new int[] { 3, 2 }); // Act DataSet result = sut.computeTDTargets(transitions); @@ -102,4 +105,7 @@ public class DoubleDQNTest { } + private Observation buildObservation(double[] data) { + return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); + } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java index 02dcdf6fd..d2608437d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -3,6 +3,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorit import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.deeplearning4j.rl4j.observation.Observation; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; @@ -24,12 +25,12 @@ public class StandardDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0}))); + add(new Transition(buildObservation(new double[]{1.1, 2.2}), + 0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); } }; StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); - sut.setNShape(new int[] { 1, 2 }); // Act DataSet result = sut.computeTDTargets(transitions); @@ -50,12 +51,12 @@ public class StandardDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0}))); + add(new Transition(buildObservation(new double[]{1.1, 2.2}), + 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); } }; StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); - sut.setNShape(new int[] { 1, 2 }); // Act DataSet result = sut.computeTDTargets(transitions); @@ -76,14 +77,16 @@ public class StandardDQNTest { List> transitions = new ArrayList>() { { - add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0}))); - add(new Transition(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0}))); - add(new Transition(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0}))); + add(new Transition(buildObservation(new double[]{1.1, 2.2}), + 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); + add(new Transition(buildObservation(new double[]{3.3, 4.4}), + 1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); + add(new Transition(buildObservation(new double[]{5.5, 6.6}), + 0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); } }; StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); - sut.setNShape(new int[] { 3, 2 }); // Act DataSet result = sut.computeTDTargets(transitions); @@ -101,4 +104,8 @@ public class StandardDQNTest { } + private Observation buildObservation(double[] data) { + return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); + } + } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java index 08957fee5..e5a87cb93 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java @@ -1,12 +1,11 @@ package org.deeplearning4j.rl4j.learning.sync.support; -import lombok.Setter; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; import java.io.OutputStream; @@ -57,6 +56,11 @@ public class MockDQN implements IDQN { return batch; } + @Override + public INDArray output(Observation observation) { + return this.output(observation.getData()); + } + @Override public INDArray[] outputAll(INDArray batch) { return new INDArray[0]; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 2ea50dd9d..ffb3680bb 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -197,7 +197,7 @@ public class PolicyTest { assertEquals(465.0, totalReward, 0.0001); // HistoryProcessor - assertEquals(27, hp.addCallCount); + assertEquals(27, hp.addCalls.size()); assertEquals(31, hp.recordCalls.size()); for(int i=0; i <= 30; ++i) { assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java index f4080c57f..680f9a653 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java @@ -4,6 +4,7 @@ import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -48,6 +49,11 @@ public class MockDQN implements IDQN { return batch; } + @Override + public INDArray output(Observation observation) { + return this.output(observation.getData()); + } + @Override public INDArray[] outputAll(INDArray batch) { return new INDArray[0]; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java index dbdb1f6a9..da435c0da 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java @@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.support; import org.apache.commons.collections4.queue.CircularFifoQueue; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -9,7 +10,6 @@ import java.util.ArrayList; public class MockHistoryProcessor implements IHistoryProcessor { - public int addCallCount = 0; public int startMonitorCallCount = 0; public int stopMonitorCallCount = 0; @@ -17,12 +17,14 @@ public class MockHistoryProcessor implements IHistoryProcessor { private final CircularFifoQueue history; public final ArrayList recordCalls; + public final ArrayList addCalls; public MockHistoryProcessor(Configuration config) { this.config = config; history = new CircularFifoQueue<>(config.getHistoryLength()); recordCalls = new ArrayList(); + addCalls = new ArrayList(); } @Override @@ -46,7 +48,7 @@ public class MockHistoryProcessor implements IHistoryProcessor { @Override public void add(INDArray image) { - ++addCallCount; + addCalls.add(image); history.add(image); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java index 5deb72b2a..c0ac23a2d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java @@ -2,7 +2,6 @@ package org.deeplearning4j.rl4j.support; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.mdp.MDP; -import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.nd4j.linalg.api.rng.Random;