RL4J: Sanitize Observation (#404)

* working on ALE image pipelines that appear to lose data

* transformation pipeline for ALE has been broken for a while and needed some cleanup to make sure that openCV tooling for scene transforms was actually working.

* allowing history length to be set and passed through to history merge transforms

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* native image loader is not thread-safe so should not be static

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* make sure the transformer for encoding observations that are not pixels converts corectly

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Test fixes for ALE pixel observation shape

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Fix compilation errors

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

* fixing some post-merge issues, and comments from PR

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

Co-authored-by: Samuel Audet <samuel.audet@gmail.com>
master
Chris Bamford 2020-04-23 02:47:26 +01:00 committed by GitHub
parent 75cc6e2ed7
commit 032b97912e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
60 changed files with 524 additions and 577 deletions

View File

@ -25,9 +25,13 @@ import org.bytedeco.javacpp.IntPointer;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author saudet
@ -70,10 +74,14 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
actions = new int[(int)a.limit()];
a.get(actions);
int height = (int)ale.getScreen().height();
int width = (int)(int)ale.getScreen().width();
discreteSpace = new DiscreteSpace(actions.length);
int[] shape = {(int)ale.getScreen().height(), (int)ale.getScreen().width(), 3};
int[] shape = {3, height, width};
observationSpace = new ArrayObservationSpace<>(shape);
screenBuffer = new byte[shape[0] * shape[1] * shape[2]];
}
public void setupGame() {
@ -103,7 +111,7 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
public GameScreen reset() {
ale.reset_game();
ale.getScreenRGB(screenBuffer);
return new GameScreen(screenBuffer);
return new GameScreen(observationSpace.getShape(), screenBuffer);
}
@ -115,7 +123,8 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
double r = ale.act(actions[action]) * scaleFactor;
log.info(ale.getEpisodeFrameNumber() + " " + r + " " + action + " ");
ale.getScreenRGB(screenBuffer);
return new StepReply(new GameScreen(screenBuffer), r, ale.game_over(), null);
return new StepReply(new GameScreen(observationSpace.getShape(), screenBuffer), r, ale.game_over(), null);
}
public ObservationSpace<GameScreen> getObservationSpace() {
@ -140,17 +149,35 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
}
public static class GameScreen implements Encodable {
double[] array;
public GameScreen(byte[] screen) {
array = new double[screen.length];
for (int i = 0; i < screen.length; i++) {
array[i] = (screen[i] & 0xFF) / 255.0;
}
final INDArray data;
public GameScreen(int[] shape, byte[] screen) {
data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1);
}
private GameScreen(INDArray toDup) {
data = toDup.dup();
}
@Override
public double[] toArray() {
return array;
return data.data().asDouble();
}
@Override
public boolean isSkipped() {
return false;
}
@Override
public INDArray getData() {
return data;
}
@Override
public GameScreen dup() {
return new GameScreen(data);
}
}
}

View File

