RL4J: Make a few fixes (#8303)

* A few fixes

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Reverted move of ObservationSpace, ActionSpace and others

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Added unit tests

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Changed ActionSpace of gym-java-client to use Nd4j's Random

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2019-10-31 00:41:52 -04:00 committed by Samuel Audet
parent ca881a987a
commit a2b973d41b
25 changed files with 444 additions and 163 deletions

View File

@ -26,12 +26,10 @@ package org.deeplearning4j.rl4j.space;
public interface ActionSpace<A> { public interface ActionSpace<A> {
/** /**
* @return A randomly uniformly sampled action, * @return A random action,
*/ */
A randomAction(); A randomAction();
void setSeed(int seed);
Object encode(A action); Object encode(A action);
int getSize(); int getSize();

View File

@ -17,8 +17,8 @@
package org.deeplearning4j.rl4j.space; package org.deeplearning4j.rl4j.space;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.rng.Random;
import java.util.Random; import org.nd4j.linalg.factory.Nd4j;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
@ -33,19 +33,19 @@ public class DiscreteSpace implements ActionSpace<Integer> {
//size of the space also defined as the number of different actions //size of the space also defined as the number of different actions
@Getter @Getter
final protected int size; final protected int size;
protected Random rd; protected final Random rnd;
public DiscreteSpace(int size) { public DiscreteSpace(int size) {
this(size, Nd4j.getRandom());
}
public DiscreteSpace(int size, Random rnd) {
this.size = size; this.size = size;
rd = new Random(); this.rnd = rnd;
} }
public Integer randomAction() { public Integer randomAction() {
return rd.nextInt(size); return rnd.nextInt(size);
}
public void setSeed(int seed) {
rd = new Random(seed);
} }
public Object encode(Integer a) { public Object encode(Integer a) {

View File

@ -67,8 +67,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
O obs = mdp.reset(); O obs = mdp.reset();
O nextO = obs;
int step = 0; int step = 0;
double reward = 0; double reward = 0;
@ -77,25 +75,30 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
while (step < requiredFrame) {
INDArray input = Learning.getInput(mdp, obs); INDArray input = Learning.getInput(mdp, obs);
if (isHistoryProcessor) if (isHistoryProcessor)
hp.record(input); hp.record(input);
while (step < requiredFrame && !mdp.isDone()) {
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
if (step % skipFrame == 0 && isHistoryProcessor) if (step % skipFrame == 0 && isHistoryProcessor)
hp.add(input); hp.add(input);
StepReply<O> stepReply = mdp.step(action); StepReply<O> stepReply = mdp.step(action);
reward += stepReply.getReward(); reward += stepReply.getReward();
nextO = stepReply.getObservation(); obs = stepReply.getObservation();
input = Learning.getInput(mdp, obs);
if (isHistoryProcessor)
hp.record(input);
step++; step++;
} }
return new InitMdp(step, nextO, reward); return new InitMdp(step, obs, reward);
} }

View File

@ -22,10 +22,11 @@ import lombok.Setter;
import lombok.Value; import lombok.Value;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.*; import org.deeplearning4j.rl4j.learning.*;
import org.deeplearning4j.rl4j.learning.listener.*; import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
@ -118,7 +119,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) { while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
handleTraining(context); handleTraining(context);
if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) { if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
canContinue = finishEpoch(context) && startNewEpoch(context); canContinue = finishEpoch(context) && startNewEpoch(context);
if (!canContinue) { if (!canContinue) {
break; break;
@ -135,16 +136,16 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
context.obs = initMdp.getLastObs(); context.obs = initMdp.getLastObs();
context.rewards = initMdp.getReward(); context.rewards = initMdp.getReward();
context.length = initMdp.getSteps(); context.epochElapsedSteps = initMdp.getSteps();
} }
private void handleTraining(RunContext<O> context) { private void handleTraining(RunContext<O> context) {
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length); int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps); SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs(); context.obs = subEpochReturn.getLastObs();
stepCounter += subEpochReturn.getSteps(); stepCounter += subEpochReturn.getSteps();
context.length += subEpochReturn.getSteps(); context.epochElapsedSteps += subEpochReturn.getSteps();
context.rewards += subEpochReturn.getReward(); context.rewards += subEpochReturn.getReward();
context.score = subEpochReturn.getScore(); context.score = subEpochReturn.getScore();
} }
@ -164,7 +165,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
private boolean finishEpoch(RunContext context) { private boolean finishEpoch(RunContext context) {
postEpoch(); postEpoch();
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.length, context.score); IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards); log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
@ -182,7 +183,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
protected abstract AsyncConfiguration getConf(); protected abstract AsyncConfiguration getConf();
protected abstract Policy<O, A> getPolicy(NN net); protected abstract IPolicy<O, A> getPolicy(NN net);
protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep); protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
@ -208,7 +209,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
private static class RunContext<O extends Encodable> { private static class RunContext<O extends Encodable> {
private O obs; private O obs;
private double rewards; private double rewards;
private int length; private int epochElapsedSteps;
private double score; private double score;
} }

View File

@ -25,7 +25,7 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -69,7 +69,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
Stack<MiniTrans<Integer>> rewards = new Stack<>(); Stack<MiniTrans<Integer>> rewards = new Stack<>();
O obs = sObs; O obs = sObs;
Policy<O, Integer> policy = getPolicy(current); IPolicy<O, Integer> policy = getPolicy(current);
Integer action; Integer action;
Integer lastAction = null; Integer lastAction = null;

View File

@ -58,7 +58,6 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
Integer seed = conf.getSeed(); Integer seed = conf.getSeed();
Random rnd = Nd4j.getRandom(); Random rnd = Nd4j.getRandom();
if(seed != null) { if(seed != null) {
mdp.getActionSpace().setSeed(seed);
rnd.setSeed(seed); rnd.setSeed(seed);
} }

View File

@ -63,8 +63,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
Integer seed = conf.getSeed(); Integer seed = conf.getSeed();
rnd = Nd4j.getRandom(); rnd = Nd4j.getRandom();
if(seed != null) { if(seed != null) {
mdp.getActionSpace().setSeed(seed + threadNumber); rnd.setSeed(seed + threadNumber);
rnd.setSeed(seed);
} }
} }

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.factory.Nd4j;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -46,10 +47,6 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
this.mdp = mdp; this.mdp = mdp;
this.configuration = conf; this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf); this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
Integer seed = conf.getSeed();
if(seed != null) {
mdp.getActionSpace().setSeed(seed);
}
} }
@Override @Override

