RL4J - Added a unit test to help refac QLearningDiscrete.trainStep() (#8065)
* Added a unit test to help refac QLearningDiscrete.trainStep() Signed-off-by: unknown <aboulang2002@yahoo.com> * Changed expReplay setter to package private Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
b2145ca780
commit
b083c22de5
|
@ -45,8 +45,12 @@ import java.util.List;
|
||||||
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
||||||
extends SyncLearning<O, A, AS, IDQN> {
|
extends SyncLearning<O, A, AS, IDQN> {
|
||||||
|
|
||||||
|
// FIXME Changed for refac
|
||||||
|
// @Getter
|
||||||
|
// final private IExpReplay<A> expReplay;
|
||||||
@Getter
|
@Getter
|
||||||
final private IExpReplay<A> expReplay;
|
@Setter(AccessLevel.PACKAGE)
|
||||||
|
private IExpReplay<A> expReplay;
|
||||||
|
|
||||||
public QLearning(QLConfiguration conf) {
|
public QLearning(QLConfiguration conf) {
|
||||||
super(conf);
|
super(conf);
|
||||||
|
|
|
@ -1,17 +1,16 @@
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
import org.deeplearning4j.gym.StepReply;
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
|
||||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -93,7 +92,7 @@ public class AsyncThreadTest {
|
||||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||||
assertEquals(i + 1, entry.getStepCounter());
|
assertEquals(i + 1, entry.getStepCounter());
|
||||||
assertEquals(i, entry.getEpochCounter());
|
assertEquals(i, entry.getEpochCounter());
|
||||||
assertEquals(1.0, entry.getReward(), 0.0);
|
assertEquals(79.0, entry.getReward(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(10, dataManager.isSaveDataCallCount);
|
assertEquals(10, dataManager.isSaveDataCallCount);
|
||||||
|
@ -128,7 +127,7 @@ public class AsyncThreadTest {
|
||||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||||
assertEquals(i + 1, entry.getStepCounter());
|
assertEquals(i + 1, entry.getStepCounter());
|
||||||
assertEquals(i, entry.getEpochCounter());
|
assertEquals(i, entry.getEpochCounter());
|
||||||
assertEquals(1.0, entry.getReward(), 0.0);
|
assertEquals(79.0, entry.getReward(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(1, dataManager.isSaveDataCallCount);
|
assertEquals(1, dataManager.isSaveDataCallCount);
|
||||||
|
@ -308,91 +307,6 @@ public class AsyncThreadTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class MockEncodable implements Encodable {
|
|
||||||
|
|
||||||
private final int value;
|
|
||||||
|
|
||||||
public MockEncodable(int value) {
|
|
||||||
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] toArray() {
|
|
||||||
return new double[] { value };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class MockObservationSpace implements ObservationSpace {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getName() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int[] getShape() {
|
|
||||||
return new int[] { 1 };
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getLow() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getHigh() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
|
||||||
|
|
||||||
private final DiscreteSpace actionSpace;
|
|
||||||
private int currentObsValue = 0;
|
|
||||||
private final ObservationSpace observationSpace;
|
|
||||||
|
|
||||||
public MockMDP(ObservationSpace observationSpace) {
|
|
||||||
actionSpace = new DiscreteSpace(5);
|
|
||||||
this.observationSpace = observationSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ObservationSpace getObservationSpace() {
|
|
||||||
return observationSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DiscreteSpace getActionSpace() {
|
|
||||||
return actionSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MockEncodable reset() {
|
|
||||||
return new MockEncodable(++currentObsValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public StepReply<MockEncodable> step(Integer obs) {
|
|
||||||
return new StepReply<MockEncodable>(new MockEncodable(obs), (double)obs, isDone(), null);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isDone() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MDP newInstance() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class MockAsyncConfiguration implements AsyncConfiguration {
|
public static class MockAsyncConfiguration implements AsyncConfiguration {
|
||||||
|
|
||||||
private final int nStep;
|
private final int nStep;
|
||||||
|
|
|
@ -0,0 +1,142 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
|
import org.deeplearning4j.rl4j.support.*;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
public class QLearningDiscreteTest {
|
||||||
|
@Test
|
||||||
|
public void refac_QLearningDiscrete_trainStep() {
|
||||||
|
// Arrange
|
||||||
|
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
|
MockMDP mdp = new MockMDP(observationSpace);
|
||||||
|
MockDQN dqn = new MockDQN();
|
||||||
|
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||||
|
0, 1.0, 0, 0, 0, 0, true);
|
||||||
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
|
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10);
|
||||||
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||||
|
sut.setHistoryProcessor(hp);
|
||||||
|
MockExpReplay expReplay = new MockExpReplay();
|
||||||
|
sut.setExpReplay(expReplay);
|
||||||
|
MockEncodable obs = new MockEncodable(1);
|
||||||
|
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.initMdp();
|
||||||
|
for(int step = 0; step < 16; ++step) {
|
||||||
|
results.add(sut.trainStep(obs));
|
||||||
|
sut.incrementStep();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
// HistoryProcessor calls
|
||||||
|
assertEquals(24, hp.recordCallCount);
|
||||||
|
assertEquals(13, hp.addCallCount);
|
||||||
|
assertEquals(0, hp.startMonitorCallCount);
|
||||||
|
assertEquals(0, hp.stopMonitorCallCount);
|
||||||
|
|
||||||
|
// DQN calls
|
||||||
|
assertEquals(1, dqn.fitParams.size());
|
||||||
|
assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
|
||||||
|
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
|
||||||
|
assertEquals(14, dqn.outputParams.size());
|
||||||
|
double[][] expectedDQNOutput = new double[][] {
|
||||||
|
new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
|
||||||
|
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||||
|
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||||
|
new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
|
||||||
|
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||||
|
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||||
|
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||||
|
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||||
|
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||||
|
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||||
|
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||||
|
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||||
|
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||||
|
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||||
|
|
||||||
|
};
|
||||||
|
for(int i = 0; i < expectedDQNOutput.length; ++i) {
|
||||||
|
INDArray outputParam = dqn.outputParams.get(i);
|
||||||
|
|
||||||
|
assertEquals(5, outputParam.shape()[0]);
|
||||||
|
assertEquals(1, outputParam.shape()[1]);
|
||||||
|
|
||||||
|
double[] expectedRow = expectedDQNOutput[i];
|
||||||
|
for(int j = 0; j < expectedRow.length; ++j) {
|
||||||
|
assertEquals(expectedRow[j] / 255.0, outputParam.getDouble(j), 0.00001);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MDP calls
|
||||||
|
assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 0 ,0, 4, 4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
|
||||||
|
|
||||||
|
// ExpReplay calls
|
||||||
|
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
|
||||||
|
int[] expectedTrActions = new int[] { 0, 4, 3, 4, 4, 4, 4, 4 };
|
||||||
|
double[] expectedTrNextObservation = new double[] { 0, 0, 0, 1.0, 9.0, 11.0, 13.0, 15.0 };
|
||||||
|
double[][] expectedTrObservations = new double[][] {
|
||||||
|
new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
|
||||||
|
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||||
|
new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
|
||||||
|
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||||
|
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||||
|
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||||
|
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||||
|
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||||
|
};
|
||||||
|
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
||||||
|
Transition tr = expReplay.transitions.get(i);
|
||||||
|
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
||||||
|
assertEquals(expectedTrActions[i], tr.getAction());
|
||||||
|
assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
|
||||||
|
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
||||||
|
assertEquals(expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// trainStep results
|
||||||
|
assertEquals(16, results.size());
|
||||||
|
double[] expectedMaxQ = new double[] { 1.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
|
||||||
|
double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
|
||||||
|
for(int i=0; i < 16; ++i) {
|
||||||
|
QLearning.QLStepReturn<MockEncodable> result = results.get(i);
|
||||||
|
if(i % 2 == 0) {
|
||||||
|
assertEquals(expectedMaxQ[i/2] / 255.0, result.getMaxQ(), 0.001);
|
||||||
|
assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
assertTrue(result.getMaxQ().isNaN());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||||
|
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
|
||||||
|
QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) {
|
||||||
|
super(mdp, dqn, conf, dataManager, epsilonNbStep);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||||
|
return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MockDQN implements IDQN {
|
||||||
|
|
||||||
|
public final List<INDArray> outputParams = new ArrayList<>();
|
||||||
|
public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNetwork[] getNeuralNetworks() {
|
||||||
|
return new NeuralNetwork[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isRecurrent() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(INDArray input, INDArray labels) {
|
||||||
|
fitParams.add(new Pair<>(input, labels));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(INDArray input, INDArray[] labels) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(INDArray batch){
|
||||||
|
outputParams.add(batch);
|
||||||
|
return batch;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray[] outputAll(INDArray batch) {
|
||||||
|
return new INDArray[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IDQN clone() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void copy(NeuralNet from) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void copy(IDQN from) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Gradient[] gradient(INDArray input, INDArray label) {
|
||||||
|
return new Gradient[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Gradient[] gradient(INDArray input, INDArray[] label) {
|
||||||
|
return new Gradient[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void applyGradient(Gradient[] gradient, int batchSize) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getLatestScore() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(OutputStream os) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(String filename) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,18 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
|
||||||
|
public class MockEncodable implements Encodable {
|
||||||
|
|
||||||
|
private final int value;
|
||||||
|
|
||||||
|
public MockEncodable(int value) {
|
||||||
|
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] toArray() {
|
||||||
|
return new double[] { value };
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,22 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MockExpReplay implements IExpReplay<Integer> {
|
||||||
|
|
||||||
|
public List<Transition<Integer>> transitions = new ArrayList<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ArrayList<Transition<Integer>> getBatch() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void store(Transition<Integer> transition) {
|
||||||
|
transitions.add(transition);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,15 +1,24 @@
|
||||||
package org.deeplearning4j.rl4j.support;
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
public class MockHistoryProcessor implements IHistoryProcessor {
|
public class MockHistoryProcessor implements IHistoryProcessor {
|
||||||
|
|
||||||
|
public int recordCallCount = 0;
|
||||||
|
public int addCallCount = 0;
|
||||||
|
public int startMonitorCallCount = 0;
|
||||||
|
public int stopMonitorCallCount = 0;
|
||||||
|
|
||||||
private final Configuration config;
|
private final Configuration config;
|
||||||
|
private final CircularFifoQueue<INDArray> history;
|
||||||
|
|
||||||
public MockHistoryProcessor(Configuration config) {
|
public MockHistoryProcessor(Configuration config) {
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
history = new CircularFifoQueue<>(config.getHistoryLength());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -19,27 +28,32 @@ public class MockHistoryProcessor implements IHistoryProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] getHistory() {
|
public INDArray[] getHistory() {
|
||||||
return new INDArray[0];
|
INDArray[] array = new INDArray[getConf().getHistoryLength()];
|
||||||
|
for (int i = 0; i < config.getHistoryLength(); i++) {
|
||||||
|
array[i] = history.get(i).castTo(Nd4j.dataType());
|
||||||
|
}
|
||||||
|
return array;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void record(INDArray image) {
|
public void record(INDArray image) {
|
||||||
|
++recordCallCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void add(INDArray image) {
|
public void add(INDArray image) {
|
||||||
|
++addCallCount;
|
||||||
|
history.add(image);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void startMonitor(String filename, int[] shape) {
|
public void startMonitor(String filename, int[] shape) {
|
||||||
|
++startMonitorCallCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void stopMonitor() {
|
public void stopMonitor() {
|
||||||
|
++stopMonitorCallCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -49,6 +63,6 @@ public class MockHistoryProcessor implements IHistoryProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double getScale() {
|
public double getScale() {
|
||||||
return 0;
|
return 255.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||||
|
|
||||||
|
private final DiscreteSpace actionSpace;
|
||||||
|
private int currentObsValue = 0;
|
||||||
|
private final ObservationSpace observationSpace;
|
||||||
|
|
||||||
|
public final List<Integer> actions = new ArrayList<>();
|
||||||
|
|
||||||
|
public MockMDP(ObservationSpace observationSpace) {
|
||||||
|
actionSpace = new DiscreteSpace(5);
|
||||||
|
this.observationSpace = observationSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ObservationSpace getObservationSpace() {
|
||||||
|
return observationSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DiscreteSpace getActionSpace() {
|
||||||
|
return actionSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MockEncodable reset() {
|
||||||
|
currentObsValue = 0;
|
||||||
|
return new MockEncodable(currentObsValue++);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public StepReply<MockEncodable> step(Integer action) {
|
||||||
|
actions.add(action);
|
||||||
|
return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isDone() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MDP newInstance() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
public class MockObservationSpace implements ObservationSpace {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int[] getShape() {
|
||||||
|
return new int[] { 1 };
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getLow() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getHigh() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue