RL4J: Change frame skipping logic (#8596)

* Added isSkipped() to Observation

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Changed refacInitMdp to use isSkipped()

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Changed getHistoryProcessor()

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Fixed getEpochCounter() incorrectly changed to getCurrentEpochStep() calls

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Removed StepCountable

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Fix build

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

* Fixed a problem in QLearningDiscrete and another in CartpoleNative

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Update versions of JavaCPP Presets for NumPy, MKL, Gym, and TensorFlow

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

* RL4J: Add ability to set a random seed for GymEnv

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

Co-authored-by: Samuel Audet <samuel.audet@gmail.com>
master
Alexandre Boulanger 2020-02-03 22:23:39 -05:00 committed by GitHub
parent 7a20324105
commit 20e3039f2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 333 additions and 372 deletions

View File

@ -303,8 +303,8 @@
<leptonica.version>1.79.0</leptonica.version>
<hdf5.version>1.10.6</hdf5.version>
<ale.version>0.6.1</ale.version>
<gym.version>0.15.4</gym.version>
<tensorflow.version>1.15.0</tensorflow.version>
<gym.version>0.15.5</gym.version>
<tensorflow.version>1.15.2</tensorflow.version>
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
<commons-compress.version>1.18</commons-compress.version>

View File

@ -0,0 +1,5 @@
package org.deeplearning4j.rl4j.learning;
public interface EpochStepCounter {
int getCurrentEpochStep();
}

View File

@ -21,9 +21,14 @@ import org.deeplearning4j.rl4j.mdp.MDP;
/**
* The common API between Learning and AsyncThread.
*
* Express the ability to count the number of step of the current training.
* Factorisation of a feature between threads in async and learning process
* for the web monitoring
*
* @author Alexandre Boulanger
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*/
public interface IEpochTrainer {
public interface IEpochTrainer extends EpochStepCounter {
int getStepCounter();
int getEpochCounter();
IHistoryProcessor getHistoryProcessor();

View File

@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
*
* A common interface that any training method should implement
*/
public interface ILearning<O, A, AS extends ActionSpace<A>> extends StepCountable {
public interface ILearning<O, A, AS extends ActionSpace<A>> {
IPolicy<O, A> getPolicy();

View File

@ -21,7 +21,6 @@ import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
@ -53,55 +52,6 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
}
public static <O, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) {
INDArray arr = Nd4j.create(((Encodable)obs).toArray());
int[] shape = mdp.getObservationSpace().getShape();
if (shape.length == 1)
return arr.reshape(new long[] {1, arr.length()});
else
return arr.reshape(shape);
}
public static <O, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp,
IHistoryProcessor hp) {
O obs = mdp.reset();
int step = 0;
double reward = 0;
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
INDArray input = Learning.getInput(mdp, obs);
if (isHistoryProcessor)
hp.record(input);
while (step < requiredFrame && !mdp.isDone()) {
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
if (step % skipFrame == 0 && isHistoryProcessor)
hp.add(input);
StepReply<O> stepReply = mdp.step(action);
reward += stepReply.getReward();
obs = stepReply.getObservation();
input = Learning.getInput(mdp, obs);
if (isHistoryProcessor)
hp.record(input);
step++;
}
return new InitMdp(step, obs, reward);
}
public static int[] makeShape(int size, int[] shape) {
int[] nshape = new int[shape.length + 1];
nshape[0] = size;
@ -122,16 +72,16 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
public abstract NN getNeuralNet();
public int incrementStep() {
return stepCounter++;
public void incrementStep() {
stepCounter++;
}
public int incrementEpoch() {
return epochCounter++;
public void incrementEpoch() {
epochCounter++;
}
public void setHistoryProcessor(HistoryProcessor.Configuration conf) {
historyProcessor = new HistoryProcessor(conf);
setHistoryProcessor(new HistoryProcessor(conf));
}
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {

View File

@ -1,30 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*
* Express the ability to count the number of step of the current training.
* Factorisation of a feature between threads in async and learning process
* for the web monitoring
*/
public interface StepCountable {
int getStepCounter();
}

View File

@ -30,8 +30,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.factory.Nd4j;
@ -48,7 +46,7 @@ import org.nd4j.linalg.factory.Nd4j;
*/
@Slf4j
public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Thread implements StepCountable, IEpochTrainer {
extends Thread implements IEpochTrainer {
@Getter
private int threadNumber;
@ -61,6 +59,9 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
@Getter @Setter
private IHistoryProcessor historyProcessor;
@Getter
private int currentEpochStep = 0;
private boolean isEpochStarted = false;
private final LegacyMDPWrapper<O, A, AS> mdp;
@ -138,7 +139,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
handleTraining(context);
if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) {
boolean canContinue = finishEpoch(context);
if (!canContinue) {
break;
@ -154,11 +155,10 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
}
private void handleTraining(RunContext context) {
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs();
context.epochElapsedSteps += subEpochReturn.getSteps();
context.rewards += subEpochReturn.getReward();
context.score = subEpochReturn.getScore();
}
@ -169,7 +169,6 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
context.obs = initMdp.getLastObs();
context.rewards = initMdp.getReward();
context.epochElapsedSteps = initMdp.getSteps();
isEpochStarted = true;
preEpoch();
@ -180,9 +179,9 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
private boolean finishEpoch(RunContext context) {
isEpochStarted = false;
postEpoch();
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score);
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards);
return listeners.notifyEpochTrainingResult(this, statEntry);
}
@ -205,37 +204,30 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
private Learning.InitMdp<Observation> refacInitMdp() {
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
IHistoryProcessor hp = getHistoryProcessor();
currentEpochStep = 0;
Observation observation = mdp.reset();
int step = 0;
double reward = 0;
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
while (step < requiredFrame && !mdp.isDone()) {
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
Observation observation = mdp.reset();
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
while (observation.isSkipped() && !mdp.isDone()) {
StepReply<Observation> stepReply = mdp.step(action);
reward += stepReply.getReward();
observation = stepReply.getObservation();
step++;
incrementStep();
}
return new Learning.InitMdp(step, observation, reward);
return new Learning.InitMdp(0, observation, reward);
}
public void incrementStep() {
++stepCounter;
++currentEpochStep;
}
@AllArgsConstructor
@ -260,7 +252,6 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
private static class RunContext {
private Observation obs;
private double rewards;
private int epochElapsedSteps;
private double score;
}

View File

@ -20,19 +20,14 @@ import lombok.Getter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.Stack;
@ -74,17 +69,18 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
IPolicy<O, Integer> policy = getPolicy(current);
Integer action;
Integer lastAction = null;
Integer lastAction = getMdp().getActionSpace().noOp();
IHistoryProcessor hp = getHistoryProcessor();
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
double reward = 0;
double accuReward = 0;
int i = 0;
while (!getMdp().isDone() && i < nstep * skipFrame) {
int stepAtStart = getCurrentEpochStep();
int lastStep = nstep * skipFrame + stepAtStart;
while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
//if step of training, just repeat lastAction
if (i % skipFrame != 0 && lastAction != null) {
if (obs.isSkipped()) {
action = lastAction;
} else {
action = policy.nextAction(obs);
@ -94,7 +90,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
accuReward += stepReply.getReward() * getConf().getRewardFactor();
//if it's not a skipped frame, you can do a step of training
if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) {
if (!obs.isSkipped() || stepReply.isDone()) {
INDArray[] output = current.outputAll(obs.getData());
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
@ -106,7 +102,6 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
reward += stepReply.getReward();
i++;
incrementStep();
lastAction = action;
}
@ -114,7 +109,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
//a bit of a trick usable because of how the stack is treated to init R
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
if (getMdp().isDone() && i < nstep * skipFrame)
if (getMdp().isDone() && getCurrentEpochStep() < lastStep)
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
else {
INDArray[] output = null;
@ -127,9 +122,9 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
}
getAsyncGlobal().enqueue(calcGradient(current, rewards), i);
getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep());
return new SubEpochReturn(i, obs, reward, current.getLatestScore());
return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
}
public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
@ -25,7 +24,6 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
/**

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.Data;
import lombok.Value;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -34,7 +35,7 @@ import java.util.List;
* @author Alexandre Boulanger
*
*/
@Value
@Data
public class Transition<A> {
Observation observation;
@ -43,12 +44,15 @@ public class Transition<A> {
boolean isTerminal;
INDArray nextObservation;
public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) {
public Transition(Observation observation, A action, double reward, boolean isTerminal) {
this.observation = observation;
this.action = action;
this.reward = reward;
this.isTerminal = isTerminal;
this.nextObservation = null;
}
public void setNextObservation(Observation nextObservation) {
// To conserve memory, only the most recent frame of the next observation is kept (if history is used).
// The full nextObservation will be re-build from observation when needed.
long[] nextObservationShape = nextObservation.getData().shape().clone();

View File

@ -21,8 +21,7 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
@ -34,9 +33,8 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
@ -49,7 +47,8 @@ import java.util.List;
*/
@Slf4j
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
extends SyncLearning<O, A, AS, IDQN> implements TargetQNetworkSource {
extends SyncLearning<O, A, AS, IDQN>
implements TargetQNetworkSource, EpochStepCounter {
// FIXME Changed for refac
// @Getter
@ -104,18 +103,22 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
protected abstract QLStepReturn<Observation> trainStep(Observation obs);
@Getter
private int currentEpochStep = 0;
protected StatEntry trainEpoch() {
resetNetworks();
InitMdp<Observation> initMdp = refacInitMdp();
Observation obs = initMdp.getLastObs();
double reward = initMdp.getReward();
int step = initMdp.getSteps();
Double startQ = Double.NaN;
double meanQ = 0;
int numQ = 0;
List<Double> scores = new ArrayList<>();
while (step < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
updateTargetNetwork();
@ -136,49 +139,53 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
reward += stepR.getStepReply().getReward();
obs = stepR.getStepReply().getObservation();
incrementStep();
step++;
}
finishEpoch(obs);
meanQ /= (numQ + 0.001); //avoid div zero
StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, step, scores,
StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, currentEpochStep, scores,
getEgPolicy().getEpsilon(), startQ, meanQ);
return statEntry;
}
protected void finishEpoch(Observation observation) {
// Do Nothing
}
@Override
public void incrementStep() {
super.incrementStep();
++currentEpochStep;
}
protected void resetNetworks() {
getQNetwork().reset();
getTargetQNetwork().reset();
}
private InitMdp<Observation> refacInitMdp() {
getQNetwork().reset();
getTargetQNetwork().reset();
currentEpochStep = 0;
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
IHistoryProcessor hp = getHistoryProcessor();
Observation observation = mdp.reset();
int step = 0;
double reward = 0;
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
while (step < requiredFrame && !mdp.isDone()) {
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
Observation observation = mdp.reset();
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
while (observation.isSkipped() && !mdp.isDone()) {
StepReply<Observation> stepReply = mdp.step(action);
reward += stepReply.getReward();
observation = stepReply.getObservation();
step++;
incrementStep();
}
return new InitMdp(step, observation, reward);
return new InitMdp(0, observation, reward);
}

View File

@ -20,10 +20,13 @@ import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
@ -36,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.ArrayList;
@ -68,6 +70,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
private int lastAction;
private double accuReward = 0;
private Transition pendingTransition;
ITDTargetAlgorithm tdTargetAlgorithm;
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
@ -83,7 +87,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
int epsilonNbStep, Random random) {
super(conf);
this.configuration = conf;
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, this);
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, null, this);
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
@ -108,8 +112,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
}
public void preEpoch() {
lastAction = 0;
lastAction = mdp.getActionSpace().noOp();
accuReward = 0;
pendingTransition = null;
}
@Override
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
super.setHistoryProcessor(historyProcessor);
mdp.setHistoryProcessor(historyProcessor);
}
/**
@ -120,9 +131,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
protected QLStepReturn<Observation> trainStep(Observation obs) {
Integer action;
boolean isHistoryProcessor = getHistoryProcessor() != null;
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
int updateStart = getConfiguration().getUpdateStart()
@ -131,7 +141,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
Double maxQ = Double.NaN; //ignore if Nan for stats
//if step of training, just repeat lastAction
if (getStepCounter() % skipFrame != 0) {
if (obs.isSkipped()) {
action = lastAction;
} else {
INDArray qs = getQNetwork().output(obs);
@ -145,22 +155,25 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
StepReply<Observation> stepReply = mdp.step(action);
Observation nextObservation = stepReply.getObservation();
accuReward += stepReply.getReward() * configuration.getRewardFactor();
//if it's not a skipped frame, you can do a step of training
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
if (!obs.isSkipped() || stepReply.isDone()) {
Transition<Integer> trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation);
getExpReplay().store(trans);
// Add experience
if(pendingTransition != null) {
pendingTransition.setNextObservation(obs);
getExpReplay().store(pendingTransition);
}
pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone());
accuReward = 0;
// Update NN
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
if (getStepCounter() > updateStart) {
DataSet targets = setTarget(getExpReplay().getBatch());
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}
accuReward = 0;
}
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
@ -172,4 +185,12 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
return tdTargetAlgorithm.computeTDTargets(transitions);
}
@Override
protected void finishEpoch(Observation observation) {
if(pendingTransition != null) {
pendingTransition.setNextObservation(observation);
getExpReplay().store(pendingTransition);
}
}
}

View File

@ -92,6 +92,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
theta = 0.1 * rnd.nextDouble() - 0.05;
thetaDot = 0.1 * rnd.nextDouble() - 0.05;
stepsBeyondDone = null;
done = false;
return new State(new double[] { x, xDot, theta, thetaDot });
}
@ -126,7 +127,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
break;
}
boolean done = x < -xThreshold || x > xThreshold
done |= x < -xThreshold || x > xThreshold
|| theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
double reward;

View File

@ -16,6 +16,8 @@
package org.deeplearning4j.rl4j.observation;
import lombok.Getter;
import lombok.Setter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
@ -28,6 +30,9 @@ public class Observation {
private final DataSet data;
@Getter @Setter
private boolean skipped;
public Observation(INDArray[] data) {
this(data, false);
}

View File

@ -18,7 +18,8 @@ package org.deeplearning4j.rl4j.policy;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
@ -46,7 +47,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
final private int epsilonNbStep;
final private Random rnd;
final private float minEpsilon;
final private StepCountable learning;
final private IEpochTrainer learning;
public NeuralNet getNeuralNet() {
return policy.getNeuralNet();

View File

@ -19,20 +19,15 @@ package org.deeplearning4j.rl4j.policy;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
@ -57,24 +52,22 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
@Override
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
RefacStepCountable stepCountable = new RefacStepCountable();
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, stepCountable);
resetNetworks();
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter();
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, epochStepCounter);
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp);
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter);
Observation obs = initMdp.getLastObs();
double reward = initMdp.getReward();
A lastAction = mdpWrapper.getActionSpace().noOp();
A action;
stepCountable.setStepCounter(initMdp.getSteps());
while (!mdpWrapper.isDone()) {
if (stepCountable.getStepCounter() % skipFrame != 0) {
if (obs.isSkipped()) {
action = lastAction;
} else {
action = nextAction(obs);
@ -86,52 +79,46 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
reward += stepReply.getReward();
obs = stepReply.getObservation();
stepCountable.increment();
epochStepCounter.incrementEpochStep();
}
return reward;
}
private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
protected void resetNetworks() {
getNeuralNet().reset();
Observation observation = mdpWrapper.reset();
}
private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
epochStepCounter.setCurrentEpochStep(0);
int step = 0;
double reward = 0;
boolean isHistoryProcessor = hp != null;
Observation observation = mdpWrapper.reset();
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
while (step < requiredFrame && !mdpWrapper.isDone()) {
A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
while (observation.isSkipped() && !mdpWrapper.isDone()) {
StepReply<Observation> stepReply = mdpWrapper.step(action);
reward += stepReply.getReward();
observation = stepReply.getObservation();
step++;
epochStepCounter.incrementEpochStep();
}
return new Learning.InitMdp(step, observation, reward);
return new Learning.InitMdp(0, observation, reward);
}
private class RefacStepCountable implements StepCountable {
public class RefacEpochStepCounter implements EpochStepCounter {
@Getter
@Setter
private int stepCounter = 0;
private int currentEpochStep = 0;
public void increment() {
++stepCounter;
public void incrementEpochStep() {
++currentEpochStep;
}
@Override
public int getStepCounter() {
return 0;
}
}
}

View File

@ -1,10 +1,11 @@
package org.deeplearning4j.rl4j.util;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
@ -19,50 +20,20 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
private final MDP<O, A, AS> wrappedMDP;
@Getter
private final WrapperObservationSpace observationSpace;
private final ILearning learning;
@Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC)
private IHistoryProcessor historyProcessor;
private final StepCountable stepCountable;
private int skipFrame;
private int step = 0;
private final EpochStepCounter epochStepCounter;
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) {
this(wrappedMDP, learning, null, null);
}
private int skipFrame = 1;
private int requiredFrame = 0;
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
this(wrappedMDP, null, historyProcessor, stepCountable);
}
private LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
this.wrappedMDP = wrappedMDP;
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
this.learning = learning;
this.historyProcessor = historyProcessor;
this.stepCountable = stepCountable;
}
private IHistoryProcessor getHistoryProcessor() {
if(historyProcessor != null) {
return historyProcessor;
}
if (learning != null) {
return learning.getHistoryProcessor();
}
return null;
}
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
this.historyProcessor = historyProcessor;
}
private int getStep() {
if(stepCountable != null) {
return stepCountable.getStepCounter();
}
return learning.getStepCounter();
this.epochStepCounter = epochStepCounter;
}
@Override
@ -83,9 +54,12 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
if(historyProcessor != null) {
skipFrame = historyProcessor.getConf().getSkipFrame();
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
historyProcessor.add(rawObservation);
}
step = 0;
observation.setSkipped(skipFrame != 0);
return observation;
}
@ -97,21 +71,18 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
StepReply<O> rawStepReply = wrappedMDP.step(a);
INDArray rawObservation = getInput(rawStepReply.getObservation());
++step;
int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
int requiredFrame = 0;
if(historyProcessor != null) {
historyProcessor.record(rawObservation);
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
if ((getStep() % skipFrame == 0 && step >= requiredFrame)
|| (step % skipFrame == 0 && step < requiredFrame )){
if (stepOfObservation % skipFrame == 0) {
historyProcessor.add(rawObservation);
}
}
Observation observation;
if(historyProcessor != null && step >= requiredFrame) {
if(historyProcessor != null && stepOfObservation >= requiredFrame) {
observation = new Observation(historyProcessor.getHistory(), true);
observation.getData().muli(1.0 / historyProcessor.getScale());
}
@ -119,6 +90,10 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
observation = new Observation(new INDArray[] { rawObservation }, false);
}
if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) {
observation.setSkipped(true);
}
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
}
@ -134,7 +109,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
@Override
public MDP<Observation, A, AS> newInstance() {
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), learning);
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), historyProcessor, epochStepCounter);
}
private INDArray getInput(O obs) {

View File

@ -32,7 +32,7 @@ public class AsyncThreadDiscreteTest {
MockMDP mdpMock = new MockMDP(observationSpace);
TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy();
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0);
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
// Act
@ -41,8 +41,8 @@ public class AsyncThreadDiscreteTest {
// Assert
assertEquals(2, sut.trainSubEpochResults.size());
double[][] expectedLastObservations = new double[][] {
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
};
double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
for(int i = 0; i < 2; ++i) {
@ -60,7 +60,7 @@ public class AsyncThreadDiscreteTest {
assertEquals(2, asyncGlobalMock.enqueueCallCount);
// HistoryProcessor
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
assertEquals(expectedAddValues.length, hpMock.addCalls.size());
for(int i = 0; i < expectedAddValues.length; ++i) {
assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
@ -75,9 +75,9 @@ public class AsyncThreadDiscreteTest {
// Policy
double[][] expectedPolicyInputs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
};
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
@ -93,11 +93,11 @@ public class AsyncThreadDiscreteTest {
assertEquals(2, nnMock.copyCallCount);
double[][] expectedNNInputs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans
};
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
for(int i = 0; i < expectedNNInputs.length; ++i) {
@ -113,13 +113,13 @@ public class AsyncThreadDiscreteTest {
double[][][] expectedMinitransObs = new double[][][] {
new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation
},
new double[][] {
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation
}
};
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };

View File

@ -5,15 +5,12 @@ import lombok.Getter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
@ -91,7 +88,7 @@ public class AsyncThreadTest {
// Assert
assertEquals(numberOfEpochs, context.listener.statEntries.size());
int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 };
double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
+ 1.0; // Reward from trainSubEpoch()
for(int i = 0; i < numberOfEpochs; ++i) {
@ -114,7 +111,7 @@ public class AsyncThreadTest {
// Assert
assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
for(int i = 0; i < context.sut.getEpochCounter(); ++i) {
for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) {
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
assertEquals(2, params.nstep);
assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
@ -199,7 +196,9 @@ public class AsyncThreadTest {
protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
asyncGlobal.increaseCurrentLoop();
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
setStepCounter(getStepCounter() + nstep);
for(int i = 0; i < nstep; ++i) {
incrementStep();
}
return new SubEpochReturn(nstep, null, 1.0, 1.0);
}

View File

@ -18,8 +18,8 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
// Act
Transition<Integer> transition = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
123, 234, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
123, 234, new Observation(Nd4j.create(1)));
sut.store(transition);
List<Transition<Integer>> results = sut.getBatch(1);
@ -36,12 +36,12 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -78,12 +78,12 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -100,12 +100,12 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -131,16 +131,16 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
7, 8, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
9, 10, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, new Observation(Nd4j.create(1)));
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
7, 8, new Observation(Nd4j.create(1)));
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
9, 10, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -168,16 +168,16 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
7, 8, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
9, 10, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, new Observation(Nd4j.create(1)));
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
7, 8, new Observation(Nd4j.create(1)));
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
9, 10, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -196,6 +196,12 @@ public class ExpReplayTest {
assertEquals(5, (int)results.get(2).getAction());
assertEquals(6, (int)results.get(2).getReward());
}
private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, Observation nextObservation) {
Transition<Integer> result = new Transition<Integer>(observation, action, reward, false);
result.setNextObservation(nextObservation);
return result;
}
}

View File

@ -1,5 +1,6 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -88,12 +89,15 @@ public class SyncLearningTest {
private final LConfiguration conf;
@Getter
private int currentEpochStep = 0;
public MockSyncLearning(LConfiguration conf) {
this.conf = conf;
}
@Override
protected void preEpoch() { }
protected void preEpoch() { currentEpochStep = 0; }
@Override
protected void postEpoch() { }
@ -101,7 +105,7 @@ public class SyncLearningTest {
@Override
protected IDataManager.StatEntry trainEpoch() {
setStepCounter(getStepCounter() + 1);
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0);
}
@Override

View File

@ -21,7 +21,7 @@ public class TransitionTest {
Observation nextObservation = buildObservation(nextObs);
// Act
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
Transition transition = buildTransition(observation, 123, 234.0, nextObservation);
// Assert
double[][] expectedObservation = new double[][] { obs };
@ -52,7 +52,7 @@ public class TransitionTest {
Observation nextObservation = buildObservation(nextObs);
// Act
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
Transition transition = buildTransition(observation, 123, 234.0, nextObservation);
// Assert
assertExpected(obs, transition.getObservation().getData());
@ -71,12 +71,12 @@ public class TransitionTest {
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
Observation observation1 = buildObservation(obs1);
Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 });
transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1));
transitions.add(buildTransition(observation1,0, 0.0, nextObservation1));
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
Observation observation2 = buildObservation(obs2);
Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 });
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
// Act
INDArray result = Transition.buildStackedObservations(transitions);
@ -101,7 +101,7 @@ public class TransitionTest {
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
double[][] obs2 = new double[][] {
{ 10.0, 11.0, 12.0 },
@ -112,7 +112,7 @@ public class TransitionTest {
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
// Act
INDArray result = Transition.buildStackedObservations(transitions);
@ -131,13 +131,13 @@ public class TransitionTest {
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation observation1 = buildObservation(obs1);
Observation nextObservation1 = buildObservation(nextObs1);
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation observation2 = buildObservation(obs2);
Observation nextObservation2 = buildObservation(nextObs2);
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
// Act
INDArray result = Transition.buildStackedNextObservations(transitions);
@ -162,7 +162,7 @@ public class TransitionTest {
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
double[][] obs2 = new double[][] {
{ 10.0, 11.0, 12.0 },
@ -174,7 +174,7 @@ public class TransitionTest {
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
// Act
INDArray result = Transition.buildStackedNextObservations(transitions);
@ -207,7 +207,13 @@ public class TransitionTest {
Nd4j.create(obs[1]).reshape(1, 3),
};
return new Observation(nextHistory);
}
private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) {
Transition result = new Transition(observation, action, reward, false);
result.setNextObservation(nextObservation);
return result;
}
private void assertExpected(double[] expected, INDArray actual) {

View File

@ -40,8 +40,10 @@ public class QLearningDiscreteTest {
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
MockMDP mdp = new MockMDP(observationSpace, random);
int initStepCount = 8;
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
0, 1.0, 0, 0, 0, 0, true);
initStepCount, 1.0, 0, 0, 0, 0, true);
MockDataManager dataManager = new MockDataManager(false);
MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
@ -60,7 +62,7 @@ public class QLearningDiscreteTest {
for(int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
}
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 };
assertEquals(expectedAdds.length, hp.addCalls.size());
for(int i = 0; i < expectedAdds.length; ++i) {
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
@ -75,19 +77,19 @@ public class QLearningDiscreteTest {
assertEquals(14, dqn.outputParams.size());
double[][] expectedDQNOutput = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
};
for(int i = 0; i < expectedDQNOutput.length; ++i) {
INDArray outputParam = dqn.outputParams.get(i);
@ -105,19 +107,20 @@ public class QLearningDiscreteTest {
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
// ExpReplay calls
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
double[][] expectedTrObservations = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
};
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
for(int i = 0; i < expectedTrRewards.length; ++i) {
Transition tr = expReplay.transitions.get(i);
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
@ -129,7 +132,7 @@ public class QLearningDiscreteTest {
}
// trainEpoch result
assertEquals(16, result.getStepCounter());
assertEquals(initStepCount + 16, result.getStepCounter());
assertEquals(300.0, result.getReward(), 0.00001);
assertTrue(dqn.hasBeenReset);
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);

View File

@ -26,7 +26,7 @@ public class DoubleDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
}
};
@ -52,7 +52,7 @@ public class DoubleDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
}
};
@ -78,11 +78,11 @@ public class DoubleDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
add(builtTransition(buildObservation(new double[]{3.3, 4.4}),
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
add(builtTransition(buildObservation(new double[]{5.5, 6.6}),
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
}
};
@ -108,4 +108,11 @@ public class DoubleDQNTest {
private Observation buildObservation(double[] data) {
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
}
private Transition<Integer> builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
Transition<Integer> result = new Transition<Integer>(observation, action, reward, isTerminal);
result.setNextObservation(nextObservation);
return result;
}
}

View File

@ -25,7 +25,7 @@ public class StandardDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
}
};
@ -51,7 +51,7 @@ public class StandardDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
}
};
@ -77,11 +77,11 @@ public class StandardDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
add(buildTransition(buildObservation(new double[]{3.3, 4.4}),
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
add(buildTransition(buildObservation(new double[]{5.5, 6.6}),
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
}
};
@ -108,4 +108,10 @@ public class StandardDQNTest {
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
}
private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
Transition<Integer> result = new Transition<Integer>(observation, action, reward, isTerminal);
result.setNextObservation(nextObservation);
return result;
}
}

View File

@ -198,7 +198,7 @@ public class PolicyTest {
assertEquals(465.0, totalReward, 0.0001);
// HistoryProcessor
assertEquals(27, hp.addCalls.size());
assertEquals(16, hp.addCalls.size());
assertEquals(31, hp.recordCalls.size());
for(int i=0; i <= 30; ++i) {
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);

View File

@ -142,6 +142,11 @@ public class DataManagerTrainingListenerTest {
return 0;
}
@Override
public int getCurrentEpochStep() {
return 0;
}
@Getter
@Setter
private IHistoryProcessor historyProcessor;

View File

@ -93,6 +93,9 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
private boolean done = false;
public GymEnv(String envId, boolean render, boolean monitor) {
this(envId, render, monitor, (Integer)null);
}
public GymEnv(String envId, boolean render, boolean monitor, Integer seed) {
this.envId = envId;
this.render = render;
this.monitor = monitor;
@ -107,6 +110,10 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null));
checkPythonError();
}
if (seed != null) {
Py_DecRef(PyRun_StringFlags("env.seed(" + seed + ")", Py_single_input, globals, locals, null));
checkPythonError();
}
PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null);
int[] shape = new int[(int)PyTuple_Size(shapeTuple)];
for (int i = 0; i < shape.length; i++) {
@ -125,7 +132,10 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
}
public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
this(envId, render, monitor);
this(envId, render, monitor, null, actions);
}
public GymEnv(String envId, boolean render, boolean monitor, Integer seed, int[] actions) {
this(envId, render, monitor, seed);
actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions);
}