From a2b973d41b17127cce6ffb12846c1476a50f7572 Mon Sep 17 00:00:00 2001
From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com>
Date: Thu, 31 Oct 2019 00:41:52 -0400
Subject: [PATCH] RL4J: Make a few fixes (#8303)

* A few fixes

Signed-off-by: unknown

* Reverted move of ObservationSpace, ActionSpace and others

Signed-off-by: unknown

* Added unit tests

Signed-off-by: unknown

* Changed ActionSpace of gym-java-client to use Nd4j's Random

Signed-off-by: Alexandre Boulanger
---
 .../rl4j/space/ActionSpace.java               |   4 +-
 .../rl4j/space/DiscreteSpace.java             |  18 +--
 .../rl4j/learning/Learning.java               |  19 +--
 .../rl4j/learning/async/AsyncThread.java      |  19 +--
 .../learning/async/AsyncThreadDiscrete.java   |   4 +-
 .../async/a3c/discrete/A3CDiscrete.java       |   1 -
 .../async/a3c/discrete/A3CThreadDiscrete.java |   3 +-
 .../discrete/AsyncNStepQLearningDiscrete.java |   5 +-
 .../AsyncNStepQLearningThreadDiscrete.java    |   1 -
 .../qlearning/discrete/QLearningDiscrete.java |  10 +-
 .../rl4j/policy/BoltzmannQ.java               |   1 -
 .../deeplearning4j/rl4j/policy/IPolicy.java   |   2 +
 .../deeplearning4j/rl4j/policy/Policy.java    |  21 ++-
 .../learning/async/AsyncLearningTest.java     |   4 +-
 .../async/AsyncThreadDiscreteTest.java        | 134 ++++++++++++++++++
 .../rl4j/learning/async/AsyncThreadTest.java  |  61 +++++++-
 .../discrete/QLearningDiscreteTest.java       |  64 +++++----
 .../rl4j/policy/PolicyTest.java               |  81 +++++++++++
 .../rl4j/support/MockAsyncConfiguration.java  |  73 ++--------
 .../rl4j/support/MockAsyncGlobal.java         |  16 ++-
 .../rl4j/support/MockHistoryProcessor.java    |   8 +-
 .../deeplearning4j/rl4j/support/MockMDP.java  |  27 +++-
 .../rl4j/support/MockNeuralNet.java           |  12 +-
 .../rl4j/support/MockPolicy.java              |  11 ++
 .../malmo/MalmoActionSpace.java               |   8 --
 25 files changed, 444 insertions(+), 163 deletions(-)
 create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java

diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java b/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java
index a61aa788a..0b149d394 100644
--- a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java
+++ b/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java
@@ -26,12 +26,10 @@ package org.deeplearning4j.rl4j.space;
 public interface ActionSpace<A> {
 
     /**
-     * @return A randomly uniformly sampled action,
+     * @return A random action,
      */
     A randomAction();
 
-    void setSeed(int seed);
-
     Object encode(A action);
 
     int getSize();
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java b/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java
index 6478d4840..0a27c1fe8 100644
--- a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java
+++ b/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java
@@ -17,8 +17,8 @@
 package org.deeplearning4j.rl4j.space;
 
 import lombok.Getter;
-
-import java.util.Random;
+import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
 
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
@@ -33,19 +33,19 @@ public class DiscreteSpace implements ActionSpace<Integer> {
     //size of the space also defined as the number of different actions
     @Getter
    final protected int size;
-    protected Random rd;
+    protected final Random rnd;
 
     public DiscreteSpace(int size) {
+        this(size, Nd4j.getRandom());
+    }
+
+    public DiscreteSpace(int size, Random rnd) {
         this.size = size;
-        rd = new Random();
+        this.rnd = rnd;
     }
 
     public Integer randomAction() {
-        return rd.nextInt(size);
-    }
-
-    public void setSeed(int seed) {
-        rd = new Random(seed);
+        return rnd.nextInt(size);
     }
 
     public Object encode(Integer a) {
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java
index 04ff06bc6..2a8c80608 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java
@@ -67,8 +67,6 @@ public abstract class Learning
 
         O obs = mdp.reset();
 
-        O nextO = obs;
-
         int step = 0;
         double reward = 0;
 
@@ -77,11 +75,12 @@
         int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
         int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
 
-        while (step < requiredFrame) {
-            INDArray input = Learning.getInput(mdp, obs);
+        INDArray input = Learning.getInput(mdp, obs);
+        if (isHistoryProcessor)
+            hp.record(input);
 
-            if (isHistoryProcessor)
-                hp.record(input);
+
+        while (step < requiredFrame && !mdp.isDone()) {
 
             A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
             if (step % skipFrame == 0 && isHistoryProcessor)
@@ -89,13 +88,17 @@
             StepReply<O> stepReply = mdp.step(action);
 
             reward += stepReply.getReward();
-            nextO = stepReply.getObservation();
+            obs = stepReply.getObservation();
+
+            input = Learning.getInput(mdp, obs);
+            if (isHistoryProcessor)
+                hp.record(input);
 
             step++;
         }
 
-        return new InitMdp(step, nextO, reward);
+        return new InitMdp(step, obs, reward);
 
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java
index 1d763be0b..34d9a9eaf 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java
@@ -22,10 +22,11 @@ import lombok.Setter;
 import lombok.Value;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.rl4j.learning.*;
-import org.deeplearning4j.rl4j.learning.listener.*;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.policy.Policy;
+import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager;
@@ -118,7 +119,7 @@ public abstract class AsyncThread
-            if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
+            if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
                 canContinue = finishEpoch(context) && startNewEpoch(context);
                 if (!canContinue) {
                     break;
                 }
             }
@@ -135,16 +136,16 @@
-        int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length);
+        int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
         SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
         context.obs = subEpochReturn.getLastObs();
         stepCounter += subEpochReturn.getSteps();
-        context.length += subEpochReturn.getSteps();
+        context.epochElapsedSteps += subEpochReturn.getSteps();
         context.rewards += subEpochReturn.getReward();
         context.score = subEpochReturn.getScore();
     }
@@ -164,7 +165,7 @@ public abstract class AsyncThread
-    protected abstract Policy<O, A> getPolicy(NN net);
+    protected abstract IPolicy<O, A> getPolicy(NN net);
 
     protected abstract SubEpochReturn trainSubEpoch(O obs, int nstep);
@@ -208,7 +209,7 @@ public abstract class AsyncThread
     private static class RunContext<O extends Encodable> {
         private O obs;
         private double rewards;
-        private int length;
+        private int epochElapsedSteps;
         private double score;
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
index 7458c0c06..8b8bc2861 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
@@ -25,7 +25,7 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.policy.Policy;
+import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -69,7 +69,7 @@ public abstract class AsyncThreadDiscrete
         Stack<MiniTrans<Integer>> rewards = new Stack<>();
 
         O obs = sObs;
-        Policy<O, Integer> policy = getPolicy(current);
+        IPolicy<O, Integer> policy = getPolicy(current);
 
         Integer action;
         Integer lastAction = null;
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java
index b420f5e83..52fa3932b 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java
@@ -58,7 +58,6 @@ public abstract class A3CDiscrete extends AsyncLearning
         this.mdp = mdp;
         this.configuration = conf;
-        mdp.getActionSpace().setSeed(conf.getSeed());
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java
@@ … @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete
         Integer seed = conf.getSeed();
         rnd = Nd4j.getRandom();
         if(seed != null) {
-            mdp.getActionSpace().setSeed(seed + threadNumber);
-            rnd.setSeed(seed);
+            rnd.setSeed(seed + threadNumber);
         }
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java
index 1b423d1a2..0c9ff057f 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java
@@ -27,6 +27,7 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.factory.Nd4j;
 
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@@ -46,10 +47,6 @@ public abstract class AsyncNStepQLearningDiscrete
         this.mdp = mdp;
         this.configuration = conf;
         this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
-        Integer seed = conf.getSeed();
-        if(seed != null) {
-            mdp.getActionSpace().setSeed(seed);
-        }
     }
 
     @Override
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java
index 4a51c91d2..f8c470269 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java
@@ -61,7 +61,6 @@ public class AsyncNStepQLearningThreadDiscrete extends AsyncThreadDiscrete
 
         Integer seed = conf.getSeed();
         if(seed != null) {
-            mdp.getActionSpace().setSeed(seed + threadNumber);
             rnd.setSeed(seed + threadNumber);
         }
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
index 6803e9521..ca5ddf0f2 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
@@ -85,7 +85,6 @@ public abstract class QLearningDiscrete extends QLearning
         this.configuration = conf;
-        mdp.getActionSpace().setSeed(conf.getSeed());
     }
@@ … @@ public abstract class QLearningDiscrete extends QLearning
         StepReply<O> stepReply = getMdp().step(action);
 
+        INDArray ninput = getInput(stepReply.getObservation());
+
+        if (isHistoryProcessor)
+            getHistoryProcessor().record(ninput);
+
         accuReward += stepReply.getReward() * configuration.getRewardFactor();
 
         //if it's not a skipped frame, you can do a step of training
         if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
-            INDArray ninput = getInput(stepReply.getObservation());
 
             if (isHistoryProcessor)
                 getHistoryProcessor().add(ninput);
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java
index 6ed7d4557..bff1a782c 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java
@@ -20,7 +20,6 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
-import org.nd4j.linalg.factory.Nd4j;
 
 import static org.nd4j.linalg.ops.transforms.Transforms.exp;
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java
index 5c9d54d45..1b5ae67af 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java
@@ -4,7 +4,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 public interface IPolicy<O extends Encodable, A> {
     <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
+    A nextAction(INDArray input);
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
index 1be123f1d..946bdca7b 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
@@ -51,6 +51,9 @@ public abstract class Policy implements IPolicy {
 
     @Override
     public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
+        boolean isHistoryProcessor = hp != null;
+        int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
+
         getNeuralNet().reset();
         Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
         O obs = initMdp.getLastObs();
@@ -62,17 +65,10 @@ public abstract class Policy implements IPolicy {
         int step = initMdp.getSteps();
         INDArray[] history = null;
 
+        INDArray input = Learning.getInput(mdp, obs);
+
         while (!mdp.isDone()) {
 
-            INDArray input = Learning.getInput(mdp, obs);
-            boolean isHistoryProcessor = hp != null;
-
-            if (isHistoryProcessor)
-                hp.record(input);
-
-            int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
-
             if (step % skipFrame != 0) {
                 action = lastAction;
             } else {
@@ -102,8 +98,11 @@ public abstract class Policy implements IPolicy {
             StepReply<O> stepReply = mdp.step(action);
             reward += stepReply.getReward();
 
-            if (isHistoryProcessor)
-                hp.add(Learning.getInput(mdp, stepReply.getObservation()));
+            input = Learning.getInput(mdp, stepReply.getObservation());
+            if (isHistoryProcessor) {
+                hp.record(input);
+                hp.add(input);
+            }
 
             history = isHistoryProcessor ? hp.getHistory()
                             : new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())};
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
index 536c6a8ad..8de9d864a 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
@@ -68,10 +68,10 @@ public class AsyncLearningTest {
 
     public static class TestContext {
-        public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1);
+        MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0, 0, 0, 0, 0);
         public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
         public final MockPolicy policy = new MockPolicy();
-        public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy);
+        public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
         public final MockTrainingListener listener = new MockTrainingListener();
 
         public TestContext() {
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
new file mode 100644
index 000000000..1d6a9e909
--- /dev/null
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
@@ -0,0 +1,134 @@
+package org.deeplearning4j.rl4j.learning.async;
+
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.Learning;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
+import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.policy.IPolicy;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.support.*;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Stack;
+
+import static org.junit.Assert.assertEquals;
+
+public class AsyncThreadDiscreteTest {
+
+    @Test
+    public void refac_AsyncThreadDiscrete_trainSubEpoch() {
+        // Arrange
+        MockNeuralNet nnMock = new MockNeuralNet();
+        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
+        MockObservationSpace observationSpace = new MockObservationSpace();
+        MockMDP mdpMock = new MockMDP(observationSpace);
+        TrainingListenerList listeners = new TrainingListenerList();
+        MockPolicy policyMock = new MockPolicy();
+        MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 0, 5, 0, 0, 0, 0);
+        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
+        MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
+        TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
+        MockEncodable obs = new MockEncodable(123);
+
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(1)));
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(2)));
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(3)));
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(4)));
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(5)));
+
+        // Act
+        AsyncThread.SubEpochReturn<MockEncodable> result = sut.trainSubEpoch(obs, 2);
+
+        // Assert
+        assertEquals(4, result.getSteps());
+        assertEquals(6.0, result.getReward(), 0.00001);
+        assertEquals(0.0, result.getScore(), 0.00001);
+        assertEquals(3.0, result.getLastObs().toArray()[0], 0.00001);
+        assertEquals(1, asyncGlobalMock.enqueueCallCount);
+
+        // HistoryProcessor
+        assertEquals(10, hpMock.addCallCount);
+        double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
+        assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
+        for(int i = 0; i < expectedRecordValues.length; ++i) {
+            assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);
+        }
+
+        // Policy
+        double[][] expectedPolicyInputs = new double[][] {
+                new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
+                new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
+                new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
+                new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
+        };
+        assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
+        for(int i = 0; i < expectedPolicyInputs.length; ++i) {
+            double[] expectedRow = expectedPolicyInputs[i];
+            INDArray input = policyMock.actionInputs.get(i);
+            assertEquals(expectedRow.length, input.shape()[0]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
+            }
+        }
+
+        // NeuralNetwork
+        assertEquals(1, nnMock.copyCallCount);
+        double[][] expectedNNInputs = new double[][] {
+                new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
+                new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
+                new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
+                new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
+                new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 },
+        };
+        assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
+        for(int i = 0; i < expectedNNInputs.length; ++i) {
+            double[] expectedRow = expectedNNInputs[i];
+            INDArray input = nnMock.outputAllInputs.get(i);
+            assertEquals(expectedRow.length, input.shape()[0]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
+            }
+        }
+
+    }
+
+    public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
+
+        private final IAsyncGlobal<MockNeuralNet> asyncGlobal;
+        private final MockPolicy policy;
+        private final MockAsyncConfiguration config;
+
+        public TestAsyncThreadDiscrete(IAsyncGlobal<MockNeuralNet> asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
+                        TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
+                        MockAsyncConfiguration config, IHistoryProcessor hp) {
+            super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
+            this.asyncGlobal = asyncGlobal;
+            this.policy = policy;
+            this.config = config;
+            setHistoryProcessor(hp);
+        }
+
+        @Override
+        public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
+            return new Gradient[0];
+        }
+
+        @Override
+        protected IAsyncGlobal<MockNeuralNet> getAsyncGlobal() {
+            return asyncGlobal;
+        }
+
+        @Override
+        protected AsyncConfiguration getConf() {
+            return config;
+        }
+
+        @Override
+        protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
+            return policy;
+        }
+    }
+}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
index 4d9e70b56..0a590a1e5 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
@@ -1,5 +1,8 @@
 package org.deeplearning4j.rl4j.learning.async;
 
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
@@ -9,7 +12,11 @@ import org.deeplearning4j.rl4j.support.*;
 import org.deeplearning4j.rl4j.util.IDataManager;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.List;
+
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 
 public class AsyncThreadTest {
 
@@ -82,7 +89,42 @@ public class AsyncThreadTest {
             IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
             assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
             assertEquals(i, statEntry.getEpochCounter());
-            assertEquals(2.0, statEntry.getReward(), 0.0001);
+            assertEquals(38.0, statEntry.getReward(), 0.0001);
+        }
+    }
+
+    @Test
+    public void when_run_expect_NeuralNetIsResetAtInitAndEveryEpoch() {
+        // Arrange
+        TestContext context = new TestContext();
+
+        // Act
+        context.sut.run();
+
+        // Assert
+        assertEquals(6, context.neuralNet.resetCallCount);
+    }
+
+    @Test
+    public void when_run_expect_trainSubEpochCalled() {
+        // Arrange
+        TestContext context = new TestContext();
+
+        // Act
+        context.sut.run();
+
+        // Assert
+        assertEquals(10, context.sut.trainSubEpochParams.size());
+        for(int i = 0; i < 10; ++i) {
+            MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
+            if(i % 2 == 0) {
+                assertEquals(2, params.nstep);
+                assertEquals(8.0, params.obs.toArray()[0], 0.00001);
+            }
+            else {
+                assertEquals(1, params.nstep);
+                assertNull(params.obs);
+            }
         }
     }
 
@@ -91,14 +133,18 @@ public class AsyncThreadTest {
         public final MockNeuralNet neuralNet = new MockNeuralNet();
         public final MockObservationSpace observationSpace = new MockObservationSpace();
         public final MockMDP mdp = new MockMDP(observationSpace);
-        public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 2);
+        public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
         public final TrainingListenerList listeners = new TrainingListenerList();
         public final MockTrainingListener listener = new MockTrainingListener();
+        private final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
+        public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);
+
         public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
 
         public TestContext() {
             asyncGlobal.setMaxLoops(10);
             listeners.add(listener);
+            sut.setHistoryProcessor(historyProcessor);
         }
     }
 
@@ -107,11 +153,12 @@ public class AsyncThreadTest {
         public int preEpochCallCount = 0;
         public int postEpochCallCount = 0;
 
-        private final IAsyncGlobal asyncGlobal;
         private final MockNeuralNet neuralNet;
         private final AsyncConfiguration conf;
 
+        private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
+
         public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
             super(asyncGlobal, mdp, listeners, threadNumber, 0);
@@ -154,8 +201,16 @@ public class AsyncThreadTest {
 
         @Override
         protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
+            trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
             return new SubEpochReturn(1, null, 1.0, 1.0);
         }
+
+        @AllArgsConstructor
+        @Getter
+        public static class TrainSubEpochParams {
+            Encodable obs;
+            int nstep;
+        }
     }
 }
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
index b5212566d..f36a5f197 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
@@ -26,9 +26,20 @@ public class QLearningDiscreteTest {
     public void refac_QLearningDiscrete_trainStep() {
         // Arrange
         MockObservationSpace observationSpace = new MockObservationSpace();
-        MockMDP mdp = new MockMDP(observationSpace);
         MockDQN dqn = new MockDQN();
-        MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null);
+        MockRandom random = new MockRandom(new double[] {
+                        0.7309677600860596,
+                        0.8314409852027893,
+                        0.2405363917350769,
+                        0.6063451766967773,
+                        0.6374173760414124,
+                        0.3090505599975586,
+                        0.5504369735717773,
+                        0.11700659990310669
+                },
+                new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
+        MockMDP mdp = new MockMDP(observationSpace, random);
+
         QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
                         0, 1.0, 0, 0, 0, 0, true);
         MockDataManager dataManager = new MockDataManager(false);
@@ -37,7 +48,7 @@ public class QLearningDiscreteTest {
         IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
         MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
         sut.setHistoryProcessor(hp);
-        MockEncodable obs = new MockEncodable(1);
+        MockEncodable obs = new MockEncodable(-100);
         List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
 
         // Act
@@ -49,7 +60,11 @@ public class QLearningDiscreteTest {
 
         // Assert
         // HistoryProcessor calls
-        assertEquals(24, hp.recordCallCount);
+        double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 };
+        assertEquals(expectedRecords.length, hp.recordCalls.size());
+        for(int i = 0; i < expectedRecords.length; ++i) {
+            assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
+        }
         assertEquals(13, hp.addCallCount);
         assertEquals(0, hp.startMonitorCallCount);
         assertEquals(0, hp.stopMonitorCallCount);
@@ -60,21 +75,20 @@ public class QLearningDiscreteTest {
         assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
         assertEquals(14, dqn.outputParams.size());
         double[][] expectedDQNOutput = new double[][] {
-                        new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
-                        new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
-                        new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
-                        new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
-                        new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
-                        new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
-                        new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
-                        new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
+                        new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
+                        new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
+                        new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
+                        new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
+                        new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
+                        new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
+                        new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
+                        new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
                         new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                         new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                         new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                         new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                         new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
                         new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
-
         };
         for(int i = 0; i < expectedDQNOutput.length; ++i) {
             INDArray outputParam = dqn.outputParams.get(i);
@@ -84,23 +98,23 @@ public class QLearningDiscreteTest {
             double[] expectedRow = expectedDQNOutput[i];
             for(int j = 0; j < expectedRow.length; ++j) {
-                assertEquals(expectedRow[j] / 255.0, outputParam.getDouble(j), 0.00001);
+                assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
             }
         }
 
         // MDP calls
-        assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
+        assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
 
         // ExpReplay calls
         double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
-        int[] expectedTrActions = new int[] { 0, 4, 3, 4, 4, 4, 4, 4 };
-        double[] expectedTrNextObservation = new double[] { 0, 0, 0, 1.0, 9.0, 11.0, 13.0, 15.0 };
+        int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
+        double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, -100.0, 9.0, 11.0, 13.0, 15.0 };
         double[][] expectedTrObservations = new double[][] {
-                        new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
-                        new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
-                        new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
-                        new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
-                        new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
+                        new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
+                        new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
+                        new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
+                        new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
+                        new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
                         new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                         new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                         new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
@@ -111,18 +125,18 @@ public class QLearningDiscreteTest {
             assertEquals(expectedTrActions[i], tr.getAction());
             assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
             for(int j = 0; j < expectedTrObservations[i].length; ++j) {
-                assertEquals(expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
+                assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
             }
         }
 
         // trainStep results
         assertEquals(16, results.size());
-        double[] expectedMaxQ = new double[] { 1.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
+        double[] expectedMaxQ = new double[] { 6.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
         double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
         for(int i=0; i < 16; ++i) {
             QLearning.QLStepReturn<MockEncodable> result = results.get(i);
             if(i % 2 == 0) {
-                assertEquals(expectedMaxQ[i/2] / 255.0, result.getMaxQ(), 0.001);
+                assertEquals(expectedMaxQ[i/2], 255.0 * result.getMaxQ(), 0.001);
                 assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
             }
             else {
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
index 2dacd88e1..2ea50dd9d 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
@@ -22,7 +22,15 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
+import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
+import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.support.*;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -31,6 +39,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
@@ -153,4 +163,75 @@ public class PolicyTest {
         assertTrue(count[2] < 40);
         assertTrue(count[3] < 50);
     }
+
+    @Test
+    public void refacPolicyPlay() {
+        // Arrange
+        MockObservationSpace observationSpace = new MockObservationSpace();
+        MockDQN dqn = new MockDQN();
+        MockRandom random = new MockRandom(new double[] {
+                        0.7309677600860596,
+                        0.8314409852027893,
+                        0.2405363917350769,
+                        0.6063451766967773,
+                        0.6374173760414124,
+                        0.3090505599975586,
+                        0.5504369735717773,
+                        0.11700659990310669
+                },
+                new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
+        MockMDP mdp = new MockMDP(observationSpace, 30, random);
+
+        QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
+                        0, 1.0, 0, 0, 0, 0, true);
+        MockNeuralNet nnMock = new MockNeuralNet();
+        MockRefacPolicy sut = new MockRefacPolicy(nnMock);
+        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
+        MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
+
+        // Act
+        double totalReward = sut.play(mdp, hp);
+
+        // Assert
+        assertEquals(1, nnMock.resetCallCount);
+        assertEquals(465.0, totalReward, 0.0001);
+
+        // HistoryProcessor
+        assertEquals(27, hp.addCallCount);
+        assertEquals(31, hp.recordCalls.size());
+        for(int i=0; i <= 30; ++i) {
+            assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
+        }
+
+        // MDP
+        assertEquals(1, mdp.resetCount);
+        assertEquals(30, mdp.actions.size());
+        for(int i = 0; i < mdp.actions.size(); ++i) {
+            assertEquals(0, (int)mdp.actions.get(i));
+        }
+
+        // DQN
+        assertEquals(0, dqn.fitParams.size());
+        assertEquals(0, dqn.outputParams.size());
+    }
+
+    public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
+
+        private NeuralNet neuralNet;
+
+        public MockRefacPolicy(NeuralNet neuralNet) {
+            this.neuralNet = neuralNet;
+        }
+
+        @Override
+        public NeuralNet getNeuralNet() {
+            return neuralNet;
+        }
+
+        @Override
+        public Integer nextAction(INDArray input) {
+            return (int)input.getDouble(0);
+        }
+    }
 }
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java
index 1706dc49e..56581cc0d 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java
@@ -1,65 +1,22 @@
 package org.deeplearning4j.rl4j.support;
 
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Value;
 import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
 
+@AllArgsConstructor
+@Value
 public class MockAsyncConfiguration implements AsyncConfiguration {
 
-    private final int nStep;
-    private final int maxEpochStep;
-
-    public MockAsyncConfiguration(int nStep, int maxEpochStep) {
-        this.nStep = nStep;
-
-        this.maxEpochStep = maxEpochStep;
-    }
-
-    @Override
-    public Integer getSeed() {
-        return 0;
-    }
-
-    @Override
-    public int getMaxEpochStep() {
-        return maxEpochStep;
-    }
-
-    @Override
-    public int getMaxStep() {
-        return 0;
-    }
-
-    @Override
-    public int getNumThread() {
-        return 0;
-    }
-
-    @Override
-    public int getNstep() {
-        return nStep;
-    }
-
-    @Override
-    public int getTargetDqnUpdateFreq() {
-        return 0;
-    }
-
-    @Override
-    public int getUpdateStart() {
-        return 0;
-    }
-
-    @Override
-    public double getRewardFactor() {
-        return 0;
-    }
-
-    @Override
-    public double getGamma() {
-        return 0;
-    }
-
-    @Override
-    public double getErrorClamp() {
-        return 0;
-    }
+    private Integer seed;
+    private int maxEpochStep;
+    private int maxStep;
+    private int numThread;
+    private int nstep;
+    private int targetDqnUpdateFreq;
+    private int updateStart;
+    private double rewardFactor;
+    private double gamma;
+    private double errorClamp;
 }
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java
index 0bc34d239..34a2078f0 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java
@@ -1,5 +1,6 @@
 package org.deeplearning4j.rl4j.support;
 
+import lombok.Getter;
 import lombok.Setter;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
@@ -9,9 +10,13 @@ import java.util.concurrent.atomic.AtomicInteger;
 
 public class MockAsyncGlobal implements IAsyncGlobal {
 
+    private final NeuralNet current;
+
     public boolean hasBeenStarted = false;
     public boolean hasBeenTerminated = false;
 
+    public int enqueueCallCount = 0;
+
     @Setter
     private int maxLoops;
     @Setter
@@ -19,8 +24,13 @@ public class MockAsyncGlobal implements IAsyncGlobal {
     private int currentLoop = 0;
 
     public MockAsyncGlobal() {
+        this(null);
+    }
+
+    public MockAsyncGlobal(NeuralNet current) {
         maxLoops = Integer.MAX_VALUE;
         numLoopsStopRunning = Integer.MAX_VALUE;
+        this.current = current;
     }
 
     @Override
@@ -50,16 +60,16 @@ public class MockAsyncGlobal implements IAsyncGlobal {
 
     @Override
     public NeuralNet getCurrent() {
-        return null;
+        return current;
     }
 
     @Override
     public NeuralNet getTarget() {
-        return null;
+        return current;
     }
 
     @Override
     public void enqueue(Gradient[] gradient, Integer nstep) {
-
+        ++enqueueCallCount;
     }
 }
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java
index 3235f21af..dbdb1f6a9 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java
@@ -5,9 +5,10 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
+import java.util.ArrayList;
+
 public class MockHistoryProcessor implements IHistoryProcessor {
 
-    public int recordCallCount = 0;
     public int addCallCount = 0;
     public int startMonitorCallCount = 0;
     public int stopMonitorCallCount = 0;
@@ -15,10 +16,13 @@ public class MockHistoryProcessor implements IHistoryProcessor {
 
     private final Configuration config;
     private final CircularFifoQueue<INDArray> history;
+    public final ArrayList<INDArray> recordCalls;
+
     public MockHistoryProcessor(Configuration config) {
         this.config = config;
         history = new CircularFifoQueue<>(config.getHistoryLength());
+        recordCalls = new ArrayList<INDArray>();
     }
 
     @Override
@@ -37,7 +41,7 @@ public class MockHistoryProcessor implements IHistoryProcessor {
 
     @Override
     public void record(INDArray image) {
-        ++recordCallCount;
+        recordCalls.add(image);
     }
 
     @Override
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
index 8dce8edea..5deb72b2a 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
@@ -2,8 +2,10 @@ package org.deeplearning4j.rl4j.support;
 
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.nd4j.linalg.api.rng.Random;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -11,14 +13,30 @@ import java.util.List;
 
 public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
 
     private final DiscreteSpace actionSpace;
+    private final int stepsUntilDone;
     private int currentObsValue = 0;
     private final ObservationSpace observationSpace;
 
     public final List<Integer> actions = new ArrayList<>();
+    private int step = 0;
+    public int resetCount = 0;
+
+    public MockMDP(ObservationSpace observationSpace, int stepsUntilDone, DiscreteSpace actionSpace) {
+        this.stepsUntilDone = stepsUntilDone;
+        this.actionSpace = actionSpace;
+        this.observationSpace = observationSpace;
+    }
+
+    public MockMDP(ObservationSpace observationSpace, int stepsUntilDone, Random rnd) {
+        this(observationSpace, stepsUntilDone, new DiscreteSpace(5, rnd));
+    }
 
     public MockMDP(ObservationSpace observationSpace) {
-        actionSpace = new DiscreteSpace(5);
-        this.observationSpace = observationSpace;
+        this(observationSpace, Integer.MAX_VALUE, new DiscreteSpace(5));
+    }
+
+    public MockMDP(ObservationSpace observationSpace, Random rnd) {
+        this(observationSpace, Integer.MAX_VALUE, new DiscreteSpace(5, rnd));
     }
 
     @Override
@@ -33,7 +51,9 @@ public class MockMDP implements MDP {
 
     @Override
     public MockEncodable reset() {
+        ++resetCount;
         currentObsValue = 0;
+        step = 0;
         return new MockEncodable(currentObsValue++);
     }
 
@@ -45,12 +65,13 @@ public class MockMDP implements MDP {
 
     @Override
     public StepReply<MockEncodable> step(Integer action) {
         actions.add(action);
+        ++step;
         return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
     }
 
     @Override
     public boolean isDone() {
-        return false;
+        return step >= stepsUntilDone;
     }
 
     @Override
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java
index bdffa59a8..6d542934b 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java
@@ -4,13 +4,18 @@ import org.deeplearning4j.nn.api.NeuralNetwork;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
 
 public class MockNeuralNet implements NeuralNet {
 
     public int resetCallCount = 0;
+    public int copyCallCount = 0;
+    public List<INDArray> outputAllInputs = new ArrayList<INDArray>();
 
     @Override
     public NeuralNetwork[] getNeuralNetworks() {
@@ -29,17 +34,18 @@ public class MockNeuralNet implements NeuralNet {
 
     @Override
     public INDArray[] outputAll(INDArray batch) {
-        return new INDArray[0];
+        outputAllInputs.add(batch);
+        return new INDArray[] { Nd4j.create(new double[] { 1.0 }) };
     }
 
     @Override
     public NeuralNet clone() {
-        return null;
+        return this;
     }
 
     @Override
     public void copy(NeuralNet from) {
-
+        ++copyCallCount;
     }
 
     @Override
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
index 28f812f33..82adc65b7 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
@@ -4,14 +4,25 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import java.util.ArrayList;
+import java.util.List;
 
 public class MockPolicy implements IPolicy<MockEncodable, Integer> {
 
     public int playCallCount = 0;
+    public List<INDArray> actionInputs = new ArrayList<INDArray>();
 
     @Override
     public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
         ++playCallCount;
         return 0;
     }
+
+    @Override
+    public Integer nextAction(INDArray input) {
+        actionInputs.add(input);
+        return null;
+    }
 }
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java
index b20714889..5961d3355 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java
@@ -45,12 +45,4 @@ public abstract class MalmoActionSpace extends DiscreteSpace {
     public Integer noOp() {
         return -1;
     }
-
-    /**
-     * Sets the seed used for random generation of actions
-     * @param seed random number generator seed
-     */
-    public void setRandomSeed(long seed) {
-        rd.setSeed(seed);
-    }
 }
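
Usage note (not part of the patch): the central API change above replaces the mutable ActionSpace.setSeed(int) with constructor injection of Nd4j's Random into DiscreteSpace, so deterministic runs seed the generator once, where the space is created. Below is a minimal sketch of that pattern; the class name and the seed value 123 are illustrative, everything else comes straight from the diff:

    import org.deeplearning4j.rl4j.space.DiscreteSpace;
    import org.nd4j.linalg.api.rng.Random;
    import org.nd4j.linalg.factory.Nd4j;

    public class SeededDiscreteSpaceExample {
        public static void main(String[] args) {
            // Seed Nd4j's Random up front, the way the learner threads now call
            // rnd.setSeed(seed + threadNumber) instead of seeding the action space.
            Random rnd = Nd4j.getRandom();
            rnd.setSeed(123L); // arbitrary seed, for illustration only

            // Inject the seeded Random through the new DiscreteSpace(int, Random)
            // constructor; the removed setSeed(int)/setRandomSeed(long) mutators
            // are no longer needed.
            DiscreteSpace space = new DiscreteSpace(5, rnd);
            System.out.println(space.randomAction()); // reproducible for a fixed seed
        }
    }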
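The other interface change is the new IPolicy.nextAction(INDArray): play() now builds the (history-stacked) input and delegates action selection to it, as MockRefacPolicy and MockPolicy illustrate in the tests above. A sketch of a greedy selector in the same spirit; this is hypothetical code, not from the patch (RL4J's DQNPolicy plays this role for real):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class GreedyNextActionExample {
        // Hypothetical helper: returns the index of the largest Q-value,
        // which is what a greedy IPolicy<O, Integer>.nextAction could do
        // with the input that play() hands it.
        static Integer nextAction(INDArray qValues) {
            int best = 0;
            for (int i = 1; i < qValues.length(); i++) {
                if (qValues.getDouble(i) > qValues.getDouble(best))
                    best = i;
            }
            return best;
        }

        public static void main(String[] args) {
            INDArray q = Nd4j.create(new double[] { 0.1, 0.7, 0.2 });
            System.out.println(nextAction(q)); // prints 1
        }
    }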