View File

@ -61,7 +61,6 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
Integer seed = conf.getSeed(); Integer seed = conf.getSeed();
if(seed != null) { if(seed != null) {
mdp.getActionSpace().setSeed(seed + threadNumber);
rnd.setSeed(seed + threadNumber); rnd.setSeed(seed + threadNumber);
} }
} }

View File

@ -85,7 +85,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
policy = new DQNPolicy(getQNetwork()); policy = new DQNPolicy(getQNetwork());
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(), egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
this); this);
mdp.getActionSpace().setSeed(conf.getSeed());
tdTargetAlgorithm = conf.isDoubleDQN() tdTargetAlgorithm = conf.isDoubleDQN()
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp()) ? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
@ -118,9 +117,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
boolean isHistoryProcessor = getHistoryProcessor() != null; boolean isHistoryProcessor = getHistoryProcessor() != null;
if (isHistoryProcessor)
getHistoryProcessor().record(input);
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1; int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1; int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
int updateStart = getConfiguration().getUpdateStart() int updateStart = getConfiguration().getUpdateStart()
@ -160,12 +156,16 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
StepReply<O> stepReply = getMdp().step(action); StepReply<O> stepReply = getMdp().step(action);
INDArray ninput = getInput(stepReply.getObservation());
if (isHistoryProcessor)
getHistoryProcessor().record(ninput);
accuReward += stepReply.getReward() * configuration.getRewardFactor(); accuReward += stepReply.getReward() * configuration.getRewardFactor();
//if it's not a skipped frame, you can do a step of training //if it's not a skipped frame, you can do a step of training
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) { if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
INDArray ninput = getInput(stepReply.getObservation());
if (isHistoryProcessor) if (isHistoryProcessor)
getHistoryProcessor().add(ninput); getHistoryProcessor().add(ninput);

View File

@ -20,7 +20,6 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import static org.nd4j.linalg.ops.transforms.Transforms.exp; import static org.nd4j.linalg.ops.transforms.Transforms.exp;

View File

@ -4,7 +4,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
public interface IPolicy<O extends Encodable, A> { public interface IPolicy<O extends Encodable, A> {
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp); <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
A nextAction(INDArray input);
} }

