RL4J: Make a few fixes (#8303)
* A few fixes
* Reverted move of ObservationSpace, ActionSpace and others
* Added unit tests
* Changed ActionSpace of gym-java-client to use Nd4j's Random

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
Branch: master
parent ca881a987a
commit a2b973d41b
@@ -26,12 +26,10 @@ package org.deeplearning4j.rl4j.space;

 public interface ActionSpace<A> {

     /**
-     * @return A randomly uniformly sampled action,
+     * @return A random action,
      */
     A randomAction();

-    void setSeed(int seed);
-
     Object encode(A action);

     int getSize();
@@ -17,8 +17,8 @@

 package org.deeplearning4j.rl4j.space;

 import lombok.Getter;
-
-import java.util.Random;
+import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
@@ -33,19 +33,19 @@ public class DiscreteSpace implements ActionSpace<Integer> {

     //size of the space also defined as the number of different actions
     @Getter
     final protected int size;
-    protected Random rd;
+    protected final Random rnd;

     public DiscreteSpace(int size) {
+        this(size, Nd4j.getRandom());
+    }
+
+    public DiscreteSpace(int size, Random rnd) {
         this.size = size;
-        rd = new Random();
+        this.rnd = rnd;
     }

     public Integer randomAction() {
-        return rd.nextInt(size);
+        return rnd.nextInt(size);
     }
-
-    public void setSeed(int seed) {
-        rd = new Random(seed);
-    }

     public Object encode(Integer a) {
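For reference, a hedged usage sketch of the new constructor (the class name and seed value below are illustrative, not part of the change): seeding ND4J's shared generator replaces the removed ActionSpace.setSeed(int).

import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;

public class SeededDiscreteSpaceSketch {
    public static void main(String[] args) {
        Random rnd = Nd4j.getRandom();
        rnd.setSeed(123);                          // seed the shared ND4J generator once
        DiscreteSpace space = new DiscreteSpace(4, rnd);
        System.out.println(space.randomAction());  // same sequence on every run
    }
}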
@@ -67,8 +67,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>

         O obs = mdp.reset();

-        O nextO = obs;
-
         int step = 0;
         double reward = 0;

@@ -77,11 +75,12 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>

         int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
         int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;

-        while (step < requiredFrame) {
-            INDArray input = Learning.getInput(mdp, obs);
+        INDArray input = Learning.getInput(mdp, obs);
+        if (isHistoryProcessor)
+            hp.record(input);

-            if (isHistoryProcessor)
-                hp.record(input);
+        while (step < requiredFrame && !mdp.isDone()) {

             A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
             if (step % skipFrame == 0 && isHistoryProcessor)

@@ -89,13 +88,17 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>

             StepReply<O> stepReply = mdp.step(action);
             reward += stepReply.getReward();
-            nextO = stepReply.getObservation();
+            obs = stepReply.getObservation();
+
+            input = Learning.getInput(mdp, obs);
+            if (isHistoryProcessor)
+                hp.record(input);

             step++;
         }

-        return new InitMdp(step, nextO, reward);
+        return new InitMdp(step, obs, reward);
     }
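A hedged sketch of how the repaired warm-up is consumed downstream (mdp and hp stand for any MDP and history processor; MockEncodable is used here only as a concrete observation type):

// initMdp() now returns the latest observation rather than the stale first
// one, records every warm-up frame, and stops early if the MDP finishes.
Learning.InitMdp<MockEncodable> initMdp = Learning.initMdp(mdp, hp);
MockEncodable lastObs = initMdp.getLastObs(); // observation after warm-up
int steps = initMdp.getSteps();               // warm-up steps consumed
double reward = initMdp.getReward();          // reward accumulated during warm-up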
@@ -22,10 +22,11 @@ import lombok.Setter;

 import lombok.Value;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.rl4j.learning.*;
-import org.deeplearning4j.rl4j.learning.listener.*;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.policy.Policy;
+import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager;

@@ -118,7 +119,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace

         while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
             handleTraining(context);
-            if (context.length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
+            if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
                 canContinue = finishEpoch(context) && startNewEpoch(context);
                 if (!canContinue) {
                     break;

@@ -135,16 +136,16 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace

         context.obs = initMdp.getLastObs();
         context.rewards = initMdp.getReward();
-        context.length = initMdp.getSteps();
+        context.epochElapsedSteps = initMdp.getSteps();
     }

     private void handleTraining(RunContext<O> context) {
-        int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.length);
+        int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
         SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);

         context.obs = subEpochReturn.getLastObs();
         stepCounter += subEpochReturn.getSteps();
-        context.length += subEpochReturn.getSteps();
+        context.epochElapsedSteps += subEpochReturn.getSteps();
         context.rewards += subEpochReturn.getReward();
         context.score = subEpochReturn.getScore();
     }

@@ -164,7 +165,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace

     private boolean finishEpoch(RunContext context) {
         postEpoch();
-        IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.length, context.score);
+        IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);

         log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);

@@ -182,7 +183,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace

     protected abstract AsyncConfiguration getConf();

-    protected abstract Policy<O, A> getPolicy(NN net);
+    protected abstract IPolicy<O, A> getPolicy(NN net);

     protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);

@@ -208,7 +209,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace

     private static class RunContext<O extends Encodable> {
         private O obs;
         private double rewards;
-        private int length;
+        private int epochElapsedSteps;
         private double score;
     }
@@ -25,7 +25,7 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.policy.Policy;
+import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -69,7 +69,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural

         Stack<MiniTrans<Integer>> rewards = new Stack<>();

         O obs = sObs;
-        Policy<O, Integer> policy = getPolicy(current);
+        IPolicy<O, Integer> policy = getPolicy(current);

         Integer action;
         Integer lastAction = null;

@@ -58,7 +58,6 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,

         Integer seed = conf.getSeed();
         Random rnd = Nd4j.getRandom();
         if(seed != null) {
-            mdp.getActionSpace().setSeed(seed);
             rnd.setSeed(seed);
         }

@@ -63,8 +63,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<

         Integer seed = conf.getSeed();
         rnd = Nd4j.getRandom();
         if(seed != null) {
-            mdp.getActionSpace().setSeed(seed + threadNumber);
-            rnd.setSeed(seed);
+            rnd.setSeed(seed + threadNumber);
         }
     }
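A minimal sketch of the per-worker seeding pattern adopted above (the loop and worker count are illustrative; in the real code each worker runs on its own thread, where Nd4j.getRandom() hands out a per-thread instance):

import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;

public class PerWorkerSeedingSketch {
    public static void main(String[] args) {
        Integer seed = 123;                       // e.g. AsyncConfiguration.getSeed()
        for (int threadNumber = 0; threadNumber < 4; threadNumber++) {
            Random rnd = Nd4j.getRandom();
            if (seed != null) {
                rnd.setSeed(seed + threadNumber); // distinct, reproducible stream per worker
            }
            System.out.println(rnd.nextInt(10));
        }
    }
}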
@@ -27,6 +27,7 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;

 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.factory.Nd4j;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.

@@ -46,10 +47,6 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>

         this.mdp = mdp;
         this.configuration = conf;
         this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
-        Integer seed = conf.getSeed();
-        if(seed != null) {
-            mdp.getActionSpace().setSeed(seed);
-        }
     }

     @Override
@@ -61,7 +61,6 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn

         Integer seed = conf.getSeed();
         if(seed != null) {
-            mdp.getActionSpace().setSeed(seed + threadNumber);
             rnd.setSeed(seed + threadNumber);
         }
     }
@@ -85,7 +85,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O

         policy = new DQNPolicy(getQNetwork());
         egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
                         this);
-        mdp.getActionSpace().setSeed(conf.getSeed());

         tdTargetAlgorithm = conf.isDoubleDQN()
                 ? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())

@@ -118,9 +117,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O

         boolean isHistoryProcessor = getHistoryProcessor() != null;

-        if (isHistoryProcessor)
-            getHistoryProcessor().record(input);
-
         int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
         int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
         int updateStart = getConfiguration().getUpdateStart()

@@ -160,12 +156,16 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O

         StepReply<O> stepReply = getMdp().step(action);

+        INDArray ninput = getInput(stepReply.getObservation());
+        if (isHistoryProcessor)
+            getHistoryProcessor().record(ninput);
+
         accuReward += stepReply.getReward() * configuration.getRewardFactor();

         //if it's not a skipped frame, you can do a step of training
         if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {

-            INDArray ninput = getInput(stepReply.getObservation());
             if (isHistoryProcessor)
                 getHistoryProcessor().add(ninput);
@@ -20,7 +20,6 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;

 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
-import org.nd4j.linalg.factory.Nd4j;

 import static org.nd4j.linalg.ops.transforms.Transforms.exp;
@@ -4,7 +4,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;

 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;

 public interface IPolicy<O extends Encodable, A> {
     <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
+    A nextAction(INDArray input);
 }
@@ -51,6 +51,9 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {

     @Override
     public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
+        boolean isHistoryProcessor = hp != null;
+        int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
+
         getNeuralNet().reset();
         Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
         O obs = initMdp.getLastObs();

@@ -62,17 +65,10 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {

         int step = initMdp.getSteps();
         INDArray[] history = null;

+        INDArray input = Learning.getInput(mdp, obs);
+
         while (!mdp.isDone()) {

-            INDArray input = Learning.getInput(mdp, obs);
-            boolean isHistoryProcessor = hp != null;
-
-            if (isHistoryProcessor)
-                hp.record(input);
-
-            int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
-
             if (step % skipFrame != 0) {
                 action = lastAction;
             } else {

@@ -102,8 +98,11 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {

             StepReply<O> stepReply = mdp.step(action);
             reward += stepReply.getReward();

-            if (isHistoryProcessor)
-                hp.add(Learning.getInput(mdp, stepReply.getObservation()));
+            input = Learning.getInput(mdp, stepReply.getObservation());
+            if (isHistoryProcessor) {
+                hp.record(input);
+                hp.add(input);
+            }

             history = isHistoryProcessor ? hp.getHistory()
                             : new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())};
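A hedged sketch of driving the reworked play() (policy and mdp are placeholders for any concrete Policy and MDP; a HistoryProcessor constructor taking a Configuration is assumed):

// play() now records every frame via hp.record() and stacks frames into the
// history only on non-skipped steps via hp.add().
IHistoryProcessor.Configuration hpConf =
        new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
IHistoryProcessor hp = new HistoryProcessor(hpConf); // assumed concrete implementation
double totalReward = policy.play(mdp, hp);           // runs one full episode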
@@ -68,10 +68,10 @@ public class AsyncLearningTest {

     public static class TestContext {
-        public final MockAsyncConfiguration conf = new MockAsyncConfiguration(1, 1);
+        MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0, 0, 0, 0, 0);
         public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
         public final MockPolicy policy = new MockPolicy();
-        public final TestAsyncLearning sut = new TestAsyncLearning(conf, asyncGlobal, policy);
+        public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
         public final MockTrainingListener listener = new MockTrainingListener();

         public TestContext() {
@@ -0,0 +1,134 @@
+package org.deeplearning4j.rl4j.learning.async;
+
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.Learning;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
+import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.policy.IPolicy;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.support.*;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Stack;
+
+import static org.junit.Assert.assertEquals;
+
+public class AsyncThreadDiscreteTest {
+
+    @Test
+    public void refac_AsyncThreadDiscrete_trainSubEpoch() {
+        // Arrange
+        MockNeuralNet nnMock = new MockNeuralNet();
+        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
+        MockObservationSpace observationSpace = new MockObservationSpace();
+        MockMDP mdpMock = new MockMDP(observationSpace);
+        TrainingListenerList listeners = new TrainingListenerList();
+        MockPolicy policyMock = new MockPolicy();
+        MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 0, 5, 0, 0, 0, 0);
+        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
+        MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
+        TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
+        MockEncodable obs = new MockEncodable(123);
+
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(1)));
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(2)));
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(3)));
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(4)));
+        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(5)));
+
+        // Act
+        AsyncThread.SubEpochReturn<MockEncodable> result = sut.trainSubEpoch(obs, 2);
+
+        // Assert
+        assertEquals(4, result.getSteps());
+        assertEquals(6.0, result.getReward(), 0.00001);
+        assertEquals(0.0, result.getScore(), 0.00001);
+        assertEquals(3.0, result.getLastObs().toArray()[0], 0.00001);
+        assertEquals(1, asyncGlobalMock.enqueueCallCount);
+
+        // HistoryProcessor
+        assertEquals(10, hpMock.addCallCount);
+        double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
+        assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
+        for(int i = 0; i < expectedRecordValues.length; ++i) {
+            assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);
+        }
+
+        // Policy
+        double[][] expectedPolicyInputs = new double[][] {
+                new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
+                new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
+                new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
+                new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
+        };
+        assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
+        for(int i = 0; i < expectedPolicyInputs.length; ++i) {
+            double[] expectedRow = expectedPolicyInputs[i];
+            INDArray input = policyMock.actionInputs.get(i);
+            assertEquals(expectedRow.length, input.shape()[0]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
+            }
+        }
+
+        // NeuralNetwork
+        assertEquals(1, nnMock.copyCallCount);
+        double[][] expectedNNInputs = new double[][] {
+                new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
+                new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
+                new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
+                new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
+                new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 },
+        };
+        assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
+        for(int i = 0; i < expectedNNInputs.length; ++i) {
+            double[] expectedRow = expectedNNInputs[i];
+            INDArray input = nnMock.outputAllInputs.get(i);
+            assertEquals(expectedRow.length, input.shape()[0]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
+            }
+        }
+    }
+
+    public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
+
+        private final IAsyncGlobal<MockNeuralNet> asyncGlobal;
+        private final MockPolicy policy;
+        private final MockAsyncConfiguration config;
+
+        public TestAsyncThreadDiscrete(IAsyncGlobal<MockNeuralNet> asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
+                                       TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
+                                       MockAsyncConfiguration config, IHistoryProcessor hp) {
+            super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
+            this.asyncGlobal = asyncGlobal;
+            this.policy = policy;
+            this.config = config;
+            setHistoryProcessor(hp);
+        }
+
+        @Override
+        public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
+            return new Gradient[0];
+        }
+
+        @Override
+        protected IAsyncGlobal<MockNeuralNet> getAsyncGlobal() {
+            return asyncGlobal;
+        }
+
+        @Override
+        protected AsyncConfiguration getConf() {
+            return config;
+        }
+
+        @Override
+        protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
+            return policy;
+        }
+    }
+}
@@ -1,5 +1,8 @@
 package org.deeplearning4j.rl4j.learning.async;

+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;

@@ -9,7 +12,11 @@ import org.deeplearning4j.rl4j.support.*;

 import org.deeplearning4j.rl4j.util.IDataManager;
 import org.junit.Test;

+import java.util.ArrayList;
+import java.util.List;
+
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;

 public class AsyncThreadTest {

@@ -82,7 +89,42 @@ public class AsyncThreadTest {

             IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
             assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
             assertEquals(i, statEntry.getEpochCounter());
-            assertEquals(2.0, statEntry.getReward(), 0.0001);
+            assertEquals(38.0, statEntry.getReward(), 0.0001);
+        }
+    }
+
+    @Test
+    public void when_run_expect_NeuralNetIsResetAtInitAndEveryEpoch() {
+        // Arrange
+        TestContext context = new TestContext();
+
+        // Act
+        context.sut.run();
+
+        // Assert
+        assertEquals(6, context.neuralNet.resetCallCount);
+    }
+
+    @Test
+    public void when_run_expect_trainSubEpochCalled() {
+        // Arrange
+        TestContext context = new TestContext();
+
+        // Act
+        context.sut.run();
+
+        // Assert
+        assertEquals(10, context.sut.trainSubEpochParams.size());
+        for(int i = 0; i < 10; ++i) {
+            MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
+            if(i % 2 == 0) {
+                assertEquals(2, params.nstep);
+                assertEquals(8.0, params.obs.toArray()[0], 0.00001);
+            }
+            else {
+                assertEquals(1, params.nstep);
+                assertNull(params.obs);
+            }
         }
     }

@@ -91,14 +133,18 @@ public class AsyncThreadTest {

         public final MockNeuralNet neuralNet = new MockNeuralNet();
         public final MockObservationSpace observationSpace = new MockObservationSpace();
         public final MockMDP mdp = new MockMDP(observationSpace);
-        public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 2);
+        public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
         public final TrainingListenerList listeners = new TrainingListenerList();
         public final MockTrainingListener listener = new MockTrainingListener();
+        private final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
+        public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);

         public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);

         public TestContext() {
             asyncGlobal.setMaxLoops(10);
             listeners.add(listener);
+            sut.setHistoryProcessor(historyProcessor);
         }
     }

@@ -107,11 +153,12 @@ public class AsyncThreadTest {

         public int preEpochCallCount = 0;
         public int postEpochCallCount = 0;

         private final IAsyncGlobal asyncGlobal;
         private final MockNeuralNet neuralNet;
         private final AsyncConfiguration conf;

+        private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
+
         public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
             super(asyncGlobal, mdp, listeners, threadNumber, 0);

@@ -154,8 +201,16 @@ public class AsyncThreadTest {

         @Override
         protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
+            trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
             return new SubEpochReturn(1, null, 1.0, 1.0);
         }

+        @AllArgsConstructor
+        @Getter
+        public static class TrainSubEpochParams {
+            Encodable obs;
+            int nstep;
+        }
     }
@@ -26,9 +26,20 @@ public class QLearningDiscreteTest {

     public void refac_QLearningDiscrete_trainStep() {
         // Arrange
         MockObservationSpace observationSpace = new MockObservationSpace();
-        MockMDP mdp = new MockMDP(observationSpace);
         MockDQN dqn = new MockDQN();
-        MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null);
+        MockRandom random = new MockRandom(new double[] {
+                        0.7309677600860596,
+                        0.8314409852027893,
+                        0.2405363917350769,
+                        0.6063451766967773,
+                        0.6374173760414124,
+                        0.3090505599975586,
+                        0.5504369735717773,
+                        0.11700659990310669
+                },
+                new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
+        MockMDP mdp = new MockMDP(observationSpace, random);

         QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
                 0, 1.0, 0, 0, 0, 0, true);
         MockDataManager dataManager = new MockDataManager(false);

@@ -37,7 +48,7 @@ public class QLearningDiscreteTest {

         IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
         MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
         sut.setHistoryProcessor(hp);
-        MockEncodable obs = new MockEncodable(1);
+        MockEncodable obs = new MockEncodable(-100);
         List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();

         // Act

@@ -49,7 +60,11 @@ public class QLearningDiscreteTest {

         // Assert
         // HistoryProcessor calls
-        assertEquals(24, hp.recordCallCount);
+        double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 };
+        assertEquals(expectedRecords.length, hp.recordCalls.size());
+        for(int i = 0; i < expectedRecords.length; ++i) {
+            assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
+        }
         assertEquals(13, hp.addCallCount);
         assertEquals(0, hp.startMonitorCallCount);
         assertEquals(0, hp.stopMonitorCallCount);

@@ -60,21 +75,20 @@ public class QLearningDiscreteTest {

         assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
         assertEquals(14, dqn.outputParams.size());
         double[][] expectedDQNOutput = new double[][] {
-                new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
-                new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
-                new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
-                new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
-                new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
-                new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
-                new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
-                new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
+                new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
+                new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
+                new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
+                new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
+                new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
+                new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
+                new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
+                new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
                 new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                 new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                 new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                 new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                 new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
                 new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
         };
         for(int i = 0; i < expectedDQNOutput.length; ++i) {
             INDArray outputParam = dqn.outputParams.get(i);

@@ -84,23 +98,23 @@ public class QLearningDiscreteTest {

             double[] expectedRow = expectedDQNOutput[i];
             for(int j = 0; j < expectedRow.length; ++j) {
-                assertEquals(expectedRow[j] / 255.0, outputParam.getDouble(j), 0.00001);
+                assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
             }
         }

         // MDP calls
-        assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
+        assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());

         // ExpReplay calls
         double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
-        int[] expectedTrActions = new int[] { 0, 4, 3, 4, 4, 4, 4, 4 };
-        double[] expectedTrNextObservation = new double[] { 0, 0, 0, 1.0, 9.0, 11.0, 13.0, 15.0 };
+        int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
+        double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, -100.0, 9.0, 11.0, 13.0, 15.0 };
         double[][] expectedTrObservations = new double[][] {
-                new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
-                new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
-                new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
-                new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
-                new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
+                new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
+                new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
+                new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
+                new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
+                new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
                 new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                 new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                 new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },

@@ -111,18 +125,18 @@ public class QLearningDiscreteTest {

             assertEquals(expectedTrActions[i], tr.getAction());
             assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
             for(int j = 0; j < expectedTrObservations[i].length; ++j) {
-                assertEquals(expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
+                assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
             }
         }

         // trainStep results
         assertEquals(16, results.size());
-        double[] expectedMaxQ = new double[] { 1.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
+        double[] expectedMaxQ = new double[] { 6.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
         double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
         for(int i=0; i < 16; ++i) {
             QLearning.QLStepReturn<MockEncodable> result = results.get(i);
             if(i % 2 == 0) {
-                assertEquals(expectedMaxQ[i/2] / 255.0, result.getMaxQ(), 0.001);
+                assertEquals(expectedMaxQ[i/2], 255.0 * result.getMaxQ(), 0.001);
                 assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
             }
             else {
@@ -22,7 +22,15 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;

 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
+import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
+import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.support.*;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -31,6 +39,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;

 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;

 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;

@@ -153,4 +163,75 @@ public class PolicyTest {

         assertTrue(count[2] < 40);
         assertTrue(count[3] < 50);
     }
+
+    @Test
+    public void refacPolicyPlay() {
+        // Arrange
+        MockObservationSpace observationSpace = new MockObservationSpace();
+        MockDQN dqn = new MockDQN();
+        MockRandom random = new MockRandom(new double[] {
+                        0.7309677600860596,
+                        0.8314409852027893,
+                        0.2405363917350769,
+                        0.6063451766967773,
+                        0.6374173760414124,
+                        0.3090505599975586,
+                        0.5504369735717773,
+                        0.11700659990310669
+                },
+                new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
+        MockMDP mdp = new MockMDP(observationSpace, 30, random);
+
+        QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
+                0, 1.0, 0, 0, 0, 0, true);
+        MockNeuralNet nnMock = new MockNeuralNet();
+        MockRefacPolicy sut = new MockRefacPolicy(nnMock);
+        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
+        MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
+
+        // Act
+        double totalReward = sut.play(mdp, hp);
+
+        // Assert
+        assertEquals(1, nnMock.resetCallCount);
+        assertEquals(465.0, totalReward, 0.0001);
+
+        // HistoryProcessor
+        assertEquals(27, hp.addCallCount);
+        assertEquals(31, hp.recordCalls.size());
+        for(int i=0; i <= 30; ++i) {
+            assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
+        }
+
+        // MDP
+        assertEquals(1, mdp.resetCount);
+        assertEquals(30, mdp.actions.size());
+        for(int i = 0; i < mdp.actions.size(); ++i) {
+            assertEquals(0, (int)mdp.actions.get(i));
+        }
+
+        // DQN
+        assertEquals(0, dqn.fitParams.size());
+        assertEquals(0, dqn.outputParams.size());
+    }
+
+    public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
+
+        private NeuralNet neuralNet;
+
+        public MockRefacPolicy(NeuralNet neuralNet) {
+            this.neuralNet = neuralNet;
+        }
+
+        @Override
+        public NeuralNet getNeuralNet() {
+            return neuralNet;
+        }
+
+        @Override
+        public Integer nextAction(INDArray input) {
+            return (int)input.getDouble(0);
+        }
+    }
 }
@@ -1,65 +1,22 @@

 package org.deeplearning4j.rl4j.support;

+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Value;
 import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;

+@AllArgsConstructor
+@Value
 public class MockAsyncConfiguration implements AsyncConfiguration {

-    private final int nStep;
-    private final int maxEpochStep;
-
-    public MockAsyncConfiguration(int nStep, int maxEpochStep) {
-        this.nStep = nStep;
-        this.maxEpochStep = maxEpochStep;
-    }
-
-    @Override
-    public Integer getSeed() {
-        return 0;
-    }
-
-    @Override
-    public int getMaxEpochStep() {
-        return maxEpochStep;
-    }
-
-    @Override
-    public int getMaxStep() {
-        return 0;
-    }
-
-    @Override
-    public int getNumThread() {
-        return 0;
-    }
-
-    @Override
-    public int getNstep() {
-        return nStep;
-    }
-
-    @Override
-    public int getTargetDqnUpdateFreq() {
-        return 0;
-    }
-
-    @Override
-    public int getUpdateStart() {
-        return 0;
-    }
-
-    @Override
-    public double getRewardFactor() {
-        return 0;
-    }
-
-    @Override
-    public double getGamma() {
-        return 0;
-    }
-
-    @Override
-    public double getErrorClamp() {
-        return 0;
-    }
+    private Integer seed;
+    private int maxEpochStep;
+    private int maxStep;
+    private int numThread;
+    private int nstep;
+    private int targetDqnUpdateFreq;
+    private int updateStart;
+    private double rewardFactor;
+    private double gamma;
+    private double errorClamp;
 }
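Since @AllArgsConstructor generates a positional constructor in field-declaration order, the ten-argument calls in the tests above map as follows (a sketch; argument names taken from the fields):

// (seed, maxEpochStep, maxStep, numThread, nstep,
//  targetDqnUpdateFreq, updateStart, rewardFactor, gamma, errorClamp)
MockAsyncConfiguration config =
        new MockAsyncConfiguration(1, 11, 0, 0, 0, 0, 0, 0, 0, 0);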
@@ -1,5 +1,6 @@
 package org.deeplearning4j.rl4j.support;

+import lombok.Getter;
 import lombok.Setter;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;

@@ -9,9 +10,13 @@ import java.util.concurrent.atomic.AtomicInteger;

 public class MockAsyncGlobal implements IAsyncGlobal {

+    private final NeuralNet current;
+
     public boolean hasBeenStarted = false;
     public boolean hasBeenTerminated = false;

+    public int enqueueCallCount = 0;
+
     @Setter
     private int maxLoops;
     @Setter

@@ -19,8 +24,13 @@ public class MockAsyncGlobal implements IAsyncGlobal {

     private int currentLoop = 0;

     public MockAsyncGlobal() {
+        this(null);
+    }
+
+    public MockAsyncGlobal(NeuralNet current) {
         maxLoops = Integer.MAX_VALUE;
         numLoopsStopRunning = Integer.MAX_VALUE;
+        this.current = current;
     }

     @Override

@@ -50,16 +60,16 @@ public class MockAsyncGlobal implements IAsyncGlobal {

     @Override
     public NeuralNet getCurrent() {
-        return null;
+        return current;
     }

     @Override
     public NeuralNet getTarget() {
-        return null;
+        return current;
     }

     @Override
     public void enqueue(Gradient[] gradient, Integer nstep) {
+        ++enqueueCallCount;
     }
 }
@@ -5,9 +5,10 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;

 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

+import java.util.ArrayList;
+
 public class MockHistoryProcessor implements IHistoryProcessor {

-    public int recordCallCount = 0;
     public int addCallCount = 0;
     public int startMonitorCallCount = 0;
     public int stopMonitorCallCount = 0;

@@ -15,10 +16,13 @@ public class MockHistoryProcessor implements IHistoryProcessor {

     private final Configuration config;
     private final CircularFifoQueue<INDArray> history;

+    public final ArrayList<INDArray> recordCalls;
+
     public MockHistoryProcessor(Configuration config) {
         this.config = config;
         history = new CircularFifoQueue<>(config.getHistoryLength());
+        recordCalls = new ArrayList<INDArray>();
     }

     @Override

@@ -37,7 +41,7 @@ public class MockHistoryProcessor implements IHistoryProcessor {

     @Override
     public void record(INDArray image) {
-        ++recordCallCount;
+        recordCalls.add(image);
     }

     @Override
@@ -2,8 +2,10 @@ package org.deeplearning4j.rl4j.support;

 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.nd4j.linalg.api.rng.Random;

 import java.util.ArrayList;
 import java.util.List;

@@ -11,14 +13,30 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {

     private final DiscreteSpace actionSpace;
+    private final int stepsUntilDone;
     private int currentObsValue = 0;
     private final ObservationSpace observationSpace;

     public final List<Integer> actions = new ArrayList<>();
+    private int step = 0;
+    public int resetCount = 0;
+
+    public MockMDP(ObservationSpace observationSpace, int stepsUntilDone, DiscreteSpace actionSpace) {
+        this.stepsUntilDone = stepsUntilDone;
+        this.actionSpace = actionSpace;
+        this.observationSpace = observationSpace;
+    }
+
+    public MockMDP(ObservationSpace observationSpace, int stepsUntilDone, Random rnd) {
+        this(observationSpace, stepsUntilDone, new DiscreteSpace(5, rnd));
+    }

     public MockMDP(ObservationSpace observationSpace) {
-        actionSpace = new DiscreteSpace(5);
-        this.observationSpace = observationSpace;
+        this(observationSpace, Integer.MAX_VALUE, new DiscreteSpace(5));
+    }
+
+    public MockMDP(ObservationSpace observationSpace, Random rnd) {
+        this(observationSpace, Integer.MAX_VALUE, new DiscreteSpace(5, rnd));
     }

     @Override

@@ -33,7 +51,9 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {

     @Override
     public MockEncodable reset() {
+        ++resetCount;
         currentObsValue = 0;
+        step = 0;
         return new MockEncodable(currentObsValue++);
     }

@@ -45,12 +65,13 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {

     @Override
     public StepReply<MockEncodable> step(Integer action) {
         actions.add(action);
+        ++step;
         return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
     }

     @Override
     public boolean isDone() {
-        return false;
+        return step >= stepsUntilDone;
     }

     @Override
@@ -4,13 +4,18 @@ import org.deeplearning4j.nn.api.NeuralNetwork;

 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;

 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;

 public class MockNeuralNet implements NeuralNet {

     public int resetCallCount = 0;
+    public int copyCallCount = 0;
+    public List<INDArray> outputAllInputs = new ArrayList<INDArray>();

     @Override
     public NeuralNetwork[] getNeuralNetworks() {

@@ -29,17 +34,18 @@ public class MockNeuralNet implements NeuralNet {

     @Override
     public INDArray[] outputAll(INDArray batch) {
-        return new INDArray[0];
+        outputAllInputs.add(batch);
+        return new INDArray[] { Nd4j.create(new double[] { 1.0 }) };
     }

     @Override
     public NeuralNet clone() {
-        return null;
+        return this;
     }

     @Override
     public void copy(NeuralNet from) {
+        ++copyCallCount;
     }

     @Override
@@ -4,14 +4,25 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;

 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import java.util.ArrayList;
+import java.util.List;

 public class MockPolicy implements IPolicy<MockEncodable, Integer> {

     public int playCallCount = 0;
+    public List<INDArray> actionInputs = new ArrayList<INDArray>();

     @Override
     public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
         ++playCallCount;
         return 0;
     }
+
+    @Override
+    public Integer nextAction(INDArray input) {
+        actionInputs.add(input);
+        return null;
+    }
 }
@@ -45,12 +45,4 @@ public abstract class MalmoActionSpace extends DiscreteSpace {

     public Integer noOp() {
         return -1;
     }
-
-    /**
-     * Sets the seed used for random generation of actions
-     * @param seed random number generator seed
-     */
-    public void setRandomSeed(long seed) {
-        rd.setSeed(seed);
-    }
 }