RL4J: Add Observation and LegacyMDPWrapper (#8368)

* Added Observable & LegacyMDPWrapper

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Moved observation processing to LegacyMDPWrapper

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Observation using DataSets, changes in Transition and BaseTDTargetAlgorithm

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Added javadoc to Transition new methods

Signed-off-by: unknown <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2019-11-26 09:05:11 -05:00 committed by Samuel Audet
parent 8d87b078c2
commit 47c58cf69d
25 changed files with 742 additions and 218 deletions

View File

@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
*
* A common interface that any training method should implement
*/
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> extends StepCountable {
public interface ILearning<O, A, AS extends ActionSpace<A>> extends StepCountable {
IPolicy<O, A> getPolicy();

View File

@ -39,7 +39,7 @@ import org.nd4j.linalg.factory.Nd4j;
*
*/
@Slf4j
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
@Getter @Setter
@ -53,8 +53,8 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
}
public static <O extends Encodable, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) {
INDArray arr = Nd4j.create(obs.toArray());
public static <O, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) {
INDArray arr = Nd4j.create(((Encodable)obs).toArray());
int[] shape = mdp.getObservationSpace().getShape();
if (shape.length == 1)
return arr.reshape(new long[] {1, arr.length()});
@ -62,7 +62,7 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
return arr.reshape(shape);
}
public static <O extends Encodable, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp,
public static <O, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp,
IHistoryProcessor hp) {
O obs = mdp.reset();
@ -138,15 +138,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
this.historyProcessor = historyProcessor;
}
public INDArray getInput(O obs) {
return getInput(getMdp(), obs);
}
public InitMdp<O> initMdp() {
getNeuralNet().reset();
return initMdp(getMdp(), getHistoryProcessor());
}
@AllArgsConstructor
@Value
public static class InitMdp<O> {
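For illustration, a minimal sketch of the reshape rule used by the static getInput helper above, with stand-in values (the literal array and shape below are illustrative, not part of the commit):

// Mirrors Learning.getInput: a flat encoded observation is promoted to a
// [1, length] row when the observation space is 1-D, so the network sees a
// batch of size 1; otherwise it is reshaped to the space's own shape.
INDArray arr = Nd4j.create(new double[] {1.0, 2.0, 3.0}); // stand-in for Nd4j.create(obs.toArray())
int[] shape = new int[] {3};                              // stand-in for mdp.getObservationSpace().getShape()
INDArray input = shape.length == 1
        ? arr.reshape(1, arr.length())  // [3] -> [1, 3]
        : arr.reshape(shape);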

View File

@ -36,7 +36,7 @@ import org.deeplearning4j.rl4j.util.IDataManager;
* @author Alexandre Boulanger
*/
@Slf4j
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> implements IEpochTrainer {
private final TrainingListenerList listeners = new TrainingListenerList();

View File

@ -16,27 +16,56 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.AllArgsConstructor;
import lombok.Value;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.List;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
*
* A transition is a SARS tuple
* State, Action, Reward, (isTerminal), State
*
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
* @author Alexandre Boulanger
*
*/
@Value
@AllArgsConstructor
public class Transition<A> {
INDArray[] observation;
Observation observation;
A action;
double reward;
boolean isTerminal;
INDArray nextObservation;
public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) {
this.observation = observation;
this.action = action;
this.reward = reward;
this.isTerminal = isTerminal;
// To conserve memory, only the most recent frame of the next observation is kept (if history is used).
// The full nextObservation will be re-built from observation when needed.
long[] nextObservationShape = nextObservation.getData().shape().clone();
nextObservationShape[0] = 1;
this.nextObservation = nextObservation.getData()
.get(new INDArrayIndex[] {NDArrayIndex.point(0)})
.reshape(nextObservationShape);
}
private Transition(Observation observation, A action, double reward, boolean isTerminal, INDArray nextObservation) {
this.observation = observation;
this.action = action;
this.reward = reward;
this.isTerminal = isTerminal;
this.nextObservation = nextObservation;
}
/**
* Concat an array history into a single INDArray with as many channels
* as there are elements in the history array
@ -53,36 +82,80 @@ public class Transition<A> {
* @return this transition duplicated
*/
public Transition<A> dup() {
INDArray[] dupObservation = dup(observation);
Observation dupObservation = observation.dup();
INDArray nextObs = nextObservation.dup();
return new Transition<>(dupObservation, action, reward, isTerminal, nextObs);
return new Transition<A>(dupObservation, action, reward, isTerminal, nextObs);
}
/**
* Duplicate a history
* @param history the history to duplicate
* @return a duplicate of the history
* Stack all the observations of the batch along dimension 0 into a single INDArray.
*
* @param transitions A list of the transitions of the batch
* @param <A> The type of the Action
* @return An INDArray of all the batch's observations stacked along dimension 0.
*/
public static INDArray[] dup(INDArray[] history) {
INDArray[] dupHistory = new INDArray[history.length];
for (int i = 0; i < history.length; i++) {
dupHistory[i] = history[i].dup();
public static <A> INDArray buildStackedObservations(List<Transition<A>> transitions) {
int size = transitions.size();
long[] shape = getShape(transitions);
INDArray[] array = new INDArray[size];
for (int i = 0; i < size; i++) {
array[i] = transitions.get(i).getObservation().getData();
}
return dupHistory;
return Nd4j.concat(0, array).reshape(shape);
}
/**
* Append a pixel frame to a history (dropping the last frame)
* @param history the history on which to append
* @param append the pixel frame to append
* @return the appended history
* Stack all the next observations of the batch along dimension 0 into a single INDArray.
*
* @param transitions A list of the transitions of the batch
* @param <A> The type of the Action
* @return An INDArray of all the batch's next observations stacked along dimension 0.
*/
public static INDArray[] append(INDArray[] history, INDArray append) {
INDArray[] appended = new INDArray[history.length];
appended[0] = append;
System.arraycopy(history, 0, appended, 1, history.length - 1);
return appended;
public static <A> INDArray buildStackedNextObservations(List<Transition<A>> transitions) {
int size = transitions.size();
long[] shape = getShape(transitions);
INDArray[] array = new INDArray[size];
for (int i = 0; i < size; i++) {
Transition<A> trans = transitions.get(i);
INDArray obs = trans.getObservation().getData();
long historyLength = obs.shape()[0];
if(historyLength != 1) {
// To conserve memory, only the most recent frame of the next observation is kept (if history is used).
// We need to rebuild the frame-stack in addition to building the batch-stack.
INDArray historyPart = obs.get(new INDArrayIndex[]{NDArrayIndex.interval(0, historyLength - 1)});
array[i] = Nd4j.concat(0, trans.getNextObservation(), historyPart);
}
else {
array[i] = trans.getNextObservation();
}
}
return Nd4j.concat(0, array).reshape(shape);
}
private static <A> long[] getShape(List<Transition<A>> transitions) {
INDArray observations = transitions.get(0).getObservation().getData();
long[] observationShape = observations.shape();
long[] stackedShape;
if(observationShape[0] == 1) {
// FIXME: Currently RL4J doesn't support 1D observations. So if we have a shape with 1 in the first dimension, we can use that dimension and don't need to add another one.
stackedShape = new long[observationShape.length];
System.arraycopy(observationShape, 0, stackedShape, 0, observationShape.length);
}
else {
stackedShape = new long[observationShape.length + 1];
System.arraycopy(observationShape, 1, stackedShape, 2, observationShape.length - 1);
stackedShape[1] = observationShape[1];
}
stackedShape[0] = transitions.size();
return stackedShape;
}
}
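For illustration, a hedged sketch of the memory trade-off above, assuming a transition built with a four-frame history (the transition variable and the shapes are illustrative):

// The stored observation is the current frame-stack, newest frame first.
// The constructor kept only the newest frame of the next observation;
// here the three newest frames of the current stack are re-attached to
// rebuild the full next frame-stack (same logic as the loop above).
INDArray obs = transition.getObservation().getData();       // shape [4, h, w]
INDArray latest = transition.getNextObservation();          // shape [1, h, w]
INDArray overlap = obs.get(NDArrayIndex.interval(0, 3));    // the 3 newest frames
INDArray nextObs = Nd4j.concat(0, latest, overlap);         // shape [4, h, w]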

View File

@ -21,15 +21,20 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
@ -53,6 +58,8 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
@Setter(AccessLevel.PROTECTED)
protected IExpReplay<A> expReplay;
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
public QLearning(QLConfiguration conf) {
this(conf, getSeededRandom(conf.getSeed()));
}
@ -95,11 +102,11 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
protected abstract void postEpoch();
protected abstract QLStepReturn<O> trainStep(O obs);
protected abstract QLStepReturn<Observation> trainStep(Observation obs);
protected StatEntry trainEpoch() {
InitMdp<O> initMdp = initMdp();
O obs = initMdp.getLastObs();
InitMdp<Observation> initMdp = refacInitMdp();
Observation obs = initMdp.getLastObs();
double reward = initMdp.getReward();
int step = initMdp.getSteps();
@ -114,7 +121,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
updateTargetNetwork();
}
QLStepReturn<O> stepR = trainStep(obs);
QLStepReturn<Observation> stepR = trainStep(obs);
if (!stepR.getMaxQ().isNaN()) {
if (startQ.isNaN())
@ -142,6 +149,36 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
}
private InitMdp<Observation> refacInitMdp() {
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
IHistoryProcessor hp = getHistoryProcessor();
Observation observation = mdp.reset();
int step = 0;
double reward = 0;
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
while (step < requiredFrame && !mdp.isDone()) {
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
StepReply<Observation> stepReply = mdp.step(action);
reward += stepReply.getReward();
observation = stepReply.getObservation();
step++;
}
return new InitMdp(step, observation, reward);
}
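The warm-up above plays no-op actions until a full history can be assembled; a hedged sketch of the frame-count arithmetic, with Atari-style settings (the concrete numbers are illustrative):

// One frame is added to the history every skipFrame steps, and
// historyLength - 1 frames are needed beyond the initial reset frame,
// so the no-op warm-up runs for skipFrame * (historyLength - 1) steps.
int skipFrame = 4;
int historyLength = 5;
int requiredFrame = skipFrame * (historyLength - 1); // 16 warm-up steps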
@AllArgsConstructor
@Builder
@Value

View File

@ -26,10 +26,12 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.api.DataSet;
@ -51,8 +53,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Getter
final private QLConfiguration configuration;
@Getter
final private MDP<O, Integer, DiscreteSpace> mdp;
private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp;
@Getter
private DQNPolicy<O> policy;
@Getter
@ -65,11 +66,14 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
private IDQN targetQNetwork;
private int lastAction;
private INDArray[] history = null;
private double accuReward = 0;
ITDTargetAlgorithm tdTargetAlgorithm;
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
return mdp;
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep) {
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
@ -79,7 +83,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
int epsilonNbStep, Random random) {
super(conf);
this.configuration = conf;
this.mdp = mdp;
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, this);
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
@ -92,6 +96,10 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
}
public MDP<O, Integer, DiscreteSpace> getMdp() {
return mdp.getWrappedMDP();
}
public void postEpoch() {
if (getHistoryProcessor() != null)
@ -100,7 +108,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
}
public void preEpoch() {
history = null;
lastAction = 0;
accuReward = 0;
}
@ -110,10 +117,9 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
* @param obs last obs
* @return relevant info for next step
*/
protected QLStepReturn<O> trainStep(O obs) {
protected QLStepReturn<Observation> trainStep(Observation obs) {
Integer action;
INDArray input = getInput(obs);
boolean isHistoryProcessor = getHistoryProcessor() != null;
@ -128,50 +134,25 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
if (getStepCounter() % skipFrame != 0) {
action = lastAction;
} else {
if (history == null) {
if (isHistoryProcessor) {
getHistoryProcessor().add(input);
history = getHistoryProcessor().getHistory();
} else
history = new INDArray[] {input};
}
//concat the history into a single INDArray input
INDArray hstack = Transition.concat(Transition.dup(history));
if (isHistoryProcessor) {
hstack.muli(1.0 / getHistoryProcessor().getScale());
}
//if the input is not 2d, a batch dimension of length 1 must be prepended
if (hstack.shape().length > 2)
hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
INDArray qs = getQNetwork().output(hstack);
INDArray qs = getQNetwork().output(obs);
int maxAction = Learning.getMaxAction(qs);
maxQ = qs.getDouble(maxAction);
action = getEgPolicy().nextAction(hstack);
action = getEgPolicy().nextAction(obs);
}
lastAction = action;
StepReply<O> stepReply = getMdp().step(action);
StepReply<Observation> stepReply = mdp.step(action);
INDArray ninput = getInput(stepReply.getObservation());
if (isHistoryProcessor)
getHistoryProcessor().record(ninput);
Observation nextObservation = stepReply.getObservation();
accuReward += stepReply.getReward() * configuration.getRewardFactor();
//if it's not a skipped frame, you can do a step of training
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
if (isHistoryProcessor)
getHistoryProcessor().add(ninput);
INDArray[] nhistory = isHistoryProcessor ? getHistoryProcessor().getHistory() : new INDArray[] {ninput};
Transition<Integer> trans = new Transition(history, action, accuReward, stepReply.isDone(), nhistory[0]);
Transition<Integer> trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation);
getExpReplay().store(trans);
if (getStepCounter() > updateStart) {
@ -179,27 +160,16 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}
history = nhistory;
accuReward = 0;
}
return new QLStepReturn<O>(maxQ, getQNetwork().getLatestScore(), stepReply);
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
}
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
if (transitions.size() == 0)
throw new IllegalArgumentException("too few transitions");
// TODO: Remove once we use DataSets in observations
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape();
((BaseTDTargetAlgorithm) tdTargetAlgorithm).setNShape(makeShape(transitions.size(), shape));
// TODO: Remove once we use DataSets in observations
if(getHistoryProcessor() != null) {
((BaseTDTargetAlgorithm) tdTargetAlgorithm).setScale(getHistoryProcessor().getScale());
}
return tdTargetAlgorithm.computeTDTargets(transitions);
}
}
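For reference, a condensed, hedged sketch of the training branch of trainStep above; getBatch() on the replay memory is assumed to return the sampled batch, the other names come from this diff:

// Accumulate the scaled reward every step; on every skipFrame-th step
// (or at episode end) store a Transition and, once past updateStart,
// fit the Q-network on TD targets computed from a replay batch.
accuReward += stepReply.getReward() * configuration.getRewardFactor();
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
    getExpReplay().store(new Transition<>(obs, action, accuReward, stepReply.isDone(), nextObservation));
    if (getStepCounter() > updateStart) {
        DataSet targets = setTarget(getExpReplay().getBatch());
        getQNetwork().fit(targets.getFeatures(), targets.getLabels());
    }
    accuReward = 0;
}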

View File

@ -16,14 +16,10 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.List;
@ -40,11 +36,6 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
private final double errorClamp;
private final boolean isClamped;
@Setter
private int[] nShape; // TODO: Remove once we use DataSets in observations
@Setter
private double scale = 1.0; // TODO: Remove once we use DataSets in observations
/**
*
* @param qNetworkSource The source of the Q-Network
@ -93,37 +84,8 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
int size = transitions.size();
INDArray observations = Nd4j.create(nShape);
INDArray nextObservations = Nd4j.create(nShape);
// TODO: Remove once we use DataSets in observations
for (int i = 0; i < size; i++) {
Transition<Integer> trans = transitions.get(i);
INDArray[] obsArray = trans.getObservation();
if (observations.rank() == 2) {
observations.putRow(i, obsArray[0]);
} else {
for (int j = 0; j < obsArray.length; j++) {
observations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
}
}
INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
if (nextObservations.rank() == 2) {
nextObservations.putRow(i, nextObsArray[0]);
} else {
for (int j = 0; j < nextObsArray.length; j++) {
nextObservations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
}
}
}
// TODO: Remove once we use DataSets in observations
if(scale != 1.0) {
observations.muli(1.0 / scale);
nextObservations.muli(1.0 / scale);
}
INDArray observations = Transition.buildStackedObservations(transitions);
INDArray nextObservations = Transition.buildStackedNextObservations(transitions);
initComputation(observations, nextObservations);
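With the manual copy loop gone, the TD-target computation consumes the two stacked batches directly. A minimal usage sketch, assuming an expReplay and a dqn in scope (both names are illustrative):

// Sample a batch, compute the TD targets as a DataSet, and fit the network.
List<Transition<Integer>> batch = expReplay.getBatch();
DataSet targets = tdTargetAlgorithm.computeTDTargets(batch);
dqn.fit(targets.getFeatures(), targets.getLabels());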

View File

@ -22,6 +22,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -70,6 +71,10 @@ public class DQN<NN extends DQN> implements IDQN<NN> {
return mln.output(batch);
}
public INDArray output(Observation observation) {
return this.output(observation.getData());
}
public INDArray[] outputAll(INDArray batch) {
return new INDArray[] {output(batch)};
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.network.dqn;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
@ -37,6 +38,7 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
void fit(INDArray input, INDArray[] labels);
INDArray output(INDArray batch);
INDArray output(Observation observation);
INDArray[] outputAll(INDArray batch);

View File

@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
/**
* Presently only a dummy container. Will contain observation channels when done.
*/
public class Observation {
// TODO: Presently only a dummy container. Will contain observation channels when done.
private final DataSet data;
public Observation(INDArray[] data) {
this(new org.nd4j.linalg.dataset.DataSet(Nd4j.concat(0, data), null));
}
// FIXME: Remove -- only used in unit tests
public Observation(INDArray data) {
this.data = new org.nd4j.linalg.dataset.DataSet(data, null);
}
private Observation(DataSet data) {
this.data = data;
}
public Observation dup() {
return new Observation(new org.nd4j.linalg.dataset.DataSet(data.getFeatures().dup(), null));
}
public INDArray getData() {
return data.getFeatures();
}
}
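A small usage sketch of the new container, assuming 1x3 frames (the values are illustrative):

// Wrap a frame history (one INDArray per frame) into an Observation;
// getData() returns the frames concatenated along dimension 0.
Observation obs = new Observation(new INDArray[] {
        Nd4j.create(new double[] {0.0, 1.0, 2.0}).reshape(1, 3),
        Nd4j.create(new double[] {3.0, 4.0, 5.0}).reshape(1, 3)
});
INDArray data = obs.getData(); // shape [2, 3]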

View File

@ -21,8 +21,8 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
@ -38,7 +38,7 @@ import org.nd4j.linalg.api.rng.Random;
*/
@AllArgsConstructor
@Slf4j
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {
public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
final private Policy<O, A> policy;
final private MDP<O, A, AS> mdp;
@ -61,8 +61,10 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend
return policy.nextAction(input);
else
return mdp.getActionSpace().randomAction();
}
public A nextAction(Observation observation) {
return this.nextAction(observation.getData());
}
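For context, a hedged reconstruction of the full epsilon-greedy rule (the random-draw line is elided from this hunk; rnd is the policy's Random instance):

// With probability 1 - epsilon exploit the learned policy,
// otherwise explore with a uniformly random action.
if (rnd.nextFloat() > getEpsilon())
    return policy.nextAction(input);
else
    return mdp.getActionSpace().randomAction();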
public float getEpsilon() {

View File

@ -6,7 +6,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
public interface IPolicy<O extends Encodable, A> {
public interface IPolicy<O, A> {
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
A nextAction(INDArray input);
}

View File

@ -35,7 +35,7 @@ import org.nd4j.linalg.util.ArrayUtil;
*
* A Policy's responsibility is to choose the next action given a state
*/
public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
public abstract class Policy<O, A> implements IPolicy<O, A> {
public abstract NeuralNet getNeuralNet();

View File

@ -0,0 +1,139 @@
package org.deeplearning4j.rl4j.util;
import lombok.Getter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
@Getter
private final MDP<O, A, AS> wrappedMDP;
@Getter
private final WrapperObservationSpace observationSpace;
private final ILearning learning;
private int skipFrame;
private int step = 0;
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) {
this.wrappedMDP = wrappedMDP;
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
this.learning = learning;
}
@Override
public AS getActionSpace() {
return wrappedMDP.getActionSpace();
}
@Override
public Observation reset() {
INDArray rawObservation = getInput(wrappedMDP.reset());
IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
if(historyProcessor != null) {
historyProcessor.record(rawObservation.dup());
rawObservation.muli(1.0 / historyProcessor.getScale());
}
Observation observation = new Observation(new INDArray[] { rawObservation });
if(historyProcessor != null) {
skipFrame = historyProcessor.getConf().getSkipFrame();
historyProcessor.add(rawObservation);
}
step = 0;
return observation;
}
@Override
public void close() {
wrappedMDP.close();
}
@Override
public StepReply<Observation> step(A a) {
IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
StepReply<O> rawStepReply = wrappedMDP.step(a);
INDArray rawObservation = getInput(rawStepReply.getObservation());
++step;
int requiredFrame = 0;
if(historyProcessor != null) {
historyProcessor.record(rawObservation.dup());
rawObservation.muli(1.0 / historyProcessor.getScale());
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
if ((learning.getStepCounter() % skipFrame == 0 && step >= requiredFrame)
|| (step % skipFrame == 0 && step < requiredFrame )){
historyProcessor.add(rawObservation);
}
}
Observation observation;
if(historyProcessor != null && step >= requiredFrame) {
observation = new Observation(historyProcessor.getHistory());
}
else {
observation = new Observation(new INDArray[] { rawObservation });
}
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
}
@Override
public boolean isDone() {
return wrappedMDP.isDone();
}
@Override
public MDP<Observation, A, AS> newInstance() {
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), learning);
}
private INDArray getInput(O obs) {
INDArray arr = Nd4j.create(obs.toArray());
int[] shape = observationSpace.getShape();
if (shape.length == 1)
return arr.reshape(new long[] {1, arr.length()});
else
return arr.reshape(shape);
}
public static class WrapperObservationSpace implements ObservationSpace<Observation> {
@Getter
private final int[] shape;
public WrapperObservationSpace(int[] shape) {
this.shape = shape;
}
@Override
public String getName() {
return null;
}
@Override
public INDArray getLow() {
return null;
}
@Override
public INDArray getHigh() {
return null;
}
}
}
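A hedged wiring sketch for the wrapper: a legacy Encodable-based MDP is adapted so that reset() and step() emit Observation instances (gymMdp, learner, and MyEncodable are illustrative names):

// The wrapper records/scales raw frames through the learner's
// IHistoryProcessor (when one is set) and assembles frame-stacks.
MDP<Observation, Integer, DiscreteSpace> mdp =
        new LegacyMDPWrapper<MyEncodable, Integer, DiscreteSpace>(gymMdp, learner);
Observation first = mdp.reset();
StepReply<Observation> reply = mdp.step(mdp.getActionSpace().noOp());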

View File

@ -50,7 +50,7 @@ public class AsyncThreadDiscreteTest {
assertEquals(1, asyncGlobalMock.enqueueCallCount);
// HistoryProcessor
assertEquals(10, hpMock.addCallCount);
assertEquals(10, hpMock.addCalls.size());
double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
for(int i = 0; i < expectedRecordValues.length; ++i) {

View File

@ -1,5 +1,6 @@
package org.deeplearning4j.rl4j.learning.sync;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.support.MockRandom;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -17,7 +18,8 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
// Act
Transition<Integer> transition = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 123, 234, false, Nd4j.create(1));
Transition<Integer> transition = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
123, 234, false, new Observation(Nd4j.create(1)));
sut.store(transition);
List<Transition<Integer>> results = sut.getBatch(1);
@ -34,9 +36,12 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -73,9 +78,12 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -92,9 +100,12 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -120,11 +131,16 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
7, 8, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
9, 10, false, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
@ -152,11 +168,16 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
1, 2, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
3, 4, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
5, 6, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
7, 8, false, new Observation(Nd4j.create(1)));
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
9, 10, false, new Observation(Nd4j.create(1)));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);

View File

@ -0,0 +1,255 @@
package org.deeplearning4j.rl4j.learning.sync;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class TransitionTest {
@Test
public void when_callingCtorWithoutHistory_expect_2DObservationAndNextObservation() {
// Arrange
double[] obs = new double[] { 1.0, 2.0, 3.0 };
Observation observation = buildObservation(obs);
double[] nextObs = new double[] { 10.0, 20.0, 30.0 };
Observation nextObservation = buildObservation(nextObs);
// Act
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
// Assert
double[][] expectedObservation = new double[][] { obs };
assertExpected(expectedObservation, transition.getObservation().getData());
double[][] expectedNextObservation = new double[][] { nextObs };
assertExpected(expectedNextObservation, transition.getNextObservation());
assertEquals(123, transition.getAction());
assertEquals(234.0, transition.getReward(), 0.0001);
}
@Test
public void when_callingCtorWithHistory_expect_ObservationWithHistoryAndNextObservationWithout() {
// Arrange
double[][] obs = new double[][] {
{ 0.0, 1.0, 2.0 },
{ 3.0, 4.0, 5.0 },
{ 6.0, 7.0, 8.0 },
};
Observation observation = buildObservation(obs);
double[][] nextObs = new double[][] {
{ 10.0, 11.0, 12.0 },
{ 0.0, 1.0, 2.0 },
{ 3.0, 4.0, 5.0 },
};
Observation nextObservation = buildObservation(nextObs);
// Act
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
// Assert
assertExpected(obs, transition.getObservation().getData());
assertExpected(nextObs[0], transition.getNextObservation());
assertEquals(123, transition.getAction());
assertEquals(234.0, transition.getReward(), 0.0001);
}
@Test
public void when_CallingBuildStackedObservationsAndShapeRankIs2_expect_2DResultWithObservationsStackedOnDimension0() {
// Arrange
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>();
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
Observation observation1 = buildObservation(obs1);
Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 });
transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1));
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
Observation observation2 = buildObservation(obs2);
Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 });
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
// Act
INDArray result = Transition.buildStackedObservations(transitions);
// Assert
double[][] expected = new double[][] { obs1, obs2 };
assertExpected(expected, result);
}
@Test
public void when_CallingBuildStackedObservationsAndShapeRankIsGreaterThan2_expect_ResultWithOneMoreDimensionAndObservationsStackedOnDimension0() {
// Arrange
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>();
double[][] obs1 = new double[][] {
{ 0.0, 1.0, 2.0 },
{ 3.0, 4.0, 5.0 },
{ 6.0, 7.0, 8.0 },
};
Observation observation1 = buildObservation(obs1);
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
double[][] obs2 = new double[][] {
{ 10.0, 11.0, 12.0 },
{ 13.0, 14.0, 15.0 },
{ 16.0, 17.0, 18.0 },
};
Observation observation2 = buildObservation(obs2);
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
// Act
INDArray result = Transition.buildStackedObservations(transitions);
// Assert
double[][][] expected = new double[][][] { obs1, obs2 };
assertExpected(expected, result);
}
@Test
public void when_CallingBuildStackedNextObservationsAndShapeRankIs2_expect_2DResultWithObservationsStackedOnDimension0() {
// Arrange
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>();
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation observation1 = buildObservation(obs1);
Observation nextObservation1 = buildObservation(nextObs1);
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation observation2 = buildObservation(obs2);
Observation nextObservation2 = buildObservation(nextObs2);
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
// Act
INDArray result = Transition.buildStackedNextObservations(transitions);
// Assert
double[][] expected = new double[][] { nextObs1, nextObs2 };
assertExpected(expected, result);
}
@Test
public void when_CallingBuildStackedNextObservationsAndShapeRankIsGreaterThan2_expect_ResultWithOneMoreDimensionAndObservationsStackedOnDimension0() {
// Arrange
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>();
double[][] obs1 = new double[][] {
{ 0.0, 1.0, 2.0 },
{ 3.0, 4.0, 5.0 },
{ 6.0, 7.0, 8.0 },
};
Observation observation1 = buildObservation(obs1);
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
double[][] obs2 = new double[][] {
{ 10.0, 11.0, 12.0 },
{ 13.0, 14.0, 15.0 },
{ 16.0, 17.0, 18.0 },
};
Observation observation2 = buildObservation(obs2);
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
// Act
INDArray result = Transition.buildStackedNextObservations(transitions);
// Assert
double[][][] expected = new double[][][] {
new double[][] { nextObs1, obs1[0], obs1[1] },
new double[][] { nextObs2, obs2[0], obs2[1] }
};
assertExpected(expected, result);
}
private Observation buildObservation(double[][] obs) {
INDArray[] history = new INDArray[] {
Nd4j.create(obs[0]).reshape(1, 3),
Nd4j.create(obs[1]).reshape(1, 3),
Nd4j.create(obs[2]).reshape(1, 3),
};
return new Observation(history);
}
private Observation buildObservation(double[] obs) {
return new Observation(new INDArray[] { Nd4j.create(obs).reshape(1, 3) });
}
private Observation buildNextObservation(double[][] obs, double[] nextObs) {
INDArray[] nextHistory = new INDArray[] {
Nd4j.create(nextObs).reshape(1, 3),
Nd4j.create(obs[0]).reshape(1, 3),
Nd4j.create(obs[1]).reshape(1, 3),
};
return new Observation(nextHistory);
}
private void assertExpected(double[] expected, INDArray actual) {
long[] shape = actual.shape();
assertEquals(2, shape.length);
assertEquals(1, shape[0]);
assertEquals(expected.length, shape[1]);
for(int i = 0; i < expected.length; ++i) {
assertEquals(expected[i], actual.getDouble(0, i), 0.0001);
}
}
private void assertExpected(double[][] expected, INDArray actual) {
long[] shape = actual.shape();
assertEquals(2, shape.length);
assertEquals(expected.length, shape[0]);
assertEquals(expected[0].length, shape[1]);
for(int i = 0; i < expected.length; ++i) {
double[] expectedLine = expected[i];
for(int j = 0; j < expectedLine.length; ++j) {
assertEquals(expectedLine[j], actual.getDouble(i, j), 0.0001);
}
}
}
private void assertExpected(double[][][] expected, INDArray actual) {
long[] shape = actual.shape();
assertEquals(3, shape.length);
assertEquals(expected.length, shape[0]);
assertEquals(expected[0].length, shape[1]);
assertEquals(expected[0][0].length, shape[2]);
for(int i = 0; i < expected.length; ++i) {
double[][] expected2D = expected[i];
for(int j = 0; j < expected2D.length; ++j) {
double[] expectedLine = expected2D[j];
for (int k = 0; k < expectedLine.length; ++k) {
assertEquals(expectedLine[k], actual.getDouble(i, j, k), 0.0001);
}
}
}
}
}

View File

@ -40,7 +40,7 @@ public class QLearningDiscreteTest {
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
MockMDP mdp = new MockMDP(observationSpace, random);
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
0, 1.0, 0, 0, 0, 0, true);
MockDataManager dataManager = new MockDataManager(false);
MockExpReplay expReplay = new MockExpReplay();
@ -48,15 +48,10 @@ public class QLearningDiscreteTest {
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hp);
MockEncodable obs = new MockEncodable(-100);
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
// Act
sut.initMdp();
for(int step = 0; step < 16; ++step) {
results.add(sut.trainStep(obs));
sut.incrementStep();
}
IDataManager.StatEntry result = sut.trainEpoch();
// Assert
// HistoryProcessor calls
@ -65,7 +60,11 @@ public class QLearningDiscreteTest {
for(int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
}
assertEquals(13, hp.addCallCount);
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
assertEquals(expectedAdds.length, hp.addCalls.size());
for(int i = 0; i < expectedAdds.length; ++i) {
assertEquals(expectedAdds[i], 255.0 * hp.addCalls.get(i).getDouble(0), 0.0001);
}
assertEquals(0, hp.startMonitorCallCount);
assertEquals(0, hp.stopMonitorCallCount);
@ -75,14 +74,14 @@ public class QLearningDiscreteTest {
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
assertEquals(14, dqn.outputParams.size());
double[][] expectedDQNOutput = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
@ -108,13 +107,13 @@ public class QLearningDiscreteTest {
// ExpReplay calls
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, -100.0, 9.0, 11.0, 13.0, 15.0 };
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
double[][] expectedTrObservations = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
@ -123,26 +122,15 @@ public class QLearningDiscreteTest {
Transition tr = expReplay.transitions.get(i);
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
assertEquals(expectedTrActions[i], tr.getAction());
assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(j, 0), 0.0001);
}
}
// trainStep results
assertEquals(16, results.size());
double[] expectedMaxQ = new double[] { 6.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
for(int i=0; i < 16; ++i) {
QLearning.QLStepReturn<MockEncodable> result = results.get(i);
if(i % 2 == 0) {
assertEquals(expectedMaxQ[i/2], 255.0 * result.getMaxQ(), 0.001);
assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
}
else {
assertTrue(result.getMaxQ().isNaN());
}
}
// trainEpoch result
assertEquals(16, result.getStepCounter());
assertEquals(300.0, result.getReward(), 0.00001);
}
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
@ -163,5 +151,9 @@ public class QLearningDiscreteTest {
this.expReplay = exp;
}
@Override
public IDataManager.StatEntry trainEpoch() {
return super.trainEpoch();
}
}
}

View File

@ -3,6 +3,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorit
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
@ -25,12 +26,12 @@ public class DoubleDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0})));
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 1, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -51,12 +52,12 @@ public class DoubleDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 1, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -77,14 +78,16 @@ public class DoubleDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0})));
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0})));
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 3, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -102,4 +105,7 @@ public class DoubleDQNTest {
}
private Observation buildObservation(double[] data) {
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
}
}

View File

@ -3,6 +3,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorit
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
@ -24,12 +25,12 @@ public class StandardDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0})));
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 1, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -50,12 +51,12 @@ public class StandardDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 1, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -76,14 +77,16 @@ public class StandardDQNTest {
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0})));
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0})));
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 3, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -101,4 +104,8 @@ public class StandardDQNTest {
}
private Observation buildObservation(double[] data) {
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
}
}

View File

@ -1,12 +1,11 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import lombok.Setter;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
import java.io.OutputStream;
@ -57,6 +56,11 @@ public class MockDQN implements IDQN {
return batch;
}
@Override
public INDArray output(Observation observation) {
return this.output(observation.getData());
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];

View File

@ -197,7 +197,7 @@ public class PolicyTest {
assertEquals(465.0, totalReward, 0.0001);
// HistoryProcessor
assertEquals(27, hp.addCallCount);
assertEquals(27, hp.addCalls.size());
assertEquals(31, hp.recordCalls.size());
for(int i=0; i <= 30; ++i) {
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);

View File

@ -4,6 +4,7 @@ import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
@ -48,6 +49,11 @@ public class MockDQN implements IDQN {
return batch;
}
@Override
public INDArray output(Observation observation) {
return this.output(observation.getData());
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];

View File

@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.support;
import org.apache.commons.collections4.queue.CircularFifoQueue;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -9,7 +10,6 @@ import java.util.ArrayList;
public class MockHistoryProcessor implements IHistoryProcessor {
public int addCallCount = 0;
public int startMonitorCallCount = 0;
public int stopMonitorCallCount = 0;
@ -17,12 +17,14 @@ public class MockHistoryProcessor implements IHistoryProcessor {
private final CircularFifoQueue<INDArray> history;
public final ArrayList<INDArray> recordCalls;
public final ArrayList<INDArray> addCalls;
public MockHistoryProcessor(Configuration config) {
this.config = config;
history = new CircularFifoQueue<>(config.getHistoryLength());
recordCalls = new ArrayList<INDArray>();
addCalls = new ArrayList<INDArray>();
}
@Override
@ -46,7 +48,7 @@ public class MockHistoryProcessor implements IHistoryProcessor {
@Override
public void add(INDArray image) {
++addCallCount;
addCalls.add(image);
history.add(image);
}

View File

@ -2,7 +2,6 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.rng.Random;