RL4J: Remove processing done on observations in Policy & Async (#8471)
* Removed processing from Policy.play() and fixed missing resets
* Adjusted unit test to check if DQNs have been reset
* Fixed a couple of problems, added and updated unit tests
* Removed processing from AsyncThreadDiscrete
* Fixed a few problems

Signed-off-by: unknown <aboulang2002@yahoo.com>
parent 65ef0622ff
commit de3975f088
@@ -57,6 +57,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
     final private NN current;
     final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
     final private AsyncConfiguration a3cc;
+    private final IAsyncLearning learning;
     @Getter
     private AtomicInteger T = new AtomicInteger(0);
     @Getter
@@ -64,10 +65,11 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
     @Getter
     private boolean running = true;

-    public AsyncGlobal(NN initial, AsyncConfiguration a3cc) {
+    public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) {
         this.current = initial;
         target = (NN) initial.clone();
         this.a3cc = a3cc;
+        this.learning = learning;
         queue = new ConcurrentLinkedQueue<>();
     }

@@ -106,11 +108,14 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
     }

     /**
-     * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
+     * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
      */
     public void terminate() {
-        running = false;
-        queue.clear();
+        if(running) {
+            running = false;
+            queue.clear();
+            learning.terminate();
+        }
     }

 }
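The mutual shutdown introduced here (AsyncGlobal.terminate() now calls into the learning instance, which may in turn call back into AsyncGlobal.terminate()) only stays finite because of the if(running) guard. The standalone sketch below isolates that idea; IAsyncLearningSketch, GlobalSketch, and the volatile flag are simplified stand-ins added for this sketch, not the actual RL4J types.

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

// Simplified stand-ins for the RL4J types named in the hunk above.
interface IAsyncLearningSketch {
    void terminate();
}

class GlobalSketch {
    private final IAsyncLearningSketch learning;
    private final Queue<Runnable> queue = new ConcurrentLinkedQueue<>();
    private volatile boolean running = true; // 'volatile' is an assumption of this sketch

    GlobalSketch(IAsyncLearningSketch learning) {
        this.learning = learning;
    }

    // The if(running) guard makes terminate() idempotent: if the learning
    // instance calls back into this method while terminating, 'running' is
    // already false and the cascade stops immediately.
    void terminate() {
        if (running) {
            running = false;
            queue.clear();        // discard queued gradient work
            learning.terminate(); // cascade the shutdown
        }
    }
}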
@@ -37,7 +37,10 @@ import org.nd4j.linalg.factory.Nd4j;
  */
 @Slf4j
 public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
-                extends Learning<O, A, AS, NN> {
+                extends Learning<O, A, AS, NN>
+                implements IAsyncLearning {

+    private Thread monitorThread = null;
+
     @Getter(AccessLevel.PROTECTED)
     private final TrainingListenerList listeners = new TrainingListenerList();

@@ -126,6 +129,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa

     protected void monitorTraining() {
         try {
+            monitorThread = Thread.currentThread();
             while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
                 canContinue = listeners.notifyTrainingProgress(this);
                 if(!canContinue) {

@@ -139,10 +143,25 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
         } catch (InterruptedException e) {
             log.error("Training interrupted.", e);
         }
+        monitorThread = null;
     }

     protected void cleanupPostTraining() {
         // Worker threads stops automatically when the global thread stops
         getAsyncGlobal().terminate();
     }

+    /**
+     * Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
+     */
+    public void terminate() {
+        if(canContinue) {
+            canContinue = false;
+
+            Thread safeMonitorThread = monitorThread;
+            if(safeMonitorThread != null) {
+                safeMonitorThread.interrupt();
+            }
+        }
+    }
 }
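The safeMonitorThread copy in terminate() above is a race guard: monitorTraining() nulls the field when it exits, so checking and interrupting the field directly could throw between the check and the call. Below is a self-contained sketch of the same pattern with illustrative names (not the RL4J classes); the volatile qualifier is an assumption this sketch adds.

// Sketch of the safe-interrupt pattern used by terminate() above.
class MonitorOwner {
    private volatile Thread monitorThread; // 'volatile' assumed for the sketch

    void monitorLoop() {
        monitorThread = Thread.currentThread();
        try {
            while (!Thread.currentThread().isInterrupted()) {
                Thread.sleep(100); // poll training progress, simplified
            }
        } catch (InterruptedException ignored) {
            // interrupted by terminate(); fall through to cleanup
        }
        monitorThread = null;
    }

    void terminate() {
        Thread safe = monitorThread; // snapshot; may be null if the loop already exited
        if (safe != null) {
            safe.interrupt();        // the local cannot be nulled out from under us
        }
    }
}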
@@ -21,15 +21,19 @@ import lombok.Getter;
 import lombok.Setter;
 import lombok.Value;
 import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.*;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager;
+import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.nd4j.linalg.factory.Nd4j;

 /**
@@ -43,7 +47,7 @@ import org.nd4j.linalg.factory.Nd4j;
  * @author Alexandre Boulanger
  */
 @Slf4j
-public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
+public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
                 extends Thread implements StepCountable, IEpochTrainer {

     @Getter
@@ -54,26 +58,35 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
     private int stepCounter = 0;
     @Getter @Setter
     private int epochCounter = 0;
-    @Getter
-    private MDP<O, A, AS> mdp;
     @Getter @Setter
     private IHistoryProcessor historyProcessor;

+    private boolean isEpochStarted = false;
+    private final LegacyMDPWrapper<O, A, AS> mdp;

     private final TrainingListenerList listeners;

     public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
-        this.mdp = mdp;
+        this.mdp = new LegacyMDPWrapper<O, A, AS>(mdp, null, this);
         this.listeners = listeners;
         this.threadNumber = threadNumber;
         this.deviceNum = deviceNum;
     }

+    public MDP<O, A, AS> getMdp() {
+        return mdp.getWrappedMDP();
+    }
+    protected LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper() {
+        return mdp;
+    }

     public void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
-        historyProcessor = new HistoryProcessor(conf);
+        setHistoryProcessor(new HistoryProcessor(conf));
     }

     public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
         this.historyProcessor = historyProcessor;
+        mdp.setHistoryProcessor(historyProcessor);
     }

     protected void postEpoch() {
@@ -109,61 +122,63 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
      */
     @Override
     public void run() {
-        RunContext<O> context = new RunContext<>();
-        Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
+        try {
+            RunContext context = new RunContext();
+            Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);

             log.info("ThreadNum-" + threadNumber + " Started!");

-        boolean canContinue = initWork(context);
-        if (canContinue) {
-
             while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
-                handleTraining(context);
-                if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
-                    canContinue = finishEpoch(context) && startNewEpoch(context);
+                if (!isEpochStarted) {
+                    boolean canContinue = startNewEpoch(context);
                     if (!canContinue) {
                         break;
                     }
                 }
+
+                handleTraining(context);
+
+                if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
+                    boolean canContinue = finishEpoch(context);
+                    if (!canContinue) {
+                        break;
+                    }
+
+                    ++epochCounter;
+                }
             }
         }
-        terminateWork();
+        finally {
+            terminateWork();
+        }
     }

-    private void initNewEpoch(RunContext context) {
-        getCurrent().reset();
-        Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor);
-
-        context.obs = initMdp.getLastObs();
-        context.rewards = initMdp.getReward();
-        context.epochElapsedSteps = initMdp.getSteps();
-    }
-
-    private void handleTraining(RunContext<O> context) {
+    private void handleTraining(RunContext context) {
         int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
-        SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);
+        SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);

         context.obs = subEpochReturn.getLastObs();
-        stepCounter += subEpochReturn.getSteps();
         context.epochElapsedSteps += subEpochReturn.getSteps();
         context.rewards += subEpochReturn.getReward();
         context.score = subEpochReturn.getScore();
     }

-    private boolean initWork(RunContext context) {
-        initNewEpoch(context);
-        preEpoch();
-        return listeners.notifyNewEpoch(this);
-    }
-
     private boolean startNewEpoch(RunContext context) {
-        initNewEpoch(context);
-        epochCounter++;
+        getCurrent().reset();
+        Learning.InitMdp<Observation> initMdp = refacInitMdp();
+
+        context.obs = initMdp.getLastObs();
+        context.rewards = initMdp.getReward();
+        context.epochElapsedSteps = initMdp.getSteps();
+
+        isEpochStarted = true;
         preEpoch();

         return listeners.notifyNewEpoch(this);
     }

     private boolean finishEpoch(RunContext context) {
+        isEpochStarted = false;
         postEpoch();
         IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);

@@ -173,8 +188,10 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
     }

     private void terminateWork() {
-        postEpoch();
         getAsyncGlobal().terminate();
+        if(isEpochStarted) {
+            postEpoch();
+        }
     }

     protected abstract NN getCurrent();
@@ -185,13 +202,47 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace

     protected abstract IPolicy<O, A> getPolicy(NN net);

-    protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
+    protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);

+    private Learning.InitMdp<Observation> refacInitMdp() {
+        LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
+        IHistoryProcessor hp = getHistoryProcessor();
+
+        Observation observation = mdp.reset();
+
+        int step = 0;
+        double reward = 0;
+
+        boolean isHistoryProcessor = hp != null;
+
+        int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
+        int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
+
+        while (step < requiredFrame && !mdp.isDone()) {
+
+            A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
+
+            StepReply<Observation> stepReply = mdp.step(action);
+            reward += stepReply.getReward();
+            observation = stepReply.getObservation();
+
+            step++;
+        }
+
+        return new Learning.InitMdp(step, observation, reward);
+    }
+
+    public void incrementStep() {
+        ++stepCounter;
+    }
+
     @AllArgsConstructor
     @Value
-    public static class SubEpochReturn<O> {
+    public static class SubEpochReturn {
         int steps;
-        O lastObs;
+        Observation lastObs;
         double reward;
         double score;
     }
@@ -206,8 +257,8 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
         double score;
     }

-    private static class RunContext<O extends Encodable> {
-        private O obs;
+    private static class RunContext {
+        private Observation obs;
         private double rewards;
         private int epochElapsedSteps;
         private double score;
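The new refacInitMdp() warms up the history buffer with no-op steps before the first policy decision: with a history of length H and a skip of S, the first full frame stack needs S * (H - 1) frames. The arithmetic is reduced to a standalone sketch below; the historyLength and skipFrame values are illustrative, not taken from the diff.

public class WarmupSketch {
    public static void main(String[] args) {
        int historyLength = 4; // hp.getConf().getHistoryLength(), assumed value
        int skipFrame = 2;     // hp.getConf().getSkipFrame(), assumed value

        // Frames needed before the first complete history stack exists.
        int requiredFrame = skipFrame * (historyLength - 1);

        int step = 0;
        while (step < requiredFrame) {
            // In refacInitMdp() this is mdp.step(noOp) with reward accumulation.
            step++;
        }
        System.out.println("no-op warm-up steps: " + step); // prints 6
    }
}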
@@ -25,9 +25,11 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.util.ArrayUtil;

@@ -40,7 +42,7 @@ import java.util.Stack;
  * Async Learning specialized for the Discrete Domain
  *
  */
-public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
+public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
                 extends AsyncThread<O, Integer, DiscreteSpace, NN> {

     @Getter
@@ -61,14 +63,14 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
      * @param nstep the number of max nstep (step until t_max or state is terminal)
      * @return subepoch training informations
      */
-    public SubEpochReturn<O> trainSubEpoch(O sObs, int nstep) {
+    public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {

         synchronized (getAsyncGlobal()) {
             current.copy(getAsyncGlobal().getCurrent());
         }
         Stack<MiniTrans<Integer>> rewards = new Stack<>();

-        O obs = sObs;
+        Observation obs = sObs;
         IPolicy<O, Integer> policy = getPolicy(current);

         Integer action;

@@ -81,93 +83,53 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
         int i = 0;
         while (!getMdp().isDone() && i < nstep * skipFrame) {

-            INDArray input = Learning.getInput(getMdp(), obs);
-            INDArray hstack = null;
-
-            if (hp != null) {
-                hp.record(input);
-            }
-
             //if step of training, just repeat lastAction
             if (i % skipFrame != 0 && lastAction != null) {
                 action = lastAction;
             } else {
-                hstack = processHistory(input);
-                action = policy.nextAction(hstack);
+                action = policy.nextAction(obs);
             }

-            StepReply<O> stepReply = getMdp().step(action);
+            StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
             accuReward += stepReply.getReward() * getConf().getRewardFactor();

             //if it's not a skipped frame, you can do a step of training
             if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) {
-                obs = stepReply.getObservation();
-
-                if (hstack == null) {
-                    hstack = processHistory(input);
-                }
-                INDArray[] output = current.outputAll(hstack);
-                rewards.add(new MiniTrans(hstack, action, output, accuReward));
-
+                INDArray[] output = current.outputAll(obs.getData());
+                rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
                 accuReward = 0;
             }

+            obs = stepReply.getObservation();
+
             reward += stepReply.getReward();

             i++;
+            incrementStep();
             lastAction = action;
         }

         //a bit of a trick usable because of how the stack is treated to init R
-        INDArray input = Learning.getInput(getMdp(), obs);
-        INDArray hstack = processHistory(input);
-
-        if (hp != null) {
-            hp.record(input);
-        }
-
+        // FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
         if (getMdp().isDone() && i < nstep * skipFrame)
-            rewards.add(new MiniTrans(hstack, null, null, 0));
+            rewards.add(new MiniTrans(obs.getData(), null, null, 0));
         else {
             INDArray[] output = null;
             if (getConf().getTargetDqnUpdateFreq() == -1)
-                output = current.outputAll(hstack);
+                output = current.outputAll(obs.getData());
             else synchronized (getAsyncGlobal()) {
-                output = getAsyncGlobal().getTarget().outputAll(hstack);
+                output = getAsyncGlobal().getTarget().outputAll(obs.getData());
             }
             double maxQ = Nd4j.max(output[0]).getDouble(0);
-            rewards.add(new MiniTrans(hstack, null, output, maxQ));
+            rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
         }

         getAsyncGlobal().enqueue(calcGradient(current, rewards), i);

-        return new SubEpochReturn<O>(i, obs, reward, current.getLatestScore());
-    }
-
-    protected INDArray processHistory(INDArray input) {
-        IHistoryProcessor hp = getHistoryProcessor();
-        INDArray[] history;
-        if (hp != null) {
-            hp.add(input);
-            history = hp.getHistory();
-        } else
-            history = new INDArray[] {input};
-        //concat the history into a single INDArray input
-        INDArray hstack = Transition.concat(history);
-        if (hp != null) {
-            hstack.muli(1.0 / hp.getScale());
-        }
-
-        if (getCurrent().isRecurrent()) {
-            //flatten everything for the RNN
-            hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1));
-        } else {
-            //if input is not 2d, you have to append that the batch is 1 length high
-            if (hstack.shape().length > 2)
-                hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
-        }
-
-        return hstack;
+        return new SubEpochReturn(i, obs, reward, current.getLatestScore());
     }

     public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);
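The "trick" comment and the new FIXME both point at how calcGradient consumes the stack: the last pushed MiniTrans only seeds the running return (0 on a terminal state, the bootstrap max-Q otherwise), and popping the stack then accumulates the n-step discounted targets. A minimal sketch of that accumulation, reduced to plain doubles with illustrative gamma and rewards:

import java.util.Stack;

public class NStepReturnSketch {
    public static void main(String[] args) {
        double gamma = 0.99;
        Stack<Double> rewards = new Stack<>();
        rewards.push(1.0);   // accumulated reward at step 0
        rewards.push(0.5);   // accumulated reward at step 1
        rewards.push(10.0);  // seed entry: bootstrap value, not a real reward

        double r = rewards.pop(); // initialize the running return R from the seed
        while (!rewards.isEmpty()) {
            r = rewards.pop() + gamma * r; // R = r_i + gamma * R, the n-step target
            System.out.println("target = " + r);
        }
    }
}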
@@ -0,0 +1,21 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.async;
+
+public interface IAsyncLearning {
+    void terminate();
+}
@@ -53,7 +53,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
         this.iActorCritic = iActorCritic;
         this.mdp = mdp;
         this.configuration = conf;
-        asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
+        asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);

         Integer seed = conf.getSeed();
         Random rnd = Nd4j.getRandom();
@@ -21,6 +21,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
 import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
+import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
 import org.deeplearning4j.rl4j.learning.async.MiniTrans;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;

@@ -46,13 +47,13 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
     @Getter
     final protected A3CDiscrete.A3CConfiguration conf;
     @Getter
-    final protected AsyncGlobal<IActorCritic> asyncGlobal;
+    final protected IAsyncGlobal<IActorCritic> asyncGlobal;
     @Getter
     final protected int threadNumber;

     final private Random rnd;

-    public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
+    public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
                     A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
                     int threadNumber) {
         super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
@@ -46,7 +46,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
     public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
         this.mdp = mdp;
         this.configuration = conf;
-        this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
+        this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
     }

     @Override
@@ -150,6 +150,9 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
     }

     private InitMdp<Observation> refacInitMdp() {
+        getQNetwork().reset();
+        getTargetQNetwork().reset();
+
         LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
         IHistoryProcessor hp = getHistoryProcessor();
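The two reset() calls added at the top of refacInitMdp() give each epoch a clean starting state for both the Q-network and the target network; the commit message's "Adjusted unit test to check if DQNs have been reset" refers to this behavior. Below is a reduced, hypothetical sketch of the pattern; StatefulNet and EpochRunner are illustrative names, and the recurrent-state rationale is an assumption of this sketch, not stated in the diff.

interface StatefulNet {
    void reset(); // clears internal state (e.g. recurrent activations)
}

class EpochRunner {
    private final StatefulNet qNetwork;
    private final StatefulNet targetQNetwork;

    EpochRunner(StatefulNet q, StatefulNet target) {
        this.qNetwork = q;
        this.targetQNetwork = target;
    }

    void startEpoch() {
        // Both copies are reset together so that neither network carries
        // state from the previous episode into the new one.
        qNetwork.reset();
        targetQNetwork.reset();
    }
}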
@@ -46,7 +46,7 @@ import java.util.ArrayList;
  *
  * DQN or Deep Q-Learning in the Discrete domain
  *
- * https://arxiv.org/abs/1312.5602
+ * http://arxiv.org/abs/1312.5602
  *
  */
 public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
@@ -29,7 +29,15 @@ public class Observation {
     private final DataSet data;

     public Observation(INDArray[] data) {
-        this(new org.nd4j.linalg.dataset.DataSet(Nd4j.concat(0, data), null));
+        this(data, false);
+    }
+
+    public Observation(INDArray[] data, boolean shouldReshape) {
+        INDArray features = Nd4j.concat(0, data);
+        if(shouldReshape) {
+            features = reshape(features);
+        }
+        this.data = new org.nd4j.linalg.dataset.DataSet(features, null);
     }

     // FIXME: Remove -- only used in unit tests

@@ -37,6 +45,15 @@ public class Observation {
         this.data = new org.nd4j.linalg.dataset.DataSet(data, null);
     }

+    private INDArray reshape(INDArray source) {
+        long[] shape = source.shape();
+        long[] nshape = new long[shape.length + 1];
+        nshape[0] = 1;
+        System.arraycopy(shape, 0, nshape, 1, shape.length);
+
+        return source.reshape(nshape);
+    }
+
     private Observation(DataSet data) {
         this.data = data;
     }
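The new reshape helper prepends a batch dimension of 1, so a stacked frame history such as [4, 84, 84] is handed to the network as a one-example batch [1, 4, 84, 84]. A standalone sketch of the same shape math; the sizes are illustrative:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;

public class ReshapeSketch {
    public static void main(String[] args) {
        INDArray history = Nd4j.zeros(4, 84, 84); // e.g. 4 stacked 84x84 frames

        // Same logic as Observation.reshape(): copy the shape, shifted by one,
        // with a leading batch dimension of 1.
        long[] shape = history.shape();
        long[] nshape = new long[shape.length + 1];
        nshape[0] = 1;
        System.arraycopy(shape, 0, nshape, 1, shape.length);

        INDArray batched = history.reshape(nshape);
        System.out.println(Arrays.toString(batched.shape())); // [1, 4, 84, 84]
    }
}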
@@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
 import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;

@@ -65,6 +66,11 @@ public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
         return actorCritic;
     }

+    @Override
+    public Integer nextAction(Observation obs) {
+        return nextAction(obs.getData());
+    }
+
     public Integer nextAction(INDArray input) {
         INDArray output = actorCritic.outputAll(input)[1];
         if (rnd == null) {
@@ -17,6 +17,7 @@
 package org.deeplearning4j.rl4j.policy;

 import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;

@@ -43,6 +44,11 @@ public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
         return dqn;
     }

+    @Override
+    public Integer nextAction(Observation obs) {
+        return nextAction(obs.getData());
+    }
+
     public Integer nextAction(INDArray input) {

         INDArray output = dqn.output(input);
@@ -20,6 +20,7 @@ import lombok.AllArgsConstructor;
 import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.network.dqn.DQN;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -44,6 +45,11 @@ public class DQNPolicy<O extends Encodable> extends Policy<O, Integer> {
         return dqn;
     }

+    @Override
+    public Integer nextAction(Observation obs) {
+        return nextAction(obs.getData());
+    }
+
     public Integer nextAction(INDArray input) {
         INDArray output = dqn.output(input);
         return Learning.getMaxAction(output);
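All three policies (ACPolicy, BoltzmannQ, DQNPolicy) gain the same thin overload: unwrap the already-processed Observation and forward it to the existing INDArray path, so no scaling or stacking logic lives in the policies themselves. A reduced sketch of the pattern with simplified stand-in types, not the RL4J classes:

class Tensor { }

class Obs {
    private final Tensor data;
    Obs(Tensor data) { this.data = data; }
    Tensor getData() { return data; }
}

abstract class SimplePolicy {
    // New entry point: callers hand over a ready-to-use observation.
    int nextAction(Obs obs) {
        return nextAction(obs.getData());
    }

    // Existing entry point kept for raw tensors; subclasses do the scoring.
    abstract int nextAction(Tensor input);
}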
@@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.policy;

 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -9,4 +10,5 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 public interface IPolicy<O, A> {
     <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
     A nextAction(INDArray input);
+    A nextAction(Observation observation);
 }
@@ -16,15 +16,21 @@

 package org.deeplearning4j.rl4j.policy;

+import lombok.Getter;
+import lombok.Setter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.HistoryProcessor;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.Learning;
+import org.deeplearning4j.rl4j.learning.StepCountable;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.util.ArrayUtil;

@@ -39,7 +45,7 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {

     public abstract NeuralNet getNeuralNet();

-    public abstract A nextAction(INDArray input);
+    public abstract A nextAction(Observation obs);

     public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
         return play(mdp, (IHistoryProcessor)null);

@@ -51,66 +57,81 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {

     @Override
     public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
+        RefacStepCountable stepCountable = new RefacStepCountable();
+        LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, stepCountable);
+
         boolean isHistoryProcessor = hp != null;
         int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;

-        getNeuralNet().reset();
-        Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
-        O obs = initMdp.getLastObs();
+        Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp);
+        Observation obs = initMdp.getLastObs();

         double reward = initMdp.getReward();

-        A lastAction = mdp.getActionSpace().noOp();
+        A lastAction = mdpWrapper.getActionSpace().noOp();
         A action;
-        int step = initMdp.getSteps();
-        INDArray[] history = null;
-
-        INDArray input = Learning.getInput(mdp, obs);
+        stepCountable.setStepCounter(initMdp.getSteps());

-        while (!mdp.isDone()) {
+        while (!mdpWrapper.isDone()) {

-            if (step % skipFrame != 0) {
+            if (stepCountable.getStepCounter() % skipFrame != 0) {
                 action = lastAction;
             } else {
-
-                if (history == null) {
-                    if (isHistoryProcessor) {
-                        hp.add(input);
-                        history = hp.getHistory();
-                    } else
-                        history = new INDArray[] {input};
-                }
-                INDArray hstack = Transition.concat(history);
-                if (isHistoryProcessor) {
-                    hstack.muli(1.0 / hp.getScale());
-                }
-                if (getNeuralNet().isRecurrent()) {
-                    //flatten everything for the RNN
-                    hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1));
-                } else {
-                    if (hstack.shape().length > 2)
-                        hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
-                }
-                action = nextAction(hstack);
+                action = nextAction(obs);
             }

             lastAction = action;

-            StepReply<O> stepReply = mdp.step(action);
+            StepReply<Observation> stepReply = mdpWrapper.step(action);
             reward += stepReply.getReward();

-            input = Learning.getInput(mdp, stepReply.getObservation());
-            if (isHistoryProcessor) {
-                hp.record(input);
-                hp.add(input);
-            }
-
-            history = isHistoryProcessor ? hp.getHistory()
-                            : new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())};
-            step++;
+            obs = stepReply.getObservation();
+            stepCountable.increment();
         }

         return reward;
     }

+    private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
+        getNeuralNet().reset();
+        Observation observation = mdpWrapper.reset();
+
+        int step = 0;
+        double reward = 0;
+
+        boolean isHistoryProcessor = hp != null;
+
+        int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
+        int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
+
+        while (step < requiredFrame && !mdpWrapper.isDone()) {
+
+            A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
+
+            StepReply<Observation> stepReply = mdpWrapper.step(action);
+            reward += stepReply.getReward();
+            observation = stepReply.getObservation();
+
+            step++;
+        }
+
+        return new Learning.InitMdp(step, observation, reward);
+    }
+
+    private class RefacStepCountable implements StepCountable {
+
+        @Getter
+        @Setter
+        private int stepCounter = 0;
+
+        public void increment() {
+            ++stepCounter;
+        }
+
+        @Override
+        public int getStepCounter() {
+            return 0;
+        }
+    }
 }
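One detail worth flagging in the hunk above: the handwritten getStepCounter() in RefacStepCountable returns a constant 0, and Lombok's @Getter does not generate a second getter when one already exists, so the skip-frame test in play() appears to always take the fresh-action branch. A sketch of a countable that reports its real counter (illustrative code, not the committed fix):

import java.util.concurrent.atomic.AtomicInteger;

class CountingStepCountable {
    private final AtomicInteger stepCounter = new AtomicInteger(0);

    public void increment() {
        stepCounter.incrementAndGet();
    }

    public void setStepCounter(int value) {
        stepCounter.set(value);
    }

    public int getStepCounter() {
        return stepCounter.get(); // report the real count, not a constant
    }
}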
@@ -4,6 +4,7 @@ import lombok.Getter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.ILearning;
+import org.deeplearning4j.rl4j.learning.StepCountable;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;

@@ -12,21 +13,53 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

-public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
+public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {

     @Getter
     private final MDP<O, A, AS> wrappedMDP;
     @Getter
     private final WrapperObservationSpace observationSpace;
     private final ILearning learning;
+    private IHistoryProcessor historyProcessor;
+    private final StepCountable stepCountable;
     private int skipFrame;

     private int step = 0;

     public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) {
+        this(wrappedMDP, learning, null, null);
+    }
+
+    public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
+        this(wrappedMDP, null, historyProcessor, stepCountable);
+    }
+
+    private LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
         this.wrappedMDP = wrappedMDP;
         this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
         this.learning = learning;
+        this.historyProcessor = historyProcessor;
+        this.stepCountable = stepCountable;
+    }
+
+    private IHistoryProcessor getHistoryProcessor() {
+        if(historyProcessor != null) {
+            return historyProcessor;
+        }
+
+        return learning.getHistoryProcessor();
+    }
+
+    public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
+        this.historyProcessor = historyProcessor;
+    }
+
+    private int getStep() {
+        if(stepCountable != null) {
+            return stepCountable.getStepCounter();
+        }
+
+        return learning.getStepCounter();
     }

     @Override
@@ -38,13 +71,12 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
     public Observation reset() {
         INDArray rawObservation = getInput(wrappedMDP.reset());

-        IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
+        IHistoryProcessor historyProcessor = getHistoryProcessor();
         if(historyProcessor != null) {
-            historyProcessor.record(rawObservation.dup());
-            rawObservation.muli(1.0 / historyProcessor.getScale());
+            historyProcessor.record(rawObservation);
         }

-        Observation observation = new Observation(new INDArray[] { rawObservation });
+        Observation observation = new Observation(new INDArray[] { rawObservation }, false);

         if(historyProcessor != null) {
             skipFrame = historyProcessor.getConf().getSkipFrame();

@@ -55,14 +87,9 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
         return observation;
     }

-    @Override
-    public void close() {
-        wrappedMDP.close();
-    }
-
     @Override
     public StepReply<Observation> step(A a) {
-        IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
+        IHistoryProcessor historyProcessor = getHistoryProcessor();

         StepReply<O> rawStepReply = wrappedMDP.step(a);
         INDArray rawObservation = getInput(rawStepReply.getObservation());

@@ -71,11 +98,10 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>

         int requiredFrame = 0;
         if(historyProcessor != null) {
-            historyProcessor.record(rawObservation.dup());
-            rawObservation.muli(1.0 / historyProcessor.getScale());
+            historyProcessor.record(rawObservation);

             requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
-            if ((learning.getStepCounter() % skipFrame == 0 && step >= requiredFrame)
+            if ((getStep() % skipFrame == 0 && step >= requiredFrame)
                 || (step % skipFrame == 0 && step < requiredFrame )){
                 historyProcessor.add(rawObservation);
             }

@@ -83,15 +109,21 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>

         Observation observation;
         if(historyProcessor != null && step >= requiredFrame) {
-            observation = new Observation(historyProcessor.getHistory());
+            observation = new Observation(historyProcessor.getHistory(), true);
+            observation.getData().muli(1.0 / historyProcessor.getScale());
         }
         else {
-            observation = new Observation(new INDArray[] { rawObservation });
+            observation = new Observation(new INDArray[] { rawObservation }, false);
         }

         return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
     }

+    @Override
+    public void close() {
+        wrappedMDP.close();
+    }
+
     @Override
     public boolean isDone() {
         return wrappedMDP.isDone();

@@ -103,7 +135,7 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
     }

     private INDArray getInput(O obs) {
-        INDArray arr = Nd4j.create(obs.toArray());
+        INDArray arr = Nd4j.create(((Encodable)obs).toArray());
         int[] shape = observationSpace.getShape();
         if (shape.length == 1)
             return arr.reshape(new long[] {1, arr.length()});
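The normalization change in this file is easy to miss: raw frames are now recorded unscaled (the old code scaled them before recording), and the division by the history processor's scale happens once, on the stacked observation handed to the network. The updated test further below checks exactly this by multiplying the last observation back by 255. A plain-Java sketch of the idea, with illustrative values:

public class ScaleOnceSketch {
    public static void main(String[] args) {
        double scale = 255.0;
        double[] rawFrame = { 4.0, 6.0, 8.0 }; // recorded as-is for the monitor

        // Frames accumulate in the history buffer unscaled...
        double[] history = rawFrame.clone();

        // ...and normalization is applied once, on the network-facing copy.
        for (int i = 0; i < history.length; i++) {
            history[i] /= scale;
        }
        System.out.println(history[0] * scale); // recovers the raw value (up to rounding)
    }
}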
@@ -72,7 +72,7 @@ public class AsyncLearningTest {
         public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
         public final MockPolicy policy = new MockPolicy();
         public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
-        public final MockTrainingListener listener = new MockTrainingListener();
+        public final MockTrainingListener listener = new MockTrainingListener(asyncGlobal);

         public TestContext() {
             sut.addListener(listener);
@@ -2,16 +2,17 @@ package org.deeplearning4j.rl4j.learning.async;

 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
-import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.support.*;
 import org.junit.Test;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;

+import java.util.ArrayList;
+import java.util.List;
 import java.util.Stack;

 import static org.junit.Assert.assertEquals;

@@ -21,37 +22,51 @@ public class AsyncThreadDiscreteTest {
     @Test
     public void refac_AsyncThreadDiscrete_trainSubEpoch() {
         // Arrange
+        int numEpochs = 1;
         MockNeuralNet nnMock = new MockNeuralNet();
+        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
+        MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
         MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
+        asyncGlobalMock.setMaxLoops(hpConf.getSkipFrame() * numEpochs);
         MockObservationSpace observationSpace = new MockObservationSpace();
         MockMDP mdpMock = new MockMDP(observationSpace);
         TrainingListenerList listeners = new TrainingListenerList();
         MockPolicy policyMock = new MockPolicy();
-        MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 0, 5,0, 0, 0, 0);
-        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
-        MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
+        MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0);
         TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
-        MockEncodable obs = new MockEncodable(123);
-
-        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(1)));
-        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(2)));
-        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(3)));
-        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(4)));
-        hpMock.add(Learning.getInput(mdpMock, new MockEncodable(5)));

         // Act
-        AsyncThread.SubEpochReturn<MockEncodable> result = sut.trainSubEpoch(obs, 2);
+        sut.run();

         // Assert
-        assertEquals(4, result.getSteps());
-        assertEquals(6.0, result.getReward(), 0.00001);
-        assertEquals(0.0, result.getScore(), 0.00001);
-        assertEquals(3.0, result.getLastObs().toArray()[0], 0.00001);
-        assertEquals(1, asyncGlobalMock.enqueueCallCount);
+        assertEquals(2, sut.trainSubEpochResults.size());
+        double[][] expectedLastObservations = new double[][] {
+            new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
+            new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
+        };
+        double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
+        for(int i = 0; i < 2; ++i) {
+            AsyncThread.SubEpochReturn result = sut.trainSubEpochResults.get(i);
+            assertEquals(4, result.getSteps());
+            assertEquals(expectedSubEpochReturnRewards[i], result.getReward(), 0.00001);
+            assertEquals(0.0, result.getScore(), 0.00001);
+
+            double[] expectedLastObservation = expectedLastObservations[i];
+            assertEquals(expectedLastObservation.length, result.getLastObs().getData().shape()[1]);
+            for(int j = 0; j < expectedLastObservation.length; ++j) {
+                assertEquals(expectedLastObservation[j], 255.0 * result.getLastObs().getData().getDouble(j), 0.00001);
+            }
+        }
+        assertEquals(2, asyncGlobalMock.enqueueCallCount);

         // HistoryProcessor
-        assertEquals(10, hpMock.addCalls.size());
-        double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
+        double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
+        assertEquals(expectedAddValues.length, hpMock.addCalls.size());
+        for(int i = 0; i < expectedAddValues.length; ++i) {
+            assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
+        }
+
+        double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, };
         assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
         for(int i = 0; i < expectedRecordValues.length; ++i) {
             assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);

@@ -59,49 +74,89 @@ public class AsyncThreadDiscreteTest {

         // Policy
         double[][] expectedPolicyInputs = new double[][] {
-            new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
-            new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
-            new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
-            new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
+            new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
+            new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
+            new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
+            new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
         };
         assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
         for(int i = 0; i < expectedPolicyInputs.length; ++i) {
             double[] expectedRow = expectedPolicyInputs[i];
|
||||||
INDArray input = policyMock.actionInputs.get(i);
|
INDArray input = policyMock.actionInputs.get(i);
|
||||||
assertEquals(expectedRow.length, input.shape()[0]);
|
assertEquals(expectedRow.length, input.shape()[1]);
|
||||||
for(int j = 0; j < expectedRow.length; ++j) {
|
for(int j = 0; j < expectedRow.length; ++j) {
|
||||||
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NeuralNetwork
|
// NeuralNetwork
|
||||||
assertEquals(1, nnMock.copyCallCount);
|
assertEquals(2, nnMock.copyCallCount);
|
||||||
double[][] expectedNNInputs = new double[][] {
|
double[][] expectedNNInputs = new double[][] {
|
||||||
new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||||
new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
|
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||||
new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
|
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||||
new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 },
|
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||||
|
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||||
};
|
};
|
||||||
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
|
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
|
||||||
for(int i = 0; i < expectedNNInputs.length; ++i) {
|
for(int i = 0; i < expectedNNInputs.length; ++i) {
|
||||||
double[] expectedRow = expectedNNInputs[i];
|
double[] expectedRow = expectedNNInputs[i];
|
||||||
INDArray input = nnMock.outputAllInputs.get(i);
|
INDArray input = nnMock.outputAllInputs.get(i);
|
||||||
assertEquals(expectedRow.length, input.shape()[0]);
|
assertEquals(expectedRow.length, input.shape()[1]);
|
||||||
for(int j = 0; j < expectedRow.length; ++j) {
|
for(int j = 0; j < expectedRow.length; ++j) {
|
||||||
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int arrayIdx = 0;
|
||||||
|
double[][][] expectedMinitransObs = new double[][][] {
|
||||||
|
new double[][] {
|
||||||
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
|
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||||
|
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation
|
||||||
|
},
|
||||||
|
new double[][] {
|
||||||
|
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||||
|
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||||
|
new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation
|
||||||
|
}
|
||||||
|
};
|
||||||
|
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
|
||||||
|
double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 };
|
||||||
|
|
||||||
|
assertEquals(2, sut.rewards.size());
|
||||||
|
for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) {
|
||||||
|
Stack<MiniTrans<Integer>> miniTransStack = sut.rewards.get(rewardIdx);
|
||||||
|
|
||||||
|
for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) {
|
||||||
|
MiniTrans minitrans = miniTransStack.get(i);
|
||||||
|
|
||||||
|
// Observation
|
||||||
|
double[] expectedRow = expectedMinitransObs[rewardIdx][i];
|
||||||
|
INDArray realRewards = minitrans.getObs();
|
||||||
|
assertEquals(expectedRow.length, realRewards.shape()[1]);
|
||||||
|
for (int j = 0; j < expectedRow.length; ++j) {
|
||||||
|
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001);
|
||||||
|
assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001);
|
||||||
|
++arrayIdx;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
|
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
|
||||||
|
|
||||||
private final IAsyncGlobal<MockNeuralNet> asyncGlobal;
|
private final MockAsyncGlobal asyncGlobal;
|
||||||
private final MockPolicy policy;
|
private final MockPolicy policy;
|
||||||
private final MockAsyncConfiguration config;
|
private final MockAsyncConfiguration config;
|
||||||
|
|
||||||
public TestAsyncThreadDiscrete(IAsyncGlobal<MockNeuralNet> asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
|
public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
|
||||||
|
public final List<Stack<MiniTrans<Integer>>> rewards = new ArrayList<Stack<MiniTrans<Integer>>>();
|
||||||
|
|
||||||
|
public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
|
||||||
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
|
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
|
||||||
MockAsyncConfiguration config, IHistoryProcessor hp) {
|
MockAsyncConfiguration config, IHistoryProcessor hp) {
|
||||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
|
@ -113,6 +168,7 @@ public class AsyncThreadDiscreteTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
|
public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
|
||||||
|
this.rewards.add(rewards);
|
||||||
return new Gradient[0];
|
return new Gradient[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,5 +186,13 @@ public class AsyncThreadDiscreteTest {
|
||||||
protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
|
protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
|
||||||
return policy;
|
return policy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
|
||||||
|
asyncGlobal.increaseCurrentLoop();
|
||||||
|
SubEpochReturn result = super.trainSubEpoch(sObs, nstep);
|
||||||
|
trainSubEpochResults.add(result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
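A note on the recurring `255.0 *` factor in the assertions above: the refactored observation pipeline appears to hand the policy and the network pixel data normalized to [0, 1] (raw values divided by 255), so the tests scale values back up before comparing against the raw mock values. A minimal self-contained sketch of that round trip, assuming nothing beyond ND4J (the class name is illustrative only):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ObservationScalingSketch {
    public static void main(String[] args) {
        // Raw mock observation values, like those produced by MockMDP in the tests above.
        double[] raw = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };

        // Normalization as the pipeline appears to apply it: divide by 255.
        INDArray normalized = Nd4j.create(raw).div(255.0);

        // The assertions invert it: multiply back by 255 before comparing.
        for (int j = 0; j < raw.length; ++j) {
            double recovered = 255.0 * normalized.getDouble(j);
            assert Math.abs(recovered - raw[j]) < 0.00001;
        }
    }
}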
@ -6,11 +6,14 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;
import java.util.List;
@ -23,107 +26,100 @@ public class AsyncThreadTest {

    @Test
    public void when_newEpochStarted_expect_neuralNetworkReset() {
        // Arrange
-       TestContext context = new TestContext();
-       context.listener.setRemainingOnNewEpochCallCount(5);
+       int numberOfEpochs = 5;
+       TestContext context = new TestContext(numberOfEpochs);

        // Act
        context.sut.run();

        // Assert
-       assertEquals(6, context.neuralNet.resetCallCount);
+       assertEquals(numberOfEpochs, context.neuralNet.resetCallCount);
    }

    @Test
    public void when_onNewEpochReturnsStop_expect_threadStopped() {
        // Arrange
-       TestContext context = new TestContext();
-       context.listener.setRemainingOnNewEpochCallCount(1);
+       int stopAfterNumCalls = 1;
+       TestContext context = new TestContext(100000);
+       context.listener.setRemainingOnNewEpochCallCount(stopAfterNumCalls);

        // Act
        context.sut.run();

        // Assert
-       assertEquals(2, context.listener.onNewEpochCallCount);
-       assertEquals(1, context.listener.onEpochTrainingResultCallCount);
+       assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: The call that returns stop is counted
+       assertEquals(stopAfterNumCalls, context.listener.onEpochTrainingResultCallCount);
    }

    @Test
    public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
        // Arrange
-       TestContext context = new TestContext();
-       context.listener.setRemainingOnEpochTrainingResult(1);
+       int stopAfterNumCalls = 1;
+       TestContext context = new TestContext(100000);
+       context.listener.setRemainingOnEpochTrainingResult(stopAfterNumCalls);

        // Act
        context.sut.run();

        // Assert
-       assertEquals(2, context.listener.onNewEpochCallCount);
-       assertEquals(2, context.listener.onEpochTrainingResultCallCount);
+       assertEquals(stopAfterNumCalls + 1, context.listener.onEpochTrainingResultCallCount); // +1: The call that returns stop is counted
+       assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: onNewEpoch is called on the epoch that onEpochTrainingResult() will stop
    }

    @Test
    public void when_run_expect_preAndPostEpochCalled() {
        // Arrange
-       TestContext context = new TestContext();
+       int numberOfEpochs = 5;
+       TestContext context = new TestContext(numberOfEpochs);

        // Act
        context.sut.run();

        // Assert
-       assertEquals(6, context.sut.preEpochCallCount);
-       assertEquals(6, context.sut.postEpochCallCount);
+       assertEquals(numberOfEpochs, context.sut.preEpochCallCount);
+       assertEquals(numberOfEpochs, context.sut.postEpochCallCount);
    }

    @Test
    public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
        // Arrange
-       TestContext context = new TestContext();
+       int numberOfEpochs = 5;
+       TestContext context = new TestContext(numberOfEpochs);

        // Act
        context.sut.run();

        // Assert
-       assertEquals(5, context.listener.statEntries.size());
+       assertEquals(numberOfEpochs, context.listener.statEntries.size());
        int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
-       for(int i = 0; i < 5; ++i) {
+       double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
+               + 1.0; // Reward from trainSubEpoch()
+       for(int i = 0; i < numberOfEpochs; ++i) {
            IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
            assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
            assertEquals(i, statEntry.getEpochCounter());
-           assertEquals(38.0, statEntry.getReward(), 0.0001);
+           assertEquals(expectedReward, statEntry.getReward(), 0.0001);
        }
    }

-   @Test
-   public void when_run_expect_NeuralNetIsResetAtInitAndEveryEpoch() {
-       // Arrange
-       TestContext context = new TestContext();
-
-       // Act
-       context.sut.run();
-
-       // Assert
-       assertEquals(6, context.neuralNet.resetCallCount);
-   }
-
    @Test
    public void when_run_expect_trainSubEpochCalled() {
        // Arrange
-       TestContext context = new TestContext();
+       int numberOfEpochs = 5;
+       TestContext context = new TestContext(numberOfEpochs);

        // Act
        context.sut.run();

        // Assert
-       assertEquals(10, context.sut.trainSubEpochParams.size());
-       for(int i = 0; i < 10; ++i) {
+       assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
+       double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
+       for(int i = 0; i < context.sut.getEpochCounter(); ++i) {
            MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
-           if(i % 2 == 0) {
-               assertEquals(2, params.nstep);
-               assertEquals(8.0, params.obs.toArray()[0], 0.00001);
-           }
-           else {
-               assertEquals(1, params.nstep);
-               assertNull(params.obs);
-           }
+           assertEquals(2, params.nstep);
+           assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
+           for(int j = 0; j < expectedObservation.length; ++j){
+               assertEquals(expectedObservation[j], 255.0 * params.obs.getData().getDouble(j), 0.00001);
+           }
        }
    }
@ -136,30 +132,30 @@ public class AsyncThreadTest {
        public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
        public final TrainingListenerList listeners = new TrainingListenerList();
        public final MockTrainingListener listener = new MockTrainingListener();
-       private final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
+       public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
        public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);

        public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);

-       public TestContext() {
-           asyncGlobal.setMaxLoops(10);
+       public TestContext(int numEpochs) {
+           asyncGlobal.setMaxLoops(numEpochs);
            listeners.add(listener);
            sut.setHistoryProcessor(historyProcessor);
        }
    }

-   public static class MockAsyncThread extends AsyncThread {
+   public static class MockAsyncThread extends AsyncThread<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {

        public int preEpochCallCount = 0;
        public int postEpochCallCount = 0;

-       private final IAsyncGlobal asyncGlobal;
+       private final MockAsyncGlobal asyncGlobal;
        private final MockNeuralNet neuralNet;
        private final AsyncConfiguration conf;

        private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();

-       public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
+       public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
            super(asyncGlobal, mdp, listeners, threadNumber, 0);

            this.asyncGlobal = asyncGlobal;
@ -180,7 +176,7 @@ public class AsyncThreadTest {
        }

        @Override
-       protected NeuralNet getCurrent() {
+       protected MockNeuralNet getCurrent() {
            return neuralNet;
        }

@ -195,20 +191,22 @@ public class AsyncThreadTest {
        }

        @Override
-       protected Policy getPolicy(NeuralNet net) {
+       protected Policy getPolicy(MockNeuralNet net) {
            return null;
        }

        @Override
-       protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
+       protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
+           asyncGlobal.increaseCurrentLoop();
            trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
-           return new SubEpochReturn(1, null, 1.0, 1.0);
+           setStepCounter(getStepCounter() + nstep);
+           return new SubEpochReturn(nstep, null, 1.0, 1.0);
        }

        @AllArgsConstructor
        @Getter
        public static class TrainSubEpochParams {
-           Encodable obs;
+           Observation obs;
            int nstep;
        }
    }
}
@ -0,0 +1,181 @@
+package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
+
+import org.deeplearning4j.nn.api.NeuralNetwork;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.async.MiniTrans;
+import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
+import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete;
+import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.deeplearning4j.rl4j.support.*;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.primitives.Pair;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Stack;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+public class A3CThreadDiscreteTest {
+
+    @Test
+    public void refac_calcGradient() {
+        // Arrange
+        double gamma = 0.9;
+        MockObservationSpace observationSpace = new MockObservationSpace();
+        MockMDP mdpMock = new MockMDP(observationSpace);
+        A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0);
+        MockActorCritic actorCriticMock = new MockActorCritic();
+        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
+        MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
+        A3CThreadDiscrete sut = new A3CThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, 0, null, 0);
+        MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
+        sut.setHistoryProcessor(hpMock);
+
+        double[][] minitransObs = new double[][] {
+            new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
+            new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
+            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
+        };
+        double[] outputs = new double[] { 1.0, 2.0, 3.0 };
+        double[] rewards = new double[] { 0.0, 0.0, 3.0 };
+
+        Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
+        for(int i = 0; i < 3; ++i) {
+            INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
+            INDArray[] output = new INDArray[] {
+                Nd4j.zeros(5)
+            };
+            output[0].putScalar(i, outputs[i]);
+            minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i]));
+        }
+        minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
+
+        // Act
+        sut.calcGradient(actorCriticMock, minitransList);
+
+        // Assert
+        assertEquals(1, actorCriticMock.gradientParams.size());
+        INDArray input = actorCriticMock.gradientParams.get(0).getFirst();
+        INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond();
+
+        assertEquals(minitransObs.length, input.shape()[0]);
+        for(int i = 0; i < minitransObs.length; ++i) {
+            double[] expectedRow = minitransObs[i];
+            assertEquals(expectedRow.length, input.shape()[1]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
+            }
+        }
+
+        double latestReward = (gamma * 4.0) + 3.0;
+        double[] expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward };
+        for(int i = 0; i < expectedLabels0.length; ++i) {
+            assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001);
+        }
+        double[][] expectedLabels1 = new double[][] {
+            new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 },
+            new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
+            new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
+        };
+
+        assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape());
+
+        for(int i = 0; i < expectedLabels1.length; ++i) {
+            double[] expectedRow = expectedLabels1[i];
+            assertEquals(expectedRow.length, labels[1].shape()[1]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001);
+            }
+        }
+    }
+
+    public class MockActorCritic implements IActorCritic {
+
+        public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
+
+        @Override
+        public NeuralNetwork[] getNeuralNetworks() {
+            return new NeuralNetwork[0];
+        }
+
+        @Override
+        public boolean isRecurrent() {
+            return false;
+        }
+
+        @Override
+        public void reset() {
+        }
+
+        @Override
+        public void fit(INDArray input, INDArray[] labels) {
+        }
+
+        @Override
+        public INDArray[] outputAll(INDArray batch) {
+            return new INDArray[0];
+        }
+
+        @Override
+        public IActorCritic clone() {
+            return this;
+        }
+
+        @Override
+        public void copy(NeuralNet from) {
+        }
+
+        @Override
+        public void copy(IActorCritic from) {
+        }
+
+        @Override
+        public Gradient[] gradient(INDArray input, INDArray[] labels) {
+            gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
+            return new Gradient[0];
+        }
+
+        @Override
+        public void applyGradient(Gradient[] gradient, int batchSize) {
+        }
+
+        @Override
+        public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
+        }
+
+        @Override
+        public void save(String pathValue, String pathPolicy) throws IOException {
+        }
+
+        @Override
+        public double getLatestScore() {
+            return 0;
+        }
+
+        @Override
+        public void save(OutputStream os) throws IOException {
+        }
+
+        @Override
+        public void save(String filename) throws IOException {
+        }
+    }
+}
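The expected labels in `refac_calcGradient` above follow the standard n-step return recursion R_i = r_i + gamma * R_{i+1}, seeded by the value carried in the batch-ending MiniTrans (4.0 here): latestReward = 0.9 * 4.0 + 3.0 = 6.6, and earlier steps pick up one extra gamma factor each. A small standalone sketch of that arithmetic (plain Java, no RL4J types assumed; class and variable names are illustrative):

public class NStepReturnSketch {
    public static void main(String[] args) {
        double gamma = 0.9;
        double[] rewards = new double[] { 0.0, 0.0, 3.0 }; // per-step rewards, oldest first
        double bootstrap = 4.0;                            // value from the batch-ending MiniTrans

        // Walk backwards: R_i = r_i + gamma * R_{i+1}, starting from the bootstrap value.
        double[] returns = new double[rewards.length];
        double running = bootstrap;
        for (int i = rewards.length - 1; i >= 0; --i) {
            running = rewards[i] + gamma * running;
            returns[i] = running;
        }

        // Prints 5.346, 5.94, 6.6 -- i.e. gamma^2 * 6.6, gamma * 6.6, 6.6,
        // matching expectedLabels0 in the test above.
        for (double r : returns) {
            System.out.println(r);
        }
    }
}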
@ -0,0 +1,81 @@
+package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
+
+import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.async.MiniTrans;
+import org.deeplearning4j.rl4j.support.*;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Stack;
+
+import static org.junit.Assert.assertEquals;
+
+public class AsyncNStepQLearningThreadDiscreteTest {
+
+    @Test
+    public void refac_calcGradient() {
+        // Arrange
+        double gamma = 0.9;
+        MockObservationSpace observationSpace = new MockObservationSpace();
+        MockMDP mdpMock = new MockMDP(observationSpace);
+        AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0);
+        MockDQN dqnMock = new MockDQN();
+        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
+        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
+        AsyncNStepQLearningThreadDiscrete sut = new AsyncNStepQLearningThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, null, 0, 0);
+        MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
+        sut.setHistoryProcessor(hpMock);
+
+        double[][] minitransObs = new double[][] {
+            new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
+            new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
+            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
+        };
+        double[] outputs = new double[] { 1.0, 2.0, 3.0 };
+        double[] rewards = new double[] { 0.0, 0.0, 3.0 };
+
+        Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
+        for(int i = 0; i < 3; ++i) {
+            INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
+            INDArray[] output = new INDArray[] {
+                Nd4j.zeros(5)
+            };
+            output[0].putScalar(i, outputs[i]);
+            minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i]));
+        }
+        minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
+
+        // Act
+        sut.calcGradient(dqnMock, minitransList);
+
+        // Assert
+        assertEquals(1, dqnMock.gradientParams.size());
+        INDArray input = dqnMock.gradientParams.get(0).getFirst();
+        INDArray labels = dqnMock.gradientParams.get(0).getSecond();
+
+        assertEquals(minitransObs.length, input.shape()[0]);
+        for(int i = 0; i < minitransObs.length; ++i) {
+            double[] expectedRow = minitransObs[i];
+            assertEquals(expectedRow.length, input.shape()[1]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
+            }
+        }
+
+        double latestReward = (gamma * 4.0) + 3.0;
+        double[][] expectedLabels = new double[][] {
+            new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 },
+            new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
+            new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
+        };
+        assertEquals(minitransObs.length, labels.shape()[0]);
+        for(int i = 0; i < minitransObs.length; ++i) {
+            double[] expectedRow = expectedLabels[i];
+            assertEquals(expectedRow.length, labels.shape()[1]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001);
+            }
+        }
+    }
+}
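The Q-learning variant uses the same recursion, but places each return only in the slot of the action that was taken: with gamma = 0.9, the targets are 5.346, 5.94 and 6.6 for actions 0, 1 and 2, and every other action slot keeps the network's (zero) output. A hedged sketch of filling such a label matrix, assuming a 3x5 Q-table with actions 0..2 taken in order (the class name is illustrative only):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class QLabelSketch {
    public static void main(String[] args) {
        double gamma = 0.9;
        double latestReward = gamma * 4.0 + 3.0; // 6.6

        // 3 transitions, 5 actions; actions 0, 1, 2 were taken in order.
        INDArray labels = Nd4j.zeros(3, 5);
        labels.putScalar(0, 0, gamma * gamma * latestReward); // 5.346
        labels.putScalar(1, 1, gamma * latestReward);         // 5.94
        labels.putScalar(2, 2, latestReward);                 // 6.6

        System.out.println(labels); // matches expectedLabels in the test above
    }
}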
@ -63,7 +63,7 @@ public class QLearningDiscreteTest {
        double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
        assertEquals(expectedAdds.length, hp.addCalls.size());
        for(int i = 0; i < expectedAdds.length; ++i) {
-           assertEquals(expectedAdds[i], 255.0 * hp.addCalls.get(i).getDouble(0), 0.0001);
+           assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
        }
        assertEquals(0, hp.startMonitorCallCount);
        assertEquals(0, hp.stopMonitorCallCount);
@ -92,8 +92,8 @@ public class QLearningDiscreteTest {
        for(int i = 0; i < expectedDQNOutput.length; ++i) {
            INDArray outputParam = dqn.outputParams.get(i);

-           assertEquals(5, outputParam.shape()[0]);
-           assertEquals(1, outputParam.shape()[1]);
+           assertEquals(5, outputParam.shape()[1]);
+           assertEquals(1, outputParam.shape()[2]);

            double[] expectedRow = expectedDQNOutput[i];
            for(int j = 0; j < expectedRow.length; ++j) {
@ -124,13 +124,15 @@ public class QLearningDiscreteTest {
            assertEquals(expectedTrActions[i], tr.getAction());
            assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
            for(int j = 0; j < expectedTrObservations[i].length; ++j) {
-               assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(j, 0), 0.0001);
+               assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
            }
        }

        // trainEpoch result
        assertEquals(16, result.getStepCounter());
        assertEquals(300.0, result.getReward(), 0.00001);
+       assertTrue(dqn.hasBeenReset);
+       assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
    }

    public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
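The index shifts above (`shape()[0]` to `shape()[1]`, `getDouble(j, 0)` to `getDouble(0, j, 0)`) suggest that observations now carry a leading batch dimension of 1. A quick illustrative snippet reading a [1, 5, 1]-shaped array the way the updated assertions do; the concrete shape is an assumption drawn from the assertions, not from the RL4J sources:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BatchDimensionSketch {
    public static void main(String[] args) {
        // Shape [1, 5, 1]: batch of one, five stacked frames, one value per frame (assumed layout).
        INDArray obs = Nd4j.create(new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }).reshape(1, 5, 1);

        System.out.println(obs.shape()[1]); // 5 -- the frame count now lives at index 1
        for (int j = 0; j < obs.shape()[1]; ++j) {
            System.out.println(obs.getDouble(0, j, 0)); // leading batch index is always 0
        }
    }
}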
@ -28,6 +28,7 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscret
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*;
@ -229,6 +230,11 @@ public class PolicyTest {
            return neuralNet;
        }

+       @Override
+       public Integer nextAction(Observation obs) {
+           return nextAction(obs.getData());
+       }
+
        @Override
        public Integer nextAction(INDArray input) {
            return (int)input.getDouble(0);
@ -8,9 +8,10 @@ import org.deeplearning4j.rl4j.network.NeuralNet;

import java.util.concurrent.atomic.AtomicInteger;

-public class MockAsyncGlobal implements IAsyncGlobal {
+public class MockAsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {

-   private final NeuralNet current;
+   @Getter
+   private final NN current;

    public boolean hasBeenStarted = false;
    public boolean hasBeenTerminated = false;
@ -27,7 +28,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {
        this(null);
    }

-   public MockAsyncGlobal(NeuralNet current) {
+   public MockAsyncGlobal(NN current) {
        maxLoops = Integer.MAX_VALUE;
        numLoopsStopRunning = Integer.MAX_VALUE;
        this.current = current;
@ -45,7 +46,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {

    @Override
    public boolean isTrainingComplete() {
-       return ++currentLoop > maxLoops;
+       return currentLoop >= maxLoops;
    }

    @Override
@ -59,12 +60,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {
    }

    @Override
-   public NeuralNet getCurrent() {
-       return current;
-   }
-
-   @Override
-   public NeuralNet getTarget() {
+   public NN getTarget() {
        return current;
    }

@ -72,4 +68,8 @@ public class MockAsyncGlobal implements IAsyncGlobal {
    public void enqueue(Gradient[] gradient, Integer nstep) {
        ++enqueueCallCount;
    }
+
+   public void increaseCurrentLoop() {
+       ++currentLoop;
+   }
}
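With `isTrainingComplete()` now a side-effect-free query (the old version incremented `currentLoop` every time it was called), something must advance the counter explicitly; the test mocks do it from `trainSubEpoch()` or `onTrainingProgress()`. A minimal sketch of the resulting driving pattern, with the mock stood in by a stripped-down counter class (hypothetical names, not RL4J API):

public class LoopCounterSketch {
    static class Counter {
        private int currentLoop = 0;
        private final int maxLoops;
        Counter(int maxLoops) { this.maxLoops = maxLoops; }
        boolean isTrainingComplete() { return currentLoop >= maxLoops; } // pure query, no side effect
        void increaseCurrentLoop() { ++currentLoop; }                    // explicit advance
    }

    public static void main(String[] args) {
        Counter global = new Counter(5);
        int epochs = 0;
        while (!global.isTrainingComplete()) {
            global.increaseCurrentLoop(); // what the mock's trainSubEpoch() now does
            ++epochs;
        }
        System.out.println(epochs); // 5 -- matches the numberOfEpochs assertions above
    }
}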
@ -15,8 +15,10 @@ import java.util.List;

public class MockDQN implements IDQN {

+   public boolean hasBeenReset = false;
    public final List<INDArray> outputParams = new ArrayList<>();
    public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
+   public final List<Pair<INDArray, INDArray>> gradientParams = new ArrayList<>();

    @Override
    public NeuralNetwork[] getNeuralNetworks() {
@ -30,7 +32,7 @@ public class MockDQN implements IDQN {

    @Override
    public void reset() {
+       hasBeenReset = true;
    }

    @Override
@ -61,7 +63,10 @@ public class MockDQN implements IDQN {

    @Override
    public IDQN clone() {
-       return null;
+       MockDQN clone = new MockDQN();
+       clone.hasBeenReset = hasBeenReset;
+
+       return clone;
    }

    @Override
@ -76,6 +81,7 @@ public class MockDQN implements IDQN {

    @Override
    public Gradient[] gradient(INDArray input, INDArray label) {
+       gradientParams.add(new Pair<INDArray, INDArray>(input, label));
        return new Gradient[0];
    }
@ -35,7 +35,7 @@ public class MockNeuralNet implements NeuralNet {
    @Override
    public INDArray[] outputAll(INDArray batch) {
        outputAllInputs.add(batch);
-       return new INDArray[] { Nd4j.create(new double[] { 1.0 }) };
+       return new INDArray[] { Nd4j.create(new double[] { outputAllInputs.size() }) };
    }

    @Override
@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.support;

import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -23,6 +24,11 @@ public class MockPolicy implements IPolicy<MockEncodable, Integer> {
    @Override
    public Integer nextAction(INDArray input) {
        actionInputs.add(input);
-       return null;
+       return input.getInt(0, 0, 0);
+   }
+
+   @Override
+   public Integer nextAction(Observation observation) {
+       return nextAction(observation.getData());
    }
}
@ -11,6 +11,7 @@ import java.util.List;

public class MockTrainingListener implements TrainingListener {

+   private final MockAsyncGlobal asyncGlobal;
    public int onTrainingStartCallCount = 0;
    public int onTrainingEndCallCount = 0;
    public int onNewEpochCallCount = 0;
@ -28,6 +29,14 @@ public class MockTrainingListener implements TrainingListener {

    public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();

+   public MockTrainingListener() {
+       this(null);
+   }
+
+   public MockTrainingListener(MockAsyncGlobal asyncGlobal) {
+       this.asyncGlobal = asyncGlobal;
+   }
+
    @Override
    public ListenerResponse onTrainingStart() {
@ -55,6 +64,9 @@ public class MockTrainingListener implements TrainingListener {
    public ListenerResponse onTrainingProgress(ILearning learning) {
        ++onTrainingProgressCallCount;
        --remainingonTrainingProgressCallCount;
+       if(asyncGlobal != null) {
+           asyncGlobal.increaseCurrentLoop();
+       }
        return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
    }