RL4J: Add Observation and LegacyMDPWrapper (#8368)
* Added Observable & LegacyMDPWrapper Signed-off-by: unknown <aboulang2002@yahoo.com> * Moved observation processing to LegacyMDPWrapper Signed-off-by: unknown <aboulang2002@yahoo.com> * Observation using DataSets, changes in Transition and BaseTDTargetAlgorithm Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Added javadoc to Transition new methods Signed-off-by: unknown <aboulang2002@yahoo.com>master
parent
8d87b078c2
commit
47c58cf69d
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
*
|
*
|
||||||
* A common interface that any training method should implement
|
* A common interface that any training method should implement
|
||||||
*/
|
*/
|
||||||
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> extends StepCountable {
|
public interface ILearning<O, A, AS extends ActionSpace<A>> extends StepCountable {
|
||||||
|
|
||||||
IPolicy<O, A> getPolicy();
|
IPolicy<O, A> getPolicy();
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
|
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
|
||||||
|
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
|
@ -53,8 +53,8 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
||||||
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
|
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O extends Encodable, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) {
|
public static <O, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) {
|
||||||
INDArray arr = Nd4j.create(obs.toArray());
|
INDArray arr = Nd4j.create(((Encodable)obs).toArray());
|
||||||
int[] shape = mdp.getObservationSpace().getShape();
|
int[] shape = mdp.getObservationSpace().getShape();
|
||||||
if (shape.length == 1)
|
if (shape.length == 1)
|
||||||
return arr.reshape(new long[] {1, arr.length()});
|
return arr.reshape(new long[] {1, arr.length()});
|
||||||
|
@ -62,7 +62,7 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
||||||
return arr.reshape(shape);
|
return arr.reshape(shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O extends Encodable, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp,
|
public static <O, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp,
|
||||||
IHistoryProcessor hp) {
|
IHistoryProcessor hp) {
|
||||||
|
|
||||||
O obs = mdp.reset();
|
O obs = mdp.reset();
|
||||||
|
@ -138,15 +138,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
||||||
this.historyProcessor = historyProcessor;
|
this.historyProcessor = historyProcessor;
|
||||||
}
|
}
|
||||||
|
|
||||||
public INDArray getInput(O obs) {
|
|
||||||
return getInput(getMdp(), obs);
|
|
||||||
}
|
|
||||||
|
|
||||||
public InitMdp<O> initMdp() {
|
|
||||||
getNeuralNet().reset();
|
|
||||||
return initMdp(getMdp(), getHistoryProcessor());
|
|
||||||
}
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Value
|
@Value
|
||||||
public static class InitMdp<O> {
|
public static class InitMdp<O> {
|
||||||
|
|
|
@ -36,7 +36,7 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
extends Learning<O, A, AS, NN> implements IEpochTrainer {
|
extends Learning<O, A, AS, NN> implements IEpochTrainer {
|
||||||
|
|
||||||
private final TrainingListenerList listeners = new TrainingListenerList();
|
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||||
|
|
|
@ -16,27 +16,56 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.sync;
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
|
|
||||||
*
|
*
|
||||||
* A transition is a SARS tuple
|
* A transition is a SARS tuple
|
||||||
* State, Action, Reward, (isTerminal), State
|
* State, Action, Reward, (isTerminal), State
|
||||||
|
*
|
||||||
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
@Value
|
@Value
|
||||||
@AllArgsConstructor
|
|
||||||
public class Transition<A> {
|
public class Transition<A> {
|
||||||
|
|
||||||
INDArray[] observation;
|
Observation observation;
|
||||||
A action;
|
A action;
|
||||||
double reward;
|
double reward;
|
||||||
boolean isTerminal;
|
boolean isTerminal;
|
||||||
INDArray nextObservation;
|
INDArray nextObservation;
|
||||||
|
|
||||||
|
public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) {
|
||||||
|
this.observation = observation;
|
||||||
|
this.action = action;
|
||||||
|
this.reward = reward;
|
||||||
|
this.isTerminal = isTerminal;
|
||||||
|
|
||||||
|
// To conserve memory, only the most recent frame of the next observation is kept (if history is used).
|
||||||
|
// The full nextObservation will be re-build from observation when needed.
|
||||||
|
long[] nextObservationShape = nextObservation.getData().shape().clone();
|
||||||
|
nextObservationShape[0] = 1;
|
||||||
|
this.nextObservation = nextObservation.getData()
|
||||||
|
.get(new INDArrayIndex[] {NDArrayIndex.point(0)})
|
||||||
|
.reshape(nextObservationShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Transition(Observation observation, A action, double reward, boolean isTerminal, INDArray nextObservation) {
|
||||||
|
this.observation = observation;
|
||||||
|
this.action = action;
|
||||||
|
this.reward = reward;
|
||||||
|
this.isTerminal = isTerminal;
|
||||||
|
this.nextObservation = nextObservation;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* concat an array history into a single INDArry of as many channel
|
* concat an array history into a single INDArry of as many channel
|
||||||
* as element in the history array
|
* as element in the history array
|
||||||
|
@ -53,36 +82,80 @@ public class Transition<A> {
|
||||||
* @return this transition duplicated
|
* @return this transition duplicated
|
||||||
*/
|
*/
|
||||||
public Transition<A> dup() {
|
public Transition<A> dup() {
|
||||||
INDArray[] dupObservation = dup(observation);
|
Observation dupObservation = observation.dup();
|
||||||
INDArray nextObs = nextObservation.dup();
|
INDArray nextObs = nextObservation.dup();
|
||||||
|
|
||||||
return new Transition<>(dupObservation, action, reward, isTerminal, nextObs);
|
return new Transition<A>(dupObservation, action, reward, isTerminal, nextObs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Duplicate an history
|
* Stack along the 0-dimension all the observations of the batch in a INDArray.
|
||||||
* @param history the history to duplicate
|
*
|
||||||
* @return a duplicate of the history
|
* @param transitions A list of the transitions of the batch
|
||||||
|
* @param <A> The type of the Action
|
||||||
|
* @return A INDArray of all of the batch's observations stacked along the 0-dimension.
|
||||||
*/
|
*/
|
||||||
public static INDArray[] dup(INDArray[] history) {
|
public static <A> INDArray buildStackedObservations(List<Transition<A>> transitions) {
|
||||||
INDArray[] dupHistory = new INDArray[history.length];
|
int size = transitions.size();
|
||||||
for (int i = 0; i < history.length; i++) {
|
long[] shape = getShape(transitions);
|
||||||
dupHistory[i] = history[i].dup();
|
|
||||||
|
INDArray[] array = new INDArray[size];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
array[i] = transitions.get(i).getObservation().getData();
|
||||||
}
|
}
|
||||||
return dupHistory;
|
|
||||||
|
return Nd4j.concat(0, array).reshape(shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* append a pixel frame to an history (throwing the last frame)
|
* Stack along the 0-dimension all the next observations of the batch in a INDArray.
|
||||||
* @param history the history on which to append
|
*
|
||||||
* @param append the pixel frame to append
|
* @param transitions A list of the transitions of the batch
|
||||||
* @return the appended history
|
* @param <A> The type of the Action
|
||||||
|
* @return A INDArray of all of the batch's next observations stacked along the 0-dimension.
|
||||||
*/
|
*/
|
||||||
public static INDArray[] append(INDArray[] history, INDArray append) {
|
public static <A> INDArray buildStackedNextObservations(List<Transition<A>> transitions) {
|
||||||
INDArray[] appended = new INDArray[history.length];
|
int size = transitions.size();
|
||||||
appended[0] = append;
|
long[] shape = getShape(transitions);
|
||||||
System.arraycopy(history, 0, appended, 1, history.length - 1);
|
|
||||||
return appended;
|
INDArray[] array = new INDArray[size];
|
||||||
|
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
Transition<A> trans = transitions.get(i);
|
||||||
|
INDArray obs = trans.getObservation().getData();
|
||||||
|
long historyLength = obs.shape()[0];
|
||||||
|
|
||||||
|
if(historyLength != 1) {
|
||||||
|
// To conserve memory, only the most recent frame of the next observation is kept (if history is used).
|
||||||
|
// We need to rebuild the frame-stack in addition to builing the batch-stack.
|
||||||
|
INDArray historyPart = obs.get(new INDArrayIndex[]{NDArrayIndex.interval(0, historyLength - 1)});
|
||||||
|
array[i] = Nd4j.concat(0, trans.getNextObservation(), historyPart);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
array[i] = trans.getNextObservation();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Nd4j.concat(0, array).reshape(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <A> long[] getShape(List<Transition<A>> transitions) {
|
||||||
|
INDArray observations = transitions.get(0).getObservation().getData();
|
||||||
|
long[] observationShape = observations.shape();
|
||||||
|
long[] stackedShape;
|
||||||
|
if(observationShape[0] == 1) {
|
||||||
|
// FIXME: Currently RL4J doesn't support 1D observations. So if we have a shape with 1 in the first dimension, we can use that dimension and don't need to add another one.
|
||||||
|
stackedShape = new long[observationShape.length];
|
||||||
|
System.arraycopy(observationShape, 0, stackedShape, 0, observationShape.length);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
stackedShape = new long[observationShape.length + 1];
|
||||||
|
System.arraycopy(observationShape, 1, stackedShape, 2, observationShape.length - 1);
|
||||||
|
stackedShape[1] = observationShape[1];
|
||||||
|
}
|
||||||
|
stackedShape[0] = transitions.size();
|
||||||
|
|
||||||
|
return stackedShape;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,15 +21,20 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
||||||
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
|
@ -53,6 +58,8 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
@Setter(AccessLevel.PROTECTED)
|
@Setter(AccessLevel.PROTECTED)
|
||||||
protected IExpReplay<A> expReplay;
|
protected IExpReplay<A> expReplay;
|
||||||
|
|
||||||
|
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
||||||
|
|
||||||
public QLearning(QLConfiguration conf) {
|
public QLearning(QLConfiguration conf) {
|
||||||
this(conf, getSeededRandom(conf.getSeed()));
|
this(conf, getSeededRandom(conf.getSeed()));
|
||||||
}
|
}
|
||||||
|
@ -95,11 +102,11 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
|
|
||||||
protected abstract void postEpoch();
|
protected abstract void postEpoch();
|
||||||
|
|
||||||
protected abstract QLStepReturn<O> trainStep(O obs);
|
protected abstract QLStepReturn<Observation> trainStep(Observation obs);
|
||||||
|
|
||||||
protected StatEntry trainEpoch() {
|
protected StatEntry trainEpoch() {
|
||||||
InitMdp<O> initMdp = initMdp();
|
InitMdp<Observation> initMdp = refacInitMdp();
|
||||||
O obs = initMdp.getLastObs();
|
Observation obs = initMdp.getLastObs();
|
||||||
|
|
||||||
double reward = initMdp.getReward();
|
double reward = initMdp.getReward();
|
||||||
int step = initMdp.getSteps();
|
int step = initMdp.getSteps();
|
||||||
|
@ -114,7 +121,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
updateTargetNetwork();
|
updateTargetNetwork();
|
||||||
}
|
}
|
||||||
|
|
||||||
QLStepReturn<O> stepR = trainStep(obs);
|
QLStepReturn<Observation> stepR = trainStep(obs);
|
||||||
|
|
||||||
if (!stepR.getMaxQ().isNaN()) {
|
if (!stepR.getMaxQ().isNaN()) {
|
||||||
if (startQ.isNaN())
|
if (startQ.isNaN())
|
||||||
|
@ -142,6 +149,36 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private InitMdp<Observation> refacInitMdp() {
|
||||||
|
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
||||||
|
IHistoryProcessor hp = getHistoryProcessor();
|
||||||
|
|
||||||
|
Observation observation = mdp.reset();
|
||||||
|
|
||||||
|
int step = 0;
|
||||||
|
double reward = 0;
|
||||||
|
|
||||||
|
boolean isHistoryProcessor = hp != null;
|
||||||
|
|
||||||
|
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||||
|
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
||||||
|
|
||||||
|
while (step < requiredFrame && !mdp.isDone()) {
|
||||||
|
|
||||||
|
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||||
|
|
||||||
|
StepReply<Observation> stepReply = mdp.step(action);
|
||||||
|
reward += stepReply.getReward();
|
||||||
|
observation = stepReply.getObservation();
|
||||||
|
|
||||||
|
step++;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return new InitMdp(step, observation, reward);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Builder
|
@Builder
|
||||||
@Value
|
@Value
|
||||||
|
|
|
@ -26,10 +26,12 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
@ -51,8 +53,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final private QLConfiguration configuration;
|
final private QLConfiguration configuration;
|
||||||
@Getter
|
private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp;
|
||||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
|
||||||
@Getter
|
@Getter
|
||||||
private DQNPolicy<O> policy;
|
private DQNPolicy<O> policy;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -65,11 +66,14 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
private IDQN targetQNetwork;
|
private IDQN targetQNetwork;
|
||||||
|
|
||||||
private int lastAction;
|
private int lastAction;
|
||||||
private INDArray[] history = null;
|
|
||||||
private double accuReward = 0;
|
private double accuReward = 0;
|
||||||
|
|
||||||
ITDTargetAlgorithm tdTargetAlgorithm;
|
ITDTargetAlgorithm tdTargetAlgorithm;
|
||||||
|
|
||||||
|
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
||||||
|
return mdp;
|
||||||
|
}
|
||||||
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||||
int epsilonNbStep) {
|
int epsilonNbStep) {
|
||||||
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
||||||
|
@ -79,7 +83,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
int epsilonNbStep, Random random) {
|
int epsilonNbStep, Random random) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.mdp = mdp;
|
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, this);
|
||||||
qNetwork = dqn;
|
qNetwork = dqn;
|
||||||
targetQNetwork = dqn.clone();
|
targetQNetwork = dqn.clone();
|
||||||
policy = new DQNPolicy(getQNetwork());
|
policy = new DQNPolicy(getQNetwork());
|
||||||
|
@ -92,6 +96,10 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MDP<O, Integer, DiscreteSpace> getMdp() {
|
||||||
|
return mdp.getWrappedMDP();
|
||||||
|
}
|
||||||
|
|
||||||
public void postEpoch() {
|
public void postEpoch() {
|
||||||
|
|
||||||
if (getHistoryProcessor() != null)
|
if (getHistoryProcessor() != null)
|
||||||
|
@ -100,7 +108,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
}
|
}
|
||||||
|
|
||||||
public void preEpoch() {
|
public void preEpoch() {
|
||||||
history = null;
|
|
||||||
lastAction = 0;
|
lastAction = 0;
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
}
|
}
|
||||||
|
@ -110,10 +117,9 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
* @param obs last obs
|
* @param obs last obs
|
||||||
* @return relevant info for next step
|
* @return relevant info for next step
|
||||||
*/
|
*/
|
||||||
protected QLStepReturn<O> trainStep(O obs) {
|
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
||||||
|
|
||||||
Integer action;
|
Integer action;
|
||||||
INDArray input = getInput(obs);
|
|
||||||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,50 +134,25 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
if (getStepCounter() % skipFrame != 0) {
|
if (getStepCounter() % skipFrame != 0) {
|
||||||
action = lastAction;
|
action = lastAction;
|
||||||
} else {
|
} else {
|
||||||
if (history == null) {
|
INDArray qs = getQNetwork().output(obs);
|
||||||
if (isHistoryProcessor) {
|
|
||||||
getHistoryProcessor().add(input);
|
|
||||||
history = getHistoryProcessor().getHistory();
|
|
||||||
} else
|
|
||||||
history = new INDArray[] {input};
|
|
||||||
}
|
|
||||||
//concat the history into a single INDArray input
|
|
||||||
INDArray hstack = Transition.concat(Transition.dup(history));
|
|
||||||
if (isHistoryProcessor) {
|
|
||||||
hstack.muli(1.0 / getHistoryProcessor().getScale());
|
|
||||||
}
|
|
||||||
|
|
||||||
//if input is not 2d, you have to append that the batch is 1 length high
|
|
||||||
if (hstack.shape().length > 2)
|
|
||||||
hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
|
|
||||||
|
|
||||||
INDArray qs = getQNetwork().output(hstack);
|
|
||||||
int maxAction = Learning.getMaxAction(qs);
|
int maxAction = Learning.getMaxAction(qs);
|
||||||
|
|
||||||
maxQ = qs.getDouble(maxAction);
|
maxQ = qs.getDouble(maxAction);
|
||||||
action = getEgPolicy().nextAction(hstack);
|
|
||||||
|
action = getEgPolicy().nextAction(obs);
|
||||||
}
|
}
|
||||||
|
|
||||||
lastAction = action;
|
lastAction = action;
|
||||||
|
|
||||||
StepReply<O> stepReply = getMdp().step(action);
|
StepReply<Observation> stepReply = mdp.step(action);
|
||||||
|
|
||||||
INDArray ninput = getInput(stepReply.getObservation());
|
Observation nextObservation = stepReply.getObservation();
|
||||||
|
|
||||||
if (isHistoryProcessor)
|
|
||||||
getHistoryProcessor().record(ninput);
|
|
||||||
|
|
||||||
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
||||||
|
|
||||||
//if it's not a skipped frame, you can do a step of training
|
//if it's not a skipped frame, you can do a step of training
|
||||||
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
|
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
|
||||||
|
|
||||||
if (isHistoryProcessor)
|
Transition<Integer> trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation);
|
||||||
getHistoryProcessor().add(ninput);
|
|
||||||
|
|
||||||
INDArray[] nhistory = isHistoryProcessor ? getHistoryProcessor().getHistory() : new INDArray[] {ninput};
|
|
||||||
|
|
||||||
Transition<Integer> trans = new Transition(history, action, accuReward, stepReply.isDone(), nhistory[0]);
|
|
||||||
getExpReplay().store(trans);
|
getExpReplay().store(trans);
|
||||||
|
|
||||||
if (getStepCounter() > updateStart) {
|
if (getStepCounter() > updateStart) {
|
||||||
|
@ -179,27 +160,16 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
||||||
}
|
}
|
||||||
|
|
||||||
history = nhistory;
|
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
return new QLStepReturn<O>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
|
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||||
if (transitions.size() == 0)
|
if (transitions.size() == 0)
|
||||||
throw new IllegalArgumentException("too few transitions");
|
throw new IllegalArgumentException("too few transitions");
|
||||||
|
|
||||||
// TODO: Remove once we use DataSets in observations
|
|
||||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
|
||||||
: getHistoryProcessor().getConf().getShape();
|
|
||||||
((BaseTDTargetAlgorithm) tdTargetAlgorithm).setNShape(makeShape(transitions.size(), shape));
|
|
||||||
|
|
||||||
// TODO: Remove once we use DataSets in observations
|
|
||||||
if(getHistoryProcessor() != null) {
|
|
||||||
((BaseTDTargetAlgorithm) tdTargetAlgorithm).setScale(getHistoryProcessor().getScale());
|
|
||||||
}
|
|
||||||
|
|
||||||
return tdTargetAlgorithm.computeTDTargets(transitions);
|
return tdTargetAlgorithm.computeTDTargets(transitions);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,14 +16,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
|
||||||
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -40,11 +36,6 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
|
||||||
private final double errorClamp;
|
private final double errorClamp;
|
||||||
private final boolean isClamped;
|
private final boolean isClamped;
|
||||||
|
|
||||||
@Setter
|
|
||||||
private int[] nShape; // TODO: Remove once we use DataSets in observations
|
|
||||||
@Setter
|
|
||||||
private double scale = 1.0; // TODO: Remove once we use DataSets in observations
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param qNetworkSource The source of the Q-Network
|
* @param qNetworkSource The source of the Q-Network
|
||||||
|
@ -93,37 +84,8 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
|
||||||
|
|
||||||
int size = transitions.size();
|
int size = transitions.size();
|
||||||
|
|
||||||
INDArray observations = Nd4j.create(nShape);
|
INDArray observations = Transition.buildStackedObservations(transitions);
|
||||||
INDArray nextObservations = Nd4j.create(nShape);
|
INDArray nextObservations = Transition.buildStackedNextObservations(transitions);
|
||||||
|
|
||||||
// TODO: Remove once we use DataSets in observations
|
|
||||||
for (int i = 0; i < size; i++) {
|
|
||||||
Transition<Integer> trans = transitions.get(i);
|
|
||||||
|
|
||||||
INDArray[] obsArray = trans.getObservation();
|
|
||||||
if (observations.rank() == 2) {
|
|
||||||
observations.putRow(i, obsArray[0]);
|
|
||||||
} else {
|
|
||||||
for (int j = 0; j < obsArray.length; j++) {
|
|
||||||
observations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
|
|
||||||
if (nextObservations.rank() == 2) {
|
|
||||||
nextObservations.putRow(i, nextObsArray[0]);
|
|
||||||
} else {
|
|
||||||
for (int j = 0; j < nextObsArray.length; j++) {
|
|
||||||
nextObservations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Remove once we use DataSets in observations
|
|
||||||
if(scale != 1.0) {
|
|
||||||
observations.muli(1.0 / scale);
|
|
||||||
nextObservations.muli(1.0 / scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
initComputation(observations, nextObservations);
|
initComputation(observations, nextObservations);
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
@ -70,6 +71,10 @@ public class DQN<NN extends DQN> implements IDQN<NN> {
|
||||||
return mln.output(batch);
|
return mln.output(batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public INDArray output(Observation observation) {
|
||||||
|
return this.output(observation.getData());
|
||||||
|
}
|
||||||
|
|
||||||
public INDArray[] outputAll(INDArray batch) {
|
public INDArray[] outputAll(INDArray batch) {
|
||||||
return new INDArray[] {output(batch)};
|
return new INDArray[] {output(batch)};
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.network.dqn;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -37,6 +38,7 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
|
||||||
void fit(INDArray input, INDArray[] labels);
|
void fit(INDArray input, INDArray[] labels);
|
||||||
|
|
||||||
INDArray output(INDArray batch);
|
INDArray output(INDArray batch);
|
||||||
|
INDArray output(Observation observation);
|
||||||
|
|
||||||
INDArray[] outputAll(INDArray batch);
|
INDArray[] outputAll(INDArray batch);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.observation;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Presently only a dummy container. Will contain observation channels when done.
|
||||||
|
*/
|
||||||
|
public class Observation {
|
||||||
|
// TODO: Presently only a dummy container. Will contain observation channels when done.
|
||||||
|
|
||||||
|
private final DataSet data;
|
||||||
|
|
||||||
|
public Observation(INDArray[] data) {
|
||||||
|
this(new org.nd4j.linalg.dataset.DataSet(Nd4j.concat(0, data), null));
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME: Remove -- only used in unit tests
|
||||||
|
public Observation(INDArray data) {
|
||||||
|
this.data = new org.nd4j.linalg.dataset.DataSet(data, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Observation(DataSet data) {
|
||||||
|
this.data = data;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Observation dup() {
|
||||||
|
return new Observation(new org.nd4j.linalg.dataset.DataSet(data.getFeatures().dup(), null));
|
||||||
|
}
|
||||||
|
|
||||||
|
public INDArray getData() {
|
||||||
|
return data.getFeatures();
|
||||||
|
}
|
||||||
|
}
|
|
@ -21,8 +21,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
import org.deeplearning4j.rl4j.learning.StepCountable;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ import org.nd4j.linalg.api.rng.Random;
|
||||||
*/
|
*/
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {
|
public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
|
||||||
|
|
||||||
final private Policy<O, A> policy;
|
final private Policy<O, A> policy;
|
||||||
final private MDP<O, A, AS> mdp;
|
final private MDP<O, A, AS> mdp;
|
||||||
|
@ -61,8 +61,10 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend
|
||||||
return policy.nextAction(input);
|
return policy.nextAction(input);
|
||||||
else
|
else
|
||||||
return mdp.getActionSpace().randomAction();
|
return mdp.getActionSpace().randomAction();
|
||||||
|
}
|
||||||
|
|
||||||
|
public A nextAction(Observation observation) {
|
||||||
|
return this.nextAction(observation.getData());
|
||||||
}
|
}
|
||||||
|
|
||||||
public float getEpsilon() {
|
public float getEpsilon() {
|
||||||
|
|
|
@ -6,7 +6,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public interface IPolicy<O extends Encodable, A> {
|
public interface IPolicy<O, A> {
|
||||||
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
|
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
|
||||||
A nextAction(INDArray input);
|
A nextAction(INDArray input);
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ import org.nd4j.linalg.util.ArrayUtil;
|
||||||
*
|
*
|
||||||
* A Policy responsability is to choose the next action given a state
|
* A Policy responsability is to choose the next action given a state
|
||||||
*/
|
*/
|
||||||
public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
|
public abstract class Policy<O, A> implements IPolicy<O, A> {
|
||||||
|
|
||||||
public abstract NeuralNet getNeuralNet();
|
public abstract NeuralNet getNeuralNet();
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,139 @@
|
||||||
|
package org.deeplearning4j.rl4j.util;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final MDP<O, A, AS> wrappedMDP;
|
||||||
|
@Getter
|
||||||
|
private final WrapperObservationSpace observationSpace;
|
||||||
|
private final ILearning learning;
|
||||||
|
private int skipFrame;
|
||||||
|
|
||||||
|
private int step = 0;
|
||||||
|
|
||||||
|
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) {
|
||||||
|
this.wrappedMDP = wrappedMDP;
|
||||||
|
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
|
||||||
|
this.learning = learning;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AS getActionSpace() {
|
||||||
|
return wrappedMDP.getActionSpace();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Observation reset() {
|
||||||
|
INDArray rawObservation = getInput(wrappedMDP.reset());
|
||||||
|
|
||||||
|
IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
|
||||||
|
if(historyProcessor != null) {
|
||||||
|
historyProcessor.record(rawObservation.dup());
|
||||||
|
rawObservation.muli(1.0 / historyProcessor.getScale());
|
||||||
|
}
|
||||||
|
|
||||||
|
Observation observation = new Observation(new INDArray[] { rawObservation });
|
||||||
|
|
||||||
|
if(historyProcessor != null) {
|
||||||
|
skipFrame = historyProcessor.getConf().getSkipFrame();
|
||||||
|
historyProcessor.add(rawObservation);
|
||||||
|
}
|
||||||
|
step = 0;
|
||||||
|
|
||||||
|
return observation;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
wrappedMDP.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public StepReply<Observation> step(A a) {
|
||||||
|
IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
|
||||||
|
|
||||||
|
StepReply<O> rawStepReply = wrappedMDP.step(a);
|
||||||
|
INDArray rawObservation = getInput(rawStepReply.getObservation());
|
||||||
|
|
||||||
|
++step;
|
||||||
|
|
||||||
|
int requiredFrame = 0;
|
||||||
|
if(historyProcessor != null) {
|
||||||
|
historyProcessor.record(rawObservation.dup());
|
||||||
|
rawObservation.muli(1.0 / historyProcessor.getScale());
|
||||||
|
|
||||||
|
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
|
||||||
|
if ((learning.getStepCounter() % skipFrame == 0 && step >= requiredFrame)
|
||||||
|
|| (step % skipFrame == 0 && step < requiredFrame )){
|
||||||
|
historyProcessor.add(rawObservation);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Observation observation;
|
||||||
|
if(historyProcessor != null && step >= requiredFrame) {
|
||||||
|
observation = new Observation(historyProcessor.getHistory());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
observation = new Observation(new INDArray[] { rawObservation });
|
||||||
|
}
|
||||||
|
|
||||||
|
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isDone() {
|
||||||
|
return wrappedMDP.isDone();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MDP<Observation, A, AS> newInstance() {
|
||||||
|
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), learning);
|
||||||
|
}
|
||||||
|
|
||||||
|
private INDArray getInput(O obs) {
|
||||||
|
INDArray arr = Nd4j.create(obs.toArray());
|
||||||
|
int[] shape = observationSpace.getShape();
|
||||||
|
if (shape.length == 1)
|
||||||
|
return arr.reshape(new long[] {1, arr.length()});
|
||||||
|
else
|
||||||
|
return arr.reshape(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class WrapperObservationSpace implements ObservationSpace<Observation> {
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final int[] shape;
|
||||||
|
|
||||||
|
public WrapperObservationSpace(int[] shape) {
|
||||||
|
|
||||||
|
this.shape = shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getLow() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getHigh() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -50,7 +50,7 @@ public class AsyncThreadDiscreteTest {
|
||||||
assertEquals(1, asyncGlobalMock.enqueueCallCount);
|
assertEquals(1, asyncGlobalMock.enqueueCallCount);
|
||||||
|
|
||||||
// HistoryProcessor
|
// HistoryProcessor
|
||||||
assertEquals(10, hpMock.addCallCount);
|
assertEquals(10, hpMock.addCalls.size());
|
||||||
double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
|
double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
|
||||||
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
|
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
|
||||||
for(int i = 0; i < expectedRecordValues.length; ++i) {
|
for(int i = 0; i < expectedRecordValues.length; ++i) {
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync;
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.support.MockRandom;
|
import org.deeplearning4j.rl4j.support.MockRandom;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -17,7 +18,8 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 123, 234, false, Nd4j.create(1));
|
Transition<Integer> transition = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
123, 234, false, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition);
|
sut.store(transition);
|
||||||
List<Transition<Integer>> results = sut.getBatch(1);
|
List<Transition<Integer>> results = sut.getBatch(1);
|
||||||
|
|
||||||
|
@ -34,9 +36,12 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
1, 2, false, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
3, 4, false, new Observation(Nd4j.create(1)));
|
||||||
|
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
5, 6, false, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -73,9 +78,12 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
1, 2, false, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
3, 4, false, new Observation(Nd4j.create(1)));
|
||||||
|
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
5, 6, false, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -92,9 +100,12 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
1, 2, false, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
3, 4, false, new Observation(Nd4j.create(1)));
|
||||||
|
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
5, 6, false, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -120,11 +131,16 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
1, 2, false, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
|
3, 4, false, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
|
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
5, 6, false, new Observation(Nd4j.create(1)));
|
||||||
|
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
7, 8, false, new Observation(Nd4j.create(1)));
|
||||||
|
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
9, 10, false, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
@ -152,11 +168,16 @@ public class ExpReplayTest {
|
||||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
1, 2, false, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
|
3, 4, false, new Observation(Nd4j.create(1)));
|
||||||
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
|
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
5, 6, false, new Observation(Nd4j.create(1)));
|
||||||
|
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
7, 8, false, new Observation(Nd4j.create(1)));
|
||||||
|
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||||
|
9, 10, false, new Observation(Nd4j.create(1)));
|
||||||
sut.store(transition1);
|
sut.store(transition1);
|
||||||
sut.store(transition2);
|
sut.store(transition2);
|
||||||
sut.store(transition3);
|
sut.store(transition3);
|
||||||
|
|
|
@ -0,0 +1,255 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class TransitionTest {
|
||||||
|
@Test
|
||||||
|
public void when_callingCtorWithoutHistory_expect_2DObservationAndNextObservation() {
|
||||||
|
// Arrange
|
||||||
|
double[] obs = new double[] { 1.0, 2.0, 3.0 };
|
||||||
|
Observation observation = buildObservation(obs);
|
||||||
|
|
||||||
|
double[] nextObs = new double[] { 10.0, 20.0, 30.0 };
|
||||||
|
Observation nextObservation = buildObservation(nextObs);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
double[][] expectedObservation = new double[][] { obs };
|
||||||
|
assertExpected(expectedObservation, transition.getObservation().getData());
|
||||||
|
|
||||||
|
double[][] expectedNextObservation = new double[][] { nextObs };
|
||||||
|
assertExpected(expectedNextObservation, transition.getNextObservation());
|
||||||
|
|
||||||
|
assertEquals(123, transition.getAction());
|
||||||
|
assertEquals(234.0, transition.getReward(), 0.0001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingCtorWithHistory_expect_ObservationWithHistoryAndNextObservationWithout() {
|
||||||
|
// Arrange
|
||||||
|
double[][] obs = new double[][] {
|
||||||
|
{ 0.0, 1.0, 2.0 },
|
||||||
|
{ 3.0, 4.0, 5.0 },
|
||||||
|
{ 6.0, 7.0, 8.0 },
|
||||||
|
};
|
||||||
|
Observation observation = buildObservation(obs);
|
||||||
|
|
||||||
|
double[][] nextObs = new double[][] {
|
||||||
|
{ 10.0, 11.0, 12.0 },
|
||||||
|
{ 0.0, 1.0, 2.0 },
|
||||||
|
{ 3.0, 4.0, 5.0 },
|
||||||
|
};
|
||||||
|
Observation nextObservation = buildObservation(nextObs);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertExpected(obs, transition.getObservation().getData());
|
||||||
|
|
||||||
|
assertExpected(nextObs[0], transition.getNextObservation());
|
||||||
|
|
||||||
|
assertEquals(123, transition.getAction());
|
||||||
|
assertEquals(234.0, transition.getReward(), 0.0001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_CallingBuildStackedObservationsAndShapeRankIs2_expect_2DResultWithObservationsStackedOnDimension0() {
|
||||||
|
// Arrange
|
||||||
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>();
|
||||||
|
|
||||||
|
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
|
||||||
|
Observation observation1 = buildObservation(obs1);
|
||||||
|
Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 });
|
||||||
|
transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1));
|
||||||
|
|
||||||
|
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
|
||||||
|
Observation observation2 = buildObservation(obs2);
|
||||||
|
Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 });
|
||||||
|
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray result = Transition.buildStackedObservations(transitions);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
double[][] expected = new double[][] { obs1, obs2 };
|
||||||
|
assertExpected(expected, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_CallingBuildStackedObservationsAndShapeRankIsGreaterThan2_expect_ResultWithOneMoreDimensionAndObservationsStackedOnDimension0() {
|
||||||
|
// Arrange
|
||||||
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>();
|
||||||
|
|
||||||
|
double[][] obs1 = new double[][] {
|
||||||
|
{ 0.0, 1.0, 2.0 },
|
||||||
|
{ 3.0, 4.0, 5.0 },
|
||||||
|
{ 6.0, 7.0, 8.0 },
|
||||||
|
};
|
||||||
|
Observation observation1 = buildObservation(obs1);
|
||||||
|
|
||||||
|
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||||
|
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
|
||||||
|
|
||||||
|
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
||||||
|
|
||||||
|
double[][] obs2 = new double[][] {
|
||||||
|
{ 10.0, 11.0, 12.0 },
|
||||||
|
{ 13.0, 14.0, 15.0 },
|
||||||
|
{ 16.0, 17.0, 18.0 },
|
||||||
|
};
|
||||||
|
Observation observation2 = buildObservation(obs2);
|
||||||
|
|
||||||
|
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||||
|
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
|
||||||
|
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray result = Transition.buildStackedObservations(transitions);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
double[][][] expected = new double[][][] { obs1, obs2 };
|
||||||
|
assertExpected(expected, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_CallingBuildStackedNextObservationsAndShapeRankIs2_expect_2DResultWithObservationsStackedOnDimension0() {
|
||||||
|
// Arrange
|
||||||
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>();
|
||||||
|
|
||||||
|
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
|
||||||
|
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||||
|
Observation observation1 = buildObservation(obs1);
|
||||||
|
Observation nextObservation1 = buildObservation(nextObs1);
|
||||||
|
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
||||||
|
|
||||||
|
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
|
||||||
|
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||||
|
Observation observation2 = buildObservation(obs2);
|
||||||
|
Observation nextObservation2 = buildObservation(nextObs2);
|
||||||
|
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray result = Transition.buildStackedNextObservations(transitions);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
double[][] expected = new double[][] { nextObs1, nextObs2 };
|
||||||
|
assertExpected(expected, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_CallingBuildStackedNextObservationsAndShapeRankIsGreaterThan2_expect_ResultWithOneMoreDimensionAndObservationsStackedOnDimension0() {
|
||||||
|
// Arrange
|
||||||
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>();
|
||||||
|
|
||||||
|
double[][] obs1 = new double[][] {
|
||||||
|
{ 0.0, 1.0, 2.0 },
|
||||||
|
{ 3.0, 4.0, 5.0 },
|
||||||
|
{ 6.0, 7.0, 8.0 },
|
||||||
|
};
|
||||||
|
Observation observation1 = buildObservation(obs1);
|
||||||
|
|
||||||
|
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||||
|
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
|
||||||
|
|
||||||
|
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
||||||
|
|
||||||
|
double[][] obs2 = new double[][] {
|
||||||
|
{ 10.0, 11.0, 12.0 },
|
||||||
|
{ 13.0, 14.0, 15.0 },
|
||||||
|
{ 16.0, 17.0, 18.0 },
|
||||||
|
};
|
||||||
|
Observation observation2 = buildObservation(obs2);
|
||||||
|
|
||||||
|
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||||
|
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
|
||||||
|
|
||||||
|
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray result = Transition.buildStackedNextObservations(transitions);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
double[][][] expected = new double[][][] {
|
||||||
|
new double[][] { nextObs1, obs1[0], obs1[1] },
|
||||||
|
new double[][] { nextObs2, obs2[0], obs2[1] }
|
||||||
|
};
|
||||||
|
assertExpected(expected, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Observation buildObservation(double[][] obs) {
|
||||||
|
INDArray[] history = new INDArray[] {
|
||||||
|
Nd4j.create(obs[0]).reshape(1, 3),
|
||||||
|
Nd4j.create(obs[1]).reshape(1, 3),
|
||||||
|
Nd4j.create(obs[2]).reshape(1, 3),
|
||||||
|
};
|
||||||
|
return new Observation(history);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Observation buildObservation(double[] obs) {
|
||||||
|
return new Observation(new INDArray[] { Nd4j.create(obs).reshape(1, 3) });
|
||||||
|
}
|
||||||
|
|
||||||
|
private Observation buildNextObservation(double[][] obs, double[] nextObs) {
|
||||||
|
INDArray[] nextHistory = new INDArray[] {
|
||||||
|
Nd4j.create(nextObs).reshape(1, 3),
|
||||||
|
Nd4j.create(obs[0]).reshape(1, 3),
|
||||||
|
Nd4j.create(obs[1]).reshape(1, 3),
|
||||||
|
};
|
||||||
|
return new Observation(nextHistory);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertExpected(double[] expected, INDArray actual) {
|
||||||
|
long[] shape = actual.shape();
|
||||||
|
assertEquals(2, shape.length);
|
||||||
|
assertEquals(1, shape[0]);
|
||||||
|
assertEquals(expected.length, shape[1]);
|
||||||
|
for(int i = 0; i < expected.length; ++i) {
|
||||||
|
assertEquals(expected[i], actual.getDouble(0, i), 0.0001);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertExpected(double[][] expected, INDArray actual) {
|
||||||
|
long[] shape = actual.shape();
|
||||||
|
assertEquals(2, shape.length);
|
||||||
|
assertEquals(expected.length, shape[0]);
|
||||||
|
assertEquals(expected[0].length, shape[1]);
|
||||||
|
|
||||||
|
for(int i = 0; i < expected.length; ++i) {
|
||||||
|
double[] expectedLine = expected[i];
|
||||||
|
for(int j = 0; j < expectedLine.length; ++j) {
|
||||||
|
assertEquals(expectedLine[j], actual.getDouble(i, j), 0.0001);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertExpected(double[][][] expected, INDArray actual) {
|
||||||
|
long[] shape = actual.shape();
|
||||||
|
assertEquals(3, shape.length);
|
||||||
|
assertEquals(expected.length, shape[0]);
|
||||||
|
assertEquals(expected[0].length, shape[1]);
|
||||||
|
assertEquals(expected[0][0].length, shape[2]);
|
||||||
|
|
||||||
|
for(int i = 0; i < expected.length; ++i) {
|
||||||
|
double[][] expected2D = expected[i];
|
||||||
|
for(int j = 0; j < expected2D.length; ++j) {
|
||||||
|
double[] expectedLine = expected2D[j];
|
||||||
|
for (int k = 0; k < expectedLine.length; ++k) {
|
||||||
|
assertEquals(expectedLine[k], actual.getDouble(i, j, k), 0.0001);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,7 +40,7 @@ public class QLearningDiscreteTest {
|
||||||
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
||||||
MockMDP mdp = new MockMDP(observationSpace, random);
|
MockMDP mdp = new MockMDP(observationSpace, random);
|
||||||
|
|
||||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
|
||||||
0, 1.0, 0, 0, 0, 0, true);
|
0, 1.0, 0, 0, 0, 0, true);
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
MockExpReplay expReplay = new MockExpReplay();
|
MockExpReplay expReplay = new MockExpReplay();
|
||||||
|
@ -48,15 +48,10 @@ public class QLearningDiscreteTest {
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||||
sut.setHistoryProcessor(hp);
|
sut.setHistoryProcessor(hp);
|
||||||
MockEncodable obs = new MockEncodable(-100);
|
|
||||||
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.initMdp();
|
IDataManager.StatEntry result = sut.trainEpoch();
|
||||||
for(int step = 0; step < 16; ++step) {
|
|
||||||
results.add(sut.trainStep(obs));
|
|
||||||
sut.incrementStep();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
// HistoryProcessor calls
|
// HistoryProcessor calls
|
||||||
|
@ -65,7 +60,11 @@ public class QLearningDiscreteTest {
|
||||||
for(int i = 0; i < expectedRecords.length; ++i) {
|
for(int i = 0; i < expectedRecords.length; ++i) {
|
||||||
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||||
}
|
}
|
||||||
assertEquals(13, hp.addCallCount);
|
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
|
||||||
|
assertEquals(expectedAdds.length, hp.addCalls.size());
|
||||||
|
for(int i = 0; i < expectedAdds.length; ++i) {
|
||||||
|
assertEquals(expectedAdds[i], 255.0 * hp.addCalls.get(i).getDouble(0), 0.0001);
|
||||||
|
}
|
||||||
assertEquals(0, hp.startMonitorCallCount);
|
assertEquals(0, hp.startMonitorCallCount);
|
||||||
assertEquals(0, hp.stopMonitorCallCount);
|
assertEquals(0, hp.stopMonitorCallCount);
|
||||||
|
|
||||||
|
@ -75,14 +74,14 @@ public class QLearningDiscreteTest {
|
||||||
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
|
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
|
||||||
assertEquals(14, dqn.outputParams.size());
|
assertEquals(14, dqn.outputParams.size());
|
||||||
double[][] expectedDQNOutput = new double[][] {
|
double[][] expectedDQNOutput = new double[][] {
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||||
new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
|
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||||
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||||
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||||
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
|
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
||||||
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
|
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
||||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||||
|
@ -108,13 +107,13 @@ public class QLearningDiscreteTest {
|
||||||
// ExpReplay calls
|
// ExpReplay calls
|
||||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
|
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
|
||||||
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
||||||
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, -100.0, 9.0, 11.0, 13.0, 15.0 };
|
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
|
||||||
double[][] expectedTrObservations = new double[][] {
|
double[][] expectedTrObservations = new double[][] {
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, -100.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, -100.0, 9.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||||
new double[] { 4.0, 6.0, -100.0, 9.0, 11.0 },
|
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||||
new double[] { 6.0, -100.0, 9.0, 11.0, 13.0 },
|
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||||
new double[] { -100.0, 9.0, 11.0, 13.0, 15.0 },
|
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
||||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||||
|
@ -123,26 +122,15 @@ public class QLearningDiscreteTest {
|
||||||
Transition tr = expReplay.transitions.get(i);
|
Transition tr = expReplay.transitions.get(i);
|
||||||
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
||||||
assertEquals(expectedTrActions[i], tr.getAction());
|
assertEquals(expectedTrActions[i], tr.getAction());
|
||||||
assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
|
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
|
||||||
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
||||||
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
|
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(j, 0), 0.0001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// trainStep results
|
// trainEpoch result
|
||||||
assertEquals(16, results.size());
|
assertEquals(16, result.getStepCounter());
|
||||||
double[] expectedMaxQ = new double[] { 6.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
|
assertEquals(300.0, result.getReward(), 0.00001);
|
||||||
double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
|
|
||||||
for(int i=0; i < 16; ++i) {
|
|
||||||
QLearning.QLStepReturn<MockEncodable> result = results.get(i);
|
|
||||||
if(i % 2 == 0) {
|
|
||||||
assertEquals(expectedMaxQ[i/2], 255.0 * result.getMaxQ(), 0.001);
|
|
||||||
assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
assertTrue(result.getMaxQ().isNaN());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||||
|
@ -163,5 +151,9 @@ public class QLearningDiscreteTest {
|
||||||
this.expReplay = exp;
|
this.expReplay = exp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IDataManager.StatEntry trainEpoch() {
|
||||||
|
return super.trainEpoch();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorit
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
@ -25,12 +26,12 @@ public class DoubleDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0})));
|
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||||
|
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
||||||
sut.setNShape(new int[] { 1, 2 });
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.computeTDTargets(transitions);
|
||||||
|
@ -51,12 +52,12 @@ public class DoubleDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
|
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||||
|
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
||||||
sut.setNShape(new int[] { 1, 2 });
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.computeTDTargets(transitions);
|
||||||
|
@ -77,14 +78,16 @@ public class DoubleDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
|
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0})));
|
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0})));
|
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
|
||||||
|
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
|
||||||
|
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
|
||||||
|
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
||||||
sut.setNShape(new int[] { 3, 2 });
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.computeTDTargets(transitions);
|
||||||
|
@ -102,4 +105,7 @@ public class DoubleDQNTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Observation buildObservation(double[] data) {
|
||||||
|
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorit
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
@ -24,12 +25,12 @@ public class StandardDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0})));
|
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||||
|
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
||||||
sut.setNShape(new int[] { 1, 2 });
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.computeTDTargets(transitions);
|
||||||
|
@ -50,12 +51,12 @@ public class StandardDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
|
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||||
|
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
||||||
sut.setNShape(new int[] { 1, 2 });
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.computeTDTargets(transitions);
|
||||||
|
@ -76,14 +77,16 @@ public class StandardDQNTest {
|
||||||
|
|
||||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
|
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0})));
|
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||||
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0})));
|
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
|
||||||
|
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
|
||||||
|
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
|
||||||
|
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
||||||
sut.setNShape(new int[] { 3, 2 });
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.computeTDTargets(transitions);
|
||||||
|
@ -101,4 +104,8 @@ public class StandardDQNTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Observation buildObservation(double[] data) {
|
||||||
|
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.support;
|
package org.deeplearning4j.rl4j.learning.sync.support;
|
||||||
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
@ -57,6 +56,11 @@ public class MockDQN implements IDQN {
|
||||||
return batch;
|
return batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(Observation observation) {
|
||||||
|
return this.output(observation.getData());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] outputAll(INDArray batch) {
|
public INDArray[] outputAll(INDArray batch) {
|
||||||
return new INDArray[0];
|
return new INDArray[0];
|
||||||
|
|
|
@ -197,7 +197,7 @@ public class PolicyTest {
|
||||||
assertEquals(465.0, totalReward, 0.0001);
|
assertEquals(465.0, totalReward, 0.0001);
|
||||||
|
|
||||||
// HistoryProcessor
|
// HistoryProcessor
|
||||||
assertEquals(27, hp.addCallCount);
|
assertEquals(27, hp.addCalls.size());
|
||||||
assertEquals(31, hp.recordCalls.size());
|
assertEquals(31, hp.recordCalls.size());
|
||||||
for(int i=0; i <= 30; ++i) {
|
for(int i=0; i <= 30; ++i) {
|
||||||
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
|
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||||
|
|
|
@ -4,6 +4,7 @@ import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
@ -48,6 +49,11 @@ public class MockDQN implements IDQN {
|
||||||
return batch;
|
return batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(Observation observation) {
|
||||||
|
return this.output(observation.getData());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] outputAll(INDArray batch) {
|
public INDArray[] outputAll(INDArray batch) {
|
||||||
return new INDArray[0];
|
return new INDArray[0];
|
||||||
|
|
|
@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -9,7 +10,6 @@ import java.util.ArrayList;
|
||||||
|
|
||||||
public class MockHistoryProcessor implements IHistoryProcessor {
|
public class MockHistoryProcessor implements IHistoryProcessor {
|
||||||
|
|
||||||
public int addCallCount = 0;
|
|
||||||
public int startMonitorCallCount = 0;
|
public int startMonitorCallCount = 0;
|
||||||
public int stopMonitorCallCount = 0;
|
public int stopMonitorCallCount = 0;
|
||||||
|
|
||||||
|
@ -17,12 +17,14 @@ public class MockHistoryProcessor implements IHistoryProcessor {
|
||||||
private final CircularFifoQueue<INDArray> history;
|
private final CircularFifoQueue<INDArray> history;
|
||||||
|
|
||||||
public final ArrayList<INDArray> recordCalls;
|
public final ArrayList<INDArray> recordCalls;
|
||||||
|
public final ArrayList<INDArray> addCalls;
|
||||||
|
|
||||||
public MockHistoryProcessor(Configuration config) {
|
public MockHistoryProcessor(Configuration config) {
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
history = new CircularFifoQueue<>(config.getHistoryLength());
|
history = new CircularFifoQueue<>(config.getHistoryLength());
|
||||||
recordCalls = new ArrayList<INDArray>();
|
recordCalls = new ArrayList<INDArray>();
|
||||||
|
addCalls = new ArrayList<INDArray>();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -46,7 +48,7 @@ public class MockHistoryProcessor implements IHistoryProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void add(INDArray image) {
|
public void add(INDArray image) {
|
||||||
++addCallCount;
|
addCalls.add(image);
|
||||||
history.add(image);
|
history.add(image);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
Loading…
Reference in New Issue