RL4J: Sanitize Observation (#404)

* reworking the ALE image pipelines, which appeared to be losing data

* the ALE transformation pipeline has been broken for a while and needed cleanup to ensure the OpenCV tooling for scene transforms actually works

* allowing history length to be set and passed through to history merge transforms

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* the native image loader is not thread-safe, so it should not be static

Signed-off-by: Bam4d <chrisbam4d@gmail.com>
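
A minimal sketch of the pattern this bullet describes, assuming DataVec's NativeImageLoader (the surrounding transform class is hypothetical, not code from this diff):

    import org.datavec.image.loader.NativeImageLoader;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import java.io.ByteArrayInputStream;
    import java.io.IOException;

    public class ImageToINDArrayTransform {
        // One loader per instance: NativeImageLoader keeps mutable native state,
        // so sharing a single static loader across worker threads is unsafe.
        private final NativeImageLoader loader = new NativeImageLoader();

        public INDArray transform(byte[] imageBytes) throws IOException {
            return loader.asMatrix(new ByteArrayInputStream(imageBytes));
        }
    }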

* make sure the transformer for encoding non-pixel observations converts correctly

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Test fixes for ALE pixel observation shape

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* Fix compilation errors

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

* fixing some post-merge issues and addressing review comments from the PR

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

Co-authored-by: Samuel Audet <samuel.audet@gmail.com>
Branch: master
Author: Chris Bamford, 2020-04-23 02:47:26 +01:00 (committed by GitHub)
Parent: 75cc6e2ed7
Commit: 032b97912e
60 changed files with 524 additions and 577 deletions

ALEMDP.java

@@ -25,9 +25,13 @@ import org.bytedeco.javacpp.IntPointer;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;

 /**
  * @author saudet
@@ -70,10 +74,14 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
         actions = new int[(int)a.limit()];
         a.get(actions);

+        int height = (int)ale.getScreen().height();
+        int width = (int)(int)ale.getScreen().width();
+
         discreteSpace = new DiscreteSpace(actions.length);
-        int[] shape = {(int)ale.getScreen().height(), (int)ale.getScreen().width(), 3};
+        int[] shape = {3, height, width};
         observationSpace = new ArrayObservationSpace<>(shape);
         screenBuffer = new byte[shape[0] * shape[1] * shape[2]];
     }

     public void setupGame() {
@@ -103,7 +111,7 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
     public GameScreen reset() {
         ale.reset_game();
         ale.getScreenRGB(screenBuffer);

-        return new GameScreen(screenBuffer);
+        return new GameScreen(observationSpace.getShape(), screenBuffer);
     }
@@ -115,7 +123,8 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
         double r = ale.act(actions[action]) * scaleFactor;
         log.info(ale.getEpisodeFrameNumber() + " " + r + " " + action + " ");
         ale.getScreenRGB(screenBuffer);

-        return new StepReply(new GameScreen(screenBuffer), r, ale.game_over(), null);
+        return new StepReply(new GameScreen(observationSpace.getShape(), screenBuffer), r, ale.game_over(), null);
     }

     public ObservationSpace<GameScreen> getObservationSpace() {
@@ -140,17 +149,35 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
     }

     public static class GameScreen implements Encodable {

-        double[] array;
+        final INDArray data;

-        public GameScreen(byte[] screen) {
-            array = new double[screen.length];
-            for (int i = 0; i < screen.length; i++) {
-                array[i] = (screen[i] & 0xFF) / 255.0;
-            }
+        public GameScreen(int[] shape, byte[] screen) {
+            data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1);
         }

+        private GameScreen(INDArray toDup) {
+            data = toDup.dup();
+        }
+
+        @Override
         public double[] toArray() {
-            return array;
+            return data.data().asDouble();
+        }
+
+        @Override
+        public boolean isSkipped() {
+            return false;
+        }
+
+        @Override
+        public INDArray getData() {
+            return data;
+        }
+
+        @Override
+        public GameScreen dup() {
+            return new GameScreen(data);
         }
     }
 }
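
The GameScreen change is the heart of the sanitization: ALE emits interleaved HWC bytes, and the observation now wraps them as a channels-first UINT8 array instead of eagerly normalizing into a double[]. A minimal sketch of the wrapping itself (the 2x2 screen size is illustrative):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class GameScreenShapeDemo {
        public static void main(String[] args) {
            // A hypothetical 2x2 RGB screen as ALE delivers it: interleaved HWC bytes.
            byte[] screen = new byte[2 * 2 * 3];
            // Wrap as [height, width, channels] UINT8, then permute to CHW so the
            // convolutional pipeline sees channels-first data; no per-pixel copy.
            INDArray chw = Nd4j.create(screen, new long[] {2, 2, 3}, DataType.UINT8)
                    .permute(2, 0, 1);
            System.out.println(java.util.Arrays.toString(chw.shape()));   // [3, 2, 2]
        }
    }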

StepReply.java

@@ -19,15 +19,15 @@ package org.deeplearning4j.gym;

 import lombok.Value;

 /**
- * @param <T> type of observation
+ * @param <OBSERVATION> type of observation
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16.
  *
  * StepReply is the container for the data returned after each step(action).
  */
 @Value
-public class StepReply<T> {
-    T observation;
+public class StepReply<OBSERVATION> {
+    OBSERVATION observation;
     double reward;
     boolean done;
     Object info;

MDP.java

@@ -32,7 +32,7 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
  * in a "functionnal manner" if step return a mdp
  *
  */
-public interface MDP<OBSERVATION, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {
+public interface MDP<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {

     ObservationSpace<OBSERVATION> getObservationSpace();

Box.java

@@ -16,6 +16,9 @@
 package org.deeplearning4j.rl4j.space;

+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
  *
@@ -25,13 +28,37 @@ package org.deeplearning4j.rl4j.space;
  */
 public class Box implements Encodable {

-    private final double[] array;
+    private final INDArray data;

-    public Box(double[] arr) {
-        this.array = arr;
+    public Box(double... arr) {
+        this.data = Nd4j.create(arr);
     }

+    public Box(int[] shape, double... arr) {
+        this.data = Nd4j.create(arr).reshape(shape);
+    }
+
+    private Box(INDArray toDup) {
+        data = toDup.dup();
+    }
+
+    @Override
     public double[] toArray() {
-        return array;
+        return data.data().asDouble();
+    }
+
+    @Override
+    public boolean isSkipped() {
+        return false;
+    }
+
+    @Override
+    public INDArray getData() {
+        return data;
+    }
+
+    @Override
+    public Encodable dup() {
+        return new Box(data);
     }
 }
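
A short usage sketch of the two public constructors above (the values are made up):

    import org.deeplearning4j.rl4j.space.Box;
    import java.util.Arrays;

    public class BoxDemo {
        public static void main(String[] args) {
            // Flat observation: varargs keep non-pixel call sites terse.
            Box flat = new Box(0.1, 0.2, 0.3, 0.4);
            // Shaped observation: the same values laid out as a 2x2 grid.
            Box grid = new Box(new int[] {2, 2}, 0.1, 0.2, 0.3, 0.4);
            System.out.println(Arrays.toString(flat.getData().shape()));
            System.out.println(Arrays.toString(grid.getData().shape()));   // [2, 2]
        }
    }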

Encodable.java

@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K. K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -16,17 +16,19 @@
 package org.deeplearning4j.rl4j.space;

-/**
- * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/19/16.
- * Encodable is an interface that ensure that the state is convertible to a double array
- */
+import org.nd4j.linalg.api.ndarray.INDArray;
+
 public interface Encodable {

-    /**
-     * $
-     * encodes all the information of an Observation in an array double and can be used as input of a DQN directly
-     *
-     * @return the encoded informations
-     */
+    @Deprecated
     double[] toArray();
+
+    boolean isSkipped();
+
+    /**
+     * Any image data should be in CHW format.
+     */
+    INDArray getData();
+
+    Encodable dup();
 }
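
Implementations elsewhere in this commit (GameScreen, Box) follow the same shape; as a reference, a minimal custom implementation of the new contract (the SensorObservation class is hypothetical):

    import org.deeplearning4j.rl4j.space.Encodable;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    // Hypothetical non-pixel observation wrapping a vector of sensor readings.
    public class SensorObservation implements Encodable {
        private final INDArray data;

        public SensorObservation(double... values) {
            this.data = Nd4j.create(values);
        }

        private SensorObservation(INDArray toDup) {
            this.data = toDup.dup();   // defensive copy, as in Box and GameScreen
        }

        @Override
        public double[] toArray() {    // legacy accessor, deprecated upstream
            return data.data().asDouble();
        }

        @Override
        public boolean isSkipped() {   // a plain observation is never skipped
            return false;
        }

        @Override
        public INDArray getData() {    // image data would have to be CHW here
            return data;
        }

        @Override
        public Encodable dup() {
            return new SensorObservation(data);
        }
    }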

INDArrayHelper.java

@@ -24,16 +24,17 @@ import org.nd4j.linalg.factory.Nd4j;
  * @author Alexandre Boulanger
  */
 public class INDArrayHelper {

     /**
-     * MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray.
-     * In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape.
+     * MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types.
      *
-     * @param source A INDArray
-     * @return The source INDArray with the correct shape
+     * We must have either shape 2 (NK) or shape 4 (NCHW)
      */
     public static INDArray forceCorrectShape(INDArray source) {
         return source.shape()[0] == 1 && source.shape().length > 1
                 ? source
                 : Nd4j.expandDims(source, 0);
     }
 }
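
To make the contract above concrete, a small usage sketch (the helper's package is assumed here, since the diff does not show it):

    import org.deeplearning4j.rl4j.helper.INDArrayHelper;   // package assumed
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class ForceCorrectShapeDemo {
        public static void main(String[] args) {
            // A single observation lacking a batch dimension...
            INDArray single = Nd4j.create(new double[] {1, 2, 3, 4});
            // ...gains a leading batch dimension of 1, e.g. [4] -> [1, 4].
            INDArray batched = INDArrayHelper.forceCorrectShape(single);
            System.out.println(java.util.Arrays.toString(batched.shape()));

            // An array whose first dimension is already 1 is returned as-is.
            INDArray image = Nd4j.zeros(1, 3, 84, 84);   // NCHW, N = 1
            System.out.println(INDArrayHelper.forceCorrectShape(image) == image);   // true
        }
    }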

HistoryProcessor.java

@@ -46,7 +46,6 @@ public class HistoryProcessor implements IHistoryProcessor {
     @Getter
     final private Configuration conf;

-    final private OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat();
     private CircularFifoQueue<INDArray> history;
     private VideoRecorder videoRecorder;
@@ -63,8 +62,7 @@ public class HistoryProcessor implements IHistoryProcessor {
     public void startMonitor(String filename, int[] shape) {
         if(videoRecorder == null) {
-            videoRecorder = VideoRecorder.builder(shape[0], shape[1])
-                    .frameInputType(VideoRecorder.FrameInputTypes.Float)
+            videoRecorder = VideoRecorder.builder(shape[1], shape[2])
                     .build();
         }
@@ -89,14 +87,13 @@ public class HistoryProcessor implements IHistoryProcessor {
         return videoRecorder != null && videoRecorder.isRecording();
     }

-    public void record(INDArray raw) {
+    public void record(INDArray pixelArray) {
         if(isMonitoring()) {
             // before accessing the raw pointer, we need to make sure that array is actual on the host side
-            Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST);
+            Nd4j.getAffinityManager().ensureLocation(pixelArray, AffinityManager.Location.HOST);

-            VideoRecorder.VideoFrame frame = videoRecorder.createFrame(raw.data().pointer());
             try {
-                videoRecorder.record(frame);
+                videoRecorder.record(pixelArray);
             } catch (Exception e) {
                 e.printStackTrace();
             }

IHistoryProcessor.java

@@ -64,7 +64,7 @@ public interface IHistoryProcessor {
         @Builder.Default int skipFrame = 4;

         public int[] getShape() {
-            return new int[] {getHistoryLength(), getCroppingHeight(), getCroppingWidth()};
+            return new int[] {getHistoryLength(), getRescaledHeight(), getRescaledWidth()};
         }
     }
 }
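
For a sense of the behavioral change, a small sketch (this assumes Lombok generates a builder for Configuration, with builder method names matching the fields implied by the getters above; treat the exact builder API as an assumption):

    import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
    import java.util.Arrays;

    public class ShapeDemo {
        public static void main(String[] args) {
            IHistoryProcessor.Configuration conf = IHistoryProcessor.Configuration.builder()
                    .historyLength(4)
                    .croppingHeight(110)   // cropping no longer leaks into the shape
                    .croppingWidth(84)
                    .rescaledHeight(84)
                    .rescaledWidth(84)
                    .build();
            // Before this change: {4, 110, 84} (cropping dims).
            // After: {4, 84, 84}, matching the frames the pipeline actually emits.
            System.out.println(Arrays.toString(conf.getShape()));
        }
    }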

ILearning.java

@@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
  *
  * A common interface that any training method should implement
  */
-public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
+public interface ILearning<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> {

     IPolicy<A> getPolicy();
@@ -38,7 +38,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {

     ILearningConfiguration getConfiguration();

-    MDP<O, A, AS> getMdp();
+    MDP<OBSERVATION, A, AS> getMdp();

     IHistoryProcessor getHistoryProcessor();

Learning.java

@@ -21,7 +21,6 @@ import lombok.Getter;
 import lombok.Setter;
 import lombok.Value;
 import lombok.extern.slf4j.Slf4j;
-import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
@@ -38,8 +37,8 @@ import org.nd4j.linalg.factory.Nd4j;
  *
  */
 @Slf4j
-public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
-        implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
+public abstract class Learning<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
+        implements ILearning<OBSERVATION, A, AS>, NeuralNetFetchable<NN> {

     @Getter @Setter
     protected int stepCount = 0;

AsyncThread.java

@@ -29,10 +29,10 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.nd4j.linalg.factory.Nd4j;
@@ -188,7 +188,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
     }

     private boolean handleTraining(RunContext context) {
-        int maxTrainSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpisodeStepCount);
+        int maxTrainSteps = Math.min(getConfiguration().getNStep(), getConfiguration().getMaxEpochStep() - currentEpisodeStepCount);
         SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps);

         context.obs = subEpochReturn.getLastObs();
@@ -219,7 +219,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
     protected abstract IAsyncGlobal<NN> getAsyncGlobal();

-    protected abstract IAsyncLearningConfiguration getConf();
+    protected abstract IAsyncLearningConfiguration getConfiguration();

     protected abstract IPolicy<ACTION> getPolicy(NN net);

AsyncThreadDiscrete.java

@@ -24,29 +24,22 @@ import lombok.Setter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.experience.ExperienceHandler;
 import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
-import org.deeplearning4j.rl4j.experience.ExperienceHandler;
-import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.util.Stack;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
  * <p>
  * Async Learning specialized for the Discrete Domain
  */
-public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
-        extends AsyncThread<O, Integer, DiscreteSpace, NN> {
+public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN extends NeuralNet>
+        extends AsyncThread<OBSERVATION, Integer, DiscreteSpace, NN> {

     @Getter
     private NN current;
@@ -59,7 +52,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
     private ExperienceHandler experienceHandler = new StateActionExperienceHandler();

     public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
-                               MDP<O, Integer, DiscreteSpace> mdp,
+                               MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
                                TrainingListenerList listeners,
                                int threadNumber,
                                int deviceNum) {
@@ -112,7 +105,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
             }

             StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
-            accuReward += stepReply.getReward() * getConf().getRewardFactor();
+            accuReward += stepReply.getReward() * getConfiguration().getRewardFactor();

             if (!obs.isSkipped()) {
                 experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
@@ -126,7 +119,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
         }

-        boolean episodeComplete = getMdp().isDone() || getConf().getMaxEpochStep() == currentEpisodeStepCount;
+        boolean episodeComplete = getMdp().isDone() || getConfiguration().getMaxEpochStep() == currentEpisodeStepCount;

         if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) {
             experienceHandler.setFinalObservation(obs);

A3CDiscrete.java

@@ -28,9 +28,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
 import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.policy.ACPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
@@ -41,19 +41,19 @@ import org.nd4j.linalg.factory.Nd4j;
  * All methods are fully implemented as described in the
  * https://arxiv.org/abs/1602.01783 paper.
  */
-public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
+public abstract class A3CDiscrete<OBSERVATION extends Encodable> extends AsyncLearning<OBSERVATION, Integer, DiscreteSpace, IActorCritic> {

     @Getter
     final public A3CLearningConfiguration configuration;
     @Getter
-    final protected MDP<O, Integer, DiscreteSpace> mdp;
+    final protected MDP<OBSERVATION, Integer, DiscreteSpace> mdp;
     final private IActorCritic iActorCritic;
     @Getter
     final private AsyncGlobal asyncGlobal;
     @Getter
-    final private ACPolicy<O> policy;
+    final private ACPolicy<OBSERVATION> policy;

-    public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
+    public A3CDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
         this.iActorCritic = iActorCritic;
         this.mdp = mdp;
         this.configuration = conf;

A3CDiscreteConv.java

@@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph;
 import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
 import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
 import org.deeplearning4j.rl4j.util.IDataManager;
@@ -42,19 +42,19 @@ import org.deeplearning4j.rl4j.util.IDataManager;
  * first layers since they're essentially doing the same dimension
  * reduction task
  **/
-public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
+public class A3CDiscreteConv<OBSERVATION extends Encodable> extends A3CDiscrete<OBSERVATION> {

     final private HistoryProcessor.Configuration hpconf;

     @Deprecated
-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
                     HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
         this(mdp, actorCritic, hpconf, conf);
         addListener(new DataManagerTrainingListener(dataManager));
     }

     @Deprecated
-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
                     HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
         super(mdp, IActorCritic, conf.toLearningConfiguration());
@@ -62,7 +62,7 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
         setHistoryProcessor(hpconf);
     }

-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
                     HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
         super(mdp, IActorCritic, conf);
         this.hpconf = hpconf;
@@ -70,35 +70,35 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
     }

     @Deprecated
-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
                     HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
     }

     @Deprecated
-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
                     HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
     }

-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
                     HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
     }

     @Deprecated
-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
                     HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
     }

     @Deprecated
-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
                     HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
     }

-    public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
+    public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
                     HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
     }

A3CDiscreteDense.java

@@ -21,8 +21,8 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.ac.*;
 import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
 import org.deeplearning4j.rl4j.util.IDataManager;
@@ -34,74 +34,74 @@ import org.deeplearning4j.rl4j.util.IDataManager;
  * We use specifically the Separate version because
  * the model is too small to have enough benefit by sharing layers
  */
-public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
+public class A3CDiscreteDense<OBSERVATION extends Encodable> extends A3CDiscrete<OBSERVATION> {

     @Deprecated
-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
                     IDataManager dataManager) {
         this(mdp, IActorCritic, conf);
         addListener(new DataManagerTrainingListener(dataManager));
     }

     @Deprecated
-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
         super(mdp, actorCritic, conf.toLearningConfiguration());
     }

-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
         super(mdp, actorCritic, conf);
     }

     @Deprecated
-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
                     A3CConfiguration conf, IDataManager dataManager) {
         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
                         dataManager);
     }

     @Deprecated
-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
                     A3CConfiguration conf) {
         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
     }

-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
                     A3CLearningConfiguration conf) {
         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
     }

     @Deprecated
-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
                     ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
                     IDataManager dataManager) {
         this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
     }

     @Deprecated
-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
                     ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
         this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf);
     }

-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
                     ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) {
         this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
     }

     @Deprecated
-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
                     A3CConfiguration conf, IDataManager dataManager) {
         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
                         dataManager);
     }

     @Deprecated
-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
                     A3CConfiguration conf) {
         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
     }

-    public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
+    public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
                     A3CLearningConfiguration conf) {
         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
     }

A3CThreadDiscrete.java

@@ -23,23 +23,23 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.policy.ACPolicy;
 import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.api.rng.Random;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
- *
+ * <p>
  * Local thread as described in the https://arxiv.org/abs/1602.01783 paper.
  */
-public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
+public class A3CThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IActorCritic> {

     @Getter
-    final protected A3CLearningConfiguration conf;
+    final protected A3CLearningConfiguration configuration;
     @Getter
     final protected IAsyncGlobal<IActorCritic> asyncGlobal;
     @Getter
@@ -47,17 +47,17 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
     final private Random rnd;

-    public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
+    public A3CThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
                     A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
                     int threadNumber) {
         super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
-        this.conf = a3cc;
+        this.configuration = a3cc;
         this.asyncGlobal = asyncGlobal;
         this.threadNumber = threadNumber;

-        Long seed = conf.getSeed();
+        Long seed = configuration.getSeed();
         rnd = Nd4j.getRandom();
-        if(seed != null) {
+        if (seed != null) {
             rnd.setSeed(seed + threadNumber);
         }
@@ -69,9 +69,12 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
         return new ACPolicy(net, rnd);
     }

+    /**
+     * calc the gradients based on the n-step rewards
+     */
     @Override
     protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
-        return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma());
+        return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
     }
 }
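
Both thread classes in this commit share the seeding pattern seen above; extracted for illustration (a standalone sketch, not code from this diff):

    import org.nd4j.linalg.api.rng.Random;
    import org.nd4j.linalg.factory.Nd4j;

    public class PerThreadSeedDemo {
        public static void main(String[] args) {
            Long seed = 123L;       // base seed from the learning configuration
            int threadNumber = 3;   // async worker index

            Random rnd = Nd4j.getRandom();
            if (seed != null) {
                // Offsetting by the thread number keeps runs reproducible while
                // giving each worker a distinct exploration stream.
                rnd.setSeed(seed + threadNumber);
            }
        }
    }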

AsyncNStepQLearningDiscrete.java

@@ -28,26 +28,26 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
 import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.policy.DQNPolicy;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
  */
-public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
-        extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
+public abstract class AsyncNStepQLearningDiscrete<OBSERVATION extends Encodable>
+        extends AsyncLearning<OBSERVATION, Integer, DiscreteSpace, IDQN> {

     @Getter
     final public AsyncQLearningConfiguration configuration;
     @Getter
-    final private MDP<O, Integer, DiscreteSpace> mdp;
+    final private MDP<OBSERVATION, Integer, DiscreteSpace> mdp;
     @Getter
     final private AsyncGlobal<IDQN> asyncGlobal;

-    public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
+    public AsyncNStepQLearningDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
         this.mdp = mdp;
         this.configuration = conf;
         this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
@@ -63,7 +63,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
     }

     public IPolicy<Integer> getPolicy() {
-        return new DQNPolicy<O>(getNeuralNet());
+        return new DQNPolicy<OBSERVATION>(getNeuralNet());
     }

     @Data

AsyncNStepQLearningDiscreteConv.java

@@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
 import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
 import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
 import org.deeplearning4j.rl4j.util.IDataManager;
@@ -35,17 +35,17 @@ import org.deeplearning4j.rl4j.util.IDataManager;
  * Specialized constructors for the Conv (pixels input) case
  * Specialized conf + provide additional type safety
  */
-public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
+public class AsyncNStepQLearningDiscreteConv<OBSERVATION extends Encodable> extends AsyncNStepQLearningDiscrete<OBSERVATION> {

     final private HistoryProcessor.Configuration hpconf;

     @Deprecated
-    public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
+    public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
                     HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
         this(mdp, dqn, hpconf, conf);
         addListener(new DataManagerTrainingListener(dataManager));
     }

-    public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
+    public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
                     HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
         super(mdp, dqn, conf);
         this.hpconf = hpconf;
@@ -53,21 +53,21 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
     }

     @Deprecated
-    public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
+    public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
                     HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
     }

-    public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
+    public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
                     HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
     }

     @Deprecated
-    public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
+    public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
                     HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
     }

-    public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
+    public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
                     HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
     }

AsyncNStepQLearningDiscreteDense.java

@@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio
 import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
 import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
 import org.deeplearning4j.rl4j.util.IDataManager;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
  */
-public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
+public class AsyncNStepQLearningDiscreteDense<OBSERVATION extends Encodable> extends AsyncNStepQLearningDiscrete<OBSERVATION> {

     @Deprecated
-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
                     AsyncNStepQLConfiguration conf, IDataManager dataManager) {
         super(mdp, dqn, conf.toLearningConfiguration());
         addListener(new DataManagerTrainingListener(dataManager));
     }

     @Deprecated
-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
                     AsyncNStepQLConfiguration conf) {
         super(mdp, dqn, conf.toLearningConfiguration());
     }

-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
                     AsyncQLearningConfiguration conf) {
         super(mdp, dqn, conf);
     }

     @Deprecated
-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
                     AsyncNStepQLConfiguration conf, IDataManager dataManager) {
         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
                         dataManager);
     }

     @Deprecated
-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
                     AsyncNStepQLConfiguration conf) {
         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
     }

-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
                     AsyncQLearningConfiguration conf) {
         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
     }

     @Deprecated
-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
                     DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
     }

     @Deprecated
-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
                     DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
     }

-    public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
+    public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
                     DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) {
         this(mdp, new DQNFactoryStdDense(netConf), conf);
     }

AsyncNStepQLearningThreadDiscrete.java

@@ -25,21 +25,21 @@ import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguratio
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.policy.DQNPolicy;
 import org.deeplearning4j.rl4j.policy.EpsGreedy;
 import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
  */
-public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
+public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IDQN> {

     @Getter
-    final protected AsyncQLearningConfiguration conf;
+    final protected AsyncQLearningConfiguration configuration;
     @Getter
     final protected IAsyncGlobal<IDQN> asyncGlobal;
     @Getter
@@ -47,17 +47,17 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
     final private Random rnd;

-    public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
-                    AsyncQLearningConfiguration conf,
+    public AsyncNStepQLearningThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
+                    AsyncQLearningConfiguration configuration,
                     TrainingListenerList listeners, int threadNumber, int deviceNum) {
         super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
-        this.conf = conf;
+        this.configuration = configuration;
         this.asyncGlobal = asyncGlobal;
         this.threadNumber = threadNumber;

         rnd = Nd4j.getRandom();
-        Long seed = conf.getSeed();
+        Long seed = configuration.getSeed();
-        if (seed != null) {
+        if(seed != null) {
             rnd.setSeed(seed + threadNumber);
         }
@@ -65,13 +65,13 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
     }

     public Policy<Integer> getPolicy(IDQN nn) {
-        return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
-                        rnd, conf.getMinEpsilon(), this);
+        return new EpsGreedy(new DQNPolicy(nn), getMdp(), configuration.getUpdateStart(), configuration.getEpsilonNbStep(),
+                        rnd, configuration.getMinEpsilon(), this);
     }

     @Override
     protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
-        return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma());
+        return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
     }
 }

QLearning.java

@@ -32,10 +32,10 @@ import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
 import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.EpsGreedy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;


@ -33,11 +33,11 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorith
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;


@ -24,8 +24,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
@ -34,59 +34,59 @@ import org.deeplearning4j.rl4j.util.IDataManager;
* Specialized constructors for the Conv (pixels input) case * Specialized constructors for the Conv (pixels input) case
* Specialized conf + provide additional type safety * Specialized conf + provide additional type safety
*/ */
public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> { public class QLearningDiscreteConv<OBSERVATION extends Encodable> extends QLearningDiscrete<OBSERVATION> {
@Deprecated @Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf, IDataManager dataManager) { QLConfiguration conf, IDataManager dataManager) {
this(mdp, dqn, hpconf, conf); this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
} }
@Deprecated @Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf) { QLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame()); super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
} }
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLearningConfiguration conf) { QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
} }
@Deprecated @Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
} }
@Deprecated @Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) { HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
} }
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
} }
@Deprecated @Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
} }
@Deprecated @Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) { HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf); this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
} }
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf, public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
} }
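For orientation, a minimal sketch of wiring the non-deprecated constructor above; mdp, netConf, hpconf and qlConf are hypothetical objects assumed to be built elsewhere, and PixelObservation stands in for any Encodable screen type:

    QLearningDiscreteConv<PixelObservation> learning =
            new QLearningDiscreteConv<>(mdp, netConf, hpconf, qlConf);
    learning.train();  // epsilon annealing runs over getEpsilonNbStep() * getSkipFrame() steps, as wired above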


@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio
import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
*/ */
public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> { public class QLearningDiscreteDense<OBSERVATION extends Encodable> extends QLearningDiscrete<OBSERVATION> {
@Deprecated @Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf, public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
IDataManager dataManager) { IDataManager dataManager) {
this(mdp, dqn, conf); this(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
} }
@Deprecated @Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) { public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep()); super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
} }
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) { public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep()); super(mdp, dqn, conf, conf.getEpsilonNbStep());
} }
@Deprecated @Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf, IDataManager dataManager) { QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
@Deprecated @Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf) { QLearning.QLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
} }
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearningConfiguration conf) { QLearningConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
} }
@Deprecated @Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf, IDataManager dataManager) { QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
} }
@Deprecated @Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf) { QLearning.QLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
} }
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf, public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf,
QLearningConfiguration conf) { QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf); this(mdp, new DQNFactoryStdDense(netConf), conf);
} }


@ -4,8 +4,8 @@ import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace; import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.space.ObservationSpace;
import java.util.Random; import java.util.Random;
@ -36,7 +36,7 @@ import java.util.Random;
*/ */
public class CartpoleNative implements MDP<CartpoleNative.State, Integer, DiscreteSpace> { public class CartpoleNative implements MDP<Box, Integer, DiscreteSpace> {
public enum KinematicsIntegrators { Euler, SemiImplicitEuler }; public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
private static final int NUM_ACTIONS = 2; private static final int NUM_ACTIONS = 2;
@ -74,7 +74,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
@Getter @Getter
private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS); private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS);
@Getter @Getter
private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES }); private ObservationSpace<Box> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
public CartpoleNative() { public CartpoleNative() {
rnd = new Random(); rnd = new Random();
@ -85,7 +85,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
} }
@Override @Override
public State reset() { public Box reset() {
x = 0.1 * rnd.nextDouble() - 0.05; x = 0.1 * rnd.nextDouble() - 0.05;
xDot = 0.1 * rnd.nextDouble() - 0.05; xDot = 0.1 * rnd.nextDouble() - 0.05;
@ -94,7 +94,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
stepsBeyondDone = null; stepsBeyondDone = null;
done = false; done = false;
return new State(new double[] { x, xDot, theta, thetaDot }); return new Box(x, xDot, theta, thetaDot);
} }
@Override @Override
@ -103,7 +103,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
} }
@Override @Override
public StepReply<State> step(Integer action) { public StepReply<Box> step(Integer action) {
double force = action == ACTION_RIGHT ? forceMag : -forceMag; double force = action == ACTION_RIGHT ? forceMag : -forceMag;
double cosTheta = Math.cos(theta); double cosTheta = Math.cos(theta);
double sinTheta = Math.sin(theta); double sinTheta = Math.sin(theta);
@ -143,26 +143,12 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
reward = 0; reward = 0;
} }
return new StepReply<>(new State(new double[] { x, xDot, theta, thetaDot }), reward, done, null); return new StepReply<>(new Box(x, xDot, theta, thetaDot), reward, done, null);
} }
@Override @Override
public MDP<State, Integer, DiscreteSpace> newInstance() { public MDP<Box, Integer, DiscreteSpace> newInstance() {
return new CartpoleNative(); return new CartpoleNative();
} }
public static class State implements Encodable {
private final double[] state;
State(double[] state) {
this.state = state;
}
@Override
public double[] toArray() {
return state;
}
}
} }
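With the State wrapper removed, callers now receive Box observations directly. A minimal interaction-loop sketch under that contract (the constant action 0 is purely illustrative):

    CartpoleNative mdp = new CartpoleNative();
    Box obs = mdp.reset();                   // encodes { x, xDot, theta, thetaDot }
    while (!mdp.isDone()) {
        StepReply<Box> reply = mdp.step(0);  // always push left, illustrative only
        obs = reply.getObservation();
    }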


@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy;
import lombok.Value; import lombok.Value;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
@ -31,4 +32,19 @@ public class HardToyState implements Encodable {
public double[] toArray() { public double[] toArray() {
return values; return values;
} }
@Override
public boolean isSkipped() {
return false;
}
@Override
public INDArray getData() {
return null;
}
@Override
public Encodable dup() {
return null;
}
} }


@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace; import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,7 +41,6 @@ import org.nd4j.linalg.factory.Nd4j;
public class SimpleToy implements MDP<SimpleToyState, Integer, DiscreteSpace> { public class SimpleToy implements MDP<SimpleToyState, Integer, DiscreteSpace> {
final private int maxStep; final private int maxStep;
//TODO 10 steps toy (always +1 reward, 2 actions), toylong (1000 steps), toyhard (7 actions, +1 only if action = (step/100+step)%7), and toyStoch (like last but reward has 0.10 odds to be somewhere else).
@Getter @Getter
private DiscreteSpace actionSpace = new DiscreteSpace(2); private DiscreteSpace actionSpace = new DiscreteSpace(2);
@Getter @Getter


@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy;
import lombok.Value; import lombok.Value;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
@ -28,11 +29,24 @@ public class SimpleToyState implements Encodable {
int i; int i;
int step; int step;
@Override
public double[] toArray() { public double[] toArray() {
double[] ar = new double[1]; double[] ar = new double[1];
ar[0] = (20 - i); ar[0] = (20 - i);
return ar; return ar;
} }
@Override
public boolean isSkipped() {
return false;
}
@Override
public INDArray getData() {
return null;
}
@Override
public Encodable dup() {
return null;
}
} }


@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
* *
* @author Alexandre Boulanger * @author Alexandre Boulanger
*/ */
public class Observation { public class Observation implements Encodable {
/** /**
* A singleton representing a skipped observation * A singleton representing a skipped observation
@ -38,6 +38,11 @@ public class Observation {
@Getter @Getter
private final INDArray data; private final INDArray data;
@Override
public double[] toArray() {
return data.data().asDouble();
}
public boolean isSkipped() { public boolean isSkipped() {
return data == null; return data == null;
} }


@ -1,5 +1,5 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2020 Konduit K.K. * Copyright (c) 2020 Konduit K. K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -13,29 +13,16 @@
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.rl4j.observation.transform.legacy;
import org.bytedeco.javacv.OpenCVFrameConverter; package org.deeplearning4j.rl4j.observation.transform;
import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.api.transform.Operation; import org.datavec.api.transform.Operation;
import org.datavec.image.data.ImageWritable;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> { public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> {
private final int[] shape;
public EncodableToINDArrayTransform(int[] shape) {
this.shape = shape;
}
@Override @Override
public INDArray transform(Encodable encodable) { public INDArray transform(Encodable encodable) {
return Nd4j.create(encodable.toArray()).reshape(shape); return encodable.getData();
} }
} }
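The non-pixel path therefore no longer reshapes a double[]; it relies on each Encodable carrying its own INDArray via getData(). A small sketch of the new contract, reusing the Box type from the CartpoleNative change above:

    Encodable obs = new Box(0.1, 0.0, -0.02, 0.0);                      // a Box carries its data as an INDArray
    INDArray data = new EncodableToINDArrayTransform().transform(obs);  // simply returns obs.getData()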


@ -15,34 +15,32 @@
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.rl4j.observation.transform.legacy; package org.deeplearning4j.rl4j.observation.transform.legacy;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.opencv_core.Mat; import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.api.transform.Operation; import org.datavec.api.transform.Operation;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.bytedeco.opencv.global.opencv_core.CV_32FC; import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
import static org.bytedeco.opencv.global.opencv_core.CV_32FC3;
import static org.bytedeco.opencv.global.opencv_core.CV_32S;
import static org.bytedeco.opencv.global.opencv_core.CV_32SC;
import static org.bytedeco.opencv.global.opencv_core.CV_32SC3;
import static org.bytedeco.opencv.global.opencv_core.CV_64FC;
import static org.bytedeco.opencv.global.opencv_core.CV_8UC3;
public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> { public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> {
private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); final static NativeImageLoader nativeImageLoader = new NativeImageLoader();
private final int height;
private final int width;
private final int colorChannels;
public EncodableToImageWritableTransform(int height, int width, int colorChannels) {
this.height = height;
this.width = width;
this.colorChannels = colorChannels;
}
@Override @Override
public ImageWritable transform(Encodable encodable) { public ImageWritable transform(Encodable encodable) {
INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels); return new ImageWritable(nativeImageLoader.asFrame(encodable.getData(), Frame.DEPTH_UBYTE));
Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
return new ImageWritable(converter.convert(mat));
} }
} }


@ -18,34 +18,31 @@ package org.deeplearning4j.rl4j.observation.transform.legacy;
import org.datavec.api.transform.Operation; import org.datavec.api.transform.Operation;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException; import java.io.IOException;
public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> { public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
private final int height; private final NativeImageLoader loader = new NativeImageLoader();
private final int width;
private final NativeImageLoader loader;
public ImageWritableToINDArrayTransform(int height, int width) {
this.height = height;
this.width = width;
this.loader = new NativeImageLoader(height, width);
}
@Override @Override
public INDArray transform(ImageWritable imageWritable) { public INDArray transform(ImageWritable imageWritable) {
int height = imageWritable.getHeight();
int width = imageWritable.getWidth();
int channels = imageWritable.getFrame().imageChannels;
INDArray out = null; INDArray out = null;
try { try {
out = loader.asMatrix(imageWritable); out = loader.asMatrix(imageWritable);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); e.printStackTrace();
} }
out = out.reshape(1, height, width);
// Convert back to uint8 and reshape to the number of channels in the image
out = out.reshape(channels, height, width);
INDArray compressed = out.castTo(DataType.UINT8); INDArray compressed = out.castTo(DataType.UINT8);
return compressed; return compressed;
} }


@ -46,19 +46,20 @@ public class HistoryMergeTransform implements Operation<INDArray, INDArray>, Res
private final HistoryMergeElementStore historyMergeElementStore; private final HistoryMergeElementStore historyMergeElementStore;
private final HistoryMergeAssembler historyMergeAssembler; private final HistoryMergeAssembler historyMergeAssembler;
private final boolean shouldStoreCopy; private final boolean shouldStoreCopy;
private final boolean isFirstDimenstionBatch; private final boolean isFirstDimensionBatch;
private HistoryMergeTransform(Builder builder) { private HistoryMergeTransform(Builder builder) {
this.historyMergeElementStore = builder.historyMergeElementStore; this.historyMergeElementStore = builder.historyMergeElementStore;
this.historyMergeAssembler = builder.historyMergeAssembler; this.historyMergeAssembler = builder.historyMergeAssembler;
this.shouldStoreCopy = builder.shouldStoreCopy; this.shouldStoreCopy = builder.shouldStoreCopy;
this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch; this.isFirstDimensionBatch = builder.isFirstDimenstionBatch;
} }
@Override @Override
public INDArray transform(INDArray input) { public INDArray transform(INDArray input) {
INDArray element; INDArray element;
if(isFirstDimenstionBatch) { if(isFirstDimensionBatch) {
element = input.slice(0, 0); element = input.slice(0, 0);
} }
else { else {
@ -132,9 +133,9 @@ public class HistoryMergeTransform implements Operation<INDArray, INDArray>, Res
return this; return this;
} }
public HistoryMergeTransform build() { public HistoryMergeTransform build(int frameStackLength) {
if(historyMergeElementStore == null) { if(historyMergeElementStore == null) {
historyMergeElementStore = new CircularFifoStore(); historyMergeElementStore = new CircularFifoStore(frameStackLength);
} }
if(historyMergeAssembler == null) { if(historyMergeAssembler == null) {

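Because the default CircularFifoStore no longer has an implicit size (see the CircularFifoStore change below), the builder now takes the frame-stack length explicitly. A usage sketch matching how LegacyMDPWrapper invokes it later in this commit:

    // Stack the last 4 frames; dimension 0 of the incoming INDArray is treated as a batch dimension.
    // The builder method keeps its original (misspelled) name.
    HistoryMergeTransform merge = HistoryMergeTransform.builder()
            .isFirstDimenstionBatch(true)
            .build(4);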

@ -28,14 +28,9 @@ import org.nd4j.linalg.factory.Nd4j;
* @author Alexandre Boulanger * @author Alexandre Boulanger
*/ */
public class CircularFifoStore implements HistoryMergeElementStore { public class CircularFifoStore implements HistoryMergeElementStore {
private static final int DEFAULT_STORE_SIZE = 4;
private final CircularFifoQueue<INDArray> queue; private final CircularFifoQueue<INDArray> queue;
public CircularFifoStore() {
this(DEFAULT_STORE_SIZE);
}
public CircularFifoStore(int size) { public CircularFifoStore(int size) {
Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size); Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size);
queue = new CircularFifoQueue<>(size); queue = new CircularFifoQueue<>(size);


@ -20,8 +20,8 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph; import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate; import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -35,7 +35,7 @@ import java.io.IOException;
* the softmax output of the actor critic, but objects constructed * the softmax output of the actor critic, but objects constructed
* with a {@link Random} argument of null return the max only. * with a {@link Random} argument of null return the max only.
*/ */
public class ACPolicy<O extends Encodable> extends Policy<Integer> { public class ACPolicy<OBSERVATION extends Encodable> extends Policy<Integer> {
final private IActorCritic actorCritic; final private IActorCritic actorCritic;
Random rnd; Random rnd;
@ -48,18 +48,18 @@ public class ACPolicy<O extends Encodable> extends Policy<Integer> {
this.rnd = rnd; this.rnd = rnd;
} }
public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException { public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String path) throws IOException {
return new ACPolicy<O>(ActorCriticCompGraph.load(path)); return new ACPolicy<>(ActorCriticCompGraph.load(path));
} }
public static <O extends Encodable> ACPolicy<O> load(String path, Random rnd) throws IOException { public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String path, Random rnd) throws IOException {
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rnd); return new ACPolicy<>(ActorCriticCompGraph.load(path), rnd);
} }
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException { public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String pathValue, String pathPolicy) throws IOException {
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy)); return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy));
} }
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rnd) throws IOException { public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String pathValue, String pathPolicy, Random rnd) throws IOException {
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd); return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
} }
public IActorCritic getNeuralNet() { public IActorCritic getNeuralNet() {

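A short load sketch for the factory methods above; the model paths are hypothetical, and loading without a Random yields greedy (max) play per the class javadoc:

    ACPolicy<Box> policy = ACPolicy.load("value.model", "policy.model");  // separate value/policy networks
    ACPolicy<Box> combined = ACPolicy.load("ac.model");                   // single computation graph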

@ -17,8 +17,8 @@
package org.deeplearning4j.rl4j.policy; package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
* Boltzmann exploration is a stochastic policy w.r.t. the * Boltzmann exploration is a stochastic policy w.r.t. the
* exponential Q-values as evaluated by the dqn model. * exponential Q-values as evaluated by the dqn model.
*/ */
public class BoltzmannQ<O extends Encodable> extends Policy<Integer> { public class BoltzmannQ<OBSERVATION extends Encodable> extends Policy<Integer> {
final private IDQN dqn; final private IDQN dqn;
final private Random rnd; final private Random rnd;

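For reference, a standalone sketch of what Boltzmann (softmax) exploration computes over the DQN's Q-values; this is not the class internals (the diff does not show them), and dqn, obs and rnd are assumed to exist. exp is the statically imported Transforms.exp visible in this file's imports:

    INDArray q = dqn.output(obs);                          // Q-values, shape [1, numActions]
    INDArray p = exp(q.sub(q.maxNumber().doubleValue()));  // numerically stable softmax numerator
    p.divi(p.sumNumber().doubleValue());                   // normalize into a distribution
    double u = rnd.nextDouble(), cum = 0.0;                // inverse-CDF sampling
    int action = 0;
    for (int a = 0; a < p.length(); a++) {
        cum += p.getDouble(a);
        if (u <= cum) { action = a; break; }
    }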

@ -20,8 +20,8 @@ import lombok.AllArgsConstructor;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.dqn.DQN; import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException; import java.io.IOException;
@ -35,12 +35,12 @@ import java.io.IOException;
// FIXME: Should we rename this "GreedyPolicy"? // FIXME: Should we rename this "GreedyPolicy"?
@AllArgsConstructor @AllArgsConstructor
public class DQNPolicy<O> extends Policy<Integer> { public class DQNPolicy<OBSERVATION> extends Policy<Integer> {
final private IDQN dqn; final private IDQN dqn;
public static <O extends Encodable> DQNPolicy<O> load(String path) throws IOException { public static <OBSERVATION extends Encodable> DQNPolicy<OBSERVATION> load(String path) throws IOException {
return new DQNPolicy<O>(DQN.load(path)); return new DQNPolicy<>(DQN.load(path));
} }
public IDQN getNeuralNet() { public IDQN getNeuralNet() {

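Similarly, a minimal sketch for DQNPolicy; the model path is hypothetical and Policy#play(mdp) is assumed available from the base class:

    DQNPolicy<Box> policy = DQNPolicy.load("dqn.model");  // hypothetical path
    double reward = policy.play(mdp);                     // greedy play: argmax of dqn.output(...) each step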

@ -20,12 +20,11 @@ package org.deeplearning4j.rl4j.policy;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
@ -41,10 +40,10 @@ import org.nd4j.linalg.api.rng.Random;
*/ */
@AllArgsConstructor @AllArgsConstructor
@Slf4j @Slf4j
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> { public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
final private Policy<A> policy; final private Policy<A> policy;
final private MDP<O, A, AS> mdp; final private MDP<OBSERVATION, A, AS> mdp;
final private int updateStart; final private int updateStart;
final private int epsilonNbStep; final private int epsilonNbStep;
final private Random rnd; final private Random rnd;


@ -22,9 +22,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
/** /**


@ -7,22 +7,22 @@ import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.CropImageTransform; import org.datavec.image.transform.CropImageTransform;
import org.datavec.image.transform.MultiImageTransform; import org.datavec.image.transform.MultiImageTransform;
import org.datavec.image.transform.ResizeImageTransform; import org.datavec.image.transform.ResizeImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess; import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform; import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform; import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -46,6 +46,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
private int skipFrame = 1; private int skipFrame = 1;
private int steps = 0; private int steps = 0;
public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) { public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) {
this.wrappedMDP = wrappedMDP; this.wrappedMDP = wrappedMDP;
this.shape = wrappedMDP.getObservationSpace().getShape(); this.shape = wrappedMDP.getObservationSpace().getShape();
@ -66,28 +67,33 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
if(historyProcessor != null && shape.length == 3) { if(historyProcessor != null && shape.length == 3) {
int skipFrame = historyProcessor.getConf().getSkipFrame(); int skipFrame = historyProcessor.getConf().getSkipFrame();
int frameStackLength = historyProcessor.getConf().getHistoryLength();
int finalHeight = historyProcessor.getConf().getCroppingHeight(); int height = shape[1];
int finalWidth = historyProcessor.getConf().getCroppingWidth(); int width = shape[2];
int cropBottom = height - historyProcessor.getConf().getCroppingHeight();
int cropRight = width - historyProcessor.getConf().getCroppingWidth();
transformProcess = TransformProcess.builder() transformProcess = TransformProcess.builder()
.filter(new UniformSkippingFilter(skipFrame)) .filter(new UniformSkippingFilter(skipFrame))
.transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2])) .transform("data", new EncodableToImageWritableTransform())
.transform("data", new MultiImageTransform( .transform("data", new MultiImageTransform(
new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), cropBottom, cropRight),
new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()), new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()),
new ColorConversionTransform(COLOR_BGR2GRAY), new ColorConversionTransform(COLOR_BGR2GRAY)
new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth) //new ShowImageTransform("crop + resize + greyscale")
)) ))
.transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth)) .transform("data", new ImageWritableToINDArrayTransform())
.transform("data", new SimpleNormalizationTransform(0.0, 255.0)) .transform("data", new SimpleNormalizationTransform(0.0, 255.0))
.transform("data", HistoryMergeTransform.builder() .transform("data", HistoryMergeTransform.builder()
.isFirstDimenstionBatch(true) .isFirstDimenstionBatch(true)
.build()) .build(frameStackLength))
.build("data"); .build("data");
} }
else { else {
transformProcess = TransformProcess.builder() transformProcess = TransformProcess.builder()
.transform("data", new EncodableToINDArrayTransform(shape)) .transform("data", new EncodableToINDArrayTransform())
.build("data"); .build("data");
} }
} }
@ -127,6 +133,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation()); Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone()); Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
} }
@ -161,12 +168,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
} }
private INDArray getInput(OBSERVATION obs) { private INDArray getInput(OBSERVATION obs) {
INDArray arr = Nd4j.create(obs.toArray()); return obs.getData();
int[] shape = observationSpace.getShape();
if (shape.length == 1)
return arr.reshape(new long[] {1, arr.length()});
else
return arr.reshape(shape);
} }
public static class WrapperObservationSpace implements ObservationSpace<Observation> { public static class WrapperObservationSpace implements ObservationSpace<Observation> {


@ -16,26 +16,21 @@
package org.deeplearning4j.rl4j.util; package org.deeplearning4j.rl4j.util;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacv.FFmpegFrameRecorder; import org.bytedeco.javacv.FFmpegFrameRecorder;
import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.loader.NativeImageLoader;
import org.bytedeco.opencv.global.opencv_core; import org.nd4j.linalg.api.buffer.DataType;
import org.bytedeco.opencv.global.opencv_imgproc; import org.nd4j.linalg.api.ndarray.INDArray;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Rect;
import org.bytedeco.opencv.opencv_core.Size;
import org.opencv.imgproc.Imgproc;
import static org.bytedeco.ffmpeg.global.avcodec.*; import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_H264;
import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_MPEG4;
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB0;
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB24;
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB8;
/** /**
* VideoRecorder is used to create a video from a sequence of individual frames. If using 3 channels * VideoRecorder is used to create a video from a sequence of INDArray frames. INDArrays are assumed to be in CHW format where C=3 and pixels are RGB encoded<br>
* images, it expects B-G-R order. An RGB order can be used by calling isRGBOrder(true).<br>
* Example:<br> * Example:<br>
* <pre> * <pre>
* {@code * {@code
@ -45,11 +40,8 @@ import static org.bytedeco.opencv.global.opencv_core.*;
* .build(); * .build();
* recorder.startRecording("myVideo.mp4"); * recorder.startRecording("myVideo.mp4");
* while(...) { * while(...) {
* byte[] data = new byte[160*100*3]; * INDArray chwData = Nd4j.zeros(3, 100, 160);
* // Todo: Fill data * recorder.record(chwData);
* VideoRecorder.VideoFrame frame = recorder.createFrame(data);
* // Todo: Apply cropping or resizing to frame
* recorder.record(frame);
* } * }
* recorder.stopRecording(); * recorder.stopRecording();
* } * }
@ -60,16 +52,13 @@ import static org.bytedeco.opencv.global.opencv_core.*;
@Slf4j @Slf4j
public class VideoRecorder implements AutoCloseable { public class VideoRecorder implements AutoCloseable {
public enum FrameInputTypes { BGR, RGB, Float } private final NativeImageLoader nativeImageLoader = new NativeImageLoader();
private final int height; private final int height;
private final int width; private final int width;
private final int imageType;
private final OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat();
private final int codec; private final int codec;
private final double framerate; private final double framerate;
private final int videoQuality; private final int videoQuality;
private final FrameInputTypes frameInputType;
private FFmpegFrameRecorder fmpegFrameRecorder = null; private FFmpegFrameRecorder fmpegFrameRecorder = null;
@ -83,11 +72,9 @@ public class VideoRecorder implements AutoCloseable {
private VideoRecorder(Builder builder) { private VideoRecorder(Builder builder) {
this.height = builder.height; this.height = builder.height;
this.width = builder.width; this.width = builder.width;
imageType = CV_8UC(builder.numChannels);
codec = builder.codec; codec = builder.codec;
framerate = builder.frameRate; framerate = builder.frameRate;
videoQuality = builder.videoQuality; videoQuality = builder.videoQuality;
frameInputType = builder.frameInputType;
} }
/** /**
@ -119,59 +106,11 @@ public class VideoRecorder implements AutoCloseable {
/** /**
* Add a frame to the video * Add a frame to the video
* @param frame the VideoFrame to add to the video * @param imageArray the INDArray that contains the data to be recorded, the data must be in CHW format
* @throws Exception * @throws Exception
*/ */
public void record(VideoFrame frame) throws Exception { public void record(INDArray imageArray) throws Exception {
Size size = frame.getMat().size(); fmpegFrameRecorder.record(nativeImageLoader.asFrame(imageArray, Frame.DEPTH_UBYTE));
if(size.height() != height || size.width() != width) {
throw new IllegalArgumentException(String.format("Wrong frame size. Got (%dh x %dw) expected (%dh x %dw)", size.height(), size.width(), height, width));
}
Frame cvFrame = openCVFrameConverter.convert(frame.getMat());
fmpegFrameRecorder.record(cvFrame);
}
/**
* Create a VideoFrame from a byte array.
* @param data A byte array. Expect the index to be of the form [(Y*Width + X) * NumChannels + channel]
* @return An instance of VideoFrame
*/
public VideoFrame createFrame(byte[] data) {
return createFrame(new BytePointer(data));
}
/**
* Create a VideoFrame from a byte array with different height and width than the video
* (the frame will need to be cropped or resized before being added to the video)
*
* @param data A byte array. Expect the index to be of the form [(Y*customWidth + X) * NumChannels + channel]
* @param customHeight The actual height of the data
* @param customWidth The actual width of the data
* @return A VideoFrame instance
*/
public VideoFrame createFrame(byte[] data, int customHeight, int customWidth) {
return createFrame(new BytePointer(data), customHeight, customWidth);
}
/**
* Create a VideoFrame from a Pointer (to use for example with a INDarray).
* @param data A Pointer (for example myINDArray.data().pointer())
* @return An instance of VideoFrame
*/
public VideoFrame createFrame(Pointer data) {
return new VideoFrame(height, width, imageType, frameInputType, data);
}
/**
* Create a VideoFrame from a Pointer with different height and width than the video
* (the frame will need to be cropped or resized before being added to the video)
* @param data
* @param customHeight The actual height of the data
* @param customWidth The actual width of the data
* @return A VideoFrame instance
*/
public VideoFrame createFrame(Pointer data, int customHeight, int customWidth) {
return new VideoFrame(customHeight, customWidth, imageType, frameInputType, data);
} }
/** /**
@ -192,69 +131,12 @@ public class VideoRecorder implements AutoCloseable {
return new Builder(height, width); return new Builder(height, width);
} }
/**
* An individual frame for the video
*/
public static class VideoFrame {
private final int height;
private final int width;
private final int imageType;
@Getter
private Mat mat;
private VideoFrame(int height, int width, int imageType, FrameInputTypes frameInputType, Pointer data) {
this.height = height;
this.width = width;
this.imageType = imageType;
switch(frameInputType) {
case RGB:
Mat src = new Mat(height, width, imageType, data);
mat = new Mat(height, width, imageType);
opencv_imgproc.cvtColor(src, mat, Imgproc.COLOR_RGB2BGR);
break;
case BGR:
mat = new Mat(height, width, imageType, data);
break;
case Float:
Mat tmpMat = new Mat(height, width, CV_32FC(3), data);
mat = new Mat(height, width, imageType);
tmpMat.convertTo(mat, CV_8UC(3), 255.0, 0.0);
}
}
/**
* Crop the video to a specified size
* @param newHeight The new height of the frame
* @param newWidth The new width of the frame
* @param heightOffset The starting height offset in the uncropped frame
* @param widthOffset The starting width offset in the uncropped frame
*/
public void crop(int newHeight, int newWidth, int heightOffset, int widthOffset) {
mat = mat.apply(new Rect(widthOffset, heightOffset, newWidth, newHeight));
}
/**
* Resize the frame to a specified size
* @param newHeight The new height of the frame
* @param newWidth The new width of the frame
*/
public void resize(int newHeight, int newWidth) {
mat = new Mat(newHeight, newWidth, imageType);
}
}
/** /**
* A builder class for the VideoRecorder * A builder class for the VideoRecorder
*/ */
public static class Builder { public static class Builder {
private final int height; private final int height;
private final int width; private final int width;
private int numChannels = 3;
private FrameInputTypes frameInputType = FrameInputTypes.BGR;
private int codec = AV_CODEC_ID_H264; private int codec = AV_CODEC_ID_H264;
private double frameRate = 30.0; private double frameRate = 30.0;
private int videoQuality = 30; private int videoQuality = 30;
@ -268,24 +150,6 @@ public class VideoRecorder implements AutoCloseable {
this.width = width; this.width = width;
} }
/**
* Specify the number of channels. Default is 3
* @param numChannels
*/
public Builder numChannels(int numChannels) {
this.numChannels = numChannels;
return this;
}
/**
* Tell the VideoRecorder what data it will receive (default is BGR)
* @param frameInputType (See {@link FrameInputTypes}}
*/
public Builder frameInputType(FrameInputTypes frameInputType) {
this.frameInputType = frameInputType;
return this;
}
/** /**
* The codec to use for the video. Default is AV_CODEC_ID_H264 * The codec to use for the video. Default is AV_CODEC_ID_H264
* @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes}) * @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes})

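Pulling the updated javadoc and builder together, an end-to-end sketch of the new INDArray-only API; the file name is hypothetical and the frame is dummy zeros where a real caller would pass CHW pixel data:

    VideoRecorder recorder = VideoRecorder.builder(100, 160).build();  // height, width
    recorder.startRecording("myVideo.mp4");
    INDArray frame = Nd4j.zeros(3, 100, 160);                          // CHW, C=3, RGB encoded
    recorder.record(frame);                                            // converted via NativeImageLoader.asFrame(...)
    recorder.stopRecording();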

@ -115,7 +115,7 @@ public class AsyncThreadDiscreteTest {
asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm); asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration); when(asyncThreadDiscrete.getConfiguration()).thenReturn(mockAsyncConfiguration);
when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0); when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal); when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy); when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);


@ -39,7 +39,6 @@ import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -130,7 +129,7 @@ public class AsyncThreadTest {
when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode); when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
when(mockAsyncConfiguration.getNStep()).thenReturn(nstep); when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
when(thread.getConf()).thenReturn(mockAsyncConfiguration); when(thread.getConfiguration()).thenReturn(mockAsyncConfiguration);
// if we hit the max step count // if we hit the max step count
when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps); when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);


@@ -18,24 +18,16 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
 
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.experience.ExperienceHandler;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
-import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
-import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
-import org.deeplearning4j.rl4j.support.*;
-import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
-import org.deeplearning4j.rl4j.util.IDataManager;
-import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -43,17 +35,17 @@ import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.api.DataSet;
-import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
-import java.util.ArrayList;
-import java.util.List;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 @RunWith(MockitoJUnitRunner.class)
@@ -82,6 +74,7 @@ public class QLearningDiscreteTest {
     @Mock
     QLearningConfiguration mockQlearningConfiguration;
 
+    // HWC
     int[] observationShape = new int[]{3, 10, 10};
     int totalObservationSize = 1;
@@ -123,6 +116,7 @@ public class QLearningDiscreteTest {
         when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
         when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
         when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
+        when(mockHistoryConfiguration.getHistoryLength()).thenReturn(1);
         when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
         qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
@@ -148,7 +142,7 @@ public class QLearningDiscreteTest {
         Observation observation = new Observation(Nd4j.zeros(observationShape));
         when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
-        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Observation(Nd4j.zeros(observationShape)), 0, false, null));
 
         // Act
         QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
@@ -170,25 +164,26 @@ public class QLearningDiscreteTest {
         // Arrange
         mockTestContext(100,0,2,1.0, 10);
-        mockHistoryProcessor(2);
-
-        // An example observation and 2 Q values output (2 actions)
-        Observation observation = new Observation(Nd4j.zeros(observationShape));
-        when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
-        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
+        Observation skippedObservation = Observation.SkippedObservation;
+        Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(nextObservation, 0, false, null));
 
         // Act
-        QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
+        QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(skippedObservation);
 
         // Assert
-        assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
+        assertEquals(Double.NaN, stepReturn.getMaxQ(), 1e-5);
         StepReply<Observation> stepReply = stepReturn.getStepReply();
         assertEquals(0, stepReply.getReward(), 1e-5);
         assertFalse(stepReply.isDone());
-        assertTrue(stepReply.getObservation().isSkipped());
+        assertFalse(stepReply.getObservation().isSkipped());
+        assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
+        verify(mockDQN, never()).output(any(INDArray.class));
     }
 
     //TODO: there are much more test cases here that can be improved upon
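
The reworked test pins down the skip contract: a skipped observation must bypass training entirely, so the DQN is never queried, no experience is stored, and max Q comes back as NaN. A minimal, self-contained sketch of that guard (illustrative only, not RL4J's actual trainStep):

    public class SkipGuardSketch {
        // stand-in for the Q-network: returns one Q value per action
        interface Network { float[] output(double[] observation); }

        static double maxQOrNaN(boolean skipped, double[] observation, Network net) {
            if (skipped) {
                return Double.NaN;           // skipped frames never reach the network
            }
            float[] q = net.output(observation);
            double max = Double.NEGATIVE_INFINITY;
            for (float v : q) {
                max = Math.max(max, v);
            }
            return max;
        }

        public static void main(String[] args) {
            Network net = observation -> new float[] {1.0f, 0.5f};
            System.out.println(maxQOrNaN(true, new double[3], net));   // NaN
            System.out.println(maxQOrNaN(false, new double[3], net));  // 1.0
        }
    }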

View File

@@ -17,7 +17,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(false)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -35,7 +35,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3);
 
         // Act
@@ -53,7 +53,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -70,7 +70,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .shouldStoreCopy(false)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -87,7 +87,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .shouldStoreCopy(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -107,7 +107,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .elementStore(store)
                 .assembler(assemble)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
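
Every builder in this test now receives the history length explicitly through build(4) instead of an implicit default. A self-contained sketch of the mechanism the transform implements, assuming a FIFO store of the last N frames stacked oldest-first (illustrative, not the RL4J class):

    import java.util.ArrayDeque;
    import java.util.Deque;

    public class HistoryMergeSketch {
        private final Deque<double[]> store = new ArrayDeque<>();
        private final int historyLength;     // the value passed to build(4) above

        HistoryMergeSketch(int historyLength) { this.historyLength = historyLength; }

        // add a frame, evicting the oldest once the window is full
        double[][] transform(double[] frame) {
            if (store.size() == historyLength) {
                store.removeFirst();
            }
            store.addLast(frame.clone());    // store a copy, as with shouldStoreCopy(true)
            return store.toArray(new double[0][]);
        }

        public static void main(String[] args) {
            HistoryMergeSketch sut = new HistoryMergeSketch(4);
            for (int i = 0; i < 6; i++) {
                System.out.println("frames kept: " + sut.transform(new double[] { i }).length);
            }   // prints 1, 2, 3, 4, 4, 4
        }
    }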

View File

@@ -252,8 +252,8 @@ public class PolicyTest {
     }
 
     @Override
-    protected <O extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
-        mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
+    protected <MockObservation extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockObservation, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
+        mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(skipFrame, historyLength));
         return super.refacInitMdp(mdpWrapper, hp);
     }
 }

View File

@@ -1,18 +0,0 @@
-package org.deeplearning4j.rl4j.support;
-
-import org.deeplearning4j.rl4j.space.Encodable;
-
-public class MockEncodable implements Encodable {
-
-    private final int value;
-
-    public MockEncodable(int value) {
-        this.value = value;
-    }
-
-    @Override
-    public double[] toArray() {
-        return new double[] { value };
-    }
-}

View File

@@ -2,9 +2,10 @@ package org.deeplearning4j.rl4j.support;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
 import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
 import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
-import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
+import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
@@ -15,7 +16,7 @@ import org.nd4j.linalg.api.rng.Random;
 import java.util.ArrayList;
 import java.util.List;
 
-public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
+public class MockMDP implements MDP<MockObservation, Integer, DiscreteSpace> {
 
     private final DiscreteSpace actionSpace;
     private final int stepsUntilDone;
@@ -55,11 +56,11 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
     }
 
     @Override
-    public MockEncodable reset() {
+    public MockObservation reset() {
         ++resetCount;
         currentObsValue = 0;
         step = 0;
-        return new MockEncodable(currentObsValue++);
+        return new MockObservation(currentObsValue++);
     }
 
     @Override
@@ -68,10 +69,10 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
     }
 
     @Override
-    public StepReply<MockEncodable> step(Integer action) {
+    public StepReply<MockObservation> step(Integer action) {
         actions.add(action);
         ++step;
-        return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
+        return new StepReply<>(new MockObservation(currentObsValue), (double) currentObsValue++, isDone(), null);
     }
 
     @Override
@@ -84,14 +85,14 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
         return null;
     }
 
-    public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) {
+    public static TransformProcess buildTransformProcess(int skipFrame, int historyLength) {
         return TransformProcess.builder()
                 .filter(new UniformSkippingFilter(skipFrame))
-                .transform("data", new EncodableToINDArrayTransform(shape))
+                .transform("data", new EncodableToINDArrayTransform())
                 .transform("data", new SimpleNormalizationTransform(0.0, 255.0))
                 .transform("data", HistoryMergeTransform.builder()
                         .elementStore(new CircularFifoStore(historyLength))
-                        .build())
+                        .build(4))
                 .build("data");
     }
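
The mock pipeline now mirrors the real one: skip frames uniformly, convert the Encodable to an INDArray, normalize raw pixel values, then merge history. Assuming SimpleNormalizationTransform(0.0, 255.0) performs plain min-max scaling, the normalization stage amounts to this sketch:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class NormalizationSketch {
        public static void main(String[] args) {
            double min = 0.0, max = 255.0;   // the arguments used in the pipeline above
            INDArray pixels = Nd4j.create(new double[] {0.0, 127.5, 255.0});
            INDArray normalized = pixels.sub(min).div(max - min);
            System.out.println(normalized);  // [0, 0.5, 1.0]
        }
    }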

View File

@@ -0,0 +1,51 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K. K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.rl4j.support;
+
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+public class MockObservation implements Encodable {
+
+    final INDArray data;
+
+    public MockObservation(int value) {
+        this.data = Nd4j.ones(1).mul(value);
+    }
+
+    @Override
+    public double[] toArray() {
+        return data.data().asDouble();
+    }
+
+    @Override
+    public boolean isSkipped() {
+        return false;
+    }
+
+    @Override
+    public INDArray getData() {
+        return data;
+    }
+
+    @Override
+    public Encodable dup() {
+        return null;
+    }
+}

View File

@@ -17,7 +17,7 @@ public class MockPolicy implements IPolicy<Integer> {
     public List<INDArray> actionInputs = new ArrayList<INDArray>();
 
     @Override
-    public <O extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<O, Integer, AS> mdp, IHistoryProcessor hp) {
+    public <MockObservation extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<MockObservation, Integer, AS> mdp, IHistoryProcessor hp) {
         ++playCallCount;
         return 0;
     }

View File

@@ -28,6 +28,9 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
 import vizdoom.*;
 
 import java.util.ArrayList;
@@ -155,7 +158,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
                 + Pointer.formatBytes(Pointer.totalPhysicalBytes()));
         game.newEpisode();
-        return new GameScreen(game.getState().screenBuffer);
+        return new GameScreen(observationSpace.getShape(), game.getState().screenBuffer);
     }
@@ -168,7 +171,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
         double r = game.makeAction(actions.get(action)) * scaleFactor;
         log.info(game.getEpisodeTime() + " " + r + " " + action + " ");
-        return new StepReply(new GameScreen(game.isEpisodeFinished()
+        return new StepReply(new GameScreen(observationSpace.getShape(), game.isEpisodeFinished()
                 ? new byte[game.getScreenSize()]
                 : game.getState().screenBuffer), r, game.isEpisodeFinished(), null);
@@ -201,18 +204,34 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
     public static class GameScreen implements Encodable {
 
-        double[] array;
-
-        public GameScreen(byte[] screen) {
-            array = new double[screen.length];
-            for (int i = 0; i < screen.length; i++) {
-                array[i] = (screen[i] & 0xFF) / 255.0;
-            }
+        final INDArray data;
+
+        public GameScreen(int[] shape, byte[] screen) {
+            data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1);
         }
 
+        private GameScreen(INDArray toDup) {
+            data = toDup.dup();
+        }
+
+        @Override
         public double[] toArray() {
-            return array;
+            return data.data().asDouble();
+        }
+
+        @Override
+        public boolean isSkipped() {
+            return false;
+        }
+
+        @Override
+        public INDArray getData() {
+            return data;
+        }
+
+        @Override
+        public GameScreen dup() {
+            return new GameScreen(data);
         }
     }
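
The new constructor builds the observation directly from the raw screen buffer: the bytes arrive laid out height x width x channel (HWC), and permute(2, 0, 1) reorders the view to channel x height x width (CHW) for the convolutional networks. A shape-only illustration of the same call:

    import java.util.Arrays;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class PermuteSketch {
        public static void main(String[] args) {
            int height = 4, width = 5;
            byte[] screen = new byte[height * width * 3];      // stand-in for the emulator buffer
            INDArray hwc = Nd4j.create(screen, new long[] {height, width, 3}, DataType.UINT8);
            INDArray chw = hwc.permute(2, 0, 1);               // same reordering as GameScreen
            System.out.println(Arrays.toString(chw.shape()));  // [3, 4, 5]
        }
    }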

View File

@@ -19,8 +19,6 @@ package org.deeplearning4j.rl4j.mdp.gym;
 import java.io.IOException;
 import lombok.Getter;
-import lombok.Setter;
-import lombok.Value;
 import lombok.extern.slf4j.Slf4j;
 import org.bytedeco.javacpp.DoublePointer;
 import org.bytedeco.javacpp.Pointer;
@@ -31,8 +29,8 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.HighLowDiscrete;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
 import org.bytedeco.cpython.*;
@@ -47,7 +45,7 @@ import static org.bytedeco.numpy.global.numpy.*;
  * @author saudet
  */
 @Slf4j
-public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
+public class GymEnv<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<OBSERVATION, A, AS> {
 
     public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";
@@ -82,7 +80,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     private PyObject locals;
 
     final protected DiscreteSpace actionSpace;
-    final protected ObservationSpace<O> observationSpace;
+    final protected ObservationSpace<OBSERVATION> observationSpace;
     @Getter
     final private String envId;
     @Getter
@@ -119,7 +117,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
             for (int i = 0; i < shape.length; i++) {
                 shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i));
             }
-            observationSpace = (ObservationSpace<O>) new ArrayObservationSpace<Box>(shape);
+            observationSpace = (ObservationSpace<OBSERVATION>) new ArrayObservationSpace<Box>(shape);
             Py_DecRef(shapeTuple);
 
             PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null);
@@ -140,7 +138,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     }
 
     @Override
-    public ObservationSpace<O> getObservationSpace() {
+    public ObservationSpace<OBSERVATION> getObservationSpace() {
         return observationSpace;
     }
@@ -153,7 +151,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     }
 
     @Override
-    public StepReply<O> step(A action) {
+    public StepReply<OBSERVATION> step(A action) {
         int gstate = PyGILState_Ensure();
         try {
             if (render) {
@@ -186,7 +184,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     }
 
     @Override
-    public O reset() {
+    public OBSERVATION reset() {
         int gstate = PyGILState_Ensure();
         try {
             Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null));
@@ -201,7 +199,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
             double[] data = new double[(int)stateData.capacity()];
             stateData.get(data);
-            return (O) new Box(data);
+            return (OBSERVATION) new Box(data);
         } finally {
             PyGILState_Release(gstate);
         }
@@ -220,7 +218,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     }
 
     @Override
-    public GymEnv<O, A, AS> newInstance() {
-        return new GymEnv<O, A, AS>(envId, render, monitor);
+    public GymEnv<OBSERVATION, A, AS> newInstance() {
+        return new GymEnv<OBSERVATION, A, AS>(envId, render, monitor);
     }
 }
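
With the tightened bound, any observation type used with GymEnv must implement Encodable, which Box does. A usage sketch (assumes a local Python gym installation, since the class drives CPython underneath):

    import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
    import org.deeplearning4j.rl4j.space.Box;
    import org.deeplearning4j.rl4j.space.DiscreteSpace;

    public class GymEnvUsageSketch {
        public static void main(String[] args) {
            GymEnv<Box, Integer, DiscreteSpace> mdp = new GymEnv<>("CartPole-v0", false, false);
            Box first = mdp.reset();                          // typed as the bound OBSERVATION
            System.out.println(first.getData().shape()[0]);   // 4 state variables for CartPole
            mdp.close();
        }
    }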

View File

@@ -40,8 +40,8 @@ public class GymEnvTest {
         assertEquals(false, mdp.isDone());
         Box o = (Box)mdp.reset();
         StepReply r = mdp.step(0);
-        assertEquals(4, o.toArray().length);
-        assertEquals(4, ((Box)r.getObservation()).toArray().length);
+        assertEquals(4, o.getData().shape()[0]);
+        assertEquals(4, ((Box)r.getObservation()).getData().shape()[0]);
         assertNotEquals(null, mdp.newInstance());
         mdp.close();
     }

View File

@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K. K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -16,33 +16,13 @@
 package org.deeplearning4j.malmo;
 
-import java.util.Arrays;
-
-import org.deeplearning4j.rl4j.space.Encodable;
-
-/**
- * Encodable state as a simple value array similar to Gym Box model, but without a JSON constructor
- * @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
- */
-public class MalmoBox implements Encodable {
-    double[] value;
-
-    /**
-     * Construct state from an array of doubles
-     * @param value state values
-     */
-    //TODO: If this constructor was added to "Box", we wouldn't need this class at all.
-    public MalmoBox(double... value) {
-        this.value = value;
-    }
-
-    @Override
-    public double[] toArray() {
-        return value;
-    }
-
-    @Override
-    public String toString() {
-        return Arrays.toString(value);
-    }
-}
+import org.deeplearning4j.rl4j.space.Box;
+import org.nd4j.linalg.factory.Nd4j;
+
+@Deprecated
+public class MalmoBox extends Box {
+    public MalmoBox(double... arr) {
+        super(arr);
+    }
+}

View File

@@ -19,6 +19,8 @@ package org.deeplearning4j.malmo;
 import java.util.Arrays;
 
 import com.microsoft.msr.malmo.WorldState;
+import org.deeplearning4j.rl4j.space.Box;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
  * A Malmo consistency policy that ensures the both there is a reward and next observation has a different position that the previous one.
@@ -30,14 +32,14 @@ public class MalmoDescretePositionPolicy implements MalmoObservationPolicy {
     @Override
     public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) {
-        MalmoBox last_observation = observationSpace.getObservation(world_state);
-        MalmoBox old_observation = observationSpace.getObservation(original_world_state);
+        Box last_observation = observationSpace.getObservation(world_state);
+        Box old_observation = observationSpace.getObservation(original_world_state);
 
-        double[] newvalues = old_observation == null ? null : old_observation.toArray();
-        double[] oldvalues = last_observation == null ? null : last_observation.toArray();
+        INDArray newvalues = old_observation == null ? null : old_observation.getData();
+        INDArray oldvalues = last_observation == null ? null : last_observation.getData();
 
         return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty()
-                        || Arrays.equals(oldvalues, newvalues));
+                        || oldvalues.eq(newvalues).all());
     }
 }
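
The comparison now operates on INDArrays: eq() yields an elementwise equality mask and all() collapses it to a single boolean, replacing Arrays.equals on double[]. In isolation:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class IndArrayEqualitySketch {
        public static void main(String[] args) {
            INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0});
            INDArray b = Nd4j.create(new double[] {1.0, 2.0, 3.0});
            INDArray c = Nd4j.create(new double[] {1.0, 2.0, 4.0});
            System.out.println(a.eq(b).all());   // true: every element matches
            System.out.println(a.eq(c).all());   // false: last element differs
        }
    }

Note that, unlike Arrays.equals, this form assumes both operands are non-null, so callers need an observation on both sides before comparing.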

View File

@@ -21,6 +21,7 @@ import java.nio.file.Paths;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 
 import com.microsoft.msr.malmo.AgentHost;
@@ -34,6 +35,7 @@ import com.microsoft.msr.malmo.WorldState;
 import lombok.Setter;
 import lombok.Getter;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -233,7 +235,7 @@ public class MalmoEnv implements MDP<MalmoBox, Integer, DiscreteSpace> {
             logger.info("Mission ended");
         }
 
-        return new StepReply<MalmoBox>(last_observation, getRewards(last_world_state), isDone(), null);
+        return new StepReply<>(last_observation, getRewards(last_world_state), isDone(), null);
     }
 
     private double getRewards(WorldState world_state) {

View File

@@ -16,6 +16,8 @@
 package org.deeplearning4j.malmo;
 
+import org.deeplearning4j.rl4j.space.Box;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
 
 import com.microsoft.msr.malmo.WorldState;

View File

@@ -19,6 +19,7 @@ package org.deeplearning4j.malmo;
 import com.microsoft.msr.malmo.TimestampedStringVector;
 import com.microsoft.msr.malmo.WorldState;
+import org.deeplearning4j.rl4j.space.Box;
 import org.json.JSONArray;
 import org.json.JSONObject;
 import org.nd4j.linalg.api.ndarray.INDArray;

View File

@@ -18,6 +18,7 @@ package org.deeplearning4j.malmo;
 import java.util.HashMap;
 
+import org.deeplearning4j.rl4j.space.Box;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

View File

@@ -16,6 +16,7 @@
 package org.deeplearning4j.malmo;
 
+import org.deeplearning4j.rl4j.space.Box;
 import org.json.JSONObject;
 import org.nd4j.linalg.api.ndarray.INDArray;