View File

@ -51,6 +51,9 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
@Override @Override
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) { public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
getNeuralNet().reset(); getNeuralNet().reset();
Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp); Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
O obs = initMdp.getLastObs(); O obs = initMdp.getLastObs();
@ -62,16 +65,9 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
int step = initMdp.getSteps(); int step = initMdp.getSteps();
INDArray[] history = null; INDArray[] history = null;
while (!mdp.isDone()) {
INDArray input = Learning.getInput(mdp, obs); INDArray input = Learning.getInput(mdp, obs);
boolean isHistoryProcessor = hp != null;
if (isHistoryProcessor)
hp.record(input);
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
while (!mdp.isDone()) {
if (step % skipFrame != 0) { if (step % skipFrame != 0) {
action = lastAction; action = lastAction;
@ -102,8 +98,11 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
StepReply<O> stepReply = mdp.step(action); StepReply<O> stepReply = mdp.step(action);
reward += stepReply.getReward(); reward += stepReply.getReward();
if (isHistoryProcessor) input = Learning.getInput(mdp, stepReply.getObservation());
hp.add(Learning.getInput(mdp, stepReply.getObservation())); if (isHistoryProcessor) {
hp.record(input);
hp.add(input);
}
history = isHistoryProcessor ? hp.getHistory() history = isHistoryProcessor ? hp.getHistory()
: new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())}; : new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())};

View File

@ -68,10 +68,10 @@ public class AsyncLearningTest {
public static class TestContext { public static class TestContext {
public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1); MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0,0, 0, 0, 0);
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockPolicy policy = new MockPolicy(); public final MockPolicy policy = new MockPolicy();
public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy); public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
public final MockTrainingListener listener = new MockTrainingListener(); public final MockTrainingListener listener = new MockTrainingListener();
public TestContext() { public TestContext() {

View File

@ -0,0 +1,134 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Stack;
import static org.junit.Assert.assertEquals;
public class AsyncThreadDiscreteTest {
@Test
public void refac_AsyncThreadDiscrete_trainSubEpoch() {
// Arrange
MockNeuralNet nnMock = new MockNeuralNet();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy();
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 0, 5,0, 0, 0, 0);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
MockEncodable obs = new MockEncodable(123);
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(1)));
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(2)));
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(3)));
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(4)));
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(5)));
// Act
AsyncThread.SubEpochReturn<MockEncodable> result = sut.trainSubEpoch(obs, 2);
// Assert
assertEquals(4, result.getSteps());
assertEquals(6.0, result.getReward(), 0.00001);
assertEquals(0.0, result.getScore(), 0.00001);
assertEquals(3.0, result.getLastObs().toArray()[0], 0.00001);
assertEquals(1, asyncGlobalMock.enqueueCallCount);
// HistoryProcessor
assertEquals(10, hpMock.addCallCount);
double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
for(int i = 0; i < expectedRecordValues.length; ++i) {
assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);
}
// Policy
double[][] expectedPolicyInputs = new double[][] {
new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
};
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
double[] expectedRow = expectedPolicyInputs[i];
INDArray input = policyMock.actionInputs.get(i);
assertEquals(expectedRow.length, input.shape()[0]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
}
}
// NeuralNetwork
assertEquals(1, nnMock.copyCallCount);
double[][] expectedNNInputs = new double[][] {
new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 },
};
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
for(int i = 0; i < expectedNNInputs.length; ++i) {
double[] expectedRow = expectedNNInputs[i];
INDArray input = nnMock.outputAllInputs.get(i);
assertEquals(expectedRow.length, input.shape()[0]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
}
}
}
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
private final IAsyncGlobal<MockNeuralNet> asyncGlobal;
private final MockPolicy policy;
private final MockAsyncConfiguration config;
public TestAsyncThreadDiscrete(IAsyncGlobal<MockNeuralNet> asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
MockAsyncConfiguration config, IHistoryProcessor hp) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.asyncGlobal = asyncGlobal;
this.policy = policy;
this.config = config;
setHistoryProcessor(hp);
}
@Override
public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
return new Gradient[0];
}
@Override
protected IAsyncGlobal<MockNeuralNet> getAsyncGlobal() {
return asyncGlobal;
}
@Override
protected AsyncConfiguration getConf() {
return config;
}
@Override
protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
return policy;
}
}
}

