RL4J: Change frame skipping logic (#8596)
* Added isSkipped() to Observation
  Signed-off-by: unknown <aboulang2002@yahoo.com>
* Changed refacInitMdp to use isSkipped()
  Signed-off-by: unknown <aboulang2002@yahoo.com>
* Changed getHistoryProcessor()
  Signed-off-by: unknown <aboulang2002@yahoo.com>
* Fixed getEpochCounter() incorrectly changed to getCurrentEpochStep() calls
  Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
* Removed StepCountable
  Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
* Fix build
  Signed-off-by: Samuel Audet <samuel.audet@gmail.com>
* Fixed a problem in QLearningDiscrete and another in CartpoleNative
  Signed-off-by: unknown <aboulang2002@yahoo.com>
* Update versions of JavaCPP Presets for NumPy, MKL, Gym, and TensorFlow
  Signed-off-by: Samuel Audet <samuel.audet@gmail.com>
* RL4J: Add ability to set a random seed for GymEnv
  Signed-off-by: Samuel Audet <samuel.audet@gmail.com>
Co-authored-by: Samuel Audet <samuel.audet@gmail.com>
master
parent
7a20324105
commit
20e3039f2e
4
pom.xml
4
pom.xml
|
@ -303,8 +303,8 @@
|
||||||
<leptonica.version>1.79.0</leptonica.version>
|
<leptonica.version>1.79.0</leptonica.version>
|
||||||
<hdf5.version>1.10.6</hdf5.version>
|
<hdf5.version>1.10.6</hdf5.version>
|
||||||
<ale.version>0.6.1</ale.version>
|
<ale.version>0.6.1</ale.version>
|
||||||
<gym.version>0.15.4</gym.version>
|
<gym.version>0.15.5</gym.version>
|
||||||
<tensorflow.version>1.15.0</tensorflow.version>
|
<tensorflow.version>1.15.2</tensorflow.version>
|
||||||
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
|
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
|
||||||
|
|
||||||
<commons-compress.version>1.18</commons-compress.version>
|
<commons-compress.version>1.18</commons-compress.version>
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning;
|
||||||
|
|
||||||
|
public interface EpochStepCounter {
|
||||||
|
int getCurrentEpochStep();
|
||||||
|
}
|
|
@ -21,9 +21,14 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
/**
|
/**
|
||||||
* The common API between Learning and AsyncThread.
|
* The common API between Learning and AsyncThread.
|
||||||
*
|
*
|
||||||
|
* Express the ability to count the number of step of the current training.
|
||||||
|
* Factorisation of a feature between threads in async and learning process
|
||||||
|
* for the web monitoring
|
||||||
|
*
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
*/
|
*/
|
||||||
public interface IEpochTrainer {
|
public interface IEpochTrainer extends EpochStepCounter {
|
||||||
int getStepCounter();
|
int getStepCounter();
|
||||||
int getEpochCounter();
|
int getEpochCounter();
|
||||||
IHistoryProcessor getHistoryProcessor();
|
IHistoryProcessor getHistoryProcessor();
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
*
|
*
|
||||||
* A common interface that any training method should implement
|
* A common interface that any training method should implement
|
||||||
*/
|
*/
|
||||||
public interface ILearning<O, A, AS extends ActionSpace<A>> extends StepCountable {
|
public interface ILearning<O, A, AS extends ActionSpace<A>> {
|
||||||
|
|
||||||
IPolicy<O, A> getPolicy();
|
IPolicy<O, A> getPolicy();
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
|
@ -53,55 +52,6 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
|
||||||
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
|
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) {
|
|
||||||
INDArray arr = Nd4j.create(((Encodable)obs).toArray());
|
|
||||||
int[] shape = mdp.getObservationSpace().getShape();
|
|
||||||
if (shape.length == 1)
|
|
||||||
return arr.reshape(new long[] {1, arr.length()});
|
|
||||||
else
|
|
||||||
return arr.reshape(shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static <O, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp,
|
|
||||||
IHistoryProcessor hp) {
|
|
||||||
|
|
||||||
O obs = mdp.reset();
|
|
||||||
|
|
||||||
int step = 0;
|
|
||||||
double reward = 0;
|
|
||||||
|
|
||||||
boolean isHistoryProcessor = hp != null;
|
|
||||||
|
|
||||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
|
||||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
|
||||||
|
|
||||||
INDArray input = Learning.getInput(mdp, obs);
|
|
||||||
if (isHistoryProcessor)
|
|
||||||
hp.record(input);
|
|
||||||
|
|
||||||
|
|
||||||
while (step < requiredFrame && !mdp.isDone()) {
|
|
||||||
|
|
||||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
|
||||||
if (step % skipFrame == 0 && isHistoryProcessor)
|
|
||||||
hp.add(input);
|
|
||||||
|
|
||||||
StepReply<O> stepReply = mdp.step(action);
|
|
||||||
reward += stepReply.getReward();
|
|
||||||
obs = stepReply.getObservation();
|
|
||||||
|
|
||||||
input = Learning.getInput(mdp, obs);
|
|
||||||
if (isHistoryProcessor)
|
|
||||||
hp.record(input);
|
|
||||||
|
|
||||||
step++;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return new InitMdp(step, obs, reward);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public static int[] makeShape(int size, int[] shape) {
|
public static int[] makeShape(int size, int[] shape) {
|
||||||
int[] nshape = new int[shape.length + 1];
|
int[] nshape = new int[shape.length + 1];
|
||||||
nshape[0] = size;
|
nshape[0] = size;
|
||||||
|
@ -122,16 +72,16 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
|
||||||
|
|
||||||
public abstract NN getNeuralNet();
|
public abstract NN getNeuralNet();
|
||||||
|
|
||||||
public int incrementStep() {
|
public void incrementStep() {
|
||||||
return stepCounter++;
|
stepCounter++;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int incrementEpoch() {
|
public void incrementEpoch() {
|
||||||
return epochCounter++;
|
epochCounter++;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setHistoryProcessor(HistoryProcessor.Configuration conf) {
|
public void setHistoryProcessor(HistoryProcessor.Configuration conf) {
|
||||||
historyProcessor = new HistoryProcessor(conf);
|
setHistoryProcessor(new HistoryProcessor(conf));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
|
||||||
*
|
|
||||||
* Express the ability to count the number of step of the current training.
|
|
||||||
* Factorisation of a feature between threads in async and learning process
|
|
||||||
* for the web monitoring
|
|
||||||
*/
|
|
||||||
public interface StepCountable {
|
|
||||||
|
|
||||||
int getStepCounter();
|
|
||||||
|
|
||||||
}
|
|
|
@ -30,8 +30,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -48,7 +46,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
extends Thread implements StepCountable, IEpochTrainer {
|
extends Thread implements IEpochTrainer {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private int threadNumber;
|
private int threadNumber;
|
||||||
|
@ -61,6 +59,9 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
private IHistoryProcessor historyProcessor;
|
private IHistoryProcessor historyProcessor;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private int currentEpochStep = 0;
|
||||||
|
|
||||||
private boolean isEpochStarted = false;
|
private boolean isEpochStarted = false;
|
||||||
private final LegacyMDPWrapper<O, A, AS> mdp;
|
private final LegacyMDPWrapper<O, A, AS> mdp;
|
||||||
|
|
||||||
|
@ -138,7 +139,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
|
|
||||||
handleTraining(context);
|
handleTraining(context);
|
||||||
|
|
||||||
if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||||
boolean canContinue = finishEpoch(context);
|
boolean canContinue = finishEpoch(context);
|
||||||
if (!canContinue) {
|
if (!canContinue) {
|
||||||
break;
|
break;
|
||||||
|
@ -154,11 +155,10 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleTraining(RunContext context) {
|
private void handleTraining(RunContext context) {
|
||||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
|
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep);
|
||||||
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
|
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
|
||||||
|
|
||||||
context.obs = subEpochReturn.getLastObs();
|
context.obs = subEpochReturn.getLastObs();
|
||||||
context.epochElapsedSteps += subEpochReturn.getSteps();
|
|
||||||
context.rewards += subEpochReturn.getReward();
|
context.rewards += subEpochReturn.getReward();
|
||||||
context.score = subEpochReturn.getScore();
|
context.score = subEpochReturn.getScore();
|
||||||
}
|
}
|
||||||
|
@ -169,7 +169,6 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
|
|
||||||
context.obs = initMdp.getLastObs();
|
context.obs = initMdp.getLastObs();
|
||||||
context.rewards = initMdp.getReward();
|
context.rewards = initMdp.getReward();
|
||||||
context.epochElapsedSteps = initMdp.getSteps();
|
|
||||||
|
|
||||||
isEpochStarted = true;
|
isEpochStarted = true;
|
||||||
preEpoch();
|
preEpoch();
|
||||||
|
@ -180,9 +179,9 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
private boolean finishEpoch(RunContext context) {
|
private boolean finishEpoch(RunContext context) {
|
||||||
isEpochStarted = false;
|
isEpochStarted = false;
|
||||||
postEpoch();
|
postEpoch();
|
||||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);
|
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score);
|
||||||
|
|
||||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
|
log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards);
|
||||||
|
|
||||||
return listeners.notifyEpochTrainingResult(this, statEntry);
|
return listeners.notifyEpochTrainingResult(this, statEntry);
|
||||||
}
|
}
|
||||||
|
@ -205,37 +204,30 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
|
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
|
||||||
|
|
||||||
private Learning.InitMdp<Observation> refacInitMdp() {
|
private Learning.InitMdp<Observation> refacInitMdp() {
|
||||||
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
currentEpochStep = 0;
|
||||||
IHistoryProcessor hp = getHistoryProcessor();
|
|
||||||
|
|
||||||
Observation observation = mdp.reset();
|
|
||||||
|
|
||||||
int step = 0;
|
|
||||||
double reward = 0;
|
double reward = 0;
|
||||||
|
|
||||||
boolean isHistoryProcessor = hp != null;
|
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
||||||
|
Observation observation = mdp.reset();
|
||||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
|
||||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
|
||||||
|
|
||||||
while (step < requiredFrame && !mdp.isDone()) {
|
|
||||||
|
|
||||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||||
|
while (observation.isSkipped() && !mdp.isDone()) {
|
||||||
StepReply<Observation> stepReply = mdp.step(action);
|
StepReply<Observation> stepReply = mdp.step(action);
|
||||||
|
|
||||||
reward += stepReply.getReward();
|
reward += stepReply.getReward();
|
||||||
observation = stepReply.getObservation();
|
observation = stepReply.getObservation();
|
||||||
|
|
||||||
step++;
|
incrementStep();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Learning.InitMdp(step, observation, reward);
|
return new Learning.InitMdp(0, observation, reward);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void incrementStep() {
|
public void incrementStep() {
|
||||||
++stepCounter;
|
++stepCounter;
|
||||||
|
++currentEpochStep;
|
||||||
}
|
}
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
|
@ -260,7 +252,6 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
private static class RunContext {
|
private static class RunContext {
|
||||||
private Observation obs;
|
private Observation obs;
|
||||||
private double rewards;
|
private double rewards;
|
||||||
private int epochElapsedSteps;
|
|
||||||
private double score;
|
private double score;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,19 +20,14 @@ import lombok.Getter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
|
||||||
|
|
||||||
import java.util.Stack;
|
import java.util.Stack;
|
||||||
|
|
||||||
|
@ -74,17 +69,18 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
IPolicy<O, Integer> policy = getPolicy(current);
|
IPolicy<O, Integer> policy = getPolicy(current);
|
||||||
|
|
||||||
Integer action;
|
Integer action;
|
||||||
Integer lastAction = null;
|
Integer lastAction = getMdp().getActionSpace().noOp();
|
||||||
IHistoryProcessor hp = getHistoryProcessor();
|
IHistoryProcessor hp = getHistoryProcessor();
|
||||||
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
|
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
|
||||||
|
|
||||||
double reward = 0;
|
double reward = 0;
|
||||||
double accuReward = 0;
|
double accuReward = 0;
|
||||||
int i = 0;
|
int stepAtStart = getCurrentEpochStep();
|
||||||
while (!getMdp().isDone() && i < nstep * skipFrame) {
|
int lastStep = nstep * skipFrame + stepAtStart;
|
||||||
|
while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
|
||||||
|
|
||||||
//if step of training, just repeat lastAction
|
//if step of training, just repeat lastAction
|
||||||
if (i % skipFrame != 0 && lastAction != null) {
|
if (obs.isSkipped()) {
|
||||||
action = lastAction;
|
action = lastAction;
|
||||||
} else {
|
} else {
|
||||||
action = policy.nextAction(obs);
|
action = policy.nextAction(obs);
|
||||||
|
@ -94,7 +90,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
accuReward += stepReply.getReward() * getConf().getRewardFactor();
|
accuReward += stepReply.getReward() * getConf().getRewardFactor();
|
||||||
|
|
||||||
//if it's not a skipped frame, you can do a step of training
|
//if it's not a skipped frame, you can do a step of training
|
||||||
if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) {
|
if (!obs.isSkipped() || stepReply.isDone()) {
|
||||||
|
|
||||||
INDArray[] output = current.outputAll(obs.getData());
|
INDArray[] output = current.outputAll(obs.getData());
|
||||||
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
|
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
|
||||||
|
@ -106,7 +102,6 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
|
|
||||||
reward += stepReply.getReward();
|
reward += stepReply.getReward();
|
||||||
|
|
||||||
i++;
|
|
||||||
incrementStep();
|
incrementStep();
|
||||||
lastAction = action;
|
lastAction = action;
|
||||||
}
|
}
|
||||||
|
@ -114,7 +109,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
//a bit of a trick usable because of how the stack is treated to init R
|
//a bit of a trick usable because of how the stack is treated to init R
|
||||||
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
|
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
|
||||||
|
|
||||||
if (getMdp().isDone() && i < nstep * skipFrame)
|
if (getMdp().isDone() && getCurrentEpochStep() < lastStep)
|
||||||
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
|
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
|
||||||
else {
|
else {
|
||||||
INDArray[] output = null;
|
INDArray[] output = null;
|
||||||
|
@ -127,9 +122,9 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
|
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
|
||||||
}
|
}
|
||||||
|
|
||||||
getAsyncGlobal().enqueue(calcGradient(current, rewards), i);
|
getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep());
|
||||||
|
|
||||||
return new SubEpochReturn(i, obs, reward, current.getLatestScore());
|
return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);
|
public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync;
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
@ -25,7 +24,6 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.*;
|
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.sync;
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -34,7 +35,7 @@ import java.util.List;
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@Value
|
@Data
|
||||||
public class Transition<A> {
|
public class Transition<A> {
|
||||||
|
|
||||||
Observation observation;
|
Observation observation;
|
||||||
|
@ -43,12 +44,15 @@ public class Transition<A> {
|
||||||
boolean isTerminal;
|
boolean isTerminal;
|
||||||
INDArray nextObservation;
|
INDArray nextObservation;
|
||||||
|
|
||||||
public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) {
|
public Transition(Observation observation, A action, double reward, boolean isTerminal) {
|
||||||
this.observation = observation;
|
this.observation = observation;
|
||||||
this.action = action;
|
this.action = action;
|
||||||
this.reward = reward;
|
this.reward = reward;
|
||||||
this.isTerminal = isTerminal;
|
this.isTerminal = isTerminal;
|
||||||
|
this.nextObservation = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setNextObservation(Observation nextObservation) {
|
||||||
// To conserve memory, only the most recent frame of the next observation is kept (if history is used).
|
// To conserve memory, only the most recent frame of the next observation is kept (if history is used).
|
||||||
// The full nextObservation will be re-build from observation when needed.
|
// The full nextObservation will be re-build from observation when needed.
|
||||||
long[] nextObservationShape = nextObservation.getData().shape().clone();
|
long[] nextObservationShape = nextObservation.getData().shape().clone();
|
||||||
|
|
|
@ -21,8 +21,7 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
||||||
|
@ -34,9 +33,8 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -49,7 +47,8 @@ import java.util.List;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
||||||
extends SyncLearning<O, A, AS, IDQN> implements TargetQNetworkSource {
|
extends SyncLearning<O, A, AS, IDQN>
|
||||||
|
implements TargetQNetworkSource, EpochStepCounter {
|
||||||
|
|
||||||
// FIXME Changed for refac
|
// FIXME Changed for refac
|
||||||
// @Getter
|
// @Getter
|
||||||
|
@ -104,18 +103,22 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
|
|
||||||
protected abstract QLStepReturn<Observation> trainStep(Observation obs);
|
protected abstract QLStepReturn<Observation> trainStep(Observation obs);
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private int currentEpochStep = 0;
|
||||||
|
|
||||||
protected StatEntry trainEpoch() {
|
protected StatEntry trainEpoch() {
|
||||||
|
resetNetworks();
|
||||||
|
|
||||||
InitMdp<Observation> initMdp = refacInitMdp();
|
InitMdp<Observation> initMdp = refacInitMdp();
|
||||||
Observation obs = initMdp.getLastObs();
|
Observation obs = initMdp.getLastObs();
|
||||||
|
|
||||||
double reward = initMdp.getReward();
|
double reward = initMdp.getReward();
|
||||||
int step = initMdp.getSteps();
|
|
||||||
|
|
||||||
Double startQ = Double.NaN;
|
Double startQ = Double.NaN;
|
||||||
double meanQ = 0;
|
double meanQ = 0;
|
||||||
int numQ = 0;
|
int numQ = 0;
|
||||||
List<Double> scores = new ArrayList<>();
|
List<Double> scores = new ArrayList<>();
|
||||||
while (step < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
|
while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
|
||||||
|
|
||||||
if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
|
if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
|
||||||
updateTargetNetwork();
|
updateTargetNetwork();
|
||||||
|
@ -136,49 +139,53 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
reward += stepR.getStepReply().getReward();
|
reward += stepR.getStepReply().getReward();
|
||||||
obs = stepR.getStepReply().getObservation();
|
obs = stepR.getStepReply().getObservation();
|
||||||
incrementStep();
|
incrementStep();
|
||||||
step++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
finishEpoch(obs);
|
||||||
|
|
||||||
meanQ /= (numQ + 0.001); //avoid div zero
|
meanQ /= (numQ + 0.001); //avoid div zero
|
||||||
|
|
||||||
|
|
||||||
StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, step, scores,
|
StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, currentEpochStep, scores,
|
||||||
getEgPolicy().getEpsilon(), startQ, meanQ);
|
getEgPolicy().getEpsilon(), startQ, meanQ);
|
||||||
|
|
||||||
return statEntry;
|
return statEntry;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void finishEpoch(Observation observation) {
|
||||||
|
// Do Nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void incrementStep() {
|
||||||
|
super.incrementStep();
|
||||||
|
++currentEpochStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void resetNetworks() {
|
||||||
|
getQNetwork().reset();
|
||||||
|
getTargetQNetwork().reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
private InitMdp<Observation> refacInitMdp() {
|
private InitMdp<Observation> refacInitMdp() {
|
||||||
getQNetwork().reset();
|
currentEpochStep = 0;
|
||||||
getTargetQNetwork().reset();
|
|
||||||
|
|
||||||
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
|
||||||
IHistoryProcessor hp = getHistoryProcessor();
|
|
||||||
|
|
||||||
Observation observation = mdp.reset();
|
|
||||||
|
|
||||||
int step = 0;
|
|
||||||
double reward = 0;
|
double reward = 0;
|
||||||
|
|
||||||
boolean isHistoryProcessor = hp != null;
|
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
||||||
|
Observation observation = mdp.reset();
|
||||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
|
||||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
|
||||||
|
|
||||||
while (step < requiredFrame && !mdp.isDone()) {
|
|
||||||
|
|
||||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||||
|
while (observation.isSkipped() && !mdp.isDone()) {
|
||||||
StepReply<Observation> stepReply = mdp.step(action);
|
StepReply<Observation> stepReply = mdp.step(action);
|
||||||
|
|
||||||
reward += stepReply.getReward();
|
reward += stepReply.getReward();
|
||||||
observation = stepReply.getObservation();
|
observation = stepReply.getObservation();
|
||||||
|
|
||||||
step++;
|
incrementStep();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new InitMdp(step, observation, reward);
|
return new InitMdp(0, observation, reward);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,10 +20,13 @@ import lombok.AccessLevel;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
@ -36,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
|
||||||
|
@ -68,6 +70,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
private int lastAction;
|
private int lastAction;
|
||||||
private double accuReward = 0;
|
private double accuReward = 0;
|
||||||
|
|
||||||
|
private Transition pendingTransition;
|
||||||
|
|
||||||
ITDTargetAlgorithm tdTargetAlgorithm;
|
ITDTargetAlgorithm tdTargetAlgorithm;
|
||||||
|
|
||||||
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
||||||
|
@ -83,7 +87,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
int epsilonNbStep, Random random) {
|
int epsilonNbStep, Random random) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, this);
|
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, null, this);
|
||||||
qNetwork = dqn;
|
qNetwork = dqn;
|
||||||
targetQNetwork = dqn.clone();
|
targetQNetwork = dqn.clone();
|
||||||
policy = new DQNPolicy(getQNetwork());
|
policy = new DQNPolicy(getQNetwork());
|
||||||
|
@ -108,8 +112,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
}
|
}
|
||||||
|
|
||||||
public void preEpoch() {
|
public void preEpoch() {
|
||||||
lastAction = 0;
|
lastAction = mdp.getActionSpace().noOp();
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
|
pendingTransition = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
||||||
|
super.setHistoryProcessor(historyProcessor);
|
||||||
|
mdp.setHistoryProcessor(historyProcessor);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -120,9 +131,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
||||||
|
|
||||||
Integer action;
|
Integer action;
|
||||||
|
|
||||||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
||||||
|
|
||||||
|
|
||||||
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
||||||
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
||||||
int updateStart = getConfiguration().getUpdateStart()
|
int updateStart = getConfiguration().getUpdateStart()
|
||||||
|
@ -131,7 +141,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
Double maxQ = Double.NaN; //ignore if Nan for stats
|
Double maxQ = Double.NaN; //ignore if Nan for stats
|
||||||
|
|
||||||
//if step of training, just repeat lastAction
|
//if step of training, just repeat lastAction
|
||||||
if (getStepCounter() % skipFrame != 0) {
|
if (obs.isSkipped()) {
|
||||||
action = lastAction;
|
action = lastAction;
|
||||||
} else {
|
} else {
|
||||||
INDArray qs = getQNetwork().output(obs);
|
INDArray qs = getQNetwork().output(obs);
|
||||||
|
@ -145,22 +155,25 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
StepReply<Observation> stepReply = mdp.step(action);
|
StepReply<Observation> stepReply = mdp.step(action);
|
||||||
|
|
||||||
Observation nextObservation = stepReply.getObservation();
|
|
||||||
|
|
||||||
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
||||||
|
|
||||||
//if it's not a skipped frame, you can do a step of training
|
//if it's not a skipped frame, you can do a step of training
|
||||||
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
|
if (!obs.isSkipped() || stepReply.isDone()) {
|
||||||
|
|
||||||
Transition<Integer> trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation);
|
// Add experience
|
||||||
getExpReplay().store(trans);
|
if(pendingTransition != null) {
|
||||||
|
pendingTransition.setNextObservation(obs);
|
||||||
|
getExpReplay().store(pendingTransition);
|
||||||
|
}
|
||||||
|
pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone());
|
||||||
|
accuReward = 0;
|
||||||
|
|
||||||
|
// Update NN
|
||||||
|
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
|
||||||
if (getStepCounter() > updateStart) {
|
if (getStepCounter() > updateStart) {
|
||||||
DataSet targets = setTarget(getExpReplay().getBatch());
|
DataSet targets = setTarget(getExpReplay().getBatch());
|
||||||
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
||||||
}
|
}
|
||||||
|
|
||||||
accuReward = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
||||||
|
@ -172,4 +185,12 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
return tdTargetAlgorithm.computeTDTargets(transitions);
|
return tdTargetAlgorithm.computeTDTargets(transitions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void finishEpoch(Observation observation) {
|
||||||
|
if(pendingTransition != null) {
|
||||||
|
pendingTransition.setNextObservation(observation);
|
||||||
|
getExpReplay().store(pendingTransition);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,6 +92,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
theta = 0.1 * rnd.nextDouble() - 0.05;
|
theta = 0.1 * rnd.nextDouble() - 0.05;
|
||||||
thetaDot = 0.1 * rnd.nextDouble() - 0.05;
|
thetaDot = 0.1 * rnd.nextDouble() - 0.05;
|
||||||
stepsBeyondDone = null;
|
stepsBeyondDone = null;
|
||||||
|
done = false;
|
||||||
|
|
||||||
return new State(new double[] { x, xDot, theta, thetaDot });
|
return new State(new double[] { x, xDot, theta, thetaDot });
|
||||||
}
|
}
|
||||||
|
@ -126,7 +127,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean done = x < -xThreshold || x > xThreshold
|
done |= x < -xThreshold || x > xThreshold
|
||||||
|| theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
|
|| theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
|
||||||
|
|
||||||
double reward;
|
double reward;
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.observation;
|
package org.deeplearning4j.rl4j.observation;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -28,6 +30,9 @@ public class Observation {
|
||||||
|
|
||||||
private final DataSet data;
|
private final DataSet data;
|
||||||
|
|
||||||
|
@Getter @Setter
|
||||||
|
private boolean skipped;
|
||||||
|
|
||||||
public Observation(INDArray[] data) {
|
public Observation(INDArray[] data) {
|
||||||
this(data, false);
|
this(data, false);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,8 @@ package org.deeplearning4j.rl4j.policy;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
@ -46,7 +47,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
|
||||||
final private int epsilonNbStep;
|
final private int epsilonNbStep;
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
final private float minEpsilon;
|
final private float minEpsilon;
|
||||||
final private StepCountable learning;
|
final private IEpochTrainer learning;
|
||||||
|
|
||||||
public NeuralNet getNeuralNet() {
|
public NeuralNet getNeuralNet() {
|
||||||
return policy.getNeuralNet();
|
return policy.getNeuralNet();
|
||||||
|
|
|
@ -19,20 +19,15 @@ package org.deeplearning4j.rl4j.policy;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
||||||
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
||||||
|
@ -57,24 +52,22 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
||||||
RefacStepCountable stepCountable = new RefacStepCountable();
|
resetNetworks();
|
||||||
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, stepCountable);
|
|
||||||
|
|
||||||
boolean isHistoryProcessor = hp != null;
|
RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter();
|
||||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, epochStepCounter);
|
||||||
|
|
||||||
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp);
|
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter);
|
||||||
Observation obs = initMdp.getLastObs();
|
Observation obs = initMdp.getLastObs();
|
||||||
|
|
||||||
double reward = initMdp.getReward();
|
double reward = initMdp.getReward();
|
||||||
|
|
||||||
A lastAction = mdpWrapper.getActionSpace().noOp();
|
A lastAction = mdpWrapper.getActionSpace().noOp();
|
||||||
A action;
|
A action;
|
||||||
stepCountable.setStepCounter(initMdp.getSteps());
|
|
||||||
|
|
||||||
while (!mdpWrapper.isDone()) {
|
while (!mdpWrapper.isDone()) {
|
||||||
|
|
||||||
if (stepCountable.getStepCounter() % skipFrame != 0) {
|
if (obs.isSkipped()) {
|
||||||
action = lastAction;
|
action = lastAction;
|
||||||
} else {
|
} else {
|
||||||
action = nextAction(obs);
|
action = nextAction(obs);
|
||||||
|
@ -86,52 +79,46 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
|
||||||
reward += stepReply.getReward();
|
reward += stepReply.getReward();
|
||||||
|
|
||||||
obs = stepReply.getObservation();
|
obs = stepReply.getObservation();
|
||||||
stepCountable.increment();
|
epochStepCounter.incrementEpochStep();
|
||||||
}
|
}
|
||||||
|
|
||||||
return reward;
|
return reward;
|
||||||
}
|
}
|
||||||
|
|
||||||
private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
|
protected void resetNetworks() {
|
||||||
getNeuralNet().reset();
|
getNeuralNet().reset();
|
||||||
Observation observation = mdpWrapper.reset();
|
}
|
||||||
|
|
||||||
|
private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
|
||||||
|
epochStepCounter.setCurrentEpochStep(0);
|
||||||
|
|
||||||
int step = 0;
|
|
||||||
double reward = 0;
|
double reward = 0;
|
||||||
|
|
||||||
boolean isHistoryProcessor = hp != null;
|
Observation observation = mdpWrapper.reset();
|
||||||
|
|
||||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
|
||||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
|
||||||
|
|
||||||
while (step < requiredFrame && !mdpWrapper.isDone()) {
|
|
||||||
|
|
||||||
A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
|
A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||||
|
while (observation.isSkipped() && !mdpWrapper.isDone()) {
|
||||||
|
|
||||||
StepReply<Observation> stepReply = mdpWrapper.step(action);
|
StepReply<Observation> stepReply = mdpWrapper.step(action);
|
||||||
|
|
||||||
reward += stepReply.getReward();
|
reward += stepReply.getReward();
|
||||||
observation = stepReply.getObservation();
|
observation = stepReply.getObservation();
|
||||||
|
|
||||||
step++;
|
epochStepCounter.incrementEpochStep();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Learning.InitMdp(step, observation, reward);
|
return new Learning.InitMdp(0, observation, reward);
|
||||||
}
|
}
|
||||||
|
|
||||||
private class RefacStepCountable implements StepCountable {
|
public class RefacEpochStepCounter implements EpochStepCounter {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
private int stepCounter = 0;
|
private int currentEpochStep = 0;
|
||||||
|
|
||||||
public void increment() {
|
public void incrementEpochStep() {
|
||||||
++stepCounter;
|
++currentEpochStep;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getStepCounter() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
package org.deeplearning4j.rl4j.util;
|
package org.deeplearning4j.rl4j.util;
|
||||||
|
|
||||||
|
import lombok.AccessLevel;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
|
||||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
|
@ -19,50 +20,20 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
||||||
private final MDP<O, A, AS> wrappedMDP;
|
private final MDP<O, A, AS> wrappedMDP;
|
||||||
@Getter
|
@Getter
|
||||||
private final WrapperObservationSpace observationSpace;
|
private final WrapperObservationSpace observationSpace;
|
||||||
private final ILearning learning;
|
|
||||||
|
@Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC)
|
||||||
private IHistoryProcessor historyProcessor;
|
private IHistoryProcessor historyProcessor;
|
||||||
private final StepCountable stepCountable;
|
|
||||||
private int skipFrame;
|
|
||||||
|
|
||||||
private int step = 0;
|
private final EpochStepCounter epochStepCounter;
|
||||||
|
|
||||||
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) {
|
private int skipFrame = 1;
|
||||||
this(wrappedMDP, learning, null, null);
|
private int requiredFrame = 0;
|
||||||
}
|
|
||||||
|
|
||||||
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
|
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
|
||||||
this(wrappedMDP, null, historyProcessor, stepCountable);
|
|
||||||
}
|
|
||||||
|
|
||||||
private LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
|
|
||||||
this.wrappedMDP = wrappedMDP;
|
this.wrappedMDP = wrappedMDP;
|
||||||
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
|
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
|
||||||
this.learning = learning;
|
|
||||||
this.historyProcessor = historyProcessor;
|
this.historyProcessor = historyProcessor;
|
||||||
this.stepCountable = stepCountable;
|
this.epochStepCounter = epochStepCounter;
|
||||||
}
|
|
||||||
|
|
||||||
private IHistoryProcessor getHistoryProcessor() {
|
|
||||||
if(historyProcessor != null) {
|
|
||||||
return historyProcessor;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (learning != null) {
|
|
||||||
return learning.getHistoryProcessor();
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
|
||||||
this.historyProcessor = historyProcessor;
|
|
||||||
}
|
|
||||||
|
|
||||||
private int getStep() {
|
|
||||||
if(stepCountable != null) {
|
|
||||||
return stepCountable.getStepCounter();
|
|
||||||
}
|
|
||||||
|
|
||||||
return learning.getStepCounter();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -83,9 +54,12 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
||||||
|
|
||||||
if(historyProcessor != null) {
|
if(historyProcessor != null) {
|
||||||
skipFrame = historyProcessor.getConf().getSkipFrame();
|
skipFrame = historyProcessor.getConf().getSkipFrame();
|
||||||
|
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
|
||||||
|
|
||||||
historyProcessor.add(rawObservation);
|
historyProcessor.add(rawObservation);
|
||||||
}
|
}
|
||||||
step = 0;
|
|
||||||
|
observation.setSkipped(skipFrame != 0);
|
||||||
|
|
||||||
return observation;
|
return observation;
|
||||||
}
|
}
|
||||||
|
@ -97,21 +71,18 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
||||||
StepReply<O> rawStepReply = wrappedMDP.step(a);
|
StepReply<O> rawStepReply = wrappedMDP.step(a);
|
||||||
INDArray rawObservation = getInput(rawStepReply.getObservation());
|
INDArray rawObservation = getInput(rawStepReply.getObservation());
|
||||||
|
|
||||||
++step;
|
int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
|
||||||
|
|
||||||
int requiredFrame = 0;
|
|
||||||
if(historyProcessor != null) {
|
if(historyProcessor != null) {
|
||||||
historyProcessor.record(rawObservation);
|
historyProcessor.record(rawObservation);
|
||||||
|
|
||||||
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
|
if (stepOfObservation % skipFrame == 0) {
|
||||||
if ((getStep() % skipFrame == 0 && step >= requiredFrame)
|
|
||||||
|| (step % skipFrame == 0 && step < requiredFrame )){
|
|
||||||
historyProcessor.add(rawObservation);
|
historyProcessor.add(rawObservation);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Observation observation;
|
Observation observation;
|
||||||
if(historyProcessor != null && step >= requiredFrame) {
|
if(historyProcessor != null && stepOfObservation >= requiredFrame) {
|
||||||
observation = new Observation(historyProcessor.getHistory(), true);
|
observation = new Observation(historyProcessor.getHistory(), true);
|
||||||
observation.getData().muli(1.0 / historyProcessor.getScale());
|
observation.getData().muli(1.0 / historyProcessor.getScale());
|
||||||
}
|
}
|
||||||
|
@ -119,6 +90,10 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
||||||
observation = new Observation(new INDArray[] { rawObservation }, false);
|
observation = new Observation(new INDArray[] { rawObservation }, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) {
|
||||||
|
observation.setSkipped(true);
|
||||||
|
}
|
||||||
|
|
||||||
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
|
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,7 +109,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MDP<Observation, A, AS> newInstance() {
|
public MDP<Observation, A, AS> newInstance() {
|
||||||
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), learning);
|
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), historyProcessor, epochStepCounter);
|
||||||
}
|
}
|
||||||
|
|
||||||
private INDArray getInput(O obs) {
|
private INDArray getInput(O obs) {
|
||||||
|
|
|
@ -32,7 +32,7 @@ public class AsyncThreadDiscreteTest {
|
||||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||||
TrainingListenerList listeners = new TrainingListenerList();
|
TrainingListenerList listeners = new TrainingListenerList();
|
||||||
MockPolicy policyMock = new MockPolicy();
|
MockPolicy policyMock = new MockPolicy();
|
||||||
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0);
|
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
|
||||||
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -41,8 +41,8 @@ public class AsyncThreadDiscreteTest {
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(2, sut.trainSubEpochResults.size());
|
assertEquals(2, sut.trainSubEpochResults.size());
|
||||||
double[][] expectedLastObservations = new double[][] {
|
double[][] expectedLastObservations = new double[][] {
|
||||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
||||||
};
|
};
|
||||||
double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
|
double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
|
||||||
for(int i = 0; i < 2; ++i) {
|
for(int i = 0; i < 2; ++i) {
|
||||||
|
@ -60,7 +60,7 @@ public class AsyncThreadDiscreteTest {
|
||||||
assertEquals(2, asyncGlobalMock.enqueueCallCount);
|
assertEquals(2, asyncGlobalMock.enqueueCallCount);
|
||||||
|
|
||||||
// HistoryProcessor
|
// HistoryProcessor
|
||||||
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
|
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
||||||
assertEquals(expectedAddValues.length, hpMock.addCalls.size());
|
assertEquals(expectedAddValues.length, hpMock.addCalls.size());
|
||||||
for(int i = 0; i < expectedAddValues.length; ++i) {
|
for(int i = 0; i < expectedAddValues.length; ++i) {
|
||||||
assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
|
assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
|
||||||
|
@ -75,9 +75,9 @@ public class AsyncThreadDiscreteTest {
|
||||||
// Policy
|
// Policy
|
||||||
double[][] expectedPolicyInputs = new double[][] {
|
double[][] expectedPolicyInputs = new double[][] {
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||||
};
|
};
|
||||||
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
|
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
|
||||||
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
|
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
|
||||||
|
@ -93,11 +93,11 @@ public class AsyncThreadDiscreteTest {
|
||||||
assertEquals(2, nnMock.copyCallCount);
|
assertEquals(2, nnMock.copyCallCount);
|
||||||
double[][] expectedNNInputs = new double[][] {
|
double[][] expectedNNInputs = new double[][] {
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||||
};
|
};
|
||||||
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
|
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
|
||||||
for(int i = 0; i < expectedNNInputs.length; ++i) {
|
for(int i = 0; i < expectedNNInputs.length; ++i) {
|
||||||
|
@ -113,13 +113,13 @@ public class AsyncThreadDiscreteTest {
|
||||||
double[][][] expectedMinitransObs = new double[][][] {
|
double[][][] expectedMinitransObs = new double[][][] {
|
||||||
new double[][] {
|
new double[][] {
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation
|
||||||
},
|
},
|
||||||
new double[][] {
|
new double[][] {
|
||||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation
|
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
|
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
|
||||||
|
|
|
@ -5,15 +5,12 @@ import lombok.Getter;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
import org.deeplearning4j.rl4j.support.*;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -91,7 +88,7 @@ public class AsyncThreadTest {
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(numberOfEpochs, context.listener.statEntries.size());
|
assertEquals(numberOfEpochs, context.listener.statEntries.size());
|
||||||
int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
|
int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 };
|
||||||
double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
|
double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
|
||||||
+ 1.0; // Reward from trainSubEpoch()
|
+ 1.0; // Reward from trainSubEpoch()
|
||||||
for(int i = 0; i < numberOfEpochs; ++i) {
|
for(int i = 0; i < numberOfEpochs; ++i) {
|
||||||
|
@ -114,7 +111,7 @@ public class AsyncThreadTest {
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
|
assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
|
||||||
double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
|
double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
|
||||||
for(int i = 0; i < context.sut.getEpochCounter(); ++i) {
|
for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) {
|
||||||
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
|
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
|
||||||
assertEquals(2, params.nstep);
|
assertEquals(2, params.nstep);
|
||||||
assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
|
assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
|
||||||
|
@ -199,7 +196,9 @@ public class AsyncThreadTest {
|
||||||
protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
|
protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
|
||||||
asyncGlobal.increaseCurrentLoop();
|
asyncGlobal.increaseCurrentLoop();
|
||||||
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
|
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
|
||||||
setStepCounter(getStepCounter() + nstep);
|
for(int i = 0; i < nstep; ++i) {
|
||||||
|
incrementStep();
|
||||||
|
}
|
||||||
return new SubEpochReturn(nstep, null, 1.0, 1.0);
|
return new SubEpochReturn(nstep, null, 1.0, 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,8 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
123, 234, false, new Observation(Nd4j.create(1)));
|
123, 234, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition);
|
sut.store(transition);
|
||||||
List<Transition<Integer>> results = sut.getBatch(1);
|
List<Transition<Integer>> results = sut.getBatch(1);
|
||||||
|
|
||||||
|
@ -36,12 +36,12 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
1, 2, false, new Observation(Nd4j.create(1)));
|
1, 2, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
3, 4, false, new Observation(Nd4j.create(1)));
|
3, 4, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
5, 6, false, new Observation(Nd4j.create(1)));
|
5, 6, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -78,12 +78,12 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
1, 2, false, new Observation(Nd4j.create(1)));
|
1, 2, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
3, 4, false, new Observation(Nd4j.create(1)));
|
3, 4, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
5, 6, false, new Observation(Nd4j.create(1)));
|
5, 6, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -100,12 +100,12 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
1, 2, false, new Observation(Nd4j.create(1)));
|
1, 2, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
3, 4, false, new Observation(Nd4j.create(1)));
|
3, 4, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
5, 6, false, new Observation(Nd4j.create(1)));
|
5, 6, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -131,16 +131,16 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
1, 2, false, new Observation(Nd4j.create(1)));
|
1, 2, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
3, 4, false, new Observation(Nd4j.create(1)));
|
3, 4, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
5, 6, false, new Observation(Nd4j.create(1)));
|
5, 6, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
7, 8, false, new Observation(Nd4j.create(1)));
|
7, 8, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
9, 10, false, new Observation(Nd4j.create(1)));
|
9, 10, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -168,16 +168,16 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
1, 2, false, new Observation(Nd4j.create(1)));
|
1, 2, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
3, 4, false, new Observation(Nd4j.create(1)));
|
3, 4, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
5, 6, false, new Observation(Nd4j.create(1)));
|
5, 6, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
7, 8, false, new Observation(Nd4j.create(1)));
|
7, 8, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
9, 10, false, new Observation(Nd4j.create(1)));
|
9, 10, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -196,6 +196,12 @@ public class ExpReplayTest {
|
||||||
|
|
||||||
assertEquals(5, (int)results.get(2).getAction());
|
assertEquals(5, (int)results.get(2).getAction());
|
||||||
assertEquals(6, (int)results.get(2).getReward());
|
assertEquals(6, (int)results.get(2).getReward());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, Observation nextObservation) {
|
||||||
|
Transition<Integer> result = new Transition<Integer>(observation, action, reward, false);
|
||||||
|
result.setNextObservation(nextObservation);
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync;
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
@ -88,12 +89,15 @@ public class SyncLearningTest {
|
||||||
|
|
||||||
private final LConfiguration conf;
|
private final LConfiguration conf;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private int currentEpochStep = 0;
|
||||||
|
|
||||||
public MockSyncLearning(LConfiguration conf) {
|
public MockSyncLearning(LConfiguration conf) {
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void preEpoch() { }
|
protected void preEpoch() { currentEpochStep = 0; }
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void postEpoch() { }
|
protected void postEpoch() { }
|
||||||
|
@ -101,7 +105,7 @@ public class SyncLearningTest {
|
||||||
@Override
|
@Override
|
||||||
protected IDataManager.StatEntry trainEpoch() {
|
protected IDataManager.StatEntry trainEpoch() {
|
||||||
setStepCounter(getStepCounter() + 1);
|
setStepCounter(getStepCounter() + 1);
|
||||||
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
|
return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -21,7 +21,7 @@ public class TransitionTest {
|
||||||
Observation nextObservation = buildObservation(nextObs);
|
Observation nextObservation = buildObservation(nextObs);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
|
Transition transition = buildTransition(observation, 123, 234.0, nextObservation);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
double[][] expectedObservation = new double[][] { obs };
|
double[][] expectedObservation = new double[][] { obs };
|
||||||
|
@ -52,7 +52,7 @@ public class TransitionTest {
|
||||||
Observation nextObservation = buildObservation(nextObs);
|
Observation nextObservation = buildObservation(nextObs);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
|
Transition transition = buildTransition(observation, 123, 234.0, nextObservation);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertExpected(obs, transition.getObservation().getData());
|
assertExpected(obs, transition.getObservation().getData());
|
||||||
|
@ -71,12 +71,12 @@ public class TransitionTest {
|
||||||
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
|
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
|
||||||
Observation observation1 = buildObservation(obs1);
|
Observation observation1 = buildObservation(obs1);
|
||||||
Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 });
|
Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 });
|
||||||
transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1));
|
transitions.add(buildTransition(observation1,0, 0.0, nextObservation1));
|
||||||
|
|
||||||
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
|
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
|
||||||
Observation observation2 = buildObservation(obs2);
|
Observation observation2 = buildObservation(obs2);
|
||||||
Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 });
|
Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 });
|
||||||
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
INDArray result = Transition.buildStackedObservations(transitions);
|
INDArray result = Transition.buildStackedObservations(transitions);
|
||||||
|
@ -101,7 +101,7 @@ public class TransitionTest {
|
||||||
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||||
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
|
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
|
||||||
|
|
||||||
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
|
||||||
|
|
||||||
double[][] obs2 = new double[][] {
|
double[][] obs2 = new double[][] {
|
||||||
{ 10.0, 11.0, 12.0 },
|
{ 10.0, 11.0, 12.0 },
|
||||||
|
@ -112,7 +112,7 @@ public class TransitionTest {
|
||||||
|
|
||||||
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||||
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
|
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
|
||||||
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
INDArray result = Transition.buildStackedObservations(transitions);
|
INDArray result = Transition.buildStackedObservations(transitions);
|
||||||
|
@ -131,13 +131,13 @@ public class TransitionTest {
|
||||||
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||||
Observation observation1 = buildObservation(obs1);
|
Observation observation1 = buildObservation(obs1);
|
||||||
Observation nextObservation1 = buildObservation(nextObs1);
|
Observation nextObservation1 = buildObservation(nextObs1);
|
||||||
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
|
||||||
|
|
||||||
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
|
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
|
||||||
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||||
Observation observation2 = buildObservation(obs2);
|
Observation observation2 = buildObservation(obs2);
|
||||||
Observation nextObservation2 = buildObservation(nextObs2);
|
Observation nextObservation2 = buildObservation(nextObs2);
|
||||||
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
INDArray result = Transition.buildStackedNextObservations(transitions);
|
INDArray result = Transition.buildStackedNextObservations(transitions);
|
||||||
|
@ -162,7 +162,7 @@ public class TransitionTest {
|
||||||
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||||
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
|
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
|
||||||
|
|
||||||
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
|
||||||
|
|
||||||
double[][] obs2 = new double[][] {
|
double[][] obs2 = new double[][] {
|
||||||
{ 10.0, 11.0, 12.0 },
|
{ 10.0, 11.0, 12.0 },
|
||||||
|
@ -174,7 +174,7 @@ public class TransitionTest {
|
||||||
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||||
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
|
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
|
||||||
|
|
||||||
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
INDArray result = Transition.buildStackedNextObservations(transitions);
|
INDArray result = Transition.buildStackedNextObservations(transitions);
|
||||||
|
@ -207,7 +207,13 @@ public class TransitionTest {
|
||||||
Nd4j.create(obs[1]).reshape(1, 3),
|
Nd4j.create(obs[1]).reshape(1, 3),
|
||||||
};
|
};
|
||||||
return new Observation(nextHistory);
|
return new Observation(nextHistory);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) {
|
||||||
|
Transition result = new Transition(observation, action, reward, false);
|
||||||
|
result.setNextObservation(nextObservation);
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assertExpected(double[] expected, INDArray actual) {
|
private void assertExpected(double[] expected, INDArray actual) {
|
||||||
|
|
|
@ -40,8 +40,10 @@ public class QLearningDiscreteTest {
|
||||||
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
||||||
MockMDP mdp = new MockMDP(observationSpace, random);
|
MockMDP mdp = new MockMDP(observationSpace, random);
|
||||||
|
|
||||||
|
int initStepCount = 8;
|
||||||
|
|
||||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
|
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
|
||||||
0, 1.0, 0, 0, 0, 0, true);
|
initStepCount, 1.0, 0, 0, 0, 0, true);
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
MockExpReplay expReplay = new MockExpReplay();
|
MockExpReplay expReplay = new MockExpReplay();
|
||||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
||||||
|
@ -60,7 +62,7 @@ public class QLearningDiscreteTest {
|
||||||
for(int i = 0; i < expectedRecords.length; ++i) {
|
for(int i = 0; i < expectedRecords.length; ++i) {
|
||||||
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||||
}
|
}
|
||||||
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
|
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 };
|
||||||
assertEquals(expectedAdds.length, hp.addCalls.size());
|
assertEquals(expectedAdds.length, hp.addCalls.size());
|
||||||
for(int i = 0; i < expectedAdds.length; ++i) {
|
for(int i = 0; i < expectedAdds.length; ++i) {
|
||||||
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
|
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
|
||||||
|
@ -75,19 +77,19 @@ public class QLearningDiscreteTest {
|
||||||
assertEquals(14, dqn.outputParams.size());
|
assertEquals(14, dqn.outputParams.size());
|
||||||
double[][] expectedDQNOutput = new double[][] {
|
double[][] expectedDQNOutput = new double[][] {
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
||||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
||||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
||||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
||||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
||||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
||||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
||||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
||||||
};
|
};
|
||||||
for(int i = 0; i < expectedDQNOutput.length; ++i) {
|
for(int i = 0; i < expectedDQNOutput.length; ++i) {
|
||||||
INDArray outputParam = dqn.outputParams.get(i);
|
INDArray outputParam = dqn.outputParams.get(i);
|
||||||
|
@ -105,19 +107,20 @@ public class QLearningDiscreteTest {
|
||||||
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
|
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
|
||||||
|
|
||||||
// ExpReplay calls
|
// ExpReplay calls
|
||||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
|
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
|
||||||
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
||||||
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
|
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
||||||
double[][] expectedTrObservations = new double[][] {
|
double[][] expectedTrObservations = new double[][] {
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
||||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
||||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
||||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
||||||
};
|
};
|
||||||
|
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
|
||||||
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
||||||
Transition tr = expReplay.transitions.get(i);
|
Transition tr = expReplay.transitions.get(i);
|
||||||
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
||||||
|
@ -129,7 +132,7 @@ public class QLearningDiscreteTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
// trainEpoch result
|
// trainEpoch result
|
||||||
assertEquals(16, result.getStepCounter());
|
assertEquals(initStepCount + 16, result.getStepCounter());
|
||||||
assertEquals(300.0, result.getReward(), 0.00001);
|
assertEquals(300.0, result.getReward(), 0.00001);
|
||||||
assertTrue(dqn.hasBeenReset);
|
assertTrue(dqn.hasBeenReset);
|
||||||
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
|
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
|
||||||
|
|
|
@ -26,7 +26,7 @@ public class DoubleDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||||
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
|
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -52,7 +52,7 @@ public class DoubleDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||||
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -78,11 +78,11 @@ public class DoubleDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||||
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
|
add(builtTransition(buildObservation(new double[]{3.3, 4.4}),
|
||||||
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
|
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
|
add(builtTransition(buildObservation(new double[]{5.5, 6.6}),
|
||||||
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
|
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -108,4 +108,11 @@ public class DoubleDQNTest {
|
||||||
private Observation buildObservation(double[] data) {
|
private Observation buildObservation(double[] data) {
|
||||||
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Transition<Integer> builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
|
||||||
|
Transition<Integer> result = new Transition<Integer>(observation, action, reward, isTerminal);
|
||||||
|
result.setNextObservation(nextObservation);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ public class StandardDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||||
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
|
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -51,7 +51,7 @@ public class StandardDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||||
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -77,11 +77,11 @@ public class StandardDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||||
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
|
add(buildTransition(buildObservation(new double[]{3.3, 4.4}),
|
||||||
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
|
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
|
||||||
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
|
add(buildTransition(buildObservation(new double[]{5.5, 6.6}),
|
||||||
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
|
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -108,4 +108,10 @@ public class StandardDQNTest {
|
||||||
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
|
||||||
|
Transition<Integer> result = new Transition<Integer>(observation, action, reward, isTerminal);
|
||||||
|
result.setNextObservation(nextObservation);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,7 +198,7 @@ public class PolicyTest {
|
||||||
assertEquals(465.0, totalReward, 0.0001);
|
assertEquals(465.0, totalReward, 0.0001);
|
||||||
|
|
||||||
// HistoryProcessor
|
// HistoryProcessor
|
||||||
assertEquals(27, hp.addCalls.size());
|
assertEquals(16, hp.addCalls.size());
|
||||||
assertEquals(31, hp.recordCalls.size());
|
assertEquals(31, hp.recordCalls.size());
|
||||||
for(int i=0; i <= 30; ++i) {
|
for(int i=0; i <= 30; ++i) {
|
||||||
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
|
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||||
|
|
|
@ -142,6 +142,11 @@ public class DataManagerTrainingListenerTest {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getCurrentEpochStep() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
private IHistoryProcessor historyProcessor;
|
private IHistoryProcessor historyProcessor;
|
||||||
|
|
|
@ -93,6 +93,9 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
private boolean done = false;
|
private boolean done = false;
|
||||||
|
|
||||||
public GymEnv(String envId, boolean render, boolean monitor) {
|
public GymEnv(String envId, boolean render, boolean monitor) {
|
||||||
|
this(envId, render, monitor, (Integer)null);
|
||||||
|
}
|
||||||
|
public GymEnv(String envId, boolean render, boolean monitor, Integer seed) {
|
||||||
this.envId = envId;
|
this.envId = envId;
|
||||||
this.render = render;
|
this.render = render;
|
||||||
this.monitor = monitor;
|
this.monitor = monitor;
|
||||||
|
@ -107,6 +110,10 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null));
|
Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null));
|
||||||
checkPythonError();
|
checkPythonError();
|
||||||
}
|
}
|
||||||
|
if (seed != null) {
|
||||||
|
Py_DecRef(PyRun_StringFlags("env.seed(" + seed + ")", Py_single_input, globals, locals, null));
|
||||||
|
checkPythonError();
|
||||||
|
}
|
||||||
PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null);
|
PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null);
|
||||||
int[] shape = new int[(int)PyTuple_Size(shapeTuple)];
|
int[] shape = new int[(int)PyTuple_Size(shapeTuple)];
|
||||||
for (int i = 0; i < shape.length; i++) {
|
for (int i = 0; i < shape.length; i++) {
|
||||||
|
@ -125,7 +132,10 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
}
|
}
|
||||||
|
|
||||||
public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
|
public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
|
||||||
this(envId, render, monitor);
|
this(envId, render, monitor, null, actions);
|
||||||
|
}
|
||||||
|
public GymEnv(String envId, boolean render, boolean monitor, Integer seed, int[] actions) {
|
||||||
|
this(envId, render, monitor, seed);
|
||||||
actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions);
|
actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue