diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 72dde9214..a2c25a43c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -45,8 +45,12 @@ import java.util.List; public abstract class QLearning> extends SyncLearning { + // FIXME Changed for refac + // @Getter + // final private IExpReplay expReplay; @Getter - final private IExpReplay expReplay; + @Setter(AccessLevel.PACKAGE) + private IExpReplay expReplay; public QLearning(QLConfiguration conf) { super(conf); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index f8d1a219c..0e53103ef 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -1,17 +1,16 @@ package org.deeplearning4j.rl4j.learning.async; -import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.policy.Policy; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.support.MockDataManager; import org.deeplearning4j.rl4j.support.MockHistoryProcessor; +import org.deeplearning4j.rl4j.support.MockMDP; +import org.deeplearning4j.rl4j.support.MockObservationSpace; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; @@ -93,7 +92,7 @@ public class AsyncThreadTest { IDataManager.StatEntry entry = dataManager.statEntries.get(i); assertEquals(i + 1, entry.getStepCounter()); assertEquals(i, entry.getEpochCounter()); - assertEquals(1.0, entry.getReward(), 0.0); + assertEquals(79.0, entry.getReward(), 0.0); } assertEquals(10, dataManager.isSaveDataCallCount); @@ -128,7 +127,7 @@ public class AsyncThreadTest { IDataManager.StatEntry entry = dataManager.statEntries.get(i); assertEquals(i + 1, entry.getStepCounter()); assertEquals(i, entry.getEpochCounter()); - assertEquals(1.0, entry.getReward(), 0.0); + assertEquals(79.0, entry.getReward(), 0.0); } assertEquals(1, dataManager.isSaveDataCallCount); @@ -308,91 +307,6 @@ public class AsyncThreadTest { } } - public static class MockEncodable implements Encodable { - - private final int value; - - public MockEncodable(int value) { - - this.value = value; - } - - @Override - public double[] toArray() { - return new double[] { value }; - } - } - - public static class MockObservationSpace implements ObservationSpace { - - @Override - public String getName() { - return null; - } - - @Override - public int[] getShape() { - return new int[] { 1 }; - } - - @Override - public INDArray getLow() { - return null; - } - - @Override - public INDArray getHigh() { - return null; - } - } - - public static class MockMDP implements MDP { - - private final DiscreteSpace actionSpace; - private int currentObsValue = 0; - private final ObservationSpace observationSpace; - - public MockMDP(ObservationSpace observationSpace) { - actionSpace = new DiscreteSpace(5); - this.observationSpace = observationSpace; - } - - @Override - public ObservationSpace getObservationSpace() { - return observationSpace; - } - - @Override - public DiscreteSpace getActionSpace() { - return actionSpace; - } - - @Override - public MockEncodable reset() { - return new MockEncodable(++currentObsValue); - } - - @Override - public void close() { - - } - - @Override - public StepReply step(Integer obs) { - return new StepReply(new MockEncodable(obs), (double)obs, isDone(), null); - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public MDP newInstance() { - return null; - } - } - public static class MockAsyncConfiguration implements AsyncConfiguration { private final int nStep; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java new file mode 100644 index 000000000..51bdeaf41 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -0,0 +1,142 @@ +package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; + +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.util.IDataManager; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.*; + +public class QLearningDiscreteTest { + @Test + public void refac_QLearningDiscrete_trainStep() { + // Arrange + MockObservationSpace observationSpace = new MockObservationSpace(); + MockMDP mdp = new MockMDP(observationSpace); + MockDQN dqn = new MockDQN(); + QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, + 0, 1.0, 0, 0, 0, 0, true); + MockDataManager dataManager = new MockDataManager(false); + TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10); + IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); + sut.setHistoryProcessor(hp); + MockExpReplay expReplay = new MockExpReplay(); + sut.setExpReplay(expReplay); + MockEncodable obs = new MockEncodable(1); + List> results = new ArrayList<>(); + + // Act + sut.initMdp(); + for(int step = 0; step < 16; ++step) { + results.add(sut.trainStep(obs)); + sut.incrementStep(); + } + + // Assert + // HistoryProcessor calls + assertEquals(24, hp.recordCallCount); + assertEquals(13, hp.addCallCount); + assertEquals(0, hp.startMonitorCallCount); + assertEquals(0, hp.stopMonitorCallCount); + + // DQN calls + assertEquals(1, dqn.fitParams.size()); + assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001); + assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001); + assertEquals(14, dqn.outputParams.size()); + double[][] expectedDQNOutput = new double[][] { + new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 }, + new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 }, + new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 }, + new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 }, + new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 }, + new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 }, + new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 }, + new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 }, + new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, + new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, + new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, + new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, + new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, + new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, + + }; + for(int i = 0; i < expectedDQNOutput.length; ++i) { + INDArray outputParam = dqn.outputParams.get(i); + + assertEquals(5, outputParam.shape()[0]); + assertEquals(1, outputParam.shape()[1]); + + double[] expectedRow = expectedDQNOutput[i]; + for(int j = 0; j < expectedRow.length; ++j) { + assertEquals(expectedRow[j] / 255.0, outputParam.getDouble(j), 0.00001); + } + } + + // MDP calls + assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 0 ,0, 4, 4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray()); + + // ExpReplay calls + double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 }; + int[] expectedTrActions = new int[] { 0, 4, 3, 4, 4, 4, 4, 4 }; + double[] expectedTrNextObservation = new double[] { 0, 0, 0, 1.0, 9.0, 11.0, 13.0, 15.0 }; + double[][] expectedTrObservations = new double[][] { + new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 }, + new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 }, + new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 }, + new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 }, + new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 }, + new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, + new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, + new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, + }; + for(int i = 0; i < expectedTrRewards.length; ++i) { + Transition tr = expReplay.transitions.get(i); + assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); + assertEquals(expectedTrActions[i], tr.getAction()); + assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001); + for(int j = 0; j < expectedTrObservations[i].length; ++j) { + assertEquals(expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001); + } + } + + // trainStep results + assertEquals(16, results.size()); + double[] expectedMaxQ = new double[] { 1.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 }; + double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 }; + for(int i=0; i < 16; ++i) { + QLearning.QLStepReturn result = results.get(i); + if(i % 2 == 0) { + assertEquals(expectedMaxQ[i/2] / 255.0, result.getMaxQ(), 0.001); + assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001); + } + else { + assertTrue(result.getMaxQ().isNaN()); + } + } + } + + public static class TestQLearningDiscrete extends QLearningDiscrete { + public TestQLearningDiscrete(MDP mdp,IDQN dqn, + QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) { + super(mdp, dqn, conf, dataManager, epsilonNbStep); + } + + @Override + protected Pair setTarget(ArrayList> transitions) { + return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java new file mode 100644 index 000000000..f4080c57f --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java @@ -0,0 +1,100 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; + +public class MockDQN implements IDQN { + + public final List outputParams = new ArrayList<>(); + public final List> fitParams = new ArrayList<>(); + + @Override + public NeuralNetwork[] getNeuralNetworks() { + return new NeuralNetwork[0]; + } + + @Override + public boolean isRecurrent() { + return false; + } + + @Override + public void reset() { + + } + + @Override + public void fit(INDArray input, INDArray labels) { + fitParams.add(new Pair<>(input, labels)); + } + + @Override + public void fit(INDArray input, INDArray[] labels) { + + } + + @Override + public INDArray output(INDArray batch){ + outputParams.add(batch); + return batch; + } + + @Override + public INDArray[] outputAll(INDArray batch) { + return new INDArray[0]; + } + + @Override + public IDQN clone() { + return null; + } + + @Override + public void copy(NeuralNet from) { + + } + + @Override + public void copy(IDQN from) { + + } + + @Override + public Gradient[] gradient(INDArray input, INDArray label) { + return new Gradient[0]; + } + + @Override + public Gradient[] gradient(INDArray input, INDArray[] label) { + return new Gradient[0]; + } + + @Override + public void applyGradient(Gradient[] gradient, int batchSize) { + + } + + @Override + public double getLatestScore() { + return 0; + } + + @Override + public void save(OutputStream os) throws IOException { + + } + + @Override + public void save(String filename) throws IOException { + + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java new file mode 100644 index 000000000..436205b42 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java @@ -0,0 +1,18 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.space.Encodable; + +public class MockEncodable implements Encodable { + + private final int value; + + public MockEncodable(int value) { + + this.value = value; + } + + @Override + public double[] toArray() { + return new double[] { value }; + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java new file mode 100644 index 000000000..d1fa84c04 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockExpReplay.java @@ -0,0 +1,22 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.learning.sync.IExpReplay; +import org.deeplearning4j.rl4j.learning.sync.Transition; + +import java.util.ArrayList; +import java.util.List; + +public class MockExpReplay implements IExpReplay { + + public List> transitions = new ArrayList<>(); + + @Override + public ArrayList> getBatch() { + return null; + } + + @Override + public void store(Transition transition) { + transitions.add(transition); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java index 9d24161b4..3235f21af 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java @@ -1,15 +1,24 @@ package org.deeplearning4j.rl4j.support; +import org.apache.commons.collections4.queue.CircularFifoQueue; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; public class MockHistoryProcessor implements IHistoryProcessor { + public int recordCallCount = 0; + public int addCallCount = 0; + public int startMonitorCallCount = 0; + public int stopMonitorCallCount = 0; + private final Configuration config; + private final CircularFifoQueue history; public MockHistoryProcessor(Configuration config) { this.config = config; + history = new CircularFifoQueue<>(config.getHistoryLength()); } @Override @@ -19,27 +28,32 @@ public class MockHistoryProcessor implements IHistoryProcessor { @Override public INDArray[] getHistory() { - return new INDArray[0]; + INDArray[] array = new INDArray[getConf().getHistoryLength()]; + for (int i = 0; i < config.getHistoryLength(); i++) { + array[i] = history.get(i).castTo(Nd4j.dataType()); + } + return array; } @Override public void record(INDArray image) { - + ++recordCallCount; } @Override public void add(INDArray image) { - + ++addCallCount; + history.add(image); } @Override public void startMonitor(String filename, int[] shape) { - + ++startMonitorCallCount; } @Override public void stopMonitor() { - + ++stopMonitorCallCount; } @Override @@ -49,6 +63,6 @@ public class MockHistoryProcessor implements IHistoryProcessor { @Override public double getScale() { - return 0; + return 255.0; } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java new file mode 100644 index 000000000..8dce8edea --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java @@ -0,0 +1,60 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.deeplearning4j.rl4j.space.ObservationSpace; + +import java.util.ArrayList; +import java.util.List; + +public class MockMDP implements MDP { + + private final DiscreteSpace actionSpace; + private int currentObsValue = 0; + private final ObservationSpace observationSpace; + + public final List actions = new ArrayList<>(); + + public MockMDP(ObservationSpace observationSpace) { + actionSpace = new DiscreteSpace(5); + this.observationSpace = observationSpace; + } + + @Override + public ObservationSpace getObservationSpace() { + return observationSpace; + } + + @Override + public DiscreteSpace getActionSpace() { + return actionSpace; + } + + @Override + public MockEncodable reset() { + currentObsValue = 0; + return new MockEncodable(currentObsValue++); + } + + @Override + public void close() { + + } + + @Override + public StepReply step(Integer action) { + actions.add(action); + return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null); + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public MDP newInstance() { + return null; + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java new file mode 100644 index 000000000..5395242b2 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java @@ -0,0 +1,27 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.space.ObservationSpace; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class MockObservationSpace implements ObservationSpace { + + @Override + public String getName() { + return null; + } + + @Override + public int[] getShape() { + return new int[] { 1 }; + } + + @Override + public INDArray getLow() { + return null; + } + + @Override + public INDArray getHigh() { + return null; + } +}