View File

@ -1,5 +1,8 @@
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.async;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
@ -9,7 +12,11 @@ import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
public class AsyncThreadTest { public class AsyncThreadTest {
@ -82,7 +89,42 @@ public class AsyncThreadTest {
IDataManager.StatEntry statEntry = context.listener.statEntries.get(i); IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
assertEquals(expectedStepCounter[i], statEntry.getStepCounter()); assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
assertEquals(i, statEntry.getEpochCounter()); assertEquals(i, statEntry.getEpochCounter());
assertEquals(2.0, statEntry.getReward(), 0.0001); assertEquals(38.0, statEntry.getReward(), 0.0001);
}
}
@Test
public void when_run_expect_NeuralNetIsResetAtInitAndEveryEpoch() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.run();
// Assert
assertEquals(6, context.neuralNet.resetCallCount);
}
@Test
public void when_run_expect_trainSubEpochCalled() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.run();
// Assert
assertEquals(10, context.sut.trainSubEpochParams.size());
for(int i = 0; i < 10; ++i) {
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
if(i % 2 == 0) {
assertEquals(2, params.nstep);
assertEquals(8.0, params.obs.toArray()[0], 0.00001);
}
else {
assertEquals(1, params.nstep);
assertNull(params.obs);
}
} }
} }
@ -91,14 +133,18 @@ public class AsyncThreadTest {
public final MockNeuralNet neuralNet = new MockNeuralNet(); public final MockNeuralNet neuralNet = new MockNeuralNet();
public final MockObservationSpace observationSpace = new MockObservationSpace(); public final MockObservationSpace observationSpace = new MockObservationSpace();
public final MockMDP mdp = new MockMDP(observationSpace); public final MockMDP mdp = new MockMDP(observationSpace);
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 2); public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
public final TrainingListenerList listeners = new TrainingListenerList(); public final TrainingListenerList listeners = new TrainingListenerList();
public final MockTrainingListener listener = new MockTrainingListener(); public final MockTrainingListener listener = new MockTrainingListener();
private final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);
public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners); public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
public TestContext() { public TestContext() {
asyncGlobal.setMaxLoops(10); asyncGlobal.setMaxLoops(10);
listeners.add(listener); listeners.add(listener);
sut.setHistoryProcessor(historyProcessor);
} }
} }
@ -107,11 +153,12 @@ public class AsyncThreadTest {
public int preEpochCallCount = 0; public int preEpochCallCount = 0;
public int postEpochCallCount = 0; public int postEpochCallCount = 0;
private final IAsyncGlobal asyncGlobal; private final IAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet; private final MockNeuralNet neuralNet;
private final AsyncConfiguration conf; private final AsyncConfiguration conf;
private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
super(asyncGlobal, mdp, listeners, threadNumber, 0); super(asyncGlobal, mdp, listeners, threadNumber, 0);
@ -154,8 +201,16 @@ public class AsyncThreadTest {
@Override @Override
protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) { protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
return new SubEpochReturn(1, null, 1.0, 1.0); return new SubEpochReturn(1, null, 1.0, 1.0);
} }
@AllArgsConstructor
@Getter
public static class TrainSubEpochParams {
Encodable obs;
int nstep;
}
} }

