RL4J: Make a few fixes (#8303)
* A few fixes Signed-off-by: unknown <aboulang2002@yahoo.com> * Reverted move of ObservationSpace, ActionSpace and others Signed-off-by: unknown <aboulang2002@yahoo.com> * Added unit tests Signed-off-by: unknown <aboulang2002@yahoo.com> * Changed ActionSpace of gym-java-client to use Nd4j's Random Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
ca881a987a
commit
a2b973d41b
|
@ -26,12 +26,10 @@ package org.deeplearning4j.rl4j.space;
|
|||
public interface ActionSpace<A> {
|
||||
|
||||
/**
|
||||
* @return A randomly uniformly sampled action,
|
||||
* @return A random action,
|
||||
*/
|
||||
A randomAction();
|
||||
|
||||
void setSeed(int seed);
|
||||
|
||||
Object encode(A action);
|
||||
|
||||
int getSize();
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
package org.deeplearning4j.rl4j.space;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.Random;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
|
||||
|
@ -33,19 +33,19 @@ public class DiscreteSpace implements ActionSpace<Integer> {
|
|||
//size of the space also defined as the number of different actions
|
||||
@Getter
|
||||
final protected int size;
|
||||
protected Random rd;
|
||||
protected final Random rnd;
|
||||
|
||||
public DiscreteSpace(int size) {
|
||||
this(size, Nd4j.getRandom());
|
||||
}
|
||||
|
||||
public DiscreteSpace(int size, Random rnd) {
|
||||
this.size = size;
|
||||
rd = new Random();
|
||||
this.rnd = rnd;
|
||||
}
|
||||
|
||||
public Integer randomAction() {
|
||||
return rd.nextInt(size);
|
||||
}
|
||||
|
||||
public void setSeed(int seed) {
|
||||
rd = new Random(seed);
|
||||
return rnd.nextInt(size);
|
||||
}
|
||||
|
||||
public Object encode(Integer a) {
|
||||
|
|
|
@ -67,8 +67,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
|||
|
||||
O obs = mdp.reset();
|
||||
|
||||
O nextO = obs;
|
||||
|
||||
int step = 0;
|
||||
double reward = 0;
|
||||
|
||||
|
@ -77,25 +75,30 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
|||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
||||
|
||||
while (step < requiredFrame) {
|
||||
INDArray input = Learning.getInput(mdp, obs);
|
||||
|
||||
if (isHistoryProcessor)
|
||||
hp.record(input);
|
||||
|
||||
|
||||
while (step < requiredFrame && !mdp.isDone()) {
|
||||
|
||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||
if (step % skipFrame == 0 && isHistoryProcessor)
|
||||
hp.add(input);
|
||||
|
||||
StepReply<O> stepReply = mdp.step(action);
|
||||
reward += stepReply.getReward();
|
||||
nextO = stepReply.getObservation();
|
||||
obs = stepReply.getObservation();
|
||||
|
||||
input = Learning.getInput(mdp, obs);
|
||||
if (isHistoryProcessor)
|
||||
hp.record(input);
|
||||
|
||||
step++;
|
||||
|
||||
}
|
||||
|
||||
return new InitMdp(step, nextO, reward);
|
||||
return new InitMdp(step, obs, reward);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -22,10 +22,11 @@ import lombok.Setter;
|
|||
import lombok.Value;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.*;
|
||||
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
@ -118,7 +119,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
|
||||
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
||||
handleTraining(context);
|
||||
if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||
if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||
canContinue = finishEpoch(context) && startNewEpoch(context);
|
||||
if (!canContinue) {
|
||||
break;
|
||||
|
@ -135,16 +136,16 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
|
||||
context.obs = initMdp.getLastObs();
|
||||
context.rewards = initMdp.getReward();
|
||||
context.length = initMdp.getSteps();
|
||||
context.epochElapsedSteps = initMdp.getSteps();
|
||||
}
|
||||
|
||||
private void handleTraining(RunContext<O> context) {
|
||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length);
|
||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
|
||||
SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);
|
||||
|
||||
context.obs = subEpochReturn.getLastObs();
|
||||
stepCounter += subEpochReturn.getSteps();
|
||||
context.length += subEpochReturn.getSteps();
|
||||
context.epochElapsedSteps += subEpochReturn.getSteps();
|
||||
context.rewards += subEpochReturn.getReward();
|
||||
context.score = subEpochReturn.getScore();
|
||||
}
|
||||
|
@ -164,7 +165,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
|
||||
private boolean finishEpoch(RunContext context) {
|
||||
postEpoch();
|
||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.length, context.score);
|
||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);
|
||||
|
||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
|
||||
|
||||
|
@ -182,7 +183,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
|
||||
protected abstract AsyncConfiguration getConf();
|
||||
|
||||
protected abstract Policy<O, A> getPolicy(NN net);
|
||||
protected abstract IPolicy<O, A> getPolicy(NN net);
|
||||
|
||||
protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
|
||||
|
||||
|
@ -208,7 +209,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
private static class RunContext<O extends Encodable> {
|
||||
private O obs;
|
||||
private double rewards;
|
||||
private int length;
|
||||
private int epochElapsedSteps;
|
||||
private double score;
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
|||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -69,7 +69,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
|||
Stack<MiniTrans<Integer>> rewards = new Stack<>();
|
||||
|
||||
O obs = sObs;
|
||||
Policy<O, Integer> policy = getPolicy(current);
|
||||
IPolicy<O, Integer> policy = getPolicy(current);
|
||||
|
||||
Integer action;
|
||||
Integer lastAction = null;
|
||||
|
|
|
@ -58,7 +58,6 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
|||
Integer seed = conf.getSeed();
|
||||
Random rnd = Nd4j.getRandom();
|
||||
if(seed != null) {
|
||||
mdp.getActionSpace().setSeed(seed);
|
||||
rnd.setSeed(seed);
|
||||
}
|
||||
|
||||
|
|
|
@ -63,8 +63,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
Integer seed = conf.getSeed();
|
||||
rnd = Nd4j.getRandom();
|
||||
if(seed != null) {
|
||||
mdp.getActionSpace().setSeed(seed + threadNumber);
|
||||
rnd.setSeed(seed);
|
||||
rnd.setSeed(seed + threadNumber);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
|||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
|
@ -46,10 +47,6 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
|||
this.mdp = mdp;
|
||||
this.configuration = conf;
|
||||
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
||||
Integer seed = conf.getSeed();
|
||||
if(seed != null) {
|
||||
mdp.getActionSpace().setSeed(seed);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -61,7 +61,6 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
|
||||
Integer seed = conf.getSeed();
|
||||
if(seed != null) {
|
||||
mdp.getActionSpace().setSeed(seed + threadNumber);
|
||||
rnd.setSeed(seed + threadNumber);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,7 +85,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
policy = new DQNPolicy(getQNetwork());
|
||||
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
|
||||
this);
|
||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||
|
||||
tdTargetAlgorithm = conf.isDoubleDQN()
|
||||
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
|
||||
|
@ -118,9 +117,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
||||
|
||||
|
||||
if (isHistoryProcessor)
|
||||
getHistoryProcessor().record(input);
|
||||
|
||||
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
||||
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
||||
int updateStart = getConfiguration().getUpdateStart()
|
||||
|
@ -160,12 +156,16 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
|
||||
StepReply<O> stepReply = getMdp().step(action);
|
||||
|
||||
INDArray ninput = getInput(stepReply.getObservation());
|
||||
|
||||
if (isHistoryProcessor)
|
||||
getHistoryProcessor().record(ninput);
|
||||
|
||||
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
||||
|
||||
//if it's not a skipped frame, you can do a step of training
|
||||
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
|
||||
|
||||
INDArray ninput = getInput(stepReply.getObservation());
|
||||
if (isHistoryProcessor)
|
||||
getHistoryProcessor().add(ninput);
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
|||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.nd4j.linalg.ops.transforms.Transforms.exp;
|
||||
|
||||
|
|
|
@ -4,7 +4,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
|||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
public interface IPolicy<O extends Encodable, A> {
|
||||
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
|
||||
A nextAction(INDArray input);
|
||||
}
|
||||
|
|
|
@ -51,6 +51,9 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
|
|||
|
||||
@Override
|
||||
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
||||
boolean isHistoryProcessor = hp != null;
|
||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||
|
||||
getNeuralNet().reset();
|
||||
Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
|
||||
O obs = initMdp.getLastObs();
|
||||
|
@ -62,16 +65,9 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
|
|||
int step = initMdp.getSteps();
|
||||
INDArray[] history = null;
|
||||
|
||||
while (!mdp.isDone()) {
|
||||
|
||||
INDArray input = Learning.getInput(mdp, obs);
|
||||
boolean isHistoryProcessor = hp != null;
|
||||
|
||||
if (isHistoryProcessor)
|
||||
hp.record(input);
|
||||
|
||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||
|
||||
while (!mdp.isDone()) {
|
||||
|
||||
if (step % skipFrame != 0) {
|
||||
action = lastAction;
|
||||
|
@ -102,8 +98,11 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
|
|||
StepReply<O> stepReply = mdp.step(action);
|
||||
reward += stepReply.getReward();
|
||||
|
||||
if (isHistoryProcessor)
|
||||
hp.add(Learning.getInput(mdp, stepReply.getObservation()));
|
||||
input = Learning.getInput(mdp, stepReply.getObservation());
|
||||
if (isHistoryProcessor) {
|
||||
hp.record(input);
|
||||
hp.add(input);
|
||||
}
|
||||
|
||||
history = isHistoryProcessor ? hp.getHistory()
|
||||
: new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())};
|
||||
|
|
|
@ -68,10 +68,10 @@ public class AsyncLearningTest {
|
|||
|
||||
|
||||
public static class TestContext {
|
||||
public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1);
|
||||
MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0,0, 0, 0, 0);
|
||||
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
|
||||
public final MockPolicy policy = new MockPolicy();
|
||||
public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy);
|
||||
public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
|
||||
public final MockTrainingListener listener = new MockTrainingListener();
|
||||
|
||||
public TestContext() {
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Stack;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class AsyncThreadDiscreteTest {
|
||||
|
||||
@Test
|
||||
public void refac_AsyncThreadDiscrete_trainSubEpoch() {
|
||||
// Arrange
|
||||
MockNeuralNet nnMock = new MockNeuralNet();
|
||||
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||
TrainingListenerList listeners = new TrainingListenerList();
|
||||
MockPolicy policyMock = new MockPolicy();
|
||||
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 0, 5,0, 0, 0, 0);
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
|
||||
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
||||
MockEncodable obs = new MockEncodable(123);
|
||||
|
||||
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(1)));
|
||||
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(2)));
|
||||
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(3)));
|
||||
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(4)));
|
||||
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(5)));
|
||||
|
||||
// Act
|
||||
AsyncThread.SubEpochReturn<MockEncodable> result = sut.trainSubEpoch(obs, 2);
|
||||
|
||||
// Assert
|
||||
assertEquals(4, result.getSteps());
|
||||
assertEquals(6.0, result.getReward(), 0.00001);
|
||||
assertEquals(0.0, result.getScore(), 0.00001);
|
||||
assertEquals(3.0, result.getLastObs().toArray()[0], 0.00001);
|
||||
assertEquals(1, asyncGlobalMock.enqueueCallCount);
|
||||
|
||||
// HistoryProcessor
|
||||
assertEquals(10, hpMock.addCallCount);
|
||||
double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
|
||||
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
|
||||
for(int i = 0; i < expectedRecordValues.length; ++i) {
|
||||
assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
// Policy
|
||||
double[][] expectedPolicyInputs = new double[][] {
|
||||
new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
|
||||
new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
|
||||
new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
|
||||
new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
|
||||
};
|
||||
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
|
||||
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
|
||||
double[] expectedRow = expectedPolicyInputs[i];
|
||||
INDArray input = policyMock.actionInputs.get(i);
|
||||
assertEquals(expectedRow.length, input.shape()[0]);
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
// NeuralNetwork
|
||||
assertEquals(1, nnMock.copyCallCount);
|
||||
double[][] expectedNNInputs = new double[][] {
|
||||
new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
|
||||
new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
|
||||
new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
|
||||
new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
|
||||
new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 },
|
||||
};
|
||||
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
|
||||
for(int i = 0; i < expectedNNInputs.length; ++i) {
|
||||
double[] expectedRow = expectedNNInputs[i];
|
||||
INDArray input = nnMock.outputAllInputs.get(i);
|
||||
assertEquals(expectedRow.length, input.shape()[0]);
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
|
||||
|
||||
private final IAsyncGlobal<MockNeuralNet> asyncGlobal;
|
||||
private final MockPolicy policy;
|
||||
private final MockAsyncConfiguration config;
|
||||
|
||||
public TestAsyncThreadDiscrete(IAsyncGlobal<MockNeuralNet> asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
|
||||
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
|
||||
MockAsyncConfiguration config, IHistoryProcessor hp) {
|
||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.policy = policy;
|
||||
this.config = config;
|
||||
setHistoryProcessor(hp);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
|
||||
return new Gradient[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IAsyncGlobal<MockNeuralNet> getAsyncGlobal() {
|
||||
return asyncGlobal;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AsyncConfiguration getConf() {
|
||||
return config;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
|
||||
return policy;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,8 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
|
@ -9,7 +12,11 @@ import org.deeplearning4j.rl4j.support.*;
|
|||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNull;
|
||||
|
||||
public class AsyncThreadTest {
|
||||
|
||||
|
@ -82,7 +89,42 @@ public class AsyncThreadTest {
|
|||
IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
|
||||
assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
|
||||
assertEquals(i, statEntry.getEpochCounter());
|
||||
assertEquals(2.0, statEntry.getReward(), 0.0001);
|
||||
assertEquals(38.0, statEntry.getReward(), 0.0001);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_run_expect_NeuralNetIsResetAtInitAndEveryEpoch() {
|
||||
// Arrange
|
||||
TestContext context = new TestContext();
|
||||
|
||||
// Act
|
||||
context.sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(6, context.neuralNet.resetCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_run_expect_trainSubEpochCalled() {
|
||||
// Arrange
|
||||
TestContext context = new TestContext();
|
||||
|
||||
// Act
|
||||
context.sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(10, context.sut.trainSubEpochParams.size());
|
||||
for(int i = 0; i < 10; ++i) {
|
||||
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
|
||||
if(i % 2 == 0) {
|
||||
assertEquals(2, params.nstep);
|
||||
assertEquals(8.0, params.obs.toArray()[0], 0.00001);
|
||||
}
|
||||
else {
|
||||
assertEquals(1, params.nstep);
|
||||
assertNull(params.obs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,14 +133,18 @@ public class AsyncThreadTest {
|
|||
public final MockNeuralNet neuralNet = new MockNeuralNet();
|
||||
public final MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
public final MockMDP mdp = new MockMDP(observationSpace);
|
||||
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 2);
|
||||
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
|
||||
public final TrainingListenerList listeners = new TrainingListenerList();
|
||||
public final MockTrainingListener listener = new MockTrainingListener();
|
||||
private final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);
|
||||
|
||||
public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
|
||||
|
||||
public TestContext() {
|
||||
asyncGlobal.setMaxLoops(10);
|
||||
listeners.add(listener);
|
||||
sut.setHistoryProcessor(historyProcessor);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,11 +153,12 @@ public class AsyncThreadTest {
|
|||
public int preEpochCallCount = 0;
|
||||
public int postEpochCallCount = 0;
|
||||
|
||||
|
||||
private final IAsyncGlobal asyncGlobal;
|
||||
private final MockNeuralNet neuralNet;
|
||||
private final AsyncConfiguration conf;
|
||||
|
||||
private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
|
||||
|
||||
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
|
||||
super(asyncGlobal, mdp, listeners, threadNumber, 0);
|
||||
|
||||
|
@ -154,8 +201,16 @@ public class AsyncThreadTest {
|
|||
|
||||
@Override
|
||||
protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
|
||||
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
|
||||
return new SubEpochReturn(1, null, 1.0, 1.0);
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@Getter
|
||||
public static class TrainSubEpochParams {
|
||||
Encodable obs;
|
||||
int nstep;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -26,9 +26,20 @@ public class QLearningDiscreteTest {
|
|||
public void refac_QLearningDiscrete_trainStep() {
|
||||
// Arrange
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockDQN dqn = new MockDQN();
|
||||
MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null);
|
||||
MockRandom random = new MockRandom(new double[] {
|
||||
0.7309677600860596,
|
||||
0.8314409852027893,
|
||||
0.2405363917350769,
|
||||
0.6063451766967773,
|
||||
0.6374173760414124,
|
||||
0.3090505599975586,
|
||||
0.5504369735717773,
|
||||
0.11700659990310669
|
||||
},
|
||||
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
||||
MockMDP mdp = new MockMDP(observationSpace, random);
|
||||
|
||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||
0, 1.0, 0, 0, 0, 0, true);
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
|
@ -37,7 +48,7 @@ public class QLearningDiscreteTest {
|
|||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
sut.setHistoryProcessor(hp);
|
||||
MockEncodable obs = new MockEncodable(1);
|
||||
MockEncodable obs = new MockEncodable(-100);
|
||||
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
||||
|
||||
// Act
|
||||
|
@ -49,7 +60,11 @@ public class QLearningDiscreteTest {
|
|||
|
||||
// Assert
|
||||
// HistoryProcessor calls
|
||||
assertEquals(24, hp.recordCallCount);
|
||||
double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 };
|
||||
assertEquals(expectedRecords.length, hp.recordCalls.size());
|
||||
for(int i = 0; i < expectedRecords.length; ++i) {
|
||||
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||
}
|
||||
assertEquals(13, hp.addCallCount);
|
||||
assertEquals(0, hp.startMonitorCallCount);
|
||||
assertEquals(0, hp.stopMonitorCallCount);
|
||||
|
@ -60,21 +75,20 @@ public class QLearningDiscreteTest {
|
|||
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
|
||||
assertEquals(14, dqn.outputParams.size());
|
||||
double[][] expectedDQNOutput = new double[][] {
|
||||
new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
|
||||
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||
new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
|
||||
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
|
||||
new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
|
||||
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
|
||||
};
|
||||
for(int i = 0; i < expectedDQNOutput.length; ++i) {
|
||||
INDArray outputParam = dqn.outputParams.get(i);
|
||||
|
@ -84,23 +98,23 @@ public class QLearningDiscreteTest {
|
|||
|
||||
double[] expectedRow = expectedDQNOutput[i];
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j] / 255.0, outputParam.getDouble(j), 0.00001);
|
||||
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
// MDP calls
|
||||
assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 0 ,0, 4, 4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
|
||||
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
|
||||
|
||||
// ExpReplay calls
|
||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
|
||||
int[] expectedTrActions = new int[] { 0, 4, 3, 4, 4, 4, 4, 4 };
|
||||
double[] expectedTrNextObservation = new double[] { 0, 0, 0, 1.0, 9.0, 11.0, 13.0, 15.0 };
|
||||
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
||||
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, -100.0, 9.0, 11.0, 13.0, 15.0 };
|
||||
double[][] expectedTrObservations = new double[][] {
|
||||
new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
|
||||
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||
new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
|
||||
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
|
||||
new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
|
||||
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
|
@ -111,18 +125,18 @@ public class QLearningDiscreteTest {
|
|||
assertEquals(expectedTrActions[i], tr.getAction());
|
||||
assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
|
||||
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
||||
assertEquals(expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
|
||||
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
|
||||
}
|
||||
}
|
||||
|
||||
// trainStep results
|
||||
assertEquals(16, results.size());
|
||||
double[] expectedMaxQ = new double[] { 1.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
|
||||
double[] expectedMaxQ = new double[] { 6.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
|
||||
double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
|
||||
for(int i=0; i < 16; ++i) {
|
||||
QLearning.QLStepReturn<MockEncodable> result = results.get(i);
|
||||
if(i % 2 == 0) {
|
||||
assertEquals(expectedMaxQ[i/2] / 255.0, result.getMaxQ(), 0.001);
|
||||
assertEquals(expectedMaxQ[i/2], 255.0 * result.getMaxQ(), 0.001);
|
||||
assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
|
||||
}
|
||||
else {
|
||||
|
|
|
@ -22,7 +22,15 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -31,6 +39,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
@ -153,4 +163,75 @@ public class PolicyTest {
|
|||
assertTrue(count[2] < 40);
|
||||
assertTrue(count[3] < 50);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void refacPolicyPlay() {
|
||||
// Arrange
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockDQN dqn = new MockDQN();
|
||||
MockRandom random = new MockRandom(new double[] {
|
||||
0.7309677600860596,
|
||||
0.8314409852027893,
|
||||
0.2405363917350769,
|
||||
0.6063451766967773,
|
||||
0.6374173760414124,
|
||||
0.3090505599975586,
|
||||
0.5504369735717773,
|
||||
0.11700659990310669
|
||||
},
|
||||
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
||||
MockMDP mdp = new MockMDP(observationSpace, 30, random);
|
||||
|
||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||
0, 1.0, 0, 0, 0, 0, true);
|
||||
MockNeuralNet nnMock = new MockNeuralNet();
|
||||
MockRefacPolicy sut = new MockRefacPolicy(nnMock);
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
|
||||
// Act
|
||||
double totalReward = sut.play(mdp, hp);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, nnMock.resetCallCount);
|
||||
assertEquals(465.0, totalReward, 0.0001);
|
||||
|
||||
// HistoryProcessor
|
||||
assertEquals(27, hp.addCallCount);
|
||||
assertEquals(31, hp.recordCalls.size());
|
||||
for(int i=0; i <= 30; ++i) {
|
||||
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||
}
|
||||
|
||||
// MDP
|
||||
assertEquals(1, mdp.resetCount);
|
||||
assertEquals(30, mdp.actions.size());
|
||||
for(int i = 0; i < mdp.actions.size(); ++i) {
|
||||
assertEquals(0, (int)mdp.actions.get(i));
|
||||
}
|
||||
|
||||
// DQN
|
||||
assertEquals(0, dqn.fitParams.size());
|
||||
assertEquals(0, dqn.outputParams.size());
|
||||
}
|
||||
|
||||
public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
|
||||
|
||||
private NeuralNet neuralNet;
|
||||
|
||||
public MockRefacPolicy(NeuralNet neuralNet) {
|
||||
|
||||
this.neuralNet = neuralNet;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getNeuralNet() {
|
||||
return neuralNet;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer nextAction(INDArray input) {
|
||||
return (int)input.getDouble(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,65 +1,22 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.Value;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
|
||||
|
||||
@AllArgsConstructor
|
||||
@Value
|
||||
public class MockAsyncConfiguration implements AsyncConfiguration {
|
||||
|
||||
private final int nStep;
|
||||
private final int maxEpochStep;
|
||||
|
||||
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
|
||||
this.nStep = nStep;
|
||||
|
||||
this.maxEpochStep = maxEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getSeed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxEpochStep() {
|
||||
return maxEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxStep() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumThread() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNstep() {
|
||||
return nStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTargetDqnUpdateFreq() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getUpdateStart() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getRewardFactor() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getGamma() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getErrorClamp() {
|
||||
return 0;
|
||||
}
|
||||
private Integer seed;
|
||||
private int maxEpochStep;
|
||||
private int maxStep;
|
||||
private int numThread;
|
||||
private int nstep;
|
||||
private int targetDqnUpdateFreq;
|
||||
private int updateStart;
|
||||
private double rewardFactor;
|
||||
private double gamma;
|
||||
private double errorClamp;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
|
@ -9,9 +10,13 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
|
||||
public class MockAsyncGlobal implements IAsyncGlobal {
|
||||
|
||||
private final NeuralNet current;
|
||||
|
||||
public boolean hasBeenStarted = false;
|
||||
public boolean hasBeenTerminated = false;
|
||||
|
||||
public int enqueueCallCount = 0;
|
||||
|
||||
@Setter
|
||||
private int maxLoops;
|
||||
@Setter
|
||||
|
@ -19,8 +24,13 @@ public class MockAsyncGlobal implements IAsyncGlobal {
|
|||
private int currentLoop = 0;
|
||||
|
||||
public MockAsyncGlobal() {
|
||||
this(null);
|
||||
}
|
||||
|
||||
public MockAsyncGlobal(NeuralNet current) {
|
||||
maxLoops = Integer.MAX_VALUE;
|
||||
numLoopsStopRunning = Integer.MAX_VALUE;
|
||||
this.current = current;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -50,16 +60,16 @@ public class MockAsyncGlobal implements IAsyncGlobal {
|
|||
|
||||
@Override
|
||||
public NeuralNet getCurrent() {
|
||||
return null;
|
||||
return current;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getTarget() {
|
||||
return null;
|
||||
return current;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||
|
||||
++enqueueCallCount;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,9 +5,10 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
public class MockHistoryProcessor implements IHistoryProcessor {
|
||||
|
||||
public int recordCallCount = 0;
|
||||
public int addCallCount = 0;
|
||||
public int startMonitorCallCount = 0;
|
||||
public int stopMonitorCallCount = 0;
|
||||
|
@ -15,10 +16,13 @@ public class MockHistoryProcessor implements IHistoryProcessor {
|
|||
private final Configuration config;
|
||||
private final CircularFifoQueue<INDArray> history;
|
||||
|
||||
public final ArrayList<INDArray> recordCalls;
|
||||
|
||||
public MockHistoryProcessor(Configuration config) {
|
||||
|
||||
this.config = config;
|
||||
history = new CircularFifoQueue<>(config.getHistoryLength());
|
||||
recordCalls = new ArrayList<INDArray>();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -37,7 +41,7 @@ public class MockHistoryProcessor implements IHistoryProcessor {
|
|||
|
||||
@Override
|
||||
public void record(INDArray image) {
|
||||
++recordCallCount;
|
||||
recordCalls.add(image);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -2,8 +2,10 @@ package org.deeplearning4j.rl4j.support;
|
|||
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -11,14 +13,30 @@ import java.util.List;
|
|||
public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||
|
||||
private final DiscreteSpace actionSpace;
|
||||
private final int stepsUntilDone;
|
||||
private int currentObsValue = 0;
|
||||
private final ObservationSpace observationSpace;
|
||||
|
||||
public final List<Integer> actions = new ArrayList<>();
|
||||
private int step = 0;
|
||||
public int resetCount = 0;
|
||||
|
||||
public MockMDP(ObservationSpace observationSpace, int stepsUntilDone, DiscreteSpace actionSpace) {
|
||||
this.stepsUntilDone = stepsUntilDone;
|
||||
this.actionSpace = actionSpace;
|
||||
this.observationSpace = observationSpace;
|
||||
}
|
||||
|
||||
public MockMDP(ObservationSpace observationSpace, int stepsUntilDone, Random rnd) {
|
||||
this(observationSpace, stepsUntilDone, new DiscreteSpace(5, rnd));
|
||||
}
|
||||
|
||||
public MockMDP(ObservationSpace observationSpace) {
|
||||
actionSpace = new DiscreteSpace(5);
|
||||
this.observationSpace = observationSpace;
|
||||
this(observationSpace, Integer.MAX_VALUE, new DiscreteSpace(5));
|
||||
}
|
||||
|
||||
public MockMDP(ObservationSpace observationSpace, Random rnd) {
|
||||
this(observationSpace, Integer.MAX_VALUE, new DiscreteSpace(5, rnd));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -33,7 +51,9 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
|||
|
||||
@Override
|
||||
public MockEncodable reset() {
|
||||
++resetCount;
|
||||
currentObsValue = 0;
|
||||
step = 0;
|
||||
return new MockEncodable(currentObsValue++);
|
||||
}
|
||||
|
||||
|
@ -45,12 +65,13 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
|||
@Override
|
||||
public StepReply<MockEncodable> step(Integer action) {
|
||||
actions.add(action);
|
||||
++step;
|
||||
return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDone() {
|
||||
return false;
|
||||
return step >= stepsUntilDone;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -4,13 +4,18 @@ import org.deeplearning4j.nn.api.NeuralNetwork;
|
|||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockNeuralNet implements NeuralNet {
|
||||
|
||||
public int resetCallCount = 0;
|
||||
public int copyCallCount = 0;
|
||||
public List<INDArray> outputAllInputs = new ArrayList<INDArray>();
|
||||
|
||||
@Override
|
||||
public NeuralNetwork[] getNeuralNetworks() {
|
||||
|
@ -29,17 +34,18 @@ public class MockNeuralNet implements NeuralNet {
|
|||
|
||||
@Override
|
||||
public INDArray[] outputAll(INDArray batch) {
|
||||
return new INDArray[0];
|
||||
outputAllInputs.add(batch);
|
||||
return new INDArray[] { Nd4j.create(new double[] { 1.0 }) };
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet clone() {
|
||||
return null;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
++copyCallCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -4,14 +4,25 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
|||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockPolicy implements IPolicy<MockEncodable, Integer> {
|
||||
|
||||
public int playCallCount = 0;
|
||||
public List<INDArray> actionInputs = new ArrayList<INDArray>();
|
||||
|
||||
@Override
|
||||
public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
|
||||
++playCallCount;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer nextAction(INDArray input) {
|
||||
actionInputs.add(input);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,12 +45,4 @@ public abstract class MalmoActionSpace extends DiscreteSpace {
|
|||
public Integer noOp() {
|
||||
return -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the seed used for random generation of actions
|
||||
* @param seed random number generator seed
|
||||
*/
|
||||
public void setRandomSeed(long seed) {
|
||||
rd.setSeed(seed);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue