From 032b97912e19b52c10898f8918ad4266341efeda Mon Sep 17 00:00:00 2001
From: Chris Bamford
Date: Thu, 23 Apr 2020 02:47:26 +0100
Subject: [PATCH] RL4J: Sanitize Observation (#404)

* working on ALE image pipelines that appear to lose data

* transformation pipeline for ALE has been broken for a while and needed some cleanup to make sure that openCV tooling for scene transforms was actually working.

* allowing history length to be set and passed through to history merge transforms

Signed-off-by: Bam4d

* native image loader is not thread-safe so should not be static

Signed-off-by: Bam4d

* make sure the transformer for encoding observations that are not pixels converts correctly

Signed-off-by: Bam4d

* Test fixes for ALE pixel observation shape

Signed-off-by: Bam4d

* Fix compilation errors

Signed-off-by: Samuel Audet

* fixing some post-merge issues, and comments from PR

Signed-off-by: Bam4d

Co-authored-by: Samuel Audet
---
 .../deeplearning4j/rl4j/mdp/ale/ALEMDP.java | 47 +++--
 .../org/deeplearning4j/gym/StepReply.java | 6 +-
 .../java/org/deeplearning4j/rl4j/mdp/MDP.java | 2 +-
 .../org/deeplearning4j/rl4j/space/Box.java | 35 +++-
 .../deeplearning4j/rl4j/space/Encodable.java | 24 +--
 .../rl4j/helper/INDArrayHelper.java | 9 +-
 .../rl4j/learning/HistoryProcessor.java | 11 +-
 .../rl4j/learning/IHistoryProcessor.java | 2 +-
 .../rl4j/learning/ILearning.java | 4 +-
 .../rl4j/learning/Learning.java | 5 +-
 .../rl4j/learning/async/AsyncThread.java | 6 +-
 .../learning/async/AsyncThreadDiscrete.java | 19 +-
 .../async/a3c/discrete/A3CDiscrete.java | 10 +-
 .../async/a3c/discrete/A3CDiscreteConv.java | 22 +--
 .../async/a3c/discrete/A3CDiscreteDense.java | 28 +--
 .../async/a3c/discrete/A3CThreadDiscrete.java | 21 ++-
 .../discrete/AsyncNStepQLearningDiscrete.java | 12 +-
 .../AsyncNStepQLearningDiscreteConv.java | 20 +--
 .../AsyncNStepQLearningDiscreteDense.java | 22 +--
 .../AsyncNStepQLearningThreadDiscrete.java | 22 +--
 .../learning/sync/qlearning/QLearning.java | 2 +-
 .../qlearning/discrete/QLearningDiscrete.java | 2 +-
 .../discrete/QLearningDiscreteConv.java | 22 +--
 .../discrete/QLearningDiscreteDense.java | 28 +--
 .../rl4j/mdp/CartpoleNative.java | 30 +---
 .../rl4j/mdp/toy/HardToyState.java | 16 ++
 .../rl4j/mdp/toy/SimpleToy.java | 2 +-
 .../rl4j/mdp/toy/SimpleToyState.java | 16 +-
 .../rl4j/observation/Observation.java | 7 +-
 .../EncodableToINDArrayTransform.java | 69 +++-----
 .../EncodableToImageWritableTransform.java | 24 ++-
 .../ImageWritableToINDArrayTransform.java | 21 +--
 .../operation/HistoryMergeTransform.java | 11 +-
 .../historymerge/CircularFifoStore.java | 5 -
 .../deeplearning4j/rl4j/policy/ACPolicy.java | 20 +--
 .../rl4j/policy/BoltzmannQ.java | 4 +-
 .../deeplearning4j/rl4j/policy/DQNPolicy.java | 8 +-
 .../deeplearning4j/rl4j/policy/EpsGreedy.java | 7 +-
 .../deeplearning4j/rl4j/policy/Policy.java | 2 +-
 .../rl4j/util/LegacyMDPWrapper.java | 36 ++--
 .../rl4j/util/VideoRecorder.java | 166 ++----------------
 .../async/AsyncThreadDiscreteTest.java | 2 +-
 .../rl4j/learning/async/AsyncThreadTest.java | 3 +-
 .../discrete/QLearningDiscreteTest.java | 45 +++--
 .../operation/HistoryMergeTransformTest.java | 12 +-
 .../rl4j/policy/PolicyTest.java | 4 +-
 .../rl4j/support/MockEncodable.java | 18 --
 .../deeplearning4j/rl4j/support/MockMDP.java | 19 +-
 .../rl4j/support/MockObservation.java | 51 ++++++
 .../rl4j/support/MockPolicy.java | 2 +-
 .../rl4j/mdp/vizdoom/VizDoom.java | 39 ++--
 .../deeplearning4j/rl4j/mdp/gym/GymEnv.java | 22 ++-
 .../rl4j/mdp/gym/GymEnvTest.java | 4 +-
 .../org/deeplearning4j/malmo/MalmoBox.java | 34 +---
 .../malmo/MalmoDescretePositionPolicy.java | 12 +-
 .../org/deeplearning4j/malmo/MalmoEnv.java | 4 +-
 .../malmo/MalmoObservationSpace.java | 2 +
 .../malmo/MalmoObservationSpaceGrid.java | 1 +
 .../malmo/MalmoObservationSpacePixels.java | 1 +
 .../malmo/MalmoObservationSpacePosition.java | 1 +
 60 files changed, 524 insertions(+), 577 deletions(-)
 rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/{legacy => }/EncodableToINDArrayTransform.java (64%)
 delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java
 create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java

diff --git a/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java b/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java
index 66606de69..7ff58ef30 100644
--- a/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java
+++ b/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java
@@ -25,9 +25,13 @@ import org.bytedeco.javacpp.IntPointer;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;

 /**
  * @author saudet
@@ -70,10 +74,14 @@ public class ALEMDP implements MDP {
         actions = new int[(int)a.limit()];
         a.get(actions);

+        int height = (int)ale.getScreen().height();
+        int width = (int)ale.getScreen().width();
+
         discreteSpace = new DiscreteSpace(actions.length);
-        int[] shape = {(int)ale.getScreen().height(), (int)ale.getScreen().width(), 3};
+        int[] shape = {3, height, width};
         observationSpace = new ArrayObservationSpace<>(shape);
         screenBuffer = new byte[shape[0] * shape[1] * shape[2]];
+
     }

     public void setupGame() {
@@ -103,7 +111,7 @@ public class ALEMDP implements MDP {
     public GameScreen reset() {
         ale.reset_game();
         ale.getScreenRGB(screenBuffer);
-        return new GameScreen(screenBuffer);
+        return new GameScreen(observationSpace.getShape(), screenBuffer);
     }

@@ -115,7 +123,8 @@ public class ALEMDP implements MDP {
         double r = ale.act(actions[action]) * scaleFactor;
         log.info(ale.getEpisodeFrameNumber() + " " + r + " " + action + " ");
         ale.getScreenRGB(screenBuffer);
-        return new StepReply(new GameScreen(screenBuffer), r, ale.game_over(), null);
+
+        return new StepReply(new GameScreen(observationSpace.getShape(), screenBuffer), r, ale.game_over(), null);
     }

     public ObservationSpace getObservationSpace() {
@@ -140,17 +149,35 @@ public class ALEMDP implements MDP {
     }

     public static class GameScreen implements Encodable {
-        double[] array;
-        public GameScreen(byte[] screen) {
-            array = new double[screen.length];
-            for (int i = 0; i < screen.length; i++) {
-                array[i] = (screen[i] & 0xFF) / 255.0;
-            }
+        final INDArray data;
+        public GameScreen(int[] shape, byte[] screen) {
+
+            data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1);
         }

+        private GameScreen(INDArray toDup) {
+            data = toDup.dup();
+        }
+
+        @Override
         public double[] toArray() {
-            return array;
+            return data.data().asDouble();
+        }
+
+        @Override
+        public boolean isSkipped() {
+            return false;
+        }
+
+        @Override
+
public INDArray getData() { + return data; + } + + @Override + public GameScreen dup() { + return new GameScreen(data); } } } diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java index ab054689a..e37750d72 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java @@ -19,15 +19,15 @@ package org.deeplearning4j.gym; import lombok.Value; /** - * @param type of observation + * @param type of observation * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16. * * StepReply is the container for the data returned after each step(action). */ @Value -public class StepReply { +public class StepReply { - T observation; + OBSERVATION observation; double reward; boolean done; Object info; diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java index 37b097dbf..e911a7acc 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java @@ -32,7 +32,7 @@ import org.deeplearning4j.rl4j.space.ObservationSpace; * in a "functionnal manner" if step return a mdp * */ -public interface MDP> { +public interface MDP> { ObservationSpace getObservationSpace(); diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java index e90601fda..3bc242fea 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java @@ -16,6 +16,9 @@ package org.deeplearning4j.rl4j.space; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. * @@ -25,13 +28,37 @@ package org.deeplearning4j.rl4j.space; */ public class Box implements Encodable { - private final double[] array; + private final INDArray data; - public Box(double[] arr) { - this.array = arr; + public Box(double... arr) { + this.data = Nd4j.create(arr); } + public Box(int[] shape, double... arr) { + this.data = Nd4j.create(arr).reshape(shape); + } + + private Box(INDArray toDup) { + data = toDup.dup(); + } + + @Override public double[] toArray() { - return array; + return data.data().asDouble(); + } + + @Override + public boolean isSkipped() { + return false; + } + + @Override + public INDArray getData() { + return data; + } + + @Override + public Encodable dup() { + return new Box(data); } } diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java index 04b0c22af..bfec24f68 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K. K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,17 +16,19 @@ package org.deeplearning4j.rl4j.space; -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/19/16. 
- * Encodable is an interface that ensure that the state is convertible to a double array - */ +import org.nd4j.linalg.api.ndarray.INDArray; + public interface Encodable { - /** - * $ - * encodes all the information of an Observation in an array double and can be used as input of a DQN directly - * - * @return the encoded informations - */ + @Deprecated double[] toArray(); + + boolean isSkipped(); + + /** + * Any image data should be in CHW format. + */ + INDArray getData(); + + Encodable dup(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java index 2e608db19..b42a7c503 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java @@ -24,16 +24,17 @@ import org.nd4j.linalg.factory.Nd4j; * @author Alexandre Boulanger */ public class INDArrayHelper { + /** - * MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray. - * In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape. + * MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types. * - * @param source A INDArray - * @return The source INDArray with the correct shape + * We must have either shape 2 (NK) or shape 4 (NCHW) */ public static INDArray forceCorrectShape(INDArray source) { + return source.shape()[0] == 1 && source.shape().length > 1 ? source : Nd4j.expandDims(source, 0); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java index 550c6eb70..f3516af50 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java @@ -46,7 +46,6 @@ public class HistoryProcessor implements IHistoryProcessor { @Getter final private Configuration conf; - final private OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat(); private CircularFifoQueue history; private VideoRecorder videoRecorder; @@ -63,8 +62,7 @@ public class HistoryProcessor implements IHistoryProcessor { public void startMonitor(String filename, int[] shape) { if(videoRecorder == null) { - videoRecorder = VideoRecorder.builder(shape[0], shape[1]) - .frameInputType(VideoRecorder.FrameInputTypes.Float) + videoRecorder = VideoRecorder.builder(shape[1], shape[2]) .build(); } @@ -89,14 +87,13 @@ public class HistoryProcessor implements IHistoryProcessor { return videoRecorder != null && videoRecorder.isRecording(); } - public void record(INDArray raw) { + public void record(INDArray pixelArray) { if(isMonitoring()) { // before accessing the raw pointer, we need to make sure that array is actual on the host side - Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST); + Nd4j.getAffinityManager().ensureLocation(pixelArray, AffinityManager.Location.HOST); - VideoRecorder.VideoFrame frame = videoRecorder.createFrame(raw.data().pointer()); try { - videoRecorder.record(frame); + videoRecorder.record(pixelArray); } catch (Exception e) { e.printStackTrace(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java index a8a09bc0b..6bd74fd28 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java @@ -64,7 +64,7 @@ public interface IHistoryProcessor { @Builder.Default int skipFrame = 4; public int[] getShape() { - return new int[] {getHistoryLength(), getCroppingHeight(), getCroppingWidth()}; + return new int[] {getHistoryLength(), getRescaledHeight(), getRescaledWidth()}; } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index db964527e..0d1b5bea2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable; * * A common interface that any training method should implement */ -public interface ILearning> { +public interface ILearning> { IPolicy getPolicy(); @@ -38,7 +38,7 @@ public interface ILearning> { ILearningConfiguration getConfiguration(); - MDP getMdp(); + MDP getMdp(); IHistoryProcessor getHistoryProcessor(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index ca9451ea2..ba88454a7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -21,7 +21,6 @@ import lombok.Getter; import lombok.Setter; import lombok.Value; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; @@ -38,8 +37,8 @@ import org.nd4j.linalg.factory.Nd4j; * */ @Slf4j -public abstract class Learning, NN extends NeuralNet> - implements ILearning, NeuralNetFetchable { +public abstract class Learning, NN extends NeuralNet> + implements ILearning, NeuralNetFetchable { @Getter @Setter protected int stepCount = 0; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java index 54be00cfb..26d8d5e02 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java @@ -29,10 +29,10 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.factory.Nd4j; @@ -188,7 +188,7 @@ public abstract class AsyncThread getAsyncGlobal(); - protected abstract IAsyncLearningConfiguration getConf(); + protected abstract IAsyncLearningConfiguration 
getConfiguration(); protected abstract IPolicy getPolicy(NN net); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index f340e2706..c32be6906 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -24,29 +24,22 @@ import lombok.Setter; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; -import org.deeplearning4j.rl4j.experience.ExperienceHandler; -import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; -import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Stack; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. *

* Async Learning specialized for the Discrete Domain */ -public abstract class AsyncThreadDiscrete - extends AsyncThread { +public abstract class AsyncThreadDiscrete + extends AsyncThread { @Getter private NN current; @@ -59,7 +52,7 @@ public abstract class AsyncThreadDiscrete asyncGlobal, - MDP mdp, + MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { @@ -112,7 +105,7 @@ public abstract class AsyncThreadDiscrete stepReply = getLegacyMDPWrapper().step(action); - accuReward += stepReply.getReward() * getConf().getRewardFactor(); + accuReward += stepReply.getReward() * getConfiguration().getRewardFactor(); if (!obs.isSkipped()) { experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone()); @@ -126,7 +119,7 @@ public abstract class AsyncThreadDiscrete extends AsyncLearning { +public abstract class A3CDiscrete extends AsyncLearning { @Getter final public A3CLearningConfiguration configuration; @Getter - final protected MDP mdp; + final protected MDP mdp; final private IActorCritic iActorCritic; @Getter final private AsyncGlobal asyncGlobal; @Getter - final private ACPolicy policy; + final private ACPolicy policy; - public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) { + public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) { this.iActorCritic = iActorCritic; this.mdp = mdp; this.configuration = conf; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java index 17c6b8da8..08fec8a94 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java @@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph; import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; @@ -42,19 +42,19 @@ import org.deeplearning4j.rl4j.util.IDataManager; * first layers since they're essentially doing the same dimension * reduction task **/ -public class A3CDiscreteConv extends A3CDiscrete { +public class A3CDiscreteConv extends A3CDiscrete { final private HistoryProcessor.Configuration hpconf; @Deprecated - public A3CDiscreteConv(MDP mdp, IActorCritic actorCritic, + public A3CDiscreteConv(MDP mdp, IActorCritic actorCritic, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, actorCritic, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } @Deprecated - public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, + public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { super(mdp, IActorCritic, conf.toLearningConfiguration()); @@ -62,7 +62,7 @@ public class A3CDiscreteConv extends A3CDiscrete { setHistoryProcessor(hpconf); } - public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, + public A3CDiscreteConv(MDP 
mdp, IActorCritic IActorCritic, HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { super(mdp, IActorCritic, conf); this.hpconf = hpconf; @@ -70,35 +70,35 @@ public class A3CDiscreteConv extends A3CDiscrete { } @Deprecated - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } @Deprecated - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } @Deprecated - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); } @Deprecated - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf); } - public A3CDiscreteConv(MDP mdp, ActorCriticNetworkConfiguration netConf, + public A3CDiscreteConv(MDP mdp, ActorCriticNetworkConfiguration netConf, HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java index 74332bf3a..5fd68f571 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java @@ -21,8 +21,8 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.*; import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; @@ -34,74 +34,74 @@ import org.deeplearning4j.rl4j.util.IDataManager; * We use specifically the Separate version because * the model is too small to have enough benefit by sharing layers */ -public class A3CDiscreteDense extends A3CDiscrete { 
+public class A3CDiscreteDense extends A3CDiscrete { @Deprecated - public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf, + public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf, IDataManager dataManager) { this(mdp, IActorCritic, conf); addListener(new DataManagerTrainingListener(dataManager)); } @Deprecated - public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf) { + public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf) { super(mdp, actorCritic, conf.toLearningConfiguration()); } - public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) { + public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) { super(mdp, actorCritic, conf); } @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CLearningConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } @Deprecated - public A3CDiscreteDense(MDP mdp, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } @Deprecated - public A3CDiscreteDense(MDP mdp, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf); } - public A3CDiscreteDense(MDP mdp, + public A3CDiscreteDense(MDP mdp, ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) { this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); } @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, A3CLearningConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index 36f973957..123680a38 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -23,23 +23,23 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.api.rng.Random; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. - * + *

* Local thread as described in the https://arxiv.org/abs/1602.01783 paper. */ -public class A3CThreadDiscrete extends AsyncThreadDiscrete { +public class A3CThreadDiscrete extends AsyncThreadDiscrete { @Getter - final protected A3CLearningConfiguration conf; + final protected A3CLearningConfiguration configuration; @Getter final protected IAsyncGlobal asyncGlobal; @Getter @@ -47,17 +47,17 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< final private Random rnd; - public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, + public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners, int threadNumber) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); - this.conf = a3cc; + this.configuration = a3cc; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; - Long seed = conf.getSeed(); + Long seed = configuration.getSeed(); rnd = Nd4j.getRandom(); - if(seed != null) { + if (seed != null) { rnd.setSeed(seed + threadNumber); } @@ -69,9 +69,12 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< return new ACPolicy(net, rnd); } + /** + * calc the gradients based on the n-step rewards + */ @Override protected UpdateAlgorithm buildUpdateAlgorithm() { int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); - return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma()); + return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), configuration.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index 94edac593..8a302d2d9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -28,26 +28,26 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread; import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. 
*/ -public abstract class AsyncNStepQLearningDiscrete - extends AsyncLearning { +public abstract class AsyncNStepQLearningDiscrete + extends AsyncLearning { @Getter final public AsyncQLearningConfiguration configuration; @Getter - final private MDP mdp; + final private MDP mdp; @Getter final private AsyncGlobal asyncGlobal; - public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { + public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { this.mdp = mdp; this.configuration = conf; this.asyncGlobal = new AsyncGlobal<>(dqn, conf); @@ -63,7 +63,7 @@ public abstract class AsyncNStepQLearningDiscrete } public IPolicy getPolicy() { - return new DQNPolicy(getNeuralNet()); + return new DQNPolicy(getNeuralNet()); } @Data diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java index f92b704b6..3f12a60ad 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java @@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.IDQN; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; @@ -35,17 +35,17 @@ import org.deeplearning4j.rl4j.util.IDataManager; * Specialized constructors for the Conv (pixels input) case * Specialized conf + provide additional type safety */ -public class AsyncNStepQLearningDiscreteConv extends AsyncNStepQLearningDiscrete { +public class AsyncNStepQLearningDiscreteConv extends AsyncNStepQLearningDiscrete { final private HistoryProcessor.Configuration hpconf; @Deprecated - public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, dqn, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } - public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { super(mdp, dqn, conf); this.hpconf = hpconf; @@ -53,21 +53,21 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN } @Deprecated - public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { + public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } - public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, + public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration 
hpconf, AsyncQLearningConfiguration conf) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } @Deprecated - public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, - HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { + public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); } - public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java index b6216e849..a94eba7a4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java @@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16. 
*/ -public class AsyncNStepQLearningDiscreteDense extends AsyncNStepQLearningDiscrete { +public class AsyncNStepQLearningDiscreteDense extends AsyncNStepQLearningDiscrete { @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf, IDataManager dataManager) { super(mdp, dqn, conf.toLearningConfiguration()); addListener(new DataManagerTrainingListener(dataManager)); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { super(mdp, dqn, conf.toLearningConfiguration()); } - public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { super(mdp, dqn, conf); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncNStepQLConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncQLearningConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); } - public AsyncNStepQLearningDiscreteDense(MDP mdp, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf), conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index ef60c685f..0b8535f53 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -25,21 +25,21 @@ import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguratio import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.EpsGreedy; 
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. */ -public class AsyncNStepQLearningThreadDiscrete extends AsyncThreadDiscrete { +public class AsyncNStepQLearningThreadDiscrete extends AsyncThreadDiscrete { @Getter - final protected AsyncQLearningConfiguration conf; + final protected AsyncQLearningConfiguration configuration; @Getter final protected IAsyncGlobal asyncGlobal; @Getter @@ -47,17 +47,17 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn final private Random rnd; - public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, - AsyncQLearningConfiguration conf, + public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, + AsyncQLearningConfiguration configuration, TrainingListenerList listeners, int threadNumber, int deviceNum) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); - this.conf = conf; + this.configuration = configuration; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; rnd = Nd4j.getRandom(); - Long seed = conf.getSeed(); - if (seed != null) { + Long seed = configuration.getSeed(); + if(seed != null) { rnd.setSeed(seed + threadNumber); } @@ -65,13 +65,13 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn } public Policy getPolicy(IDQN nn) { - return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(), - rnd, conf.getMinEpsilon(), this); + return new EpsGreedy(new DQNPolicy(nn), getMdp(), configuration.getUpdateStart(), configuration.getEpsilonNbStep(), + rnd, configuration.getMinEpsilon(), this); } @Override protected UpdateAlgorithm buildUpdateAlgorithm() { int[] shape = getHistoryProcessor() == null ? 
getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); - return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma()); + return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index d12db5d67..b2e06dc9c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -32,10 +32,10 @@ import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.SyncLearning; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager.StatEntry; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index b2ad597d0..771650340 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -33,11 +33,11 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorith import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java index 98c690269..450d0e27e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java @@ -24,8 +24,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.IDQN; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import 
org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; @@ -34,59 +34,59 @@ import org.deeplearning4j.rl4j.util.IDataManager; * Specialized constructors for the Conv (pixels input) case * Specialized conf + provide additional type safety */ -public class QLearningDiscreteConv extends QLearningDiscrete { +public class QLearningDiscreteConv extends QLearningDiscrete { @Deprecated - public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, + public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, dqn, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } @Deprecated - public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, + public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame()); setHistoryProcessor(hpconf); } - public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, + public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); setHistoryProcessor(hpconf); } @Deprecated - public QLearningDiscreteConv(MDP mdp, DQNFactory factory, + public QLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } @Deprecated - public QLearningDiscreteConv(MDP mdp, DQNFactory factory, + public QLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } - public QLearningDiscreteConv(MDP mdp, DQNFactory factory, + public QLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } @Deprecated - public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, + public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); } @Deprecated - public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, + public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf); } - public QLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + public QLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java index 5b95cc84e..789e71b42 100644 --- 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java @@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16. */ -public class QLearningDiscreteDense extends QLearningDiscrete { +public class QLearningDiscreteDense extends QLearningDiscrete { @Deprecated - public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf, - IDataManager dataManager) { + public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf, + IDataManager dataManager) { this(mdp, dqn, conf); addListener(new DataManagerTrainingListener(dataManager)); } @Deprecated - public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf) { + public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf) { super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep()); } - public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearningConfiguration conf) { + public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearningConfiguration conf) { super(mdp, dqn, conf, conf.getEpsilonNbStep()); } @Deprecated - public QLearningDiscreteDense(MDP mdp, DQNFactory factory, - QLearning.QLConfiguration conf, IDataManager dataManager) { + public QLearningDiscreteDense(MDP mdp, DQNFactory factory, + QLearning.QLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } @Deprecated - public QLearningDiscreteDense(MDP mdp, DQNFactory factory, + public QLearningDiscreteDense(MDP mdp, DQNFactory factory, QLearning.QLConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - public QLearningDiscreteDense(MDP mdp, DQNFactory factory, + public QLearningDiscreteDense(MDP mdp, DQNFactory factory, QLearningConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } @Deprecated - public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, - QLearning.QLConfiguration conf, IDataManager dataManager) { + public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, + QLearning.QLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } @Deprecated - public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, + public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, QLearning.QLConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); } - public QLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf, + public QLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf, QLearningConfiguration conf) { this(mdp, 
new DQNFactoryStdDense(netConf), conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java index 94aa79b0b..8b33e54d0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java @@ -4,8 +4,8 @@ import lombok.Getter; import lombok.Setter; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.space.ArrayObservationSpace; +import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.ObservationSpace; import java.util.Random; @@ -36,7 +36,7 @@ import java.util.Random; */ -public class CartpoleNative implements MDP { +public class CartpoleNative implements MDP { public enum KinematicsIntegrators { Euler, SemiImplicitEuler }; private static final int NUM_ACTIONS = 2; @@ -74,7 +74,7 @@ public class CartpoleNative implements MDP observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES }); + private ObservationSpace observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES }); public CartpoleNative() { rnd = new Random(); @@ -85,7 +85,7 @@ public class CartpoleNative implements MDP step(Integer action) { + public StepReply step(Integer action) { double force = action == ACTION_RIGHT ? forceMag : -forceMag; double cosTheta = Math.cos(theta); double sinTheta = Math.sin(theta); @@ -143,26 +143,12 @@ public class CartpoleNative implements MDP(new State(new double[] { x, xDot, theta, thetaDot }), reward, done, null); + return new StepReply<>(new Box(x, xDot, theta, thetaDot), reward, done, null); } @Override - public MDP newInstance() { + public MDP newInstance() { return new CartpoleNative(); } - public static class State implements Encodable { - - private final double[] state; - - State(double[] state) { - - this.state = state; - } - - @Override - public double[] toArray() { - return state; - } - } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java index 6fd96b7ea..a357eaeda 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java @@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy; import lombok.Value; import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.ndarray.INDArray; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. 
@@ -31,4 +32,19 @@ public class HardToyState implements Encodable { public double[] toArray() { return values; } + + @Override + public boolean isSkipped() { + return false; + } + + @Override + public INDArray getData() { + return null; + } + + @Override + public Encodable dup() { + return null; + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java index 19b07b0b1..933332125 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java @@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.learning.NeuralNetFetchable; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.ArrayObservationSpace; +import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -40,7 +41,6 @@ import org.nd4j.linalg.factory.Nd4j; public class SimpleToy implements MDP { final private int maxStep; - //TODO 10 steps toy (always +1 reward2 actions), toylong (1000 steps), toyhard (7 actions, +1 only if actiion = (step/100+step)%7, and toyStoch (like last but reward has 0.10 odd to be somewhere else). @Getter private DiscreteSpace actionSpace = new DiscreteSpace(2); @Getter diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java index 1c38cf384..6e41ea414 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java @@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy; import lombok.Value; import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.ndarray.INDArray; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. 
@@ -28,11 +29,24 @@ public class SimpleToyState implements Encodable { int i; int step; - @Override public double[] toArray() { double[] ar = new double[1]; ar[0] = (20 - i); return ar; } + @Override + public boolean isSkipped() { + return false; + } + + @Override + public INDArray getData() { + return null; + } + + @Override + public Encodable dup() { + return null; + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java index 0444aa32d..6603429cc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java @@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * * @author Alexandre Boulanger */ -public class Observation { +public class Observation implements Encodable { /** * A singleton representing a skipped observation @@ -38,6 +38,11 @@ public class Observation { @Getter private final INDArray data; + @Override + public double[] toArray() { + return data.data().asDouble(); + } + public boolean isSkipped() { return data == null; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java similarity index 64% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java index a9214bbff..8be8c7ed9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java @@ -1,41 +1,28 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -package org.deeplearning4j.rl4j.observation.transform.legacy; - -import org.bytedeco.javacv.OpenCVFrameConverter; -import org.bytedeco.opencv.opencv_core.Mat; -import org.datavec.api.transform.Operation; -import org.datavec.image.data.ImageWritable; -import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.bytedeco.opencv.global.opencv_core.CV_32FC; - -public class EncodableToINDArrayTransform implements Operation { - - private final int[] shape; - - public EncodableToINDArrayTransform(int[] shape) { - this.shape = shape; - } - - @Override - public INDArray transform(Encodable encodable) { - return Nd4j.create(encodable.toArray()).reshape(shape); - } - -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K. K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.transform; + +import org.datavec.api.transform.Operation; +import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class EncodableToINDArrayTransform implements Operation { + @Override + public INDArray transform(Encodable encodable) { + return encodable.getData(); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java index 133fbdb61..870b366ff 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java @@ -15,34 +15,32 @@ ******************************************************************************/ package org.deeplearning4j.rl4j.observation.transform.legacy; +import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.opencv.opencv_core.Mat; import org.datavec.api.transform.Operation; import org.datavec.image.data.ImageWritable; +import org.datavec.image.loader.NativeImageLoader; import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.bytedeco.opencv.global.opencv_core.CV_32FC; +import static org.bytedeco.opencv.global.opencv_core.CV_32FC3; +import static org.bytedeco.opencv.global.opencv_core.CV_32S; +import static org.bytedeco.opencv.global.opencv_core.CV_32SC; +import static org.bytedeco.opencv.global.opencv_core.CV_32SC3; +import static 
org.bytedeco.opencv.global.opencv_core.CV_64FC; +import static org.bytedeco.opencv.global.opencv_core.CV_8UC3; public class EncodableToImageWritableTransform implements Operation { - private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); - private final int height; - private final int width; - private final int colorChannels; - - public EncodableToImageWritableTransform(int height, int width, int colorChannels) { - this.height = height; - this.width = width; - this.colorChannels = colorChannels; - } + final static NativeImageLoader nativeImageLoader = new NativeImageLoader(); @Override public ImageWritable transform(Encodable encodable) { - INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels); - Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer()); - return new ImageWritable(converter.convert(mat)); + return new ImageWritable(nativeImageLoader.asFrame(encodable.getData(), Frame.DEPTH_UBYTE)); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java index 3a48c128a..88615325d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java @@ -18,34 +18,31 @@ package org.deeplearning4j.rl4j.observation.transform.legacy; import org.datavec.api.transform.Operation; import org.datavec.image.data.ImageWritable; import org.datavec.image.loader.NativeImageLoader; -import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; public class ImageWritableToINDArrayTransform implements Operation { - private final int height; - private final int width; - private final NativeImageLoader loader; - - public ImageWritableToINDArrayTransform(int height, int width) { - this.height = height; - this.width = width; - this.loader = new NativeImageLoader(height, width); - } + private final NativeImageLoader loader = new NativeImageLoader(); @Override public INDArray transform(ImageWritable imageWritable) { + + int height = imageWritable.getHeight(); + int width = imageWritable.getWidth(); + int channels = imageWritable.getFrame().imageChannels; + INDArray out = null; try { out = loader.asMatrix(imageWritable); } catch (IOException e) { e.printStackTrace(); } - out = out.reshape(1, height, width); + + // Convert back to uint8 and reshape to the number of channels in the image + out = out.reshape(channels, height, width); INDArray compressed = out.castTo(DataType.UINT8); return compressed; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java index e27d1134c..e8920bbdd 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java @@ -46,19 +46,20 @@ public class HistoryMergeTransform implements Operation, Res private final 
HistoryMergeElementStore historyMergeElementStore; private final HistoryMergeAssembler historyMergeAssembler; private final boolean shouldStoreCopy; - private final boolean isFirstDimenstionBatch; + private final boolean isFirstDimensionBatch; private HistoryMergeTransform(Builder builder) { this.historyMergeElementStore = builder.historyMergeElementStore; this.historyMergeAssembler = builder.historyMergeAssembler; this.shouldStoreCopy = builder.shouldStoreCopy; - this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch; + this.isFirstDimensionBatch = builder.isFirstDimenstionBatch; } @Override public INDArray transform(INDArray input) { + INDArray element; - if(isFirstDimenstionBatch) { + if(isFirstDimensionBatch) { element = input.slice(0, 0); } else { @@ -132,9 +133,9 @@ public class HistoryMergeTransform implements Operation, Res return this; } - public HistoryMergeTransform build() { + public HistoryMergeTransform build(int frameStackLength) { if(historyMergeElementStore == null) { - historyMergeElementStore = new CircularFifoStore(); + historyMergeElementStore = new CircularFifoStore(frameStackLength); } if(historyMergeAssembler == null) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java index db1cbb2bd..5b00bba3c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java @@ -28,14 +28,9 @@ import org.nd4j.linalg.factory.Nd4j; * @author Alexandre Boulanger */ public class CircularFifoStore implements HistoryMergeElementStore { - private static final int DEFAULT_STORE_SIZE = 4; private final CircularFifoQueue queue; - public CircularFifoStore() { - this(DEFAULT_STORE_SIZE); - } - public CircularFifoStore(int size) { Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size); queue = new CircularFifoQueue<>(size); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java index e01456729..6824e75cb 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java @@ -20,8 +20,8 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph; import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate; import org.deeplearning4j.rl4j.network.ac.IActorCritic; -import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; @@ -35,7 +35,7 @@ import java.io.IOException; * the softmax output of the actor critic, but objects constructed * with a {@link Random} argument of null return the max only. 
*/ -public class ACPolicy extends Policy { +public class ACPolicy extends Policy { final private IActorCritic actorCritic; Random rnd; @@ -48,18 +48,18 @@ public class ACPolicy extends Policy { this.rnd = rnd; } - public static ACPolicy load(String path) throws IOException { - return new ACPolicy(ActorCriticCompGraph.load(path)); + public static ACPolicy load(String path) throws IOException { + return new ACPolicy<>(ActorCriticCompGraph.load(path)); } - public static ACPolicy load(String path, Random rnd) throws IOException { - return new ACPolicy(ActorCriticCompGraph.load(path), rnd); + public static ACPolicy load(String path, Random rnd) throws IOException { + return new ACPolicy<>(ActorCriticCompGraph.load(path), rnd); } - public static ACPolicy load(String pathValue, String pathPolicy) throws IOException { - return new ACPolicy(ActorCriticSeparate.load(pathValue, pathPolicy)); + public static ACPolicy load(String pathValue, String pathPolicy) throws IOException { + return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy)); } - public static ACPolicy load(String pathValue, String pathPolicy, Random rnd) throws IOException { - return new ACPolicy(ActorCriticSeparate.load(pathValue, pathPolicy), rnd); + public static ACPolicy load(String pathValue, String pathPolicy, Random rnd) throws IOException { + return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd); } public IActorCritic getNeuralNet() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java index 7508655c3..6f2e63620 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java @@ -17,8 +17,8 @@ package org.deeplearning4j.rl4j.policy; import org.deeplearning4j.rl4j.network.dqn.IDQN; -import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; @@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp; * Boltzmann exploration is a stochastic policy wrt to the * exponential Q-values as evaluated by the dqn model. */ -public class BoltzmannQ extends Policy { +public class BoltzmannQ extends Policy { final private IDQN dqn; final private Random rnd; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java index e2982823d..ed591a1ff 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java @@ -20,8 +20,8 @@ import lombok.AllArgsConstructor; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.network.dqn.DQN; import org.deeplearning4j.rl4j.network.dqn.IDQN; -import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.IOException; @@ -35,12 +35,12 @@ import java.io.IOException; // FIXME: Should we rename this "GreedyPolicy"? 
@AllArgsConstructor -public class DQNPolicy extends Policy { +public class DQNPolicy extends Policy { final private IDQN dqn; - public static DQNPolicy load(String path) throws IOException { - return new DQNPolicy(DQN.load(path)); + public static DQNPolicy load(String path) throws IOException { + return new DQNPolicy<>(DQN.load(path)); } public IDQN getNeuralNet() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java index 4801c7b70..a7282f139 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java @@ -20,12 +20,11 @@ package org.deeplearning4j.rl4j.policy; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.rl4j.learning.IEpochTrainer; -import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; @@ -41,10 +40,10 @@ import org.nd4j.linalg.api.rng.Random; */ @AllArgsConstructor @Slf4j -public class EpsGreedy> extends Policy { +public class EpsGreedy> extends Policy { final private Policy policy; - final private MDP mdp; + final private MDP mdp; final private int updateStart; final private int epsilonNbStep; final private Random rnd; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index 4885e2c62..6a4146c94 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -22,9 +22,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; /** diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java index 981f35379..cc0a12e38 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java @@ -7,22 +7,22 @@ import org.datavec.image.transform.ColorConversionTransform; import org.datavec.image.transform.CropImageTransform; import org.datavec.image.transform.MultiImageTransform; import org.datavec.image.transform.ResizeImageTransform; +import org.datavec.image.transform.ShowImageTransform; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import 
org.deeplearning4j.rl4j.observation.transform.TransformProcess; import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; -import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform; import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform; import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform; import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; import java.util.HashMap; import java.util.Map; @@ -46,6 +46,7 @@ public class LegacyMDPWrapper wrappedMDP, IHistoryProcessor historyProcessor) { this.wrappedMDP = wrappedMDP; this.shape = wrappedMDP.getObservationSpace().getShape(); @@ -66,28 +67,33 @@ public class LegacyMDPWrapper channelsData = buildChannelsData(rawStepReply.getObservation()); Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone()); + return new StepReply(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); } @@ -161,12 +168,7 @@ public class LegacyMDPWrapper { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java index 5a91b71e4..e7e9fcd4c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java @@ -16,26 +16,21 @@ package org.deeplearning4j.rl4j.util; -import lombok.Getter; import lombok.extern.slf4j.Slf4j; -import org.bytedeco.javacpp.BytePointer; -import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacv.FFmpegFrameRecorder; import org.bytedeco.javacv.Frame; -import org.bytedeco.javacv.OpenCVFrameConverter; -import org.bytedeco.opencv.global.opencv_core; -import org.bytedeco.opencv.global.opencv_imgproc; -import org.bytedeco.opencv.opencv_core.Mat; -import org.bytedeco.opencv.opencv_core.Rect; -import org.bytedeco.opencv.opencv_core.Size; -import org.opencv.imgproc.Imgproc; +import org.datavec.image.loader.NativeImageLoader; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; -import static org.bytedeco.ffmpeg.global.avcodec.*; -import static org.bytedeco.opencv.global.opencv_core.*; +import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_H264; +import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_MPEG4; +import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB0; +import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB24; +import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB8; /** - * VideoRecorder is used to create a video from a sequence of individual frames. If using 3 channels - * images, it expects B-G-R order. A RGB order can be used by calling isRGBOrder(true).
+ * VideoRecorder is used to create a video from a sequence of INDArray frames. The INDArrays are assumed to be in CHW format, where C=3 and the pixels are RGB-encoded.
* Example:
*

  * {@code
@@ -45,11 +40,8 @@ import static org.bytedeco.opencv.global.opencv_core.*;
  *             .build();
  *         recorder.startRecording("myVideo.mp4");
  *         while(...) {
- *             byte[] data = new byte[160*100*3];
- *             // Todo: Fill data
- *             VideoRecorder.VideoFrame frame = recorder.createFrame(data);
- *             // Todo: Apply cropping or resizing to frame
- *             recorder.record(frame);
+ *             INDArray chwData = Nd4j.create(3, height, width); // CHW frame matching the recorder's height and width
+ *             recorder.record(chwData);
  *         }
  *         recorder.stopRecording();
  * }
@@ -60,16 +52,13 @@ import static org.bytedeco.opencv.global.opencv_core.*;
 @Slf4j
 public class VideoRecorder implements AutoCloseable {
 
-    public enum FrameInputTypes { BGR, RGB, Float }
+    private final NativeImageLoader nativeImageLoader = new NativeImageLoader();
 
     private final int height;
     private final int width;
-    private final int imageType;
-    private final OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat();
     private final int codec;
     private final double framerate;
     private final int videoQuality;
-    private final FrameInputTypes frameInputType;
 
     private FFmpegFrameRecorder fmpegFrameRecorder = null;
 
@@ -83,11 +72,9 @@ public class VideoRecorder implements AutoCloseable {
     private VideoRecorder(Builder builder) {
         this.height = builder.height;
         this.width = builder.width;
-        imageType = CV_8UC(builder.numChannels);
         codec = builder.codec;
         framerate = builder.frameRate;
         videoQuality = builder.videoQuality;
-        frameInputType = builder.frameInputType;
     }
 
     /**
@@ -119,59 +106,11 @@ public class VideoRecorder implements AutoCloseable {
 
     /**
      * Add a frame to the video
-     * @param frame the VideoFrame to add to the video
+     * @param imageArray the INDArray containing the data to record; the data must be in CHW format
      * @throws Exception
      */
-    public void record(VideoFrame frame) throws Exception {
-        Size size = frame.getMat().size();
-        if(size.height() != height || size.width() != width) {
-            throw new IllegalArgumentException(String.format("Wrong frame size. Got (%dh x %dw) expected (%dh x %dw)", size.height(), size.width(), height, width));
-        }
-        Frame cvFrame = openCVFrameConverter.convert(frame.getMat());
-        fmpegFrameRecorder.record(cvFrame);
-    }
-
-    /**
-     * Create a VideoFrame from a byte array.
-     * @param data A byte array. Expect the index to be of the form [(Y*Width + X) * NumChannels + channel]
-     * @return An instance of VideoFrame
-     */
-    public VideoFrame createFrame(byte[] data) {
-        return createFrame(new BytePointer(data));
-    }
-
-    /**
-     * Create a VideoFrame from a byte array with different height and width than the video
-     * the frame will need to be cropped or resized before being added to the video)
-     *
-     * @param data A byte array Expect the index to be of the form [(Y*customWidth + X) * NumChannels + channel]
-     * @param customHeight The actual height of the data
-     * @param customWidth The actual width of the data
-     * @return A VideoFrame instance
-     */
-    public VideoFrame createFrame(byte[] data, int customHeight, int customWidth) {
-        return createFrame(new BytePointer(data), customHeight, customWidth);
-    }
-
-    /**
-     * Create a VideoFrame from a Pointer (to use for example with a INDarray).
-     * @param data A Pointer (for example myINDArray.data().pointer())
-     * @return An instance of VideoFrame
-     */
-    public VideoFrame createFrame(Pointer data) {
-        return new VideoFrame(height, width, imageType, frameInputType, data);
-    }
-
-    /**
-     *  Create a VideoFrame from a Pointer with different height and width than the video
-     * the frame will need to be cropped or resized before being added to the video)
-     * @param data
-     * @param customHeight The actual height of the data
-     * @param customWidth The actual width of the data
-     * @return A VideoFrame instance
-     */
-    public VideoFrame createFrame(Pointer data, int customHeight, int customWidth) {
-        return new VideoFrame(customHeight, customWidth, imageType, frameInputType, data);
+    public void record(INDArray imageArray) throws Exception {
+        fmpegFrameRecorder.record(nativeImageLoader.asFrame(imageArray, Frame.DEPTH_UBYTE));
     }
 
     /**
@@ -192,69 +131,12 @@ public class VideoRecorder implements AutoCloseable {
         return new Builder(height, width);
     }
 
-    /**
-     * An individual frame for the video
-     */
-    public static class VideoFrame {
-
-        private final int height;
-        private final int width;
-        private final int imageType;
-        @Getter
-        private Mat mat;
-
-        private VideoFrame(int height, int width, int imageType, FrameInputTypes frameInputType, Pointer data) {
-            this.height = height;
-            this.width = width;
-            this.imageType = imageType;
-
-            switch(frameInputType) {
-                case RGB:
-                    Mat src = new Mat(height, width, imageType, data);
-                    mat = new Mat(height, width, imageType);
-                    opencv_imgproc.cvtColor(src, mat, Imgproc.COLOR_RGB2BGR);
-                    break;
-
-                case BGR:
-                    mat = new Mat(height, width, imageType, data);
-                    break;
-
-                case Float:
-                    Mat tmpMat = new Mat(height, width, CV_32FC(3), data);
-                    mat = new Mat(height, width, imageType);
-                    tmpMat.convertTo(mat, CV_8UC(3), 255.0, 0.0);
-            }
-        }
-
-        /**
-         * Crop the video to a specified size
-         * @param newHeight The new height of the frame
-         * @param newWidth The new width of the frame
-         * @param heightOffset The starting height offset in the uncropped frame
-         * @param widthOffset The starting weight offset in the uncropped frame
-         */
-        public void crop(int newHeight, int newWidth, int heightOffset, int widthOffset) {
-            mat = mat.apply(new Rect(widthOffset, heightOffset, newWidth, newHeight));
-        }
-
-        /**
-         * Resize the frame to a specified size
-         * @param newHeight The new height of the frame
-         * @param newWidth The new width of the frame
-         */
-        public void resize(int newHeight, int newWidth) {
-            mat = new Mat(newHeight, newWidth, imageType);
-        }
-    }
-
     /**
      * A builder class for the VideoRecorder
      */
     public static class Builder {
         private final int height;
         private final int width;
-        private int numChannels = 3;
-        private FrameInputTypes frameInputType = FrameInputTypes.BGR;
         private int codec = AV_CODEC_ID_H264;
         private double frameRate = 30.0;
         private int videoQuality = 30;
@@ -268,24 +150,6 @@ public class VideoRecorder implements AutoCloseable {
             this.width = width;
         }
 
-        /**
-         * Specify the number of channels. Default is 3
-         * @param numChannels
-         */
-        public Builder numChannels(int numChannels) {
-            this.numChannels = numChannels;
-            return this;
-        }
-
-        /**
-         * Tell the VideoRecorder what data it will receive (default is BGR)
-         * @param frameInputType (See {@link FrameInputTypes}}
-         */
-        public Builder frameInputType(FrameInputTypes frameInputType) {
-            this.frameInputType = frameInputType;
-            return this;
-        }
-
         /**
          * The codec to use for the video. Default is AV_CODEC_ID_H264
          * @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes})
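A minimal usage sketch of the INDArray-based recorder API in this file, assuming the static builder(height, width) factory shown above; the dimensions, file name, and frame contents are placeholder values:

    import org.deeplearning4j.rl4j.util.VideoRecorder;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    class VideoRecorderSketch {
        void recordDemo() throws Exception {
            int height = 100, width = 160;                       // placeholder dimensions
            VideoRecorder recorder = VideoRecorder.builder(height, width).build();
            recorder.startRecording("myVideo.mp4");
            for (int i = 0; i < 100; i++) {
                INDArray chwData = Nd4j.zeros(3, height, width); // CHW frame; real RGB pixel data goes here
                recorder.record(chwData);
            }
            recorder.stopRecording();
        }
    }
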
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
index 9499da99e..de9778a80 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
@@ -115,7 +115,7 @@ public class AsyncThreadDiscreteTest {
 
         asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
 
-        when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration);
+        when(asyncThreadDiscrete.getConfiguration()).thenReturn(mockAsyncConfiguration);
         when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
         when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
         when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
index 117465de3..5b6afbc28 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
@@ -39,7 +39,6 @@ import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.clearInvocations;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -130,7 +129,7 @@ public class AsyncThreadTest {
 
         when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
         when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
-        when(thread.getConf()).thenReturn(mockAsyncConfiguration);
+        when(thread.getConfiguration()).thenReturn(mockAsyncConfiguration);
 
         // if we hit the max step count
         when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
index 82129e0df..e19af338b 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
@@ -18,24 +18,16 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
 
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.experience.ExperienceHandler;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
-import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
-import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
-import org.deeplearning4j.rl4j.support.*;
-import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
-import org.deeplearning4j.rl4j.util.IDataManager;
-import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -43,17 +35,17 @@ import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.api.DataSet;
-import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
 
-import java.util.ArrayList;
-import java.util.List;
-
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 @RunWith(MockitoJUnitRunner.class)
@@ -82,6 +74,7 @@ public class QLearningDiscreteTest {
     @Mock
     QLearningConfiguration mockQlearningConfiguration;
 
+    // CHW: channels, height, width
     int[] observationShape = new int[]{3, 10, 10};
     int totalObservationSize = 1;
 
@@ -123,6 +116,7 @@ public class QLearningDiscreteTest {
         when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
         when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
         when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
+        when(mockHistoryConfiguration.getHistoryLength()).thenReturn(1);
         when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
 
         qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
@@ -148,7 +142,7 @@ public class QLearningDiscreteTest {
         Observation observation = new Observation(Nd4j.zeros(observationShape));
         when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
 
-        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Observation(Nd4j.zeros(observationShape)), 0, false, null));
 
         // Act
         QLearning.QLStepReturn stepReturn = qLearningDiscrete.trainStep(observation);
@@ -170,25 +164,26 @@ public class QLearningDiscreteTest {
         // Arrange
         mockTestContext(100,0,2,1.0, 10);
 
-        mockHistoryProcessor(2);
+        Observation skippedObservation = Observation.SkippedObservation;
+        Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
 
-        // An example observation and 2 Q values output (2 actions)
-        Observation observation = new Observation(Nd4j.zeros(observationShape));
-        when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
-
-        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(nextObservation, 0, false, null));
 
         // Act
-        QLearning.QLStepReturn stepReturn = qLearningDiscrete.trainStep(observation);
+        QLearning.QLStepReturn stepReturn = qLearningDiscrete.trainStep(skippedObservation);
 
         // Assert
-        assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
+        assertEquals(Double.NaN, stepReturn.getMaxQ(), 1e-5);
 
         StepReply stepReply = stepReturn.getStepReply();
 
         assertEquals(0, stepReply.getReward(), 1e-5);
         assertFalse(stepReply.isDone());
-        assertTrue(stepReply.getObservation().isSkipped());
+        assertFalse(stepReply.getObservation().isSkipped());
+        assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
+
+        verify(mockDQN, never()).output(any(INDArray.class));
+
     }
 
     //TODO: there are many more test cases here that can be improved upon
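A short sketch of the skip contract exercised by the test above; Observation and its SkippedObservation singleton come from this patch, and the shape is the same placeholder used in the test:

    Observation skipped = Observation.SkippedObservation;
    boolean isSkipped = skipped.isSkipped();                                // true: its backing data is null
    boolean hasData = !new Observation(Nd4j.zeros(3, 10, 10)).isSkipped();  // true: real frames are never skipped
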
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java
index 9c7a172bb..9126ea1fa 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java
@@ -17,7 +17,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(false)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -35,7 +35,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3);
 
         // Act
@@ -53,7 +53,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -70,7 +70,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .shouldStoreCopy(false)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -87,7 +87,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .shouldStoreCopy(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -107,7 +107,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .elementStore(store)
                 .assembler(assemble)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
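For reference, a minimal sketch of the build(int) signature used throughout these tests; when no element store is supplied, the frame-stack length sizes the default CircularFifoStore (4 below is an arbitrary choice):

    HistoryMergeTransform transform = HistoryMergeTransform.builder()
            .isFirstDimenstionBatch(true)   // builder method spelled as in the current API
            .build(4);                      // frame-stack length for the default CircularFifoStore
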
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
index 403d3c91e..249304afb 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
@@ -252,8 +252,8 @@ public class PolicyTest {
         }
 
         @Override
-        protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) {
-            mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
+        protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) {
+            mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(skipFrame, historyLength));
             return super.refacInitMdp(mdpWrapper, hp);
         }
     }
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java
deleted file mode 100644
index 436205b42..000000000
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java
+++ /dev/null
@@ -1,18 +0,0 @@
-package org.deeplearning4j.rl4j.support;
-
-import org.deeplearning4j.rl4j.space.Encodable;
-
-public class MockEncodable implements Encodable {
-
-    private final int value;
-
-    public MockEncodable(int value) {
-
-        this.value = value;
-    }
-
-    @Override
-    public double[] toArray() {
-        return new double[] { value };
-    }
-}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
index bbed87624..61db6dd6e 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
@@ -2,9 +2,10 @@ package org.deeplearning4j.rl4j.support;
 
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
 import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
 import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
-import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
+import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
@@ -15,7 +16,7 @@ import org.nd4j.linalg.api.rng.Random;
 import java.util.ArrayList;
 import java.util.List;
 
-public class MockMDP implements MDP {
+public class MockMDP implements MDP {
 
     private final DiscreteSpace actionSpace;
     private final int stepsUntilDone;
@@ -55,11 +56,11 @@ public class MockMDP implements MDP {
     }
 
     @Override
-    public MockEncodable reset() {
+    public MockObservation reset() {
         ++resetCount;
         currentObsValue = 0;
         step = 0;
-        return new MockEncodable(currentObsValue++);
+        return new MockObservation(currentObsValue++);
     }
 
     @Override
@@ -68,10 +69,10 @@ public class MockMDP implements MDP {
     }
 
     @Override
-    public StepReply step(Integer action) {
+    public StepReply step(Integer action) {
         actions.add(action);
         ++step;
-        return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
+        return new StepReply<>(new MockObservation(currentObsValue), (double) currentObsValue++, isDone(), null);
     }
 
     @Override
@@ -84,14 +85,14 @@ public class MockMDP implements MDP {
         return null;
     }
 
-    public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) {
+    public static TransformProcess buildTransformProcess(int skipFrame, int historyLength) {
         return TransformProcess.builder()
                 .filter(new UniformSkippingFilter(skipFrame))
-                .transform("data", new EncodableToINDArrayTransform(shape))
+                .transform("data", new EncodableToINDArrayTransform())
                 .transform("data", new SimpleNormalizationTransform(0.0, 255.0))
                 .transform("data", HistoryMergeTransform.builder()
                         .elementStore(new CircularFifoStore(historyLength))
-                        .build())
+                        .build(4))
                 .build("data");
     }
 
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java
new file mode 100644
index 000000000..70a3e76c6
--- /dev/null
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java
@@ -0,0 +1,51 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K. K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.support;
+
+
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+public class MockObservation implements Encodable {
+
+    final INDArray data;
+
+    public MockObservation(int value) {
+        this.data = Nd4j.ones(1).mul(value);
+    }
+
+    @Override
+    public double[] toArray() {
+        return data.data().asDouble();
+    }
+
+    @Override
+    public boolean isSkipped() {
+        return false;
+    }
+
+    @Override
+    public INDArray getData() {
+        return data;
+    }
+
+    @Override
+    public Encodable dup() {
+        return null;
+    }
+}
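A brief sketch of how the relocated EncodableToINDArrayTransform consumes this mock: the transform now simply unwraps getData(), so no target shape needs to be passed in (the value 7 is arbitrary):

    EncodableToINDArrayTransform transform = new EncodableToINDArrayTransform();
    INDArray data = transform.transform(new MockObservation(7));  // a 1-element array holding 7.0
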
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
index 2b9fd491b..8786f7d7d 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
@@ -17,7 +17,7 @@ public class MockPolicy implements IPolicy {
     public List actionInputs = new ArrayList();
 
     @Override
-    public > double play(MDP mdp, IHistoryProcessor hp) {
+    public > double play(MDP mdp, IHistoryProcessor hp) {
         ++playCallCount;
         return 0;
     }
diff --git a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java b/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java
index 4c331ad78..becce416f 100644
--- a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java
+++ b/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java
@@ -28,6 +28,9 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
 import vizdoom.*;
 
 import java.util.ArrayList;
@@ -155,7 +158,7 @@ abstract public class VizDoom implements MDP> implements MDP {
+public class GymEnv> implements MDP {
 
     public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";
 
@@ -82,7 +80,7 @@ public class GymEnv> implements MDP {
     private PyObject locals;
 
     final protected DiscreteSpace actionSpace;
-    final protected ObservationSpace observationSpace;
+    final protected ObservationSpace observationSpace;
     @Getter
     final private String envId;
     @Getter
@@ -119,7 +117,7 @@ public class GymEnv> implements MDP {
             for (int i = 0; i < shape.length; i++) {
                 shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i));
             }
-            observationSpace = (ObservationSpace) new ArrayObservationSpace(shape);
+            observationSpace = (ObservationSpace) new ArrayObservationSpace(shape);
             Py_DecRef(shapeTuple);
 
             PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null);
@@ -140,7 +138,7 @@ public class GymEnv> implements MDP {
     }
 
     @Override
-    public ObservationSpace getObservationSpace() {
+    public ObservationSpace getObservationSpace() {
         return observationSpace;
     }
 
@@ -153,7 +151,7 @@ public class GymEnv> implements MDP {
     }
 
     @Override
-    public StepReply step(A action) {
+    public StepReply step(A action) {
         int gstate = PyGILState_Ensure();
         try {
             if (render) {
@@ -186,7 +184,7 @@ public class GymEnv> implements MDP {
     }
 
     @Override
-    public O reset() {
+    public OBSERVATION reset() {
         int gstate = PyGILState_Ensure();
         try {
             Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null));
@@ -201,7 +199,7 @@ public class GymEnv> implements MDP {
 
             double[] data = new double[(int)stateData.capacity()];
             stateData.get(data);
-            return (O) new Box(data);
+            return (OBSERVATION) new Box(data);
         } finally {
             PyGILState_Release(gstate);
         }
@@ -220,7 +218,7 @@ public class GymEnv> implements MDP {
     }
 
     @Override
-    public GymEnv newInstance() {
-        return new GymEnv(envId, render, monitor);
+    public GymEnv newInstance() {
+        return new GymEnv(envId, render, monitor);
     }
 }
diff --git a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java
index 2196d7b31..4faf26b2b 100644
--- a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java
+++ b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java
@@ -40,8 +40,8 @@ public class GymEnvTest {
         assertEquals(false, mdp.isDone());
         Box o = (Box)mdp.reset();
         StepReply r = mdp.step(0);
-        assertEquals(4, o.toArray().length);
-        assertEquals(4, ((Box)r.getObservation()).toArray().length);
+        assertEquals(4, o.getData().shape()[0]);
+        assertEquals(4, ((Box)r.getObservation()).getData().shape()[0]);
         assertNotEquals(null, mdp.newInstance());
         mdp.close();
     }
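A brief sketch of the Box observation as used by the updated assertions: its state is now held as an INDArray, so the element count is read from shape() rather than toArray().length (the values below are placeholders):

    Box state = new Box(0.1, 0.0, -0.2, 0.0);       // e.g. a CartPole-style 4-feature state
    long numFeatures = state.getData().shape()[0];  // 4
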
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java
index 412976b27..91cec3d8b 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K. K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -16,33 +16,13 @@
 
 package org.deeplearning4j.malmo;
 
-import java.util.Arrays;
+import org.deeplearning4j.rl4j.space.Box;
+import org.nd4j.linalg.factory.Nd4j;
 
-import org.deeplearning4j.rl4j.space.Encodable;
+@Deprecated
+public class MalmoBox extends Box {
 
-/**
- * Encodable state as a simple value array similar to Gym Box model, but without a JSON constructor
- * @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
- */
-public class MalmoBox implements Encodable {
-    double[] value;
-
-    /**
-     * Construct state from an array of doubles
-     * @param value state values
-     */
-    //TODO: If this constructor was added to "Box", we wouldn't need this class at all.
-    public MalmoBox(double... value) {
-        this.value = value;
-    }
-
-    @Override
-    public double[] toArray() {
-        return value;
-    }
-
-    @Override
-    public String toString() {
-        return Arrays.toString(value);
+    public MalmoBox(double... arr) {
+        super(arr);
     }
 }
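
A short migration sketch (not from the patch): MalmoBox is now just a deprecated alias for Box, so existing call sites keep compiling while new code can use Box directly. The class name below is made up; getData() comes from the reworked Box/Encodable shown elsewhere in this patch.

    import org.deeplearning4j.malmo.MalmoBox;
    import org.deeplearning4j.rl4j.space.Box;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class MalmoBoxMigrationSketch {
        public static void main(String[] args) {
            MalmoBox legacy = new MalmoBox(1.0, 2.0, 3.0);   // old call sites keep compiling (now @Deprecated)

            Box observation = legacy;                        // a MalmoBox is-a Box after this change
            INDArray data = observation.getData();
            double[] flat = observation.toArray();
            System.out.println(data.length() == flat.length);   // prints true
        }
    }
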
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java
index c853fa362..d68de87ef 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java
@@ -19,6 +19,8 @@ package org.deeplearning4j.malmo;
 import java.util.Arrays;
 
 import com.microsoft.msr.malmo.WorldState;
+import org.deeplearning4j.rl4j.space.Box;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
  * A Malmo consistency policy that ensures both that there is a reward and that the next observation has a different position than the previous one.
@@ -30,14 +32,14 @@ public class MalmoDescretePositionPolicy implements MalmoObservationPolicy {
 
     @Override
     public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) {
-        MalmoBox last_observation = observationSpace.getObservation(world_state);
-        MalmoBox old_observation = observationSpace.getObservation(original_world_state);
+        Box last_observation = observationSpace.getObservation(world_state);
+        Box old_observation = observationSpace.getObservation(original_world_state);
 
-        double[] newvalues = old_observation == null ? null : old_observation.toArray();
-        double[] oldvalues = last_observation == null ? null : last_observation.toArray();
+        INDArray newvalues = old_observation == null ? null : old_observation.getData();
+        INDArray oldvalues = last_observation == null ? null : last_observation.getData();
 
         return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty()
-                        || Arrays.equals(oldvalues, newvalues));
+                        || oldvalues.eq(newvalues).all());
     }
 
 }
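
A hedged sketch (not from the patch) of the element-wise comparison the updated check relies on: eq(...) produces a boolean mask and all() is true only when every element matches. Note that the old Arrays.equals(double[], double[]) tolerated nulls; the explicit null guard below is an assumption added for illustration, not something this hunk introduces.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class ObservationEqualitySketch {
        // Mirrors the spirit of the updated consistency check.
        static boolean sameObservation(INDArray oldValues, INDArray newValues) {
            if (oldValues == null || newValues == null) {
                return oldValues == newValues;    // null-tolerant, like Arrays.equals in the old code
            }
            return oldValues.eq(newValues).all(); // element-wise equality, reduced to a single boolean
        }

        public static void main(String[] args) {
            INDArray a = Nd4j.create(new double[] {1, 2, 3});
            INDArray b = Nd4j.create(new double[] {1, 2, 3});
            System.out.println(sameObservation(a, b));     // true
            System.out.println(sameObservation(a, null));  // false
        }
    }
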
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java
index b27412b99..b98db3650 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java
@@ -21,6 +21,7 @@ import java.nio.file.Paths;
 
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 
 import com.microsoft.msr.malmo.AgentHost;
@@ -34,6 +35,7 @@ import com.microsoft.msr.malmo.WorldState;
 import lombok.Setter;
 import lombok.Getter;
 
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -233,7 +235,7 @@ public class MalmoEnv implements MDP<MalmoBox, Integer, DiscreteSpace> {
             logger.info("Mission ended");
         }
 
-        return new StepReply(last_observation, getRewards(last_world_state), isDone(), null);
+        return new StepReply<>(last_observation, getRewards(last_world_state), isDone(), null);
     }
 
     private double getRewards(WorldState world_state) {
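
A small sketch (not from the patch) of what the diamond buys here: the reply is now typed by its observation instead of being a raw StepReply. The helper below is hypothetical; the (observation, reward, done, info) argument order is taken from the call above.

    import org.deeplearning4j.gym.StepReply;
    import org.deeplearning4j.malmo.MalmoBox;

    public class TypedStepReplySketch {
        static StepReply<MalmoBox> reply(MalmoBox observation, double reward, boolean done) {
            // The diamond infers StepReply<MalmoBox>, so callers get a typed getObservation().
            return new StepReply<>(observation, reward, done, null);
        }
    }
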
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java
index 61a0dddc7..cc140bee2 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java
@@ -16,6 +16,8 @@
 
 package org.deeplearning4j.malmo;
 
+import org.deeplearning4j.rl4j.space.Box;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
 
 import com.microsoft.msr.malmo.WorldState;
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
index 00b7c4f7a..1595def55 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.malmo;
 
 import com.microsoft.msr.malmo.TimestampedStringVector;
 import com.microsoft.msr.malmo.WorldState;
+import org.deeplearning4j.rl4j.space.Box;
 import org.json.JSONArray;
 import org.json.JSONObject;
 import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java
index 52dc02918..4fbbb6cc2 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.malmo;
 
 import java.util.HashMap;
 
+import org.deeplearning4j.rl4j.space.Box;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java
index 50f710bf5..cf85059d8 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java
@@ -16,6 +16,7 @@
 
 package org.deeplearning4j.malmo;
 
+import org.deeplearning4j.rl4j.space.Box;
 import org.json.JSONObject;
 
 import org.nd4j.linalg.api.ndarray.INDArray;