View File

@ -26,9 +26,20 @@ public class QLearningDiscreteTest {
public void refac_QLearningDiscrete_trainStep() { public void refac_QLearningDiscrete_trainStep() {
// Arrange // Arrange
MockObservationSpace observationSpace = new MockObservationSpace(); MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockDQN dqn = new MockDQN(); MockDQN dqn = new MockDQN();
MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null); MockRandom random = new MockRandom(new double[] {
0.7309677600860596,
0.8314409852027893,
0.2405363917350769,
0.6063451766967773,
0.6374173760414124,
0.3090505599975586,
0.5504369735717773,
0.11700659990310669
},
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
MockMDP mdp = new MockMDP(observationSpace, random);
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
0, 1.0, 0, 0, 0, 0, true); 0, 1.0, 0, 0, 0, 0, true);
MockDataManager dataManager = new MockDataManager(false); MockDataManager dataManager = new MockDataManager(false);
@ -37,7 +48,7 @@ public class QLearningDiscreteTest {
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hp); sut.setHistoryProcessor(hp);
MockEncodable obs = new MockEncodable(1); MockEncodable obs = new MockEncodable(-100);
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>(); List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
// Act // Act
@ -49,7 +60,11 @@ public class QLearningDiscreteTest {
// Assert // Assert
// HistoryProcessor calls // HistoryProcessor calls
assertEquals(24, hp.recordCallCount); double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 };
assertEquals(expectedRecords.length, hp.recordCalls.size());
for(int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
}
assertEquals(13, hp.addCallCount); assertEquals(13, hp.addCallCount);
assertEquals(0, hp.startMonitorCallCount); assertEquals(0, hp.startMonitorCallCount);
assertEquals(0, hp.stopMonitorCallCount); assertEquals(0, hp.stopMonitorCallCount);
@ -60,21 +75,20 @@ public class QLearningDiscreteTest {
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001); assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
assertEquals(14, dqn.outputParams.size()); assertEquals(14, dqn.outputParams.size());
double[][] expectedDQNOutput = new double[][] { double[][] expectedDQNOutput = new double[][] {
new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 }, new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 }, new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 }, new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 }, new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 }, new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 }, new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 }, new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 }, new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
}; };
for(int i = 0; i < expectedDQNOutput.length; ++i) { for(int i = 0; i < expectedDQNOutput.length; ++i) {
INDArray outputParam = dqn.outputParams.get(i); INDArray outputParam = dqn.outputParams.get(i);
@ -84,23 +98,23 @@ public class QLearningDiscreteTest {
double[] expectedRow = expectedDQNOutput[i]; double[] expectedRow = expectedDQNOutput[i];
for(int j = 0; j < expectedRow.length; ++j) { for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j] / 255.0, outputParam.getDouble(j), 0.00001); assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
} }
} }
// MDP calls // MDP calls
assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 0 ,0, 4, 4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray()); assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
// ExpReplay calls // ExpReplay calls
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 }; double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
int[] expectedTrActions = new int[] { 0, 4, 3, 4, 4, 4, 4, 4 }; int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
double[] expectedTrNextObservation = new double[] { 0, 0, 0, 1.0, 9.0, 11.0, 13.0, 15.0 }; double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, -100.0, 9.0, 11.0, 13.0, 15.0 };
double[][] expectedTrObservations = new double[][] { double[][] expectedTrObservations = new double[][] {
new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 }, new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 }, new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 }, new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 }, new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 }, new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
@ -111,18 +125,18 @@ public class QLearningDiscreteTest {
assertEquals(expectedTrActions[i], tr.getAction()); assertEquals(expectedTrActions[i], tr.getAction());
assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001); assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
for(int j = 0; j < expectedTrObservations[i].length; ++j) { for(int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals(expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001); assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
} }
} }
// trainStep results // trainStep results
assertEquals(16, results.size()); assertEquals(16, results.size());
double[] expectedMaxQ = new double[] { 1.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 }; double[] expectedMaxQ = new double[] { 6.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 }; double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
for(int i=0; i < 16; ++i) { for(int i=0; i < 16; ++i) {
QLearning.QLStepReturn<MockEncodable> result = results.get(i); QLearning.QLStepReturn<MockEncodable> result = results.get(i);
if(i % 2 == 0) { if(i % 2 == 0) {
assertEquals(expectedMaxQ[i/2] / 255.0, result.getMaxQ(), 0.001); assertEquals(expectedMaxQ[i/2], 255.0 * result.getMaxQ(), 0.001);
assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001); assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
} }
else { else {

View File

@ -22,7 +22,15 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -31,6 +39,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
@ -153,4 +163,75 @@ public class PolicyTest {
assertTrue(count[2] < 40); assertTrue(count[2] < 40);
assertTrue(count[3] < 50); assertTrue(count[3] < 50);
} }
@Test
public void refacPolicyPlay() {
// Arrange
MockObservationSpace observationSpace = new MockObservationSpace();
MockDQN dqn = new MockDQN();
MockRandom random = new MockRandom(new double[] {
0.7309677600860596,
0.8314409852027893,
0.2405363917350769,
0.6063451766967773,
0.6374173760414124,
0.3090505599975586,
0.5504369735717773,
0.11700659990310669
},
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
MockMDP mdp = new MockMDP(observationSpace, 30, random);
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
0, 1.0, 0, 0, 0, 0, true);
MockNeuralNet nnMock = new MockNeuralNet();
MockRefacPolicy sut = new MockRefacPolicy(nnMock);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
// Act
double totalReward = sut.play(mdp, hp);
// Assert
assertEquals(1, nnMock.resetCallCount);
assertEquals(465.0, totalReward, 0.0001);
// HistoryProcessor
assertEquals(27, hp.addCallCount);
assertEquals(31, hp.recordCalls.size());
for(int i=0; i <= 30; ++i) {
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
}
// MDP
assertEquals(1, mdp.resetCount);
assertEquals(30, mdp.actions.size());
for(int i = 0; i < mdp.actions.size(); ++i) {
assertEquals(0, (int)mdp.actions.get(i));
}
// DQN
assertEquals(0, dqn.fitParams.size());
assertEquals(0, dqn.outputParams.size());
}
public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
private NeuralNet neuralNet;
public MockRefacPolicy(NeuralNet neuralNet) {
this.neuralNet = neuralNet;
}
@Override
public NeuralNet getNeuralNet() {
return neuralNet;
}
@Override
public Integer nextAction(INDArray input) {
return (int)input.getDouble(0);
}
}
} }

