diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java
index fb07baf1e..5501a29e1 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java
@@ -57,6 +57,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
     final private NN current;
     final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
     final private AsyncConfiguration a3cc;
+    private final IAsyncLearning learning;
     @Getter
     private AtomicInteger T = new AtomicInteger(0);
     @Getter
@@ -64,10 +65,11 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
     @Getter
     private boolean running = true;

-    public AsyncGlobal(NN initial, AsyncConfiguration a3cc) {
+    public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) {
         this.current = initial;
         target = (NN) initial.clone();
         this.a3cc = a3cc;
+        this.learning = learning;
         queue = new ConcurrentLinkedQueue<>();
     }

@@ -106,11 +108,14 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
     }

     /**
-     * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
+     * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
     */
     public void terminate() {
-        running = false;
-        queue.clear();
+        if(running) {
+            running = false;
+            queue.clear();
+            learning.terminate();
+        }
     }
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java
index 0835bf692..994ec9cb0 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java
@@ -37,7 +37,10 @@ import org.nd4j.linalg.factory.Nd4j;
  */
 @Slf4j
 public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
-                extends Learning<O, A, AS, NN> {
+                extends Learning<O, A, AS, NN>
+                implements IAsyncLearning {
+
+    private Thread monitorThread = null;

     @Getter(AccessLevel.PROTECTED)
     private final TrainingListenerList listeners = new TrainingListenerList();
@@ -126,6 +129,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
+public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> extends Thread implements StepCountable, IEpochTrainer {
     @Getter
@@ -54,26 +58,35 @@ public abstract class AsyncThread
     mdp;
     @Getter @Setter
     private IHistoryProcessor historyProcessor;
+    private boolean isEpochStarted = false;
+    private final LegacyMDPWrapper<O, A, AS> mdp;
+    private final TrainingListenerList listeners;

     public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
-        this.mdp = mdp;
+        this.mdp = new LegacyMDPWrapper<O, A, AS>(mdp, null, this);
         this.listeners = listeners;
         this.threadNumber = threadNumber;
         this.deviceNum = deviceNum;
     }

+    public MDP<O, A, AS> getMdp() {
+        return mdp.getWrappedMDP();
+    }
+    protected LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper() {
+        return mdp;
+    }
+
     public void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
-        historyProcessor = new HistoryProcessor(conf);
+        setHistoryProcessor(new HistoryProcessor(conf));
     }

     public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
         this.historyProcessor = historyProcessor;
+        mdp.setHistoryProcessor(historyProcessor);
     }
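// Reviewer sketch (not part of the patch): the terminate() hunk above guards on
// `running` so that AsyncGlobal.terminate() and IAsyncLearning.terminate() can
// call each other without recursing forever. A minimal, self-contained
// illustration of that handshake; all class names here are hypothetical.
class ShutdownHandshakeSketch {
    interface Terminable { void terminate(); }

    static class GlobalSide implements Terminable {
        private boolean running = true;
        Terminable learning; // wired after construction, like AsyncGlobal's `learning` field

        @Override
        public void terminate() {
            if (running) {             // only the first call does any work
                running = false;       // flip the flag before cascading...
                learning.terminate();  // ...so any callback into us is a no-op
            }
        }
    }

    static class LearningSide implements Terminable {
        GlobalSide global;

        @Override
        public void terminate() {
            global.terminate(); // re-entry is harmless: `running` is already false
        }
    }

    public static void main(String[] args) {
        GlobalSide global = new GlobalSide();
        LearningSide learning = new LearningSide();
        global.learning = learning;
        learning.global = global;
        global.terminate(); // shuts down both sides, no infinite recursion
    }
}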
protected void postEpoch() { @@ -109,61 +122,63 @@ public abstract class AsyncThread context = new RunContext<>(); - Nd4j.getAffinityManager().unsafeSetDevice(deviceNum); + try { + RunContext context = new RunContext(); + Nd4j.getAffinityManager().unsafeSetDevice(deviceNum); - log.info("ThreadNum-" + threadNumber + " Started!"); - - boolean canContinue = initWork(context); - if (canContinue) { + log.info("ThreadNum-" + threadNumber + " Started!"); while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) { - handleTraining(context); - if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) { - canContinue = finishEpoch(context) && startNewEpoch(context); + if (!isEpochStarted) { + boolean canContinue = startNewEpoch(context); if (!canContinue) { break; } } + + handleTraining(context); + + if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) { + boolean canContinue = finishEpoch(context); + if (!canContinue) { + break; + } + + ++epochCounter; + } } } - terminateWork(); + finally { + terminateWork(); + } } - private void initNewEpoch(RunContext context) { - getCurrent().reset(); - Learning.InitMdp initMdp = Learning.initMdp(getMdp(), historyProcessor); - - context.obs = initMdp.getLastObs(); - context.rewards = initMdp.getReward(); - context.epochElapsedSteps = initMdp.getSteps(); - } - - private void handleTraining(RunContext context) { + private void handleTraining(RunContext context) { int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps); - SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); + SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); context.obs = subEpochReturn.getLastObs(); - stepCounter += subEpochReturn.getSteps(); context.epochElapsedSteps += subEpochReturn.getSteps(); context.rewards += subEpochReturn.getReward(); context.score = subEpochReturn.getScore(); } - private boolean initWork(RunContext context) { - initNewEpoch(context); - preEpoch(); - return listeners.notifyNewEpoch(this); - } - private boolean startNewEpoch(RunContext context) { - initNewEpoch(context); - epochCounter++; + getCurrent().reset(); + Learning.InitMdp initMdp = refacInitMdp(); + + context.obs = initMdp.getLastObs(); + context.rewards = initMdp.getReward(); + context.epochElapsedSteps = initMdp.getSteps(); + + isEpochStarted = true; preEpoch(); + return listeners.notifyNewEpoch(this); } private boolean finishEpoch(RunContext context) { + isEpochStarted = false; postEpoch(); IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score); @@ -173,8 +188,10 @@ public abstract class AsyncThread getPolicy(NN net); - protected abstract SubEpochReturn trainSubEpoch(O obs, int nstep); + protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep); + + private Learning.InitMdp refacInitMdp() { + LegacyMDPWrapper mdp = getLegacyMDPWrapper(); + IHistoryProcessor hp = getHistoryProcessor(); + + Observation observation = mdp.reset(); + + int step = 0; + double reward = 0; + + boolean isHistoryProcessor = hp != null; + + int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; + int requiredFrame = isHistoryProcessor ? 
skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; + + while (step < requiredFrame && !mdp.isDone()) { + + A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP + + StepReply stepReply = mdp.step(action); + reward += stepReply.getReward(); + observation = stepReply.getObservation(); + + step++; + + } + + return new Learning.InitMdp(step, observation, reward); + + } + + public void incrementStep() { + ++stepCounter; + } @AllArgsConstructor @Value - public static class SubEpochReturn { + public static class SubEpochReturn { int steps; - O lastObs; + Observation lastObs; double reward; double score; } @@ -206,8 +257,8 @@ public abstract class AsyncThread { - private O obs; + private static class RunContext { + private Observation obs; private double rewards; private int epochElapsedSteps; private double score; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index 8b8bc2861..6b0078883 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -25,9 +25,11 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.ArrayUtil; @@ -40,7 +42,7 @@ import java.util.Stack; * Async Learning specialized for the Discrete Domain * */ -public abstract class AsyncThreadDiscrete +public abstract class AsyncThreadDiscrete extends AsyncThread { @Getter @@ -61,14 +63,14 @@ public abstract class AsyncThreadDiscrete trainSubEpoch(O sObs, int nstep) { + public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) { synchronized (getAsyncGlobal()) { current.copy(getAsyncGlobal().getCurrent()); } Stack> rewards = new Stack<>(); - O obs = sObs; + Observation obs = sObs; IPolicy policy = getPolicy(current); Integer action; @@ -81,93 +83,53 @@ public abstract class AsyncThreadDiscrete stepReply = getMdp().step(action); + StepReply stepReply = getLegacyMDPWrapper().step(action); accuReward += stepReply.getReward() * getConf().getRewardFactor(); //if it's not a skipped frame, you can do a step of training if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) { - obs = stepReply.getObservation(); - if (hstack == null) { - hstack = processHistory(input); - } - INDArray[] output = current.outputAll(hstack); - rewards.add(new MiniTrans(hstack, action, output, accuReward)); + INDArray[] output = current.outputAll(obs.getData()); + rewards.add(new MiniTrans(obs.getData(), action, output, accuReward)); accuReward = 0; } + obs = stepReply.getObservation(); + reward += stepReply.getReward(); i++; + incrementStep(); lastAction = action; } //a bit of a trick usable because of how the stack is treated to init R - INDArray input = Learning.getInput(getMdp(), obs); - INDArray hstack = processHistory(input); - - if (hp != null) { - hp.record(input); - } + // FIXME: The last 
element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored. if (getMdp().isDone() && i < nstep * skipFrame) - rewards.add(new MiniTrans(hstack, null, null, 0)); + rewards.add(new MiniTrans(obs.getData(), null, null, 0)); else { INDArray[] output = null; if (getConf().getTargetDqnUpdateFreq() == -1) - output = current.outputAll(hstack); + output = current.outputAll(obs.getData()); else synchronized (getAsyncGlobal()) { - output = getAsyncGlobal().getTarget().outputAll(hstack); + output = getAsyncGlobal().getTarget().outputAll(obs.getData()); } double maxQ = Nd4j.max(output[0]).getDouble(0); - rewards.add(new MiniTrans(hstack, null, output, maxQ)); + rewards.add(new MiniTrans(obs.getData(), null, output, maxQ)); } getAsyncGlobal().enqueue(calcGradient(current, rewards), i); - return new SubEpochReturn(i, obs, reward, current.getLatestScore()); - } - - protected INDArray processHistory(INDArray input) { - IHistoryProcessor hp = getHistoryProcessor(); - INDArray[] history; - if (hp != null) { - hp.add(input); - history = hp.getHistory(); - } else - history = new INDArray[] {input}; - //concat the history into a single INDArray input - INDArray hstack = Transition.concat(history); - if (hp != null) { - hstack.muli(1.0 / hp.getScale()); - } - - if (getCurrent().isRecurrent()) { - //flatten everything for the RNN - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1)); - } else { - //if input is not 2d, you have to append that the batch is 1 length high - if (hstack.shape().length > 2) - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()))); - } - - return hstack; + return new SubEpochReturn(i, obs, reward, current.getLatestScore()); } public abstract Gradient[] calcGradient(NN nn, Stack> rewards); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java new file mode 100644 index 000000000..6bae9fddf --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.async; + +public interface IAsyncLearning { + void terminate(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index 52fa3932b..81308ba5a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -53,7 +53,7 @@ public abstract class A3CDiscrete extends AsyncLearning(iActorCritic, conf); + asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this); Integer seed = conf.getSeed(); Random rnd = Nd4j.getRandom(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index d71fa95a0..22b3894b2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; +import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -46,13 +47,13 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< @Getter final protected A3CDiscrete.A3CConfiguration conf; @Getter - final protected AsyncGlobal asyncGlobal; + final protected IAsyncGlobal asyncGlobal; @Getter final protected int threadNumber; final private Random rnd; - public A3CThreadDiscrete(MDP mdp, AsyncGlobal asyncGlobal, + public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners, int threadNumber) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index 0c9ff057f..c18de9e10 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -46,7 +46,7 @@ public abstract class AsyncNStepQLearningDiscrete public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { this.mdp = mdp; this.configuration = conf; - this.asyncGlobal = new AsyncGlobal<>(dqn, conf); + this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this); } @Override diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 098aefeaa..bfd23ef5b 100644 --- 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -150,6 +150,9 @@ public abstract class QLearning refacInitMdp() { + getQNetwork().reset(); + getTargetQNetwork().reset(); + LegacyMDPWrapper mdp = getLegacyMDPWrapper(); IHistoryProcessor hp = getHistoryProcessor(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 4d089d1ee..9da30ccef 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -46,7 +46,7 @@ import java.util.ArrayList; * * DQN or Deep Q-Learning in the Discrete domain * - * https://arxiv.org/abs/1312.5602 + * http://arxiv.org/abs/1312.5602 * */ public abstract class QLearningDiscrete extends QLearning { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java index 7ca63baaf..197a8f744 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java @@ -29,7 +29,15 @@ public class Observation { private final DataSet data; public Observation(INDArray[] data) { - this(new org.nd4j.linalg.dataset.DataSet(Nd4j.concat(0, data), null)); + this(data, false); + } + + public Observation(INDArray[] data, boolean shouldReshape) { + INDArray features = Nd4j.concat(0, data); + if(shouldReshape) { + features = reshape(features); + } + this.data = new org.nd4j.linalg.dataset.DataSet(features, null); } // FIXME: Remove -- only used in unit tests @@ -37,6 +45,15 @@ public class Observation { this.data = new org.nd4j.linalg.dataset.DataSet(data, null); } + private INDArray reshape(INDArray source) { + long[] shape = source.shape(); + long[] nshape = new long[shape.length + 1]; + nshape[0] = 1; + System.arraycopy(shape, 0, nshape, 1, shape.length); + + return source.reshape(nshape); + } + private Observation(DataSet data) { this.data = data; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java index 09e396ac4..61ba70825 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java @@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph; import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate; import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; @@ -65,6 +66,11 @@ public class ACPolicy extends Policy { return actorCritic; } + @Override + public Integer nextAction(Observation obs) { + return nextAction(obs.getData()); + } + public Integer nextAction(INDArray input) { INDArray output = actorCritic.outputAll(input)[1]; if (rnd == null) { diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java index bff1a782c..cf2b60f41 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java @@ -17,6 +17,7 @@ package org.deeplearning4j.rl4j.policy; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; @@ -43,6 +44,11 @@ public class BoltzmannQ extends Policy { return dqn; } + @Override + public Integer nextAction(Observation obs) { + return nextAction(obs.getData()); + } + public Integer nextAction(INDArray input) { INDArray output = dqn.output(input); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java index 70562b0f1..c7ef91665 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java @@ -20,6 +20,7 @@ import lombok.AllArgsConstructor; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.network.dqn.DQN; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,6 +45,11 @@ public class DQNPolicy extends Policy { return dqn; } + @Override + public Integer nextAction(Observation obs) { + return nextAction(obs.getData()); + } + public Integer nextAction(INDArray input) { INDArray output = dqn.output(input); return Learning.getMaxAction(output); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java index 885fa36a2..0bef2a757 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java @@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.policy; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; @@ -9,4 +10,5 @@ import org.nd4j.linalg.api.ndarray.INDArray; public interface IPolicy { > double play(MDP mdp, IHistoryProcessor hp); A nextAction(INDArray input); + A nextAction(Observation observation); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index 84ef26f25..97a11b99c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -16,15 +16,21 @@ package org.deeplearning4j.rl4j.policy; +import lombok.Getter; +import lombok.Setter; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.StepCountable; import 
org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.util.ArrayUtil; @@ -39,7 +45,7 @@ public abstract class Policy implements IPolicy { public abstract NeuralNet getNeuralNet(); - public abstract A nextAction(INDArray input); + public abstract A nextAction(Observation obs); public > double play(MDP mdp) { return play(mdp, (IHistoryProcessor)null); @@ -51,66 +57,81 @@ public abstract class Policy implements IPolicy { @Override public > double play(MDP mdp, IHistoryProcessor hp) { + RefacStepCountable stepCountable = new RefacStepCountable(); + LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp, stepCountable); + boolean isHistoryProcessor = hp != null; int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; - getNeuralNet().reset(); - Learning.InitMdp initMdp = Learning.initMdp(mdp, hp); - O obs = initMdp.getLastObs(); + Learning.InitMdp initMdp = refacInitMdp(mdpWrapper, hp); + Observation obs = initMdp.getLastObs(); double reward = initMdp.getReward(); - A lastAction = mdp.getActionSpace().noOp(); + A lastAction = mdpWrapper.getActionSpace().noOp(); A action; - int step = initMdp.getSteps(); - INDArray[] history = null; + stepCountable.setStepCounter(initMdp.getSteps()); - INDArray input = Learning.getInput(mdp, obs); + while (!mdpWrapper.isDone()) { - while (!mdp.isDone()) { - - if (step % skipFrame != 0) { + if (stepCountable.getStepCounter() % skipFrame != 0) { action = lastAction; } else { - - if (history == null) { - if (isHistoryProcessor) { - hp.add(input); - history = hp.getHistory(); - } else - history = new INDArray[] {input}; - } - INDArray hstack = Transition.concat(history); - if (isHistoryProcessor) { - hstack.muli(1.0 / hp.getScale()); - } - if (getNeuralNet().isRecurrent()) { - //flatten everything for the RNN - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1)); - } else { - if (hstack.shape().length > 2) - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()))); - } - action = nextAction(hstack); + action = nextAction(obs); } + lastAction = action; - StepReply stepReply = mdp.step(action); + StepReply stepReply = mdpWrapper.step(action); reward += stepReply.getReward(); - input = Learning.getInput(mdp, stepReply.getObservation()); - if (isHistoryProcessor) { - hp.record(input); - hp.add(input); - } - - history = isHistoryProcessor ? hp.getHistory() - : new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())}; - step++; + obs = stepReply.getObservation(); + stepCountable.increment(); } - return reward; } + private > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { + getNeuralNet().reset(); + Observation observation = mdpWrapper.reset(); + + int step = 0; + double reward = 0; + + boolean isHistoryProcessor = hp != null; + + int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; + int requiredFrame = isHistoryProcessor ? 
skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
+
+        while (step < requiredFrame && !mdpWrapper.isDone()) {
+
+            A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
+
+            StepReply<Observation> stepReply = mdpWrapper.step(action);
+            reward += stepReply.getReward();
+            observation = stepReply.getObservation();
+
+            step++;
+
+        }
+
+        return new Learning.InitMdp(step, observation, reward);
+    }
+
+    private class RefacStepCountable implements StepCountable {
+
+        @Getter
+        @Setter
+        private int stepCounter = 0;
+
+        public void increment() {
+            ++stepCounter;
+        }
+
+        @Override
+        public int getStepCounter() {
+            return stepCounter;
+        }
+    }
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
index efbf29603..221409040 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
@@ -4,6 +4,7 @@ import lombok.Getter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.ILearning;
+import org.deeplearning4j.rl4j.learning.StepCountable;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
@@ -12,21 +13,53 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

-public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
+public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {

     @Getter
     private final MDP<O, A, AS> wrappedMDP;
     @Getter
     private final WrapperObservationSpace observationSpace;
     private final ILearning learning;
+    private IHistoryProcessor historyProcessor;
+    private final StepCountable stepCountable;

     private int skipFrame;
     private int step = 0;

     public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) {
+        this(wrappedMDP, learning, null, null);
+    }
+
+    public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
+        this(wrappedMDP, null, historyProcessor, stepCountable);
+    }
+
+    private LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
         this.wrappedMDP = wrappedMDP;
         this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
         this.learning = learning;
+        this.historyProcessor = historyProcessor;
+        this.stepCountable = stepCountable;
+    }
+
+    private IHistoryProcessor getHistoryProcessor() {
+        if(historyProcessor != null) {
+            return historyProcessor;
+        }
+
+        return learning.getHistoryProcessor();
+    }
+
+    public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
+        this.historyProcessor = historyProcessor;
+    }
+
+    private int getStep() {
+        if(stepCountable != null) {
+            return stepCountable.getStepCounter();
+        }
+
+        return learning.getStepCounter();
     }

@@ -38,13 +71,12 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>>
     public Observation reset() {
         INDArray rawObservation = getInput(wrappedMDP.reset());

-        IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
+        IHistoryProcessor historyProcessor = getHistoryProcessor();
         if(historyProcessor != null) {
-            historyProcessor.record(rawObservation.dup());
-            rawObservation.muli(1.0 / historyProcessor.getScale());
+            historyProcessor.record(rawObservation);
         }

-        Observation observation = new Observation(new INDArray[] {
rawObservation }); + Observation observation = new Observation(new INDArray[] { rawObservation }, false); if(historyProcessor != null) { skipFrame = historyProcessor.getConf().getSkipFrame(); @@ -55,14 +87,9 @@ public class LegacyMDPWrapper> return observation; } - @Override - public void close() { - wrappedMDP.close(); - } - @Override public StepReply step(A a) { - IHistoryProcessor historyProcessor = learning.getHistoryProcessor(); + IHistoryProcessor historyProcessor = getHistoryProcessor(); StepReply rawStepReply = wrappedMDP.step(a); INDArray rawObservation = getInput(rawStepReply.getObservation()); @@ -71,11 +98,10 @@ public class LegacyMDPWrapper> int requiredFrame = 0; if(historyProcessor != null) { - historyProcessor.record(rawObservation.dup()); - rawObservation.muli(1.0 / historyProcessor.getScale()); + historyProcessor.record(rawObservation); requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1); - if ((learning.getStepCounter() % skipFrame == 0 && step >= requiredFrame) + if ((getStep() % skipFrame == 0 && step >= requiredFrame) || (step % skipFrame == 0 && step < requiredFrame )){ historyProcessor.add(rawObservation); } @@ -83,15 +109,21 @@ public class LegacyMDPWrapper> Observation observation; if(historyProcessor != null && step >= requiredFrame) { - observation = new Observation(historyProcessor.getHistory()); + observation = new Observation(historyProcessor.getHistory(), true); + observation.getData().muli(1.0 / historyProcessor.getScale()); } else { - observation = new Observation(new INDArray[] { rawObservation }); + observation = new Observation(new INDArray[] { rawObservation }, false); } return new StepReply(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); } + @Override + public void close() { + wrappedMDP.close(); + } + @Override public boolean isDone() { return wrappedMDP.isDone(); @@ -103,7 +135,7 @@ public class LegacyMDPWrapper> } private INDArray getInput(O obs) { - INDArray arr = Nd4j.create(obs.toArray()); + INDArray arr = Nd4j.create(((Encodable)obs).toArray()); int[] shape = observationSpace.getShape(); if (shape.length == 1) return arr.reshape(new long[] {1, arr.length()}); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java index 8de9d864a..2302117d2 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java @@ -72,7 +72,7 @@ public class AsyncLearningTest { public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); public final MockPolicy policy = new MockPolicy(); public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy); - public final MockTrainingListener listener = new MockTrainingListener(); + public final MockTrainingListener listener = new MockTrainingListener(asyncGlobal); public TestContext() { sut.addListener(listener); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index e3658b8dd..b94541159 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ 
-2,16 +2,17 @@ package org.deeplearning4j.rl4j.learning.async; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.support.*; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; +import java.util.ArrayList; +import java.util.List; import java.util.Stack; import static org.junit.Assert.assertEquals; @@ -21,37 +22,51 @@ public class AsyncThreadDiscreteTest { @Test public void refac_AsyncThreadDiscrete_trainSubEpoch() { // Arrange + int numEpochs = 1; MockNeuralNet nnMock = new MockNeuralNet(); + IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock); + asyncGlobalMock.setMaxLoops(hpConf.getSkipFrame() * numEpochs); MockObservationSpace observationSpace = new MockObservationSpace(); MockMDP mdpMock = new MockMDP(observationSpace); TrainingListenerList listeners = new TrainingListenerList(); MockPolicy policyMock = new MockPolicy(); - MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 0, 5,0, 0, 0, 0); - IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); - MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); + MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0); TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); - MockEncodable obs = new MockEncodable(123); - - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(1))); - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(2))); - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(3))); - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(4))); - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(5))); // Act - AsyncThread.SubEpochReturn result = sut.trainSubEpoch(obs, 2); + sut.run(); // Assert - assertEquals(4, result.getSteps()); - assertEquals(6.0, result.getReward(), 0.00001); - assertEquals(0.0, result.getScore(), 0.00001); - assertEquals(3.0, result.getLastObs().toArray()[0], 0.00001); - assertEquals(1, asyncGlobalMock.enqueueCallCount); + assertEquals(2, sut.trainSubEpochResults.size()); + double[][] expectedLastObservations = new double[][] { + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, + }; + double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 }; + for(int i = 0; i < 2; ++i) { + AsyncThread.SubEpochReturn result = sut.trainSubEpochResults.get(i); + assertEquals(4, result.getSteps()); + assertEquals(expectedSubEpochReturnRewards[i], result.getReward(), 0.00001); + assertEquals(0.0, result.getScore(), 0.00001); + + double[] expectedLastObservation = expectedLastObservations[i]; + assertEquals(expectedLastObservation.length, result.getLastObs().getData().shape()[1]); + for(int j = 0; j < expectedLastObservation.length; ++j) { + assertEquals(expectedLastObservation[j], 255.0 * result.getLastObs().getData().getDouble(j), 0.00001); + } + } + 
assertEquals(2, asyncGlobalMock.enqueueCallCount); // HistoryProcessor - assertEquals(10, hpMock.addCalls.size()); - double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 }; + double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 }; + assertEquals(expectedAddValues.length, hpMock.addCalls.size()); + for(int i = 0; i < expectedAddValues.length; ++i) { + assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001); + } + + double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, }; assertEquals(expectedRecordValues.length, hpMock.recordCalls.size()); for(int i = 0; i < expectedRecordValues.length; ++i) { assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001); @@ -59,49 +74,89 @@ public class AsyncThreadDiscreteTest { // Policy double[][] expectedPolicyInputs = new double[][] { - new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 }, - new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 }, - new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 }, - new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 }, + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, }; assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size()); for(int i = 0; i < expectedPolicyInputs.length; ++i) { double[] expectedRow = expectedPolicyInputs[i]; INDArray input = policyMock.actionInputs.get(i); - assertEquals(expectedRow.length, input.shape()[0]); + assertEquals(expectedRow.length, input.shape()[1]); for(int j = 0; j < expectedRow.length; ++j) { assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001); } } // NeuralNetwork - assertEquals(1, nnMock.copyCallCount); + assertEquals(2, nnMock.copyCallCount); double[][] expectedNNInputs = new double[][] { - new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 }, - new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 }, - new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 }, - new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 }, - new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 }, + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans }; assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size()); for(int i = 0; i < expectedNNInputs.length; ++i) { double[] expectedRow = expectedNNInputs[i]; INDArray input = nnMock.outputAllInputs.get(i); - assertEquals(expectedRow.length, input.shape()[0]); + assertEquals(expectedRow.length, input.shape()[1]); for(int j = 0; j < expectedRow.length; ++j) { assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001); } } + int arrayIdx = 0; + double[][][] expectedMinitransObs = new double[][][] { + new double[][] { + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation + }, + new double[][] { + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation + } + }; 
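// Reviewer sketch (not part of the patch): the FIXME notes above concern how
// calcGradient() consumes the MiniTrans stack. In simplified form, the
// stack-ending MiniTrans only seeds the bootstrap value R; every earlier entry
// folds its reward into the discounted n-step return R <- r_i + gamma * R.
// The helper below is hypothetical, not the RL4J implementation.
import java.util.Stack;

class NStepReturnSketch {
    // bootstrap is maxQ of the final observation, or 0.0 when the episode ended
    static double[] nStepReturns(Stack<Double> stepRewards, double bootstrap, double gamma) {
        double r = bootstrap;
        double[] targets = new double[stepRewards.size()];
        for (int i = targets.length - 1; i >= 0; --i) {
            r = stepRewards.pop() + gamma * r; // R <- r_i + gamma * R
            targets[i] = r;                    // training label for step i
        }
        return targets;
    }
}
// With rewards {0.0, 0.0, 3.0}, bootstrap 4.0 and gamma 0.9 this yields
// {gamma^2 * 6.6, gamma * 6.6, 6.6}, matching the expectedLabels values
// asserted in the calcGradient tests further below.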
+ double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 }; + double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 }; + + assertEquals(2, sut.rewards.size()); + for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) { + Stack> miniTransStack = sut.rewards.get(rewardIdx); + + for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) { + MiniTrans minitrans = miniTransStack.get(i); + + // Observation + double[] expectedRow = expectedMinitransObs[rewardIdx][i]; + INDArray realRewards = minitrans.getObs(); + assertEquals(expectedRow.length, realRewards.shape()[1]); + for (int j = 0; j < expectedRow.length; ++j) { + assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001); + } + + assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001); + assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001); + ++arrayIdx; + } + } } public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete { - private final IAsyncGlobal asyncGlobal; + private final MockAsyncGlobal asyncGlobal; private final MockPolicy policy; private final MockAsyncConfiguration config; - public TestAsyncThreadDiscrete(IAsyncGlobal asyncGlobal, MDP mdp, + public final List trainSubEpochResults = new ArrayList(); + public final List>> rewards = new ArrayList>>(); + + public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy, MockAsyncConfiguration config, IHistoryProcessor hp) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); @@ -113,6 +168,7 @@ public class AsyncThreadDiscreteTest { @Override public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack> rewards) { + this.rewards.add(rewards); return new Gradient[0]; } @@ -130,5 +186,13 @@ public class AsyncThreadDiscreteTest { protected IPolicy getPolicy(MockNeuralNet net) { return policy; } + + @Override + public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) { + asyncGlobal.increaseCurrentLoop(); + SubEpochReturn result = super.trainSubEpoch(sObs, nstep); + trainSubEpochResults.add(result); + return result; + } } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 0a590a1e5..377b32175 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -6,11 +6,14 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.Policy; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; import java.util.List; @@ -23,107 +26,100 @@ public class AsyncThreadTest { @Test public void when_newEpochStarted_expect_neuralNetworkReset() { // Arrange - TestContext context = new TestContext(); - context.listener.setRemainingOnNewEpochCallCount(5); + int numberOfEpochs = 5; + TestContext 
context = new TestContext(numberOfEpochs); // Act context.sut.run(); // Assert - assertEquals(6, context.neuralNet.resetCallCount); + assertEquals(numberOfEpochs, context.neuralNet.resetCallCount); } @Test public void when_onNewEpochReturnsStop_expect_threadStopped() { // Arrange - TestContext context = new TestContext(); - context.listener.setRemainingOnNewEpochCallCount(1); + int stopAfterNumCalls = 1; + TestContext context = new TestContext(100000); + context.listener.setRemainingOnNewEpochCallCount(stopAfterNumCalls); // Act context.sut.run(); // Assert - assertEquals(2, context.listener.onNewEpochCallCount); - assertEquals(1, context.listener.onEpochTrainingResultCallCount); + assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: The call that returns stop is counted + assertEquals(stopAfterNumCalls, context.listener.onEpochTrainingResultCallCount); } @Test public void when_epochTrainingResultReturnsStop_expect_threadStopped() { // Arrange - TestContext context = new TestContext(); - context.listener.setRemainingOnEpochTrainingResult(1); + int stopAfterNumCalls = 1; + TestContext context = new TestContext(100000); + context.listener.setRemainingOnEpochTrainingResult(stopAfterNumCalls); // Act context.sut.run(); // Assert - assertEquals(2, context.listener.onNewEpochCallCount); - assertEquals(2, context.listener.onEpochTrainingResultCallCount); + assertEquals(stopAfterNumCalls + 1, context.listener.onEpochTrainingResultCallCount); // +1: The call that returns stop is counted + assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: onNewEpoch is called on the epoch that onEpochTrainingResult() will stop } @Test public void when_run_expect_preAndPostEpochCalled() { // Arrange - TestContext context = new TestContext(); + int numberOfEpochs = 5; + TestContext context = new TestContext(numberOfEpochs); // Act context.sut.run(); // Assert - assertEquals(6, context.sut.preEpochCallCount); - assertEquals(6, context.sut.postEpochCallCount); + assertEquals(numberOfEpochs, context.sut.preEpochCallCount); + assertEquals(numberOfEpochs, context.sut.postEpochCallCount); } @Test public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() { // Arrange - TestContext context = new TestContext(); + int numberOfEpochs = 5; + TestContext context = new TestContext(numberOfEpochs); // Act context.sut.run(); // Assert - assertEquals(5, context.listener.statEntries.size()); + assertEquals(numberOfEpochs, context.listener.statEntries.size()); int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 }; - for(int i = 0; i < 5; ++i) { + double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init + + 1.0; // Reward from trainSubEpoch() + for(int i = 0; i < numberOfEpochs; ++i) { IDataManager.StatEntry statEntry = context.listener.statEntries.get(i); assertEquals(expectedStepCounter[i], statEntry.getStepCounter()); assertEquals(i, statEntry.getEpochCounter()); - assertEquals(38.0, statEntry.getReward(), 0.0001); + assertEquals(expectedReward, statEntry.getReward(), 0.0001); } } - @Test - public void when_run_expect_NeuralNetIsResetAtInitAndEveryEpoch() { - // Arrange - TestContext context = new TestContext(); - - // Act - context.sut.run(); - - // Assert - assertEquals(6, context.neuralNet.resetCallCount); - } - @Test public void when_run_expect_trainSubEpochCalled() { // Arrange - TestContext context = new TestContext(); + int numberOfEpochs = 5; + TestContext context = new TestContext(numberOfEpochs); // Act 
context.sut.run(); // Assert - assertEquals(10, context.sut.trainSubEpochParams.size()); - for(int i = 0; i < 10; ++i) { + assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size()); + double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }; + for(int i = 0; i < context.sut.getEpochCounter(); ++i) { MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i); - if(i % 2 == 0) { - assertEquals(2, params.nstep); - assertEquals(8.0, params.obs.toArray()[0], 0.00001); - } - else { - assertEquals(1, params.nstep); - assertNull(params.obs); + assertEquals(2, params.nstep); + assertEquals(expectedObservation.length, params.obs.getData().shape()[1]); + for(int j = 0; j < expectedObservation.length; ++j){ + assertEquals(expectedObservation[j], 255.0 * params.obs.getData().getDouble(j), 0.00001); } } } @@ -136,30 +132,30 @@ public class AsyncThreadTest { public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0); public final TrainingListenerList listeners = new TrainingListenerList(); public final MockTrainingListener listener = new MockTrainingListener(); - private final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf); public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners); - public TestContext() { - asyncGlobal.setMaxLoops(10); + public TestContext(int numEpochs) { + asyncGlobal.setMaxLoops(numEpochs); listeners.add(listener); sut.setHistoryProcessor(historyProcessor); } } - public static class MockAsyncThread extends AsyncThread { + public static class MockAsyncThread extends AsyncThread { public int preEpochCallCount = 0; public int postEpochCallCount = 0; - private final IAsyncGlobal asyncGlobal; + private final MockAsyncGlobal asyncGlobal; private final MockNeuralNet neuralNet; private final AsyncConfiguration conf; private final List trainSubEpochParams = new ArrayList(); - public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { + public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { super(asyncGlobal, mdp, listeners, threadNumber, 0); this.asyncGlobal = asyncGlobal; @@ -180,7 +176,7 @@ public class AsyncThreadTest { } @Override - protected NeuralNet getCurrent() { + protected MockNeuralNet getCurrent() { return neuralNet; } @@ -195,20 +191,22 @@ public class AsyncThreadTest { } @Override - protected Policy getPolicy(NeuralNet net) { + protected Policy getPolicy(MockNeuralNet net) { return null; } @Override - protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) { + protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) { + asyncGlobal.increaseCurrentLoop(); trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep)); - return new SubEpochReturn(1, null, 1.0, 1.0); + setStepCounter(getStepCounter() + nstep); + return new SubEpochReturn(nstep, null, 1.0, 1.0); } @AllArgsConstructor @Getter public static class TrainSubEpochParams { - Encodable obs; + Observation obs; int nstep; } } diff --git 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java new file mode 100644 index 000000000..ef7fec7d0 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java @@ -0,0 +1,181 @@ +package org.deeplearning4j.rl4j.learning.async.a3c.discrete; + +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete; +import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.support.*; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Stack; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class A3CThreadDiscreteTest { + + @Test + public void refac_calcGradient() { + // Arrange + double gamma = 0.9; + MockObservationSpace observationSpace = new MockObservationSpace(); + MockMDP mdpMock = new MockMDP(observationSpace); + A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0); + MockActorCritic actorCriticMock = new MockActorCritic(); + IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); + A3CThreadDiscrete sut = new A3CThreadDiscrete(mdpMock, asyncGlobalMock, config, 0, null, 0); + MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); + sut.setHistoryProcessor(hpMock); + + double[][] minitransObs = new double[][] { + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + }; + double[] outputs = new double[] { 1.0, 2.0, 3.0 }; + double[] rewards = new double[] { 0.0, 0.0, 3.0 }; + + Stack> minitransList = new Stack>(); + for(int i = 0; i < 3; ++i) { + INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1); + INDArray[] output = new INDArray[] { + Nd4j.zeros(5) + }; + output[0].putScalar(i, outputs[i]); + minitransList.push(new MiniTrans(obs, i, output, rewards[i])); + } + minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans + + // Act + sut.calcGradient(actorCriticMock, minitransList); + + // Assert + assertEquals(1, actorCriticMock.gradientParams.size()); + INDArray input = actorCriticMock.gradientParams.get(0).getFirst(); + INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond(); + + assertEquals(minitransObs.length, input.shape()[0]); + for(int i = 0; i < minitransObs.length; ++i) { + double[] expectedRow = minitransObs[i]; + assertEquals(expectedRow.length, input.shape()[1]); + for(int j = 0; j < expectedRow.length; ++j) { + assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001); + } + } + + double latestReward = (gamma * 4.0) + 3.0; + double[] 
expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward }; + for(int i = 0; i < expectedLabels0.length; ++i) { + assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001); + } + double[][] expectedLabels1 = new double[][] { + new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 }, + new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 }, + new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 }, + }; + + assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape()); + + for(int i = 0; i < expectedLabels1.length; ++i) { + double[] expectedRow = expectedLabels1[i]; + assertEquals(expectedRow.length, labels[1].shape()[1]); + for(int j = 0; j < expectedRow.length; ++j) { + assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001); + } + } + + } + + public class MockActorCritic implements IActorCritic { + + public final List> gradientParams = new ArrayList<>(); + + @Override + public NeuralNetwork[] getNeuralNetworks() { + return new NeuralNetwork[0]; + } + + @Override + public boolean isRecurrent() { + return false; + } + + @Override + public void reset() { + + } + + @Override + public void fit(INDArray input, INDArray[] labels) { + + } + + @Override + public INDArray[] outputAll(INDArray batch) { + return new INDArray[0]; + } + + @Override + public IActorCritic clone() { + return this; + } + + @Override + public void copy(NeuralNet from) { + + } + + @Override + public void copy(IActorCritic from) { + + } + + @Override + public Gradient[] gradient(INDArray input, INDArray[] labels) { + gradientParams.add(new Pair(input, labels)); + return new Gradient[0]; + } + + @Override + public void applyGradient(Gradient[] gradient, int batchSize) { + + } + + @Override + public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException { + + } + + @Override + public void save(String pathValue, String pathPolicy) throws IOException { + + } + + @Override + public double getLatestScore() { + return 0; + } + + @Override + public void save(OutputStream os) throws IOException { + + } + + @Override + public void save(String filename) throws IOException { + + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java new file mode 100644 index 000000000..d105419df --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java @@ -0,0 +1,81 @@ +package org.deeplearning4j.rl4j.learning.async.nstep.discrete; + +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.support.*; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Stack; + +import static org.junit.Assert.assertEquals; + +public class AsyncNStepQLearningThreadDiscreteTest { + + @Test + public void refac_calcGradient() { + // Arrange + double gamma = 0.9; + MockObservationSpace observationSpace = new MockObservationSpace(); + MockMDP mdpMock = new MockMDP(observationSpace); + AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0); + MockDQN dqnMock = new MockDQN(); + IHistoryProcessor.Configuration hpConf = 
+        double[][] expectedLabels = new double[][] {
+            new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 },
+            new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
+            new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
+        };
+        assertEquals(minitransObs.length, labels.shape()[0]);
+        for(int i = 0; i < minitransObs.length; ++i) {
+            double[] expectedRow = expectedLabels[i];
+            assertEquals(expectedRow.length, labels.shape()[1]);
+            for(int j = 0; j < expectedRow.length; ++j) {
+                assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001);
+            }
+        }
+    }
+}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
index 59c28551b..22e1ba49d 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
@@ -63,7 +63,7 @@ public class QLearningDiscreteTest {
         double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
         assertEquals(expectedAdds.length, hp.addCalls.size());
         for(int i = 0; i < expectedAdds.length; ++i) {
-            assertEquals(expectedAdds[i], 255.0 * hp.addCalls.get(i).getDouble(0), 0.0001);
+            assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
         }
         assertEquals(0, hp.startMonitorCallCount);
         assertEquals(0, hp.stopMonitorCallCount);
@@ -92,8 +92,8 @@ public class QLearningDiscreteTest {
 
         for(int i = 0; i < expectedDQNOutput.length; ++i) {
             INDArray outputParam = dqn.outputParams.get(i);
-            assertEquals(5, outputParam.shape()[0]);
-            assertEquals(1, outputParam.shape()[1]);
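+            // Annotation (inferred from this change alone): the Observation refactor
+            // appears to add a leading batch dimension to the data fed to the DQN,
+            // shifting the shape indices by one.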
+            assertEquals(5, outputParam.shape()[1]);
+            assertEquals(1, outputParam.shape()[2]);
 
             double[] expectedRow = expectedDQNOutput[i];
             for(int j = 0; j < expectedRow.length; ++j) {
@@ -124,13 +124,15 @@ public class QLearningDiscreteTest {
             assertEquals(expectedTrActions[i], tr.getAction());
             assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
             for(int j = 0; j < expectedTrObservations[i].length; ++j) {
-                assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(j, 0), 0.0001);
+                assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
             }
         }
 
         // trainEpoch result
         assertEquals(16, result.getStepCounter());
         assertEquals(300.0, result.getReward(), 0.00001);
+        assertTrue(dqn.hasBeenReset);
+        assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
     }
 
     public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
index ffb3680bb..6fa2d06d4 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
@@ -28,6 +28,7 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscret
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.support.*;
@@ -229,6 +230,11 @@ public class PolicyTest {
         return neuralNet;
     }
 
+    @Override
+    public Integer nextAction(Observation obs) {
+        return nextAction(obs.getData());
+    }
+
     @Override
     public Integer nextAction(INDArray input) {
         return (int)input.getDouble(0);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java
index 34a2078f0..33dc82314 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java
@@ -8,9 +8,10 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 
 import java.util.concurrent.atomic.AtomicInteger;
 
-public class MockAsyncGlobal implements IAsyncGlobal {
+public class MockAsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {
 
-    private final NeuralNet current;
+    @Getter
+    private final NN current;
 
     public boolean hasBeenStarted = false;
     public boolean hasBeenTerminated = false;
@@ -27,7 +28,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {
         this(null);
     }
 
-    public MockAsyncGlobal(NeuralNet current) {
+    public MockAsyncGlobal(NN current) {
         maxLoops = Integer.MAX_VALUE;
         numLoopsStopRunning = Integer.MAX_VALUE;
         this.current = current;
@@ -45,7 +46,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {
 
     @Override
     public boolean isTrainingComplete() {
-        return ++currentLoop > maxLoops;
+        return currentLoop >= maxLoops;
     }
 
     @Override
@@ -59,12 +60,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {
     }
 
     @Override
-    public NeuralNet getCurrent() {
-        return current;
-    }
-
-    @Override
-    public NeuralNet getTarget() {
+    public NN getTarget() {
         return current;
     }
 
@@ -72,4 +68,8 @@ public class MockAsyncGlobal implements IAsyncGlobal {
     public void enqueue(Gradient[] gradient, Integer nstep) {
        ++enqueueCallCount;
    }
+
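+    // Lets tests advance the loop count explicitly (e.g. from a training listener),
+    // now that isTrainingComplete() compares currentLoop >= maxLoops instead of
+    // incrementing the counter as a side effect of being called.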
+    public void increaseCurrentLoop() {
+        ++currentLoop;
+    }
 }
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java
index 680f9a653..28d7f3914 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java
@@ -15,8 +15,10 @@ import java.util.List;
 
 public class MockDQN implements IDQN {
 
+    public boolean hasBeenReset = false;
     public final List<INDArray> outputParams = new ArrayList<>();
     public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
+    public final List<Pair<INDArray, INDArray>> gradientParams = new ArrayList<>();
 
     @Override
     public NeuralNetwork[] getNeuralNetworks() {
@@ -30,7 +32,7 @@ public class MockDQN implements IDQN {
 
     @Override
     public void reset() {
-
+        hasBeenReset = true;
     }
 
     @Override
@@ -61,7 +63,10 @@ public class MockDQN implements IDQN {
 
     @Override
     public IDQN clone() {
-        return null;
+        MockDQN clone = new MockDQN();
+        clone.hasBeenReset = hasBeenReset;
+
+        return clone;
     }
 
     @Override
@@ -76,6 +81,7 @@ public class MockDQN implements IDQN {
 
     @Override
     public Gradient[] gradient(INDArray input, INDArray label) {
+        gradientParams.add(new Pair<>(input, label));
         return new Gradient[0];
     }
 
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java
index 6d542934b..a5d7c5f3e 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java
@@ -35,7 +35,7 @@ public class MockNeuralNet implements NeuralNet {
     @Override
     public INDArray[] outputAll(INDArray batch) {
         outputAllInputs.add(batch);
-        return new INDArray[] { Nd4j.create(new double[] { 1.0 }) };
+        return new INDArray[] { Nd4j.create(new double[] { outputAllInputs.size() }) };
     }
 
     @Override
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
index 82adc65b7..4c4f100e9 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
@@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.support;
 
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -23,6 +24,11 @@ public class MockPolicy implements IPolicy {
     @Override
     public Integer nextAction(INDArray input) {
         actionInputs.add(input);
-        return null;
+        return input.getInt(0, 0, 0);
     }
+
+    @Override
+    public Integer nextAction(Observation observation) {
+        return nextAction(observation.getData());
+    }
 }
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java
index 97bf5cc28..d4e696248 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java
@@ -11,6 +11,7 @@ import java.util.List;
 
 public class MockTrainingListener implements TrainingListener {
 
+    private final MockAsyncGlobal asyncGlobal;
     public int onTrainingStartCallCount = 0;
     public int onTrainingEndCallCount = 0;
     public int onNewEpochCallCount = 0;
@@ -28,6 +29,14 @@ public class MockTrainingListener implements TrainingListener {
 
     public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
 
+    public MockTrainingListener() {
+        this(null);
+    }
+
+    public MockTrainingListener(MockAsyncGlobal asyncGlobal) {
+        this.asyncGlobal = asyncGlobal;
+    }
+
     @Override
     public ListenerResponse onTrainingStart() {
@@ -55,6 +64,9 @@ public class MockTrainingListener implements TrainingListener {
     public ListenerResponse onTrainingProgress(ILearning learning) {
         ++onTrainingProgressCallCount;
         --remainingonTrainingProgressCallCount;
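+        // Drive the mock global's loop counter from the listener so that
+        // isTrainingComplete() eventually returns true and training can stop.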
+        if(asyncGlobal != null) {
+            asyncGlobal.increaseCurrentLoop();
+        }
         return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
     }