@ -19,15 +19,15 @@ package org.deeplearning4j.gym;
import lombok.Value;
/**
* @param <T> type of observation
* @param <OBSERVATION> type of observation
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16.
*
* StepReply is the container for the data returned after each step(action).
*/
@Value
public class StepReply<T> {
public class StepReply<OBSERVATION> {
T observation;
OBSERVATION observation;
double reward;
boolean done;
Object info;

View File

@ -32,7 +32,7 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
* in a "functionnal manner" if step return a mdp
*
*/
public interface MDP<OBSERVATION, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {
public interface MDP<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {
ObservationSpace<OBSERVATION> getObservationSpace();

View File

@ -16,6 +16,9 @@
package org.deeplearning4j.rl4j.space;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
*
@ -25,13 +28,37 @@ package org.deeplearning4j.rl4j.space;
*/
public class Box implements Encodable {
private final double[] array;
private final INDArray data;
public Box(double[] arr) {
this.array = arr;
public Box(double... arr) {
this.data = Nd4j.create(arr);
}
public Box(int[] shape, double... arr) {
this.data = Nd4j.create(arr).reshape(shape);
}
private Box(INDArray toDup) {
data = toDup.dup();
}
@Override
public double[] toArray() {
return array;
return data.data().asDouble();
}
@Override
public boolean isSkipped() {
return false;
}
@Override
public INDArray getData() {
return data;
}
@Override
public Encodable dup() {
return new Box(data);
}
}

View File

@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K. K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,17 +16,19 @@
package org.deeplearning4j.rl4j.space;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/19/16.
* Encodable is an interface that ensure that the state is convertible to a double array
*/
import org.nd4j.linalg.api.ndarray.INDArray;
public interface Encodable {
/**
* $
* encodes all the information of an Observation in an array double and can be used as input of a DQN directly
*
* @return the encoded informations
*/
@Deprecated
double[] toArray();
boolean isSkipped();
/**
* Any image data should be in CHW format.
*/
INDArray getData();
Encodable dup();
}

View File

@ -24,16 +24,17 @@ import org.nd4j.linalg.factory.Nd4j;
* @author Alexandre Boulanger
*/
public class INDArrayHelper {
/**
* MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray.
* In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape.
* MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types.
*
* @param source A INDArray
* @return The source INDArray with the correct shape
* We must have either shape 2 (NK) or shape 4 (NCHW)
*/
public static INDArray forceCorrectShape(INDArray source) {
return source.shape()[0] == 1 && source.shape().length > 1
? source
: Nd4j.expandDims(source, 0);
}
}

View File

@ -46,7 +46,6 @@ public class HistoryProcessor implements IHistoryProcessor {
@Getter
final private Configuration conf;
final private OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat();
private CircularFifoQueue<INDArray> history;
private VideoRecorder videoRecorder;
@ -63,8 +62,7 @@ public class HistoryProcessor implements IHistoryProcessor {
public void startMonitor(String filename, int[] shape) {
if(videoRecorder == null) {
videoRecorder = VideoRecorder.builder(shape[0], shape[1])
.frameInputType(VideoRecorder.FrameInputTypes.Float)
videoRecorder = VideoRecorder.builder(shape[1], shape[2])
.build();
}
@ -89,14 +87,13 @@ public class HistoryProcessor implements IHistoryProcessor {
return videoRecorder != null && videoRecorder.isRecording();
}
public void record(INDArray raw) {
public void record(INDArray pixelArray) {
if(isMonitoring()) {
// before accessing the raw pointer, we need to make sure that array is actual on the host side
Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST);
Nd4j.getAffinityManager().ensureLocation(pixelArray, AffinityManager.Location.HOST);
VideoRecorder.VideoFrame frame = videoRecorder.createFrame(raw.data().pointer());
try {
videoRecorder.record(frame);
videoRecorder.record(pixelArray);
} catch (Exception e) {
e.printStackTrace();
}

View File

@ -64,7 +64,7 @@ public interface IHistoryProcessor {
@Builder.Default int skipFrame = 4;
public int[] getShape() {
return new int[] {getHistoryLength(), getCroppingHeight(), getCroppingWidth()};
return new int[] {getHistoryLength(), getRescaledHeight(), getRescaledWidth()};
}
}
}

View File

@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
*
* A common interface that any training method should implement
*/
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
public interface ILearning<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> {
IPolicy<A> getPolicy();
@ -38,7 +38,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
ILearningConfiguration getConfiguration();
MDP<O, A, AS> getMdp();
MDP<OBSERVATION, A, AS> getMdp();
IHistoryProcessor getHistoryProcessor();

View File

@ -21,7 +21,6 @@ import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
@ -38,8 +37,8 @@ import org.nd4j.linalg.factory.Nd4j;
*
*/
@Slf4j
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
public abstract class Learning<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
implements ILearning<OBSERVATION, A, AS>, NeuralNetFetchable<NN> {
@Getter @Setter
protected int stepCount = 0;

View File

@ -29,10 +29,10 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.factory.Nd4j;
@ -188,7 +188,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
}
private boolean handleTraining(RunContext context) {
int maxTrainSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpisodeStepCount);
int maxTrainSteps = Math.min(getConfiguration().getNStep(), getConfiguration().getMaxEpochStep() - currentEpisodeStepCount);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps);
context.obs = subEpochReturn.getLastObs();
@ -219,7 +219,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
protected abstract IAsyncLearningConfiguration getConf();
protected abstract IAsyncLearningConfiguration getConfiguration();
protected abstract IPolicy<ACTION> getPolicy(NN net);

View File

@ -24,29 +24,22 @@ import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Stack;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
* <p>
* Async Learning specialized for the Discrete Domain
*/
public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
extends AsyncThread<O, Integer, DiscreteSpace, NN> {
public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN extends NeuralNet>
extends AsyncThread<OBSERVATION, Integer, DiscreteSpace, NN> {
@Getter
private NN current;
@ -59,7 +52,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
MDP<O, Integer, DiscreteSpace> mdp,
MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
TrainingListenerList listeners,
int threadNumber,
int deviceNum) {
@ -112,7 +105,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
}
StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
accuReward += stepReply.getReward() * getConf().getRewardFactor();
accuReward += stepReply.getReward() * getConfiguration().getRewardFactor();
if (!obs.isSkipped()) {
experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
@ -126,7 +119,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
}
boolean episodeComplete = getMdp().isDone() || getConf().getMaxEpochStep() == currentEpisodeStepCount;
boolean episodeComplete = getMdp().isDone() || getConfiguration().getMaxEpochStep() == currentEpisodeStepCount;
if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) {
experienceHandler.setFinalObservation(obs);

View File

@ -28,9 +28,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
@ -41,19 +41,19 @@ import org.nd4j.linalg.factory.Nd4j;
* All methods are fully implemented as described in the
* https://arxiv.org/abs/1602.01783 paper.
*/
public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
public abstract class A3CDiscrete<OBSERVATION extends Encodable> extends AsyncLearning<OBSERVATION, Integer, DiscreteSpace, IActorCritic> {
@Getter
final public A3CLearningConfiguration configuration;
@Getter
final protected MDP<O, Integer, DiscreteSpace> mdp;
final protected MDP<OBSERVATION, Integer, DiscreteSpace> mdp;
final private IActorCritic iActorCritic;
@Getter
final private AsyncGlobal asyncGlobal;
@Getter
final private ACPolicy<O> policy;
final private ACPolicy<OBSERVATION> policy;
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
public A3CDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
this.iActorCritic = iActorCritic;
this.mdp = mdp;
this.configuration = conf;

View File

@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
@ -42,19 +42,19 @@ import org.deeplearning4j.rl4j.util.IDataManager;
* first layers since they're essentially doing the same dimension
* reduction task
**/
public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
public class A3CDiscreteConv<OBSERVATION extends Encodable> extends A3CDiscrete<OBSERVATION> {
final private HistoryProcessor.Configuration hpconf;
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, actorCritic, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
super(mdp, IActorCritic, conf.toLearningConfiguration());
@ -62,7 +62,7 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
setHistoryProcessor(hpconf);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
super(mdp, IActorCritic, conf);
this.hpconf = hpconf;
@ -70,35 +70,35 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
}

View File

@ -21,8 +21,8 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.*;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
@ -34,74 +34,74 @@ import org.deeplearning4j.rl4j.util.IDataManager;
* We use specifically the Separate version because
* the model is too small to have enough benefit by sharing layers
*/
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
public class A3CDiscreteDense<OBSERVATION extends Encodable> extends A3CDiscrete<OBSERVATION> {
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
IDataManager dataManager) {
this(mdp, IActorCritic, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
super(mdp, actorCritic, conf.toLearningConfiguration());
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
super(mdp, actorCritic, conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}

View File

@ -23,23 +23,23 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
*
* <p>
* Local thread as described in the https://arxiv.org/abs/1602.01783 paper.
*/
public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
public class A3CThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IActorCritic> {
@Getter
final protected A3CLearningConfiguration conf;
final protected A3CLearningConfiguration configuration;
@Getter
final protected IAsyncGlobal<IActorCritic> asyncGlobal;
@Getter
@ -47,17 +47,17 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
final private Random rnd;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
public A3CThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
int threadNumber) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = a3cc;
this.configuration = a3cc;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;
Long seed = conf.getSeed();
Long seed = configuration.getSeed();
rnd = Nd4j.getRandom();
if(seed != null) {
if (seed != null) {
rnd.setSeed(seed + threadNumber);
}
@ -69,9 +69,12 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
return new ACPolicy(net, rnd);
}
/**
* calc the gradients based on the n-step rewards
*/
@Override
protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma());
return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
}
}

View File

@ -28,26 +28,26 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*/
public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
public abstract class AsyncNStepQLearningDiscrete<OBSERVATION extends Encodable>
extends AsyncLearning<OBSERVATION, Integer, DiscreteSpace, IDQN> {
@Getter
final public AsyncQLearningConfiguration configuration;
@Getter
final private MDP<O, Integer, DiscreteSpace> mdp;
final private MDP<OBSERVATION, Integer, DiscreteSpace> mdp;
@Getter
final private AsyncGlobal<IDQN> asyncGlobal;
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
public AsyncNStepQLearningDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
this.mdp = mdp;
this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
@ -63,7 +63,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
}
public IPolicy<Integer> getPolicy() {
return new DQNPolicy<O>(getNeuralNet());
return new DQNPolicy<OBSERVATION>(getNeuralNet());
}
@Data

View File

@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
@ -35,17 +35,17 @@ import org.deeplearning4j.rl4j.util.IDataManager;
* Specialized constructors for the Conv (pixels input) case
* Specialized conf + provide additional type safety
*/
public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
public class AsyncNStepQLearningDiscreteConv<OBSERVATION extends Encodable> extends AsyncNStepQLearningDiscrete<OBSERVATION> {
final private HistoryProcessor.Configuration hpconf;
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
super(mdp, dqn, conf);
this.hpconf = hpconf;
@ -53,21 +53,21 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
}
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}

View File

@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
*/
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
public class AsyncNStepQLearningDiscreteDense<OBSERVATION extends Encodable> extends AsyncNStepQLearningDiscrete<OBSERVATION> {
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf.toLearningConfiguration());
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration());
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncQLearningConfiguration conf) {
super(mdp, dqn, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncQLearningConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}

View File

@ -25,21 +25,21 @@ import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguratio
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*/
public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IDQN> {
@Getter
final protected AsyncQLearningConfiguration conf;
final protected AsyncQLearningConfiguration configuration;
@Getter
final protected IAsyncGlobal<IDQN> asyncGlobal;
@Getter
@ -47,17 +47,17 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
final private Random rnd;
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncQLearningConfiguration conf,
public AsyncNStepQLearningThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncQLearningConfiguration configuration,
TrainingListenerList listeners, int threadNumber, int deviceNum) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = conf;
this.configuration = configuration;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;
rnd = Nd4j.getRandom();
Long seed = conf.getSeed();
if (seed != null) {
Long seed = configuration.getSeed();
if(seed != null) {
rnd.setSeed(seed + threadNumber);
}
@ -65,13 +65,13 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
}
public Policy<Integer> getPolicy(IDQN nn) {
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
rnd, conf.getMinEpsilon(), this);
return new EpsGreedy(new DQNPolicy(nn), getMdp(), configuration.getUpdateStart(), configuration.getEpsilonNbStep(),
rnd, configuration.getMinEpsilon(), this);
}
@Override
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma());
return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
}
}

View File

@ -32,10 +32,10 @@ import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;

View File

@ -33,11 +33,11 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorith
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;

View File

@ -24,8 +24,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
@ -34,59 +34,59 @@ import org.deeplearning4j.rl4j.util.IDataManager;
* Specialized constructors for the Conv (pixels input) case
* Specialized conf + provide additional type safety
*/
public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> {
public class QLearningDiscreteConv<OBSERVATION extends Encodable> extends QLearningDiscrete<OBSERVATION> {
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf, IDataManager dataManager) {
this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}

View File

@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
*/
public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> {
public class QLearningDiscreteDense<OBSERVATION extends Encodable> extends QLearningDiscrete<OBSERVATION> {
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
IDataManager dataManager) {
this(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) {
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep());
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearningConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf,
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf,
QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}

View File

@ -4,8 +4,8 @@ import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import java.util.Random;
@ -36,7 +36,7 @@ import java.util.Random;
*/
public class CartpoleNative implements MDP<CartpoleNative.State, Integer, DiscreteSpace> {
public class CartpoleNative implements MDP<Box, Integer, DiscreteSpace> {
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
private static final int NUM_ACTIONS = 2;
@ -74,7 +74,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
@Getter
private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS);
@Getter
private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
private ObservationSpace<Box> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
public CartpoleNative() {
rnd = new Random();
@ -85,7 +85,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
}
@Override
public State reset() {
public Box reset() {
x = 0.1 * rnd.nextDouble() - 0.05;
xDot = 0.1 * rnd.nextDouble() - 0.05;
@ -94,7 +94,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
stepsBeyondDone = null;
done = false;
return new State(new double[] { x, xDot, theta, thetaDot });
return new Box(x, xDot, theta, thetaDot);
}
@Override
@ -103,7 +103,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
}
@Override
public StepReply<State> step(Integer action) {
public StepReply<Box> step(Integer action) {
double force = action == ACTION_RIGHT ? forceMag : -forceMag;
double cosTheta = Math.cos(theta);
double sinTheta = Math.sin(theta);
@ -143,26 +143,12 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
reward = 0;
}
return new StepReply<>(new State(new double[] { x, xDot, theta, thetaDot }), reward, done, null);
return new StepReply<>(new Box(x, xDot, theta, thetaDot), reward, done, null);
}
@Override
public MDP<State, Integer, DiscreteSpace> newInstance() {
public MDP<Box, Integer, DiscreteSpace> newInstance() {
return new CartpoleNative();
}
public static class State implements Encodable {
private final double[] state;
State(double[] state) {
this.state = state;
}
@Override
public double[] toArray() {
return state;
}
}
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy;
import lombok.Value;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
@ -31,4 +32,19 @@ public class HardToyState implements Encodable {
public double[] toArray() {
return values;
}
@Override
public boolean isSkipped() {
return false;
}
@Override
public INDArray getData() {
return null;
}
@Override
public Encodable dup() {
return null;
}
}

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,7 +41,6 @@ import org.nd4j.linalg.factory.Nd4j;
public class SimpleToy implements MDP<SimpleToyState, Integer, DiscreteSpace> {
final private int maxStep;
//TODO 10 steps toy (always +1 reward2 actions), toylong (1000 steps), toyhard (7 actions, +1 only if actiion = (step/100+step)%7, and toyStoch (like last but reward has 0.10 odd to be somewhere else).
@Getter
private DiscreteSpace actionSpace = new DiscreteSpace(2);
@Getter

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy;
import lombok.Value;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
@ -28,11 +29,24 @@ public class SimpleToyState implements Encodable {
int i;
int step;
@Override
public double[] toArray() {
double[] ar = new double[1];
ar[0] = (20 - i);
return ar;
}
@Override
public boolean isSkipped() {
return false;
}
@Override
public INDArray getData() {
return null;
}
@Override
public Encodable dup() {
return null;
}
}

View File

@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
*
* @author Alexandre Boulanger
*/
public class Observation {
public class Observation implements Encodable {
/**
* A singleton representing a skipped observation
@ -38,6 +38,11 @@ public class Observation {
@Getter
private final INDArray data;
@Override
public double[] toArray() {
return data.data().asDouble();
}
public boolean isSkipped() {
return data == null;
}

View File

@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
* Copyright (c) 2020 Konduit K. K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -13,29 +13,16 @@
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.observation.transform.legacy;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.opencv_core.Mat;
package org.deeplearning4j.rl4j.observation.transform;
import org.datavec.api.transform.Operation;
import org.datavec.image.data.ImageWritable;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> {
private final int[] shape;
public EncodableToINDArrayTransform(int[] shape) {
this.shape = shape;
}
@Override
public INDArray transform(Encodable encodable) {
return Nd4j.create(encodable.toArray()).reshape(shape);
return encodable.getData();
}
}

View File

@ -15,34 +15,32 @@
******************************************************************************/
package org.deeplearning4j.rl4j.observation.transform.legacy;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.api.transform.Operation;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
import static org.bytedeco.opencv.global.opencv_core.CV_32FC3;
import static org.bytedeco.opencv.global.opencv_core.CV_32S;
import static org.bytedeco.opencv.global.opencv_core.CV_32SC;
import static org.bytedeco.opencv.global.opencv_core.CV_32SC3;
import static org.bytedeco.opencv.global.opencv_core.CV_64FC;
import static org.bytedeco.opencv.global.opencv_core.CV_8UC3;
public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> {
private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
private final int height;
private final int width;
private final int colorChannels;
public EncodableToImageWritableTransform(int height, int width, int colorChannels) {
this.height = height;
this.width = width;
this.colorChannels = colorChannels;
}
final static NativeImageLoader nativeImageLoader = new NativeImageLoader();
@Override
public ImageWritable transform(Encodable encodable) {
INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels);
Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
return new ImageWritable(converter.convert(mat));
return new ImageWritable(nativeImageLoader.asFrame(encodable.getData(), Frame.DEPTH_UBYTE));
}
}

View File

@ -18,34 +18,31 @@ package org.deeplearning4j.rl4j.observation.transform.legacy;
import org.datavec.api.transform.Operation;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
private final int height;
private final int width;
private final NativeImageLoader loader;
public ImageWritableToINDArrayTransform(int height, int width) {
this.height = height;
this.width = width;
this.loader = new NativeImageLoader(height, width);
}
private final NativeImageLoader loader = new NativeImageLoader();
@Override
public INDArray transform(ImageWritable imageWritable) {
int height = imageWritable.getHeight();
int width = imageWritable.getWidth();
int channels = imageWritable.getFrame().imageChannels;
INDArray out = null;
try {
out = loader.asMatrix(imageWritable);
} catch (IOException e) {
e.printStackTrace();
}
out = out.reshape(1, height, width);
// Convert back to uint8 and reshape to the number of channels in the image
out = out.reshape(channels, height, width);
INDArray compressed = out.castTo(DataType.UINT8);
return compressed;
}

View File

@ -46,19 +46,20 @@ public class HistoryMergeTransform implements Operation<INDArray, INDArray>, Res
private final HistoryMergeElementStore historyMergeElementStore;
private final HistoryMergeAssembler historyMergeAssembler;
private final boolean shouldStoreCopy;
private final boolean isFirstDimenstionBatch;
private final boolean isFirstDimensionBatch;
private HistoryMergeTransform(Builder builder) {
this.historyMergeElementStore = builder.historyMergeElementStore;
this.historyMergeAssembler = builder.historyMergeAssembler;
this.shouldStoreCopy = builder.shouldStoreCopy;
this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch;
this.isFirstDimensionBatch = builder.isFirstDimenstionBatch;
}
@Override
public INDArray transform(INDArray input) {
INDArray element;
if(isFirstDimenstionBatch) {
if(isFirstDimensionBatch) {
element = input.slice(0, 0);
}
else {
@ -132,9 +133,9 @@ public class HistoryMergeTransform implements Operation<INDArray, INDArray>, Res
return this;
}
public HistoryMergeTransform build() {
public HistoryMergeTransform build(int frameStackLength) {
if(historyMergeElementStore == null) {
historyMergeElementStore = new CircularFifoStore();
historyMergeElementStore = new CircularFifoStore(frameStackLength);
}
if(historyMergeAssembler == null) {

View File

@ -28,14 +28,9 @@ import org.nd4j.linalg.factory.Nd4j;
* @author Alexandre Boulanger
*/
public class CircularFifoStore implements HistoryMergeElementStore {
private static final int DEFAULT_STORE_SIZE = 4;
private final CircularFifoQueue<INDArray> queue;
public CircularFifoStore() {
this(DEFAULT_STORE_SIZE);
}
public CircularFifoStore(int size) {
Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size);
queue = new CircularFifoQueue<>(size);

View File

@ -20,8 +20,8 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
@ -35,7 +35,7 @@ import java.io.IOException;
* the softmax output of the actor critic, but objects constructed
* with a {@link Random} argument of null return the max only.
*/
public class ACPolicy<O extends Encodable> extends Policy<Integer> {
public class ACPolicy<OBSERVATION extends Encodable> extends Policy<Integer> {
final private IActorCritic actorCritic;
Random rnd;
@ -48,18 +48,18 @@ public class ACPolicy<O extends Encodable> extends Policy<Integer> {
this.rnd = rnd;
}
public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException {
return new ACPolicy<O>(ActorCriticCompGraph.load(path));
public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String path) throws IOException {
return new ACPolicy<>(ActorCriticCompGraph.load(path));
}
public static <O extends Encodable> ACPolicy<O> load(String path, Random rnd) throws IOException {
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rnd);
public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String path, Random rnd) throws IOException {
return new ACPolicy<>(ActorCriticCompGraph.load(path), rnd);
}
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException {
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy));
public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String pathValue, String pathPolicy) throws IOException {
return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy));
}
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rnd) throws IOException {
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String pathValue, String pathPolicy, Random rnd) throws IOException {
return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
}
public IActorCritic getNeuralNet() {

View File

@ -17,8 +17,8 @@
package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
* Boltzmann exploration is a stochastic policy wrt to the
* exponential Q-values as evaluated by the dqn model.
*/
public class BoltzmannQ<O extends Encodable> extends Policy<Integer> {
public class BoltzmannQ<OBSERVATION extends Encodable> extends Policy<Integer> {
final private IDQN dqn;
final private Random rnd;

View File

@ -20,8 +20,8 @@ import lombok.AllArgsConstructor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
@ -35,12 +35,12 @@ import java.io.IOException;
// FIXME: Should we rename this "GreedyPolicy"?
@AllArgsConstructor
public class DQNPolicy<O> extends Policy<Integer> {
public class DQNPolicy<OBSERVATION> extends Policy<Integer> {
final private IDQN dqn;
public static <O extends Encodable> DQNPolicy<O> load(String path) throws IOException {
return new DQNPolicy<O>(DQN.load(path));
public static <OBSERVATION extends Encodable> DQNPolicy<OBSERVATION> load(String path) throws IOException {
return new DQNPolicy<>(DQN.load(path));
}
public IDQN getNeuralNet() {

View File

@ -20,12 +20,11 @@ package org.deeplearning4j.rl4j.policy;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
@ -41,10 +40,10 @@ import org.nd4j.linalg.api.rng.Random;
*/
@AllArgsConstructor
@Slf4j
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
final private Policy<A> policy;
final private MDP<O, A, AS> mdp;
final private MDP<OBSERVATION, A, AS> mdp;
final private int updateStart;
final private int epsilonNbStep;
final private Random rnd;

View File

@ -22,9 +22,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
/**

View File

@ -7,22 +7,22 @@ import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.CropImageTransform;
import org.datavec.image.transform.MultiImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.Map;
@ -46,6 +46,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
private int skipFrame = 1;
private int steps = 0;
public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) {
this.wrappedMDP = wrappedMDP;
this.shape = wrappedMDP.getObservationSpace().getShape();
@ -66,28 +67,33 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
if(historyProcessor != null && shape.length == 3) {
int skipFrame = historyProcessor.getConf().getSkipFrame();
int frameStackLength = historyProcessor.getConf().getHistoryLength();
int finalHeight = historyProcessor.getConf().getCroppingHeight();
int finalWidth = historyProcessor.getConf().getCroppingWidth();
int height = shape[1];
int width = shape[2];
int cropBottom = height - historyProcessor.getConf().getCroppingHeight();
int cropRight = width - historyProcessor.getConf().getCroppingWidth();
transformProcess = TransformProcess.builder()
.filter(new UniformSkippingFilter(skipFrame))
.transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2]))
.transform("data", new EncodableToImageWritableTransform())
.transform("data", new MultiImageTransform(
new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), cropBottom, cropRight),
new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()),
new ColorConversionTransform(COLOR_BGR2GRAY),
new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth)
new ColorConversionTransform(COLOR_BGR2GRAY)
//new ShowImageTransform("crop + resize + greyscale")
))
.transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth))
.transform("data", new ImageWritableToINDArrayTransform())
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
.transform("data", HistoryMergeTransform.builder()
.isFirstDimenstionBatch(true)
.build())
.build(frameStackLength))
.build("data");
}
else {
transformProcess = TransformProcess.builder()
.transform("data", new EncodableToINDArrayTransform(shape))
.transform("data", new EncodableToINDArrayTransform())
.build("data");
}
}
@ -127,6 +133,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
}
@ -161,12 +168,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
}
private INDArray getInput(OBSERVATION obs) {
INDArray arr = Nd4j.create(obs.toArray());
int[] shape = observationSpace.getShape();
if (shape.length == 1)
return arr.reshape(new long[] {1, arr.length()});
else
return arr.reshape(shape);
return obs.getData();
}
public static class WrapperObservationSpace implements ObservationSpace<Observation> {

View File

@ -16,26 +16,21 @@
package org.deeplearning4j.rl4j.util;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacv.FFmpegFrameRecorder;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_imgproc;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Rect;
import org.bytedeco.opencv.opencv_core.Size;
import org.opencv.imgproc.Imgproc;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import static org.bytedeco.ffmpeg.global.avcodec.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_H264;
import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_MPEG4;
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB0;
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB24;
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB8;
/**
* VideoRecorder is used to create a video from a sequence of individual frames. If using 3 channels
* images, it expects B-G-R order. A RGB order can be used by calling isRGBOrder(true).<br>
* VideoRecorder is used to create a video from a sequence of INDArray frames. INDArrays are assumed to be in CHW format where C=3 and pixels are RGB encoded<br>
* Example:<br>
* <pre>
* {@code
@ -45,11 +40,8 @@ import static org.bytedeco.opencv.global.opencv_core.*;
* .build();
* recorder.startRecording("myVideo.mp4");
* while(...) {
* byte[] data = new byte[160*100*3];
* // Todo: Fill data
* VideoRecorder.VideoFrame frame = recorder.createFrame(data);
* // Todo: Apply cropping or resizing to frame
* recorder.record(frame);
* INDArray chwData = Nd4j.create()
* recorder.record(chwData);
* }
* recorder.stopRecording();
* }
@ -60,16 +52,13 @@ import static org.bytedeco.opencv.global.opencv_core.*;
@Slf4j
public class VideoRecorder implements AutoCloseable {
public enum FrameInputTypes { BGR, RGB, Float }
private final NativeImageLoader nativeImageLoader = new NativeImageLoader();
private final int height;
private final int width;
private final int imageType;
private final OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat();
private final int codec;
private final double framerate;
private final int videoQuality;
private final FrameInputTypes frameInputType;
private FFmpegFrameRecorder fmpegFrameRecorder = null;
@ -83,11 +72,9 @@ public class VideoRecorder implements AutoCloseable {
private VideoRecorder(Builder builder) {
this.height = builder.height;
this.width = builder.width;
imageType = CV_8UC(builder.numChannels);
codec = builder.codec;
framerate = builder.frameRate;
videoQuality = builder.videoQuality;
frameInputType = builder.frameInputType;
}
/**
@ -119,59 +106,11 @@ public class VideoRecorder implements AutoCloseable {
/**
* Add a frame to the video
* @param frame the VideoFrame to add to the video
* @param imageArray the INDArray that contains the data to be recorded, the data must be in CHW format
* @throws Exception
*/
public void record(VideoFrame frame) throws Exception {
Size size = frame.getMat().size();
if(size.height() != height || size.width() != width) {
throw new IllegalArgumentException(String.format("Wrong frame size. Got (%dh x %dw) expected (%dh x %dw)", size.height(), size.width(), height, width));
}
Frame cvFrame = openCVFrameConverter.convert(frame.getMat());
fmpegFrameRecorder.record(cvFrame);
}
/**
* Create a VideoFrame from a byte array.
* @param data A byte array. Expect the index to be of the form [(Y*Width + X) * NumChannels + channel]
* @return An instance of VideoFrame
*/
public VideoFrame createFrame(byte[] data) {
return createFrame(new BytePointer(data));
}
/**
* Create a VideoFrame from a byte array with different height and width than the video
* the frame will need to be cropped or resized before being added to the video)
*
* @param data A byte array Expect the index to be of the form [(Y*customWidth + X) * NumChannels + channel]
* @param customHeight The actual height of the data
* @param customWidth The actual width of the data
* @return A VideoFrame instance
*/
public VideoFrame createFrame(byte[] data, int customHeight, int customWidth) {
return createFrame(new BytePointer(data), customHeight, customWidth);
}
/**
* Create a VideoFrame from a Pointer (to use for example with a INDarray).
* @param data A Pointer (for example myINDArray.data().pointer())
* @return An instance of VideoFrame
*/
public VideoFrame createFrame(Pointer data) {
return new VideoFrame(height, width, imageType, frameInputType, data);
}
/**
* Create a VideoFrame from a Pointer with different height and width than the video
* the frame will need to be cropped or resized before being added to the video)
* @param data
* @param customHeight The actual height of the data
* @param customWidth The actual width of the data
* @return A VideoFrame instance
*/
public VideoFrame createFrame(Pointer data, int customHeight, int customWidth) {
return new VideoFrame(customHeight, customWidth, imageType, frameInputType, data);
public void record(INDArray imageArray) throws Exception {
fmpegFrameRecorder.record(nativeImageLoader.asFrame(imageArray, Frame.DEPTH_UBYTE));
}
/**
@ -192,69 +131,12 @@ public class VideoRecorder implements AutoCloseable {
return new Builder(height, width);
}
/**
* An individual frame for the video
*/
public static class VideoFrame {
private final int height;
private final int width;
private final int imageType;
@Getter
private Mat mat;
private VideoFrame(int height, int width, int imageType, FrameInputTypes frameInputType, Pointer data) {
this.height = height;
this.width = width;
this.imageType = imageType;
switch(frameInputType) {
case RGB:
Mat src = new Mat(height, width, imageType, data);
mat = new Mat(height, width, imageType);
opencv_imgproc.cvtColor(src, mat, Imgproc.COLOR_RGB2BGR);
break;
case BGR:
mat = new Mat(height, width, imageType, data);
break;
case Float:
Mat tmpMat = new Mat(height, width, CV_32FC(3), data);
mat = new Mat(height, width, imageType);
tmpMat.convertTo(mat, CV_8UC(3), 255.0, 0.0);
}
}
/**
* Crop the video to a specified size
* @param newHeight The new height of the frame
* @param newWidth The new width of the frame
* @param heightOffset The starting height offset in the uncropped frame
* @param widthOffset The starting weight offset in the uncropped frame
*/
public void crop(int newHeight, int newWidth, int heightOffset, int widthOffset) {
mat = mat.apply(new Rect(widthOffset, heightOffset, newWidth, newHeight));
}
/**
* Resize the frame to a specified size
* @param newHeight The new height of the frame
* @param newWidth The new width of the frame
*/
public void resize(int newHeight, int newWidth) {
mat = new Mat(newHeight, newWidth, imageType);
}
}
/**
* A builder class for the VideoRecorder
*/
public static class Builder {
private final int height;
private final int width;
private int numChannels = 3;
private FrameInputTypes frameInputType = FrameInputTypes.BGR;
private int codec = AV_CODEC_ID_H264;
private double frameRate = 30.0;
private int videoQuality = 30;
@ -268,24 +150,6 @@ public class VideoRecorder implements AutoCloseable {
this.width = width;
}
/**
* Specify the number of channels. Default is 3
* @param numChannels
*/
public Builder numChannels(int numChannels) {
this.numChannels = numChannels;
return this;
}
/**
* Tell the VideoRecorder what data it will receive (default is BGR)
* @param frameInputType (See {@link FrameInputTypes}}
*/
public Builder frameInputType(FrameInputTypes frameInputType) {
this.frameInputType = frameInputType;
return this;
}
/**
* The codec to use for the video. Default is AV_CODEC_ID_H264
* @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes})

View File

@ -115,7 +115,7 @@ public class AsyncThreadDiscreteTest {
asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration);
when(asyncThreadDiscrete.getConfiguration()).thenReturn(mockAsyncConfiguration);
when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);

View File

@ -39,7 +39,6 @@ import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@ -130,7 +129,7 @@ public class AsyncThreadTest {
when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
when(thread.getConf()).thenReturn(mockAsyncConfiguration);
when(thread.getConfiguration()).thenReturn(mockAsyncConfiguration);
// if we hit the max step count
when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);

View File

@ -18,24 +18,16 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -43,17 +35,17 @@ import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@ -82,6 +74,7 @@ public class QLearningDiscreteTest {
@Mock
QLearningConfiguration mockQlearningConfiguration;
// HWC
int[] observationShape = new int[]{3, 10, 10};
int totalObservationSize = 1;
@ -123,6 +116,7 @@ public class QLearningDiscreteTest {
when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
when(mockHistoryConfiguration.getHistoryLength()).thenReturn(1);
when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
@ -148,7 +142,7 @@ public class QLearningDiscreteTest {
Observation observation = new Observation(Nd4j.zeros(observationShape));
when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Observation(Nd4j.zeros(observationShape)), 0, false, null));
// Act
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
@ -170,25 +164,26 @@ public class QLearningDiscreteTest {
// Arrange
mockTestContext(100,0,2,1.0, 10);
mockHistoryProcessor(2);
Observation skippedObservation = Observation.SkippedObservation;
Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
// An example observation and 2 Q values output (2 actions)
Observation observation = new Observation(Nd4j.zeros(observationShape));
when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(nextObservation, 0, false, null));
// Act
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(skippedObservation);
// Assert
assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
assertEquals(Double.NaN, stepReturn.getMaxQ(), 1e-5);
StepReply<Observation> stepReply = stepReturn.getStepReply();
assertEquals(0, stepReply.getReward(), 1e-5);
assertFalse(stepReply.isDone());
assertTrue(stepReply.getObservation().isSkipped());
assertFalse(stepReply.getObservation().isSkipped());
assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
verify(mockDQN, never()).output(any(INDArray.class));
}
//TODO: there are much more test cases here that can be improved upon

View File

@ -17,7 +17,7 @@ public class HistoryMergeTransformTest {
HistoryMergeTransform sut = HistoryMergeTransform.builder()
.isFirstDimenstionBatch(false)
.elementStore(store)
.build();
.build(4);
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
// Act
@ -35,7 +35,7 @@ public class HistoryMergeTransformTest {
HistoryMergeTransform sut = HistoryMergeTransform.builder()
.isFirstDimenstionBatch(true)
.elementStore(store)
.build();
.build(4);
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3);
// Act
@ -53,7 +53,7 @@ public class HistoryMergeTransformTest {
HistoryMergeTransform sut = HistoryMergeTransform.builder()
.isFirstDimenstionBatch(true)
.elementStore(store)
.build();
.build(4);
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
// Act
@ -70,7 +70,7 @@ public class HistoryMergeTransformTest {
HistoryMergeTransform sut = HistoryMergeTransform.builder()
.shouldStoreCopy(false)
.elementStore(store)
.build();
.build(4);
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
// Act
@ -87,7 +87,7 @@ public class HistoryMergeTransformTest {
HistoryMergeTransform sut = HistoryMergeTransform.builder()
.shouldStoreCopy(true)
.elementStore(store)
.build();
.build(4);
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
// Act
@ -107,7 +107,7 @@ public class HistoryMergeTransformTest {
HistoryMergeTransform sut = HistoryMergeTransform.builder()
.elementStore(store)
.assembler(assemble)
.build();
.build(4);
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
// Act

View File

@ -252,8 +252,8 @@ public class PolicyTest {
}
@Override
protected <O extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
protected <MockObservation extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockObservation, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(skipFrame, historyLength));
return super.refacInitMdp(mdpWrapper, hp);
}
}

View File

@ -1,18 +0,0 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.space.Encodable;
public class MockEncodable implements Encodable {
private final int value;
public MockEncodable(int value) {
this.value = value;
}
@Override
public double[] toArray() {
return new double[] { value };
}
}

View File

@ -2,9 +2,10 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
@ -15,7 +16,7 @@ import org.nd4j.linalg.api.rng.Random;
import java.util.ArrayList;
import java.util.List;
public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
public class MockMDP implements MDP<MockObservation, Integer, DiscreteSpace> {
private final DiscreteSpace actionSpace;
private final int stepsUntilDone;
@ -55,11 +56,11 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
}
@Override
public MockEncodable reset() {
public MockObservation reset() {
++resetCount;
currentObsValue = 0;
step = 0;
return new MockEncodable(currentObsValue++);
return new MockObservation(currentObsValue++);
}
@Override
@ -68,10 +69,10 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
}
@Override
public StepReply<MockEncodable> step(Integer action) {
public StepReply<MockObservation> step(Integer action) {
actions.add(action);
++step;
return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
return new StepReply<>(new MockObservation(currentObsValue), (double) currentObsValue++, isDone(), null);
}
@Override
@ -84,14 +85,14 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
return null;
}
public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) {
public static TransformProcess buildTransformProcess(int skipFrame, int historyLength) {
return TransformProcess.builder()
.filter(new UniformSkippingFilter(skipFrame))
.transform("data", new EncodableToINDArrayTransform(shape))
.transform("data", new EncodableToINDArrayTransform())
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
.transform("data", HistoryMergeTransform.builder()
.elementStore(new CircularFifoStore(historyLength))
.build())
.build(4))
.build("data");
}

View File

@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K. K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class MockObservation implements Encodable {
final INDArray data;
public MockObservation(int value) {
this.data = Nd4j.ones(1).mul(value);
}
@Override
public double[] toArray() {
return data.data().asDouble();
}
@Override
public boolean isSkipped() {
return false;
}
@Override
public INDArray getData() {
return data;
}
@Override
public Encodable dup() {
return null;
}
}

View File

@ -17,7 +17,7 @@ public class MockPolicy implements IPolicy<Integer> {
public List<INDArray> actionInputs = new ArrayList<INDArray>();
@Override
public <O extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<O, Integer, AS> mdp, IHistoryProcessor hp) {
public <MockObservation extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<MockObservation, Integer, AS> mdp, IHistoryProcessor hp) {
++playCallCount;
return 0;
}

View File

@ -28,6 +28,9 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import vizdoom.*;
import java.util.ArrayList;
@ -155,7 +158,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
+ Pointer.formatBytes(Pointer.totalPhysicalBytes()));
game.newEpisode();
return new GameScreen(game.getState().screenBuffer);
return new GameScreen(observationSpace.getShape(), game.getState().screenBuffer);
}
@ -168,7 +171,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
double r = game.makeAction(actions.get(action)) * scaleFactor;
log.info(game.getEpisodeTime() + " " + r + " " + action + " ");
return new StepReply(new GameScreen(game.isEpisodeFinished()
return new StepReply(new GameScreen(observationSpace.getShape(), game.isEpisodeFinished()
? new byte[game.getScreenSize()]
: game.getState().screenBuffer), r, game.isEpisodeFinished(), null);
@ -201,18 +204,34 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
public static class GameScreen implements Encodable {
final INDArray data;
public GameScreen(int[] shape, byte[] screen) {
double[] array;
public GameScreen(byte[] screen) {
array = new double[screen.length];
for (int i = 0; i < screen.length; i++) {
array[i] = (screen[i] & 0xFF) / 255.0;
}
data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1);
}
private GameScreen(INDArray toDup) {
data = toDup.dup();
}
@Override
public double[] toArray() {
return array;
return data.data().asDouble();
}
@Override
public boolean isSkipped() {
return false;
}
@Override
public INDArray getData() {
return data;
}
@Override
public GameScreen dup() {
return new GameScreen(data);
}
}

View File

@ -19,8 +19,6 @@ package org.deeplearning4j.rl4j.mdp.gym;
import java.io.IOException;
import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.Pointer;
@ -31,8 +29,8 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.HighLowDiscrete;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.bytedeco.cpython.*;
@ -47,7 +45,7 @@ import static org.bytedeco.numpy.global.numpy.*;
* @author saudet
*/
@Slf4j
public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
public class GymEnv<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<OBSERVATION, A, AS> {
public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";
@ -82,7 +80,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
private PyObject locals;
final protected DiscreteSpace actionSpace;
final protected ObservationSpace<O> observationSpace;
final protected ObservationSpace<OBSERVATION> observationSpace;
@Getter
final private String envId;
@Getter
@ -119,7 +117,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
for (int i = 0; i < shape.length; i++) {
shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i));
}
observationSpace = (ObservationSpace<O>) new ArrayObservationSpace<Box>(shape);
observationSpace = (ObservationSpace<OBSERVATION>) new ArrayObservationSpace<Box>(shape);
Py_DecRef(shapeTuple);
PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null);
@ -140,7 +138,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
}
@Override
public ObservationSpace<O> getObservationSpace() {
public ObservationSpace<OBSERVATION> getObservationSpace() {
return observationSpace;
}
@ -153,7 +151,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
}
@Override
public StepReply<O> step(A action) {
public StepReply<OBSERVATION> step(A action) {
int gstate = PyGILState_Ensure();
try {
if (render) {
@ -186,7 +184,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
}
@Override
public O reset() {
public OBSERVATION reset() {
int gstate = PyGILState_Ensure();
try {
Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null));
@ -201,7 +199,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
double[] data = new double[(int)stateData.capacity()];
stateData.get(data);
return (O) new Box(data);
return (OBSERVATION) new Box(data);
} finally {
PyGILState_Release(gstate);
}
@ -220,7 +218,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
}
@Override
public GymEnv<O, A, AS> newInstance() {
return new GymEnv<O, A, AS>(envId, render, monitor);
public GymEnv<OBSERVATION, A, AS> newInstance() {
return new GymEnv<OBSERVATION, A, AS>(envId, render, monitor);
}
}

View File

@ -40,8 +40,8 @@ public class GymEnvTest {
assertEquals(false, mdp.isDone());
Box o = (Box)mdp.reset();
StepReply r = mdp.step(0);
assertEquals(4, o.toArray().length);
assertEquals(4, ((Box)r.getObservation()).toArray().length);
assertEquals(4, o.getData().shape()[0]);
assertEquals(4, ((Box)r.getObservation()).getData().shape()[0]);
assertNotEquals(null, mdp.newInstance());
mdp.close();
}

View File

@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K. K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,33 +16,13 @@
package org.deeplearning4j.malmo;
import java.util.Arrays;
import org.deeplearning4j.rl4j.space.Box;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.rl4j.space.Encodable;
@Deprecated
public class MalmoBox extends Box {
/**
* Encodable state as a simple value array similar to Gym Box model, but without a JSON constructor
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
*/
public class MalmoBox implements Encodable {
double[] value;
/**
* Construct state from an array of doubles
* @param value state values
*/
//TODO: If this constructor was added to "Box", we wouldn't need this class at all.
public MalmoBox(double... value) {
this.value = value;
}
@Override
public double[] toArray() {
return value;
}
@Override
public String toString() {
return Arrays.toString(value);
public MalmoBox(double... arr) {
super(arr);
}
}

View File

@ -19,6 +19,8 @@ package org.deeplearning4j.malmo;
import java.util.Arrays;
import com.microsoft.msr.malmo.WorldState;
import org.deeplearning4j.rl4j.space.Box;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* A Malmo consistency policy that ensures the both there is a reward and next observation has a different position that the previous one.
@ -30,14 +32,14 @@ public class MalmoDescretePositionPolicy implements MalmoObservationPolicy {
@Override
public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) {
MalmoBox last_observation = observationSpace.getObservation(world_state);
MalmoBox old_observation = observationSpace.getObservation(original_world_state);
Box last_observation = observationSpace.getObservation(world_state);
Box old_observation = observationSpace.getObservation(original_world_state);
double[] newvalues = old_observation == null ? null : old_observation.toArray();
double[] oldvalues = last_observation == null ? null : last_observation.toArray();
INDArray newvalues = old_observation == null ? null : old_observation.getData();
INDArray oldvalues = last_observation == null ? null : last_observation.getData();
return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty()
|| Arrays.equals(oldvalues, newvalues));
|| oldvalues.eq(newvalues).all());
}
}

View File

@ -21,6 +21,7 @@ import java.nio.file.Paths;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import com.microsoft.msr.malmo.AgentHost;
@ -34,6 +35,7 @@ import com.microsoft.msr.malmo.WorldState;
import lombok.Setter;
import lombok.Getter;
import org.deeplearning4j.rl4j.space.Encodable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -233,7 +235,7 @@ public class MalmoEnv implements MDP<MalmoBox, Integer, DiscreteSpace> {
logger.info("Mission ended");
}
return new StepReply<MalmoBox>(last_observation, getRewards(last_world_state), isDone(), null);
return new StepReply<>(last_observation, getRewards(last_world_state), isDone(), null);
}
private double getRewards(WorldState world_state) {

View File

@ -16,6 +16,8 @@
package org.deeplearning4j.malmo;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import com.microsoft.msr.malmo.WorldState;

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.malmo;
import com.microsoft.msr.malmo.TimestampedStringVector;
import com.microsoft.msr.malmo.WorldState;
import org.deeplearning4j.rl4j.space.Box;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.malmo;
import java.util.HashMap;
import org.deeplearning4j.rl4j.space.Box;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.malmo;
import org.deeplearning4j.rl4j.space.Box;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;