View File

@ -1,65 +1,22 @@
package org.deeplearning4j.rl4j.support; package org.deeplearning4j.rl4j.support;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Value;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
@AllArgsConstructor
@Value
public class MockAsyncConfiguration implements AsyncConfiguration { public class MockAsyncConfiguration implements AsyncConfiguration {
private final int nStep; private Integer seed;
private final int maxEpochStep; private int maxEpochStep;
private int maxStep;
public MockAsyncConfiguration(int nStep, int maxEpochStep) { private int numThread;
this.nStep = nStep; private int nstep;
private int targetDqnUpdateFreq;
this.maxEpochStep = maxEpochStep; private int updateStart;
} private double rewardFactor;
private double gamma;
@Override private double errorClamp;
public Integer getSeed() {
return 0;
}
@Override
public int getMaxEpochStep() {
return maxEpochStep;
}
@Override
public int getMaxStep() {
return 0;
}
@Override
public int getNumThread() {
return 0;
}
@Override
public int getNstep() {
return nStep;
}
@Override
public int getTargetDqnUpdateFreq() {
return 0;
}
@Override
public int getUpdateStart() {
return 0;
}
@Override
public double getRewardFactor() {
return 0;
}
@Override
public double getGamma() {
return 0;
}
@Override
public double getErrorClamp() {
return 0;
}
} }

View File

@ -1,5 +1,6 @@
package org.deeplearning4j.rl4j.support; package org.deeplearning4j.rl4j.support;
import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
@ -9,9 +10,13 @@ import java.util.concurrent.atomic.AtomicInteger;
public class MockAsyncGlobal implements IAsyncGlobal { public class MockAsyncGlobal implements IAsyncGlobal {
private final NeuralNet current;
public boolean hasBeenStarted = false; public boolean hasBeenStarted = false;
public boolean hasBeenTerminated = false; public boolean hasBeenTerminated = false;
public int enqueueCallCount = 0;
@Setter @Setter
private int maxLoops; private int maxLoops;
@Setter @Setter
@ -19,8 +24,13 @@ public class MockAsyncGlobal implements IAsyncGlobal {
private int currentLoop = 0; private int currentLoop = 0;
public MockAsyncGlobal() { public MockAsyncGlobal() {
this(null);
}
public MockAsyncGlobal(NeuralNet current) {
maxLoops = Integer.MAX_VALUE; maxLoops = Integer.MAX_VALUE;
numLoopsStopRunning = Integer.MAX_VALUE; numLoopsStopRunning = Integer.MAX_VALUE;
this.current = current;
} }
@Override @Override
@ -50,16 +60,16 @@ public class MockAsyncGlobal implements IAsyncGlobal {
@Override @Override
public NeuralNet getCurrent() { public NeuralNet getCurrent() {
return null; return current;
} }
@Override @Override
public NeuralNet getTarget() { public NeuralNet getTarget() {
return null; return current;
} }
@Override @Override
public void enqueue(Gradient[] gradient, Integer nstep) { public void enqueue(Gradient[] gradient, Integer nstep) {
++enqueueCallCount;
} }
} }

View File

@ -5,9 +5,10 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
public class MockHistoryProcessor implements IHistoryProcessor { public class MockHistoryProcessor implements IHistoryProcessor {
public int recordCallCount = 0;
public int addCallCount = 0; public int addCallCount = 0;
public int startMonitorCallCount = 0; public int startMonitorCallCount = 0;
public int stopMonitorCallCount = 0; public int stopMonitorCallCount = 0;
@ -15,10 +16,13 @@ public class MockHistoryProcessor implements IHistoryProcessor {
private final Configuration config; private final Configuration config;
private final CircularFifoQueue<INDArray> history; private final CircularFifoQueue<INDArray> history;
public final ArrayList<INDArray> recordCalls;
public MockHistoryProcessor(Configuration config) { public MockHistoryProcessor(Configuration config) {
this.config = config; this.config = config;
history = new CircularFifoQueue<>(config.getHistoryLength()); history = new CircularFifoQueue<>(config.getHistoryLength());
recordCalls = new ArrayList<INDArray>();
} }
@Override @Override
@ -37,7 +41,7 @@ public class MockHistoryProcessor implements IHistoryProcessor {
@Override @Override
public void record(INDArray image) { public void record(INDArray image) {
++recordCallCount; recordCalls.add(image);
} }
@Override @Override

View File

@ -2,8 +2,10 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.rng.Random;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -11,14 +13,30 @@ import java.util.List;
public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> { public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
private final DiscreteSpace actionSpace; private final DiscreteSpace actionSpace;
private final int stepsUntilDone;
private int currentObsValue = 0; private int currentObsValue = 0;
private final ObservationSpace observationSpace; private final ObservationSpace observationSpace;
public final List<Integer> actions = new ArrayList<>(); public final List<Integer> actions = new ArrayList<>();
private int step = 0;
public int resetCount = 0;
public MockMDP(ObservationSpace observationSpace, int stepsUntilDone, DiscreteSpace actionSpace) {
this.stepsUntilDone = stepsUntilDone;
this.actionSpace = actionSpace;
this.observationSpace = observationSpace;
}
public MockMDP(ObservationSpace observationSpace, int stepsUntilDone, Random rnd) {
this(observationSpace, stepsUntilDone, new DiscreteSpace(5, rnd));
}
public MockMDP(ObservationSpace observationSpace) { public MockMDP(ObservationSpace observationSpace) {
actionSpace = new DiscreteSpace(5); this(observationSpace, Integer.MAX_VALUE, new DiscreteSpace(5));
this.observationSpace = observationSpace; }
public MockMDP(ObservationSpace observationSpace, Random rnd) {
this(observationSpace, Integer.MAX_VALUE, new DiscreteSpace(5, rnd));
} }
@Override @Override
@ -33,7 +51,9 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
@Override @Override
public MockEncodable reset() { public MockEncodable reset() {
++resetCount;
currentObsValue = 0; currentObsValue = 0;
step = 0;
return new MockEncodable(currentObsValue++); return new MockEncodable(currentObsValue++);
} }
@ -45,12 +65,13 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
@Override @Override
public StepReply<MockEncodable> step(Integer action) { public StepReply<MockEncodable> step(Integer action) {
actions.add(action); actions.add(action);
++step;
return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null); return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
} }
@Override @Override
public boolean isDone() { public boolean isDone() {
return false; return step >= stepsUntilDone;
} }
@Override @Override

View File

@ -4,13 +4,18 @@ import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
public class MockNeuralNet implements NeuralNet { public class MockNeuralNet implements NeuralNet {
public int resetCallCount = 0; public int resetCallCount = 0;
public int copyCallCount = 0;
public List<INDArray> outputAllInputs = new ArrayList<INDArray>();
@Override @Override
public NeuralNetwork[] getNeuralNetworks() { public NeuralNetwork[] getNeuralNetworks() {
@ -29,17 +34,18 @@ public class MockNeuralNet implements NeuralNet {
@Override @Override
public INDArray[] outputAll(INDArray batch) { public INDArray[] outputAll(INDArray batch) {
return new INDArray[0]; outputAllInputs.add(batch);
return new INDArray[] { Nd4j.create(new double[] { 1.0 }) };
} }
@Override @Override
public NeuralNet clone() { public NeuralNet clone() {
return null; return this;
} }
@Override @Override
public void copy(NeuralNet from) { public void copy(NeuralNet from) {
++copyCallCount;
} }
@Override @Override

View File

@ -4,14 +4,25 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
public class MockPolicy implements IPolicy<MockEncodable, Integer> { public class MockPolicy implements IPolicy<MockEncodable, Integer> {
public int playCallCount = 0; public int playCallCount = 0;
public List<INDArray> actionInputs = new ArrayList<INDArray>();
@Override @Override
public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) { public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
++playCallCount; ++playCallCount;
return 0; return 0;
} }
@Override
public Integer nextAction(INDArray input) {
actionInputs.add(input);
return null;
}
} }

View File

@ -45,12 +45,4 @@ public abstract class MalmoActionSpace extends DiscreteSpace {
public Integer noOp() { public Integer noOp() {
return -1; return -1;
} }
/**
* Sets the seed used for random generation of actions
* @param seed random number generator seed
*/
public void setRandomSeed(long seed) {
rd.setSeed(seed);
}
} }