RL4J: Sanitize Observation (#404)
* working on ALE image pipelines that appear to lose data * transformation pipeline for ALE has been broken for a while and needed some cleanup to make sure that openCV tooling for scene transforms was actually working. * allowing history length to be set and passed through to history merge transforms Signed-off-by: Bam4d <chrisbam4d@gmail.com> * native image loader is not thread-safe so should not be static Signed-off-by: Bam4d <chrisbam4d@gmail.com> * make sure the transformer for encoding observations that are not pixels converts correctly Signed-off-by: Bam4d <chrisbam4d@gmail.com> * Test fixes for ALE pixel observation shape Signed-off-by: Bam4d <chrisbam4d@gmail.com> * Fix compilation errors Signed-off-by: Samuel Audet <samuel.audet@gmail.com> * fixing some post-merge issues, and comments from PR Signed-off-by: Bam4d <chrisbam4d@gmail.com> Co-authored-by: Samuel Audet <samuel.audet@gmail.com>
This commit is contained in:
		
							parent
							
								
									75cc6e2ed7
								
							
						
					
					
						commit
						032b97912e
					
				| @ -25,9 +25,13 @@ import org.bytedeco.javacpp.IntPointer; | |||||||
| import org.deeplearning4j.gym.StepReply; | import org.deeplearning4j.gym.StepReply; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.space.ObservationSpace; | import org.deeplearning4j.rl4j.space.ObservationSpace; | ||||||
|  | import org.nd4j.linalg.api.buffer.DataType; | ||||||
|  | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
|  | import org.nd4j.linalg.factory.Nd4j; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author saudet |  * @author saudet | ||||||
| @ -70,10 +74,14 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> { | |||||||
|         actions = new int[(int)a.limit()]; |         actions = new int[(int)a.limit()]; | ||||||
|         a.get(actions); |         a.get(actions); | ||||||
| 
 | 
 | ||||||
|  |         int height = (int)ale.getScreen().height(); | ||||||
|  |         int width = (int)(int)ale.getScreen().width(); | ||||||
|  | 
 | ||||||
|         discreteSpace = new DiscreteSpace(actions.length); |         discreteSpace = new DiscreteSpace(actions.length); | ||||||
|         int[] shape = {(int)ale.getScreen().height(), (int)ale.getScreen().width(), 3}; |         int[] shape = {3, height, width}; | ||||||
|         observationSpace = new ArrayObservationSpace<>(shape); |         observationSpace = new ArrayObservationSpace<>(shape); | ||||||
|         screenBuffer = new byte[shape[0] * shape[1] * shape[2]]; |         screenBuffer = new byte[shape[0] * shape[1] * shape[2]]; | ||||||
|  | 
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public void setupGame() { |     public void setupGame() { | ||||||
| @ -103,7 +111,7 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> { | |||||||
|     public GameScreen reset() { |     public GameScreen reset() { | ||||||
|         ale.reset_game(); |         ale.reset_game(); | ||||||
|         ale.getScreenRGB(screenBuffer); |         ale.getScreenRGB(screenBuffer); | ||||||
|         return new GameScreen(screenBuffer); |         return new GameScreen(observationSpace.getShape(), screenBuffer); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -115,7 +123,8 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> { | |||||||
|         double r = ale.act(actions[action]) * scaleFactor; |         double r = ale.act(actions[action]) * scaleFactor; | ||||||
|         log.info(ale.getEpisodeFrameNumber() + " " + r + " " + action + " "); |         log.info(ale.getEpisodeFrameNumber() + " " + r + " " + action + " "); | ||||||
|         ale.getScreenRGB(screenBuffer); |         ale.getScreenRGB(screenBuffer); | ||||||
|         return new StepReply(new GameScreen(screenBuffer), r, ale.game_over(), null); | 
 | ||||||
|  |         return new StepReply(new GameScreen(observationSpace.getShape(), screenBuffer), r, ale.game_over(), null); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public ObservationSpace<GameScreen> getObservationSpace() { |     public ObservationSpace<GameScreen> getObservationSpace() { | ||||||
| @ -140,17 +149,35 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public static class GameScreen implements Encodable { |     public static class GameScreen implements Encodable { | ||||||
|         double[] array; |  | ||||||
| 
 | 
 | ||||||
|         public GameScreen(byte[] screen) { |         final INDArray data; | ||||||
|             array = new double[screen.length]; |         public GameScreen(int[] shape, byte[] screen) { | ||||||
|             for (int i = 0; i < screen.length; i++) { | 
 | ||||||
|                 array[i] = (screen[i] & 0xFF) / 255.0; |             data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1); | ||||||
|             } |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         private GameScreen(INDArray toDup) { | ||||||
|  |             data = toDup.dup(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|         public double[] toArray() { |         public double[] toArray() { | ||||||
|             return array; |             return data.data().asDouble(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|  |         public boolean isSkipped() { | ||||||
|  |             return false; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|  |         public INDArray getData() { | ||||||
|  |             return data; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|  |         public GameScreen dup() { | ||||||
|  |             return new GameScreen(data); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -19,15 +19,15 @@ package org.deeplearning4j.gym; | |||||||
| import lombok.Value; | import lombok.Value; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @param <T> type of observation |  * @param <OBSERVATION> type of observation | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16. | ||||||
|  * |  * | ||||||
|  *  StepReply is the container for the data returned after each step(action). |  *  StepReply is the container for the data returned after each step(action). | ||||||
|  */ |  */ | ||||||
| @Value | @Value | ||||||
| public class StepReply<T> { | public class StepReply<OBSERVATION> { | ||||||
| 
 | 
 | ||||||
|     T observation; |     OBSERVATION observation; | ||||||
|     double reward; |     double reward; | ||||||
|     boolean done; |     boolean done; | ||||||
|     Object info; |     Object info; | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ import org.deeplearning4j.rl4j.space.ObservationSpace; | |||||||
|  * in a "functional manner" if step return a mdp |  * in a "functional manner" if step return a mdp | ||||||
|  * |  * | ||||||
|  */ |  */ | ||||||
| public interface MDP<OBSERVATION, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> { | public interface MDP<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> { | ||||||
| 
 | 
 | ||||||
|     ObservationSpace<OBSERVATION> getObservationSpace(); |     ObservationSpace<OBSERVATION> getObservationSpace(); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -16,6 +16,9 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.rl4j.space; | package org.deeplearning4j.rl4j.space; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
|  | import org.nd4j.linalg.factory.Nd4j; | ||||||
|  | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. | ||||||
|  * |  * | ||||||
| @ -25,13 +28,37 @@ package org.deeplearning4j.rl4j.space; | |||||||
|  */ |  */ | ||||||
| public class Box implements Encodable { | public class Box implements Encodable { | ||||||
| 
 | 
 | ||||||
|     private final double[] array; |     private final INDArray data; | ||||||
| 
 | 
 | ||||||
|     public Box(double[] arr) { |     public Box(double... arr) { | ||||||
|         this.array = arr; |         this.data = Nd4j.create(arr); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public Box(int[] shape, double... arr) { | ||||||
|  |         this.data = Nd4j.create(arr).reshape(shape); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private Box(INDArray toDup) { | ||||||
|  |         data = toDup.dup(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|     public double[] toArray() { |     public double[] toArray() { | ||||||
|         return array; |         return data.data().asDouble(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public boolean isSkipped() { | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public INDArray getData() { | ||||||
|  |         return data; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public Encodable dup() { | ||||||
|  |         return new Box(data); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| /******************************************************************************* | /******************************************************************************* | ||||||
|  * Copyright (c) 2015-2018 Skymind, Inc. |  * Copyright (c) 2020 Konduit K. K. | ||||||
|  * |  * | ||||||
|  * This program and the accompanying materials are made available under the |  * This program and the accompanying materials are made available under the | ||||||
|  * terms of the Apache License, Version 2.0 which is available at |  * terms of the Apache License, Version 2.0 which is available at | ||||||
| @ -16,17 +16,19 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.rl4j.space; | package org.deeplearning4j.rl4j.space; | ||||||
| 
 | 
 | ||||||
| /** | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/19/16. | 
 | ||||||
|  *         Encodable is an interface that ensure that the state is convertible to a double array |  | ||||||
|  */ |  | ||||||
| public interface Encodable { | public interface Encodable { | ||||||
| 
 | 
 | ||||||
|     /** |     @Deprecated | ||||||
|      * $ |  | ||||||
|      * encodes all the information of an Observation in an array double and can be used as input of a DQN directly |  | ||||||
|      * |  | ||||||
|      * @return the encoded informations |  | ||||||
|      */ |  | ||||||
|     double[] toArray(); |     double[] toArray(); | ||||||
|  | 
 | ||||||
|  |     boolean isSkipped(); | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Any image data should be in CHW format. | ||||||
|  |      */ | ||||||
|  |     INDArray getData(); | ||||||
|  | 
 | ||||||
|  |     Encodable dup(); | ||||||
| } | } | ||||||
|  | |||||||
| @ -24,16 +24,17 @@ import org.nd4j.linalg.factory.Nd4j; | |||||||
|  * @author Alexandre Boulanger |  * @author Alexandre Boulanger | ||||||
|  */ |  */ | ||||||
| public class INDArrayHelper { | public class INDArrayHelper { | ||||||
|  | 
 | ||||||
|     /** |     /** | ||||||
|      * MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray. |      * MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types. | ||||||
|      * In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape. |  | ||||||
|      * |      * | ||||||
|      * @param source A INDArray |      * We must have either shape 2 (NK) or shape 4 (NCHW) | ||||||
|      * @return The source INDArray with the correct shape |  | ||||||
|      */ |      */ | ||||||
|     public static INDArray forceCorrectShape(INDArray source) { |     public static INDArray forceCorrectShape(INDArray source) { | ||||||
|  | 
 | ||||||
|         return source.shape()[0] == 1 && source.shape().length > 1 |         return source.shape()[0] == 1 && source.shape().length > 1 | ||||||
|                 ? source |                 ? source | ||||||
|                 : Nd4j.expandDims(source, 0); |                 : Nd4j.expandDims(source, 0); | ||||||
|  | 
 | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -46,7 +46,6 @@ public class HistoryProcessor implements IHistoryProcessor { | |||||||
| 
 | 
 | ||||||
|     @Getter |     @Getter | ||||||
|     final private Configuration conf; |     final private Configuration conf; | ||||||
|     final private OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat(); |  | ||||||
|     private CircularFifoQueue<INDArray> history; |     private CircularFifoQueue<INDArray> history; | ||||||
|     private VideoRecorder videoRecorder; |     private VideoRecorder videoRecorder; | ||||||
| 
 | 
 | ||||||
| @ -63,8 +62,7 @@ public class HistoryProcessor implements IHistoryProcessor { | |||||||
| 
 | 
 | ||||||
|     public void startMonitor(String filename, int[] shape) { |     public void startMonitor(String filename, int[] shape) { | ||||||
|         if(videoRecorder == null) { |         if(videoRecorder == null) { | ||||||
|             videoRecorder = VideoRecorder.builder(shape[0], shape[1]) |             videoRecorder = VideoRecorder.builder(shape[1], shape[2]) | ||||||
|                     .frameInputType(VideoRecorder.FrameInputTypes.Float) |  | ||||||
|                     .build(); |                     .build(); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| @ -89,14 +87,13 @@ public class HistoryProcessor implements IHistoryProcessor { | |||||||
|         return videoRecorder != null && videoRecorder.isRecording(); |         return videoRecorder != null && videoRecorder.isRecording(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public void record(INDArray raw) { |     public void record(INDArray pixelArray) { | ||||||
|         if(isMonitoring()) { |         if(isMonitoring()) { | ||||||
|             // before accessing the raw pointer, we need to make sure that array is actual on the host side |             // before accessing the raw pointer, we need to make sure that array is actual on the host side | ||||||
|             Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST); |             Nd4j.getAffinityManager().ensureLocation(pixelArray, AffinityManager.Location.HOST); | ||||||
| 
 | 
 | ||||||
|             VideoRecorder.VideoFrame frame = videoRecorder.createFrame(raw.data().pointer()); |  | ||||||
|             try { |             try { | ||||||
|                 videoRecorder.record(frame); |                 videoRecorder.record(pixelArray); | ||||||
|             } catch (Exception e) { |             } catch (Exception e) { | ||||||
|                 e.printStackTrace(); |                 e.printStackTrace(); | ||||||
|             } |             } | ||||||
|  | |||||||
| @ -64,7 +64,7 @@ public interface IHistoryProcessor { | |||||||
|         @Builder.Default int skipFrame = 4; |         @Builder.Default int skipFrame = 4; | ||||||
| 
 | 
 | ||||||
|         public int[] getShape() { |         public int[] getShape() { | ||||||
|             return new int[] {getHistoryLength(), getCroppingHeight(), getCroppingWidth()}; |             return new int[] {getHistoryLength(), getRescaledHeight(), getRescaledWidth()}; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable; | |||||||
|  * |  * | ||||||
|  * A common interface that any training method should implement |  * A common interface that any training method should implement | ||||||
|  */ |  */ | ||||||
| public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> { | public interface ILearning<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> { | ||||||
| 
 | 
 | ||||||
|     IPolicy<A> getPolicy(); |     IPolicy<A> getPolicy(); | ||||||
| 
 | 
 | ||||||
| @ -38,7 +38,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> { | |||||||
| 
 | 
 | ||||||
|     ILearningConfiguration getConfiguration(); |     ILearningConfiguration getConfiguration(); | ||||||
| 
 | 
 | ||||||
|     MDP<O, A, AS> getMdp(); |     MDP<OBSERVATION, A, AS> getMdp(); | ||||||
| 
 | 
 | ||||||
|     IHistoryProcessor getHistoryProcessor(); |     IHistoryProcessor getHistoryProcessor(); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -21,7 +21,6 @@ import lombok.Getter; | |||||||
| import lombok.Setter; | import lombok.Setter; | ||||||
| import lombok.Value; | import lombok.Value; | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; |  | ||||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | import org.deeplearning4j.rl4j.network.NeuralNet; | ||||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | import org.deeplearning4j.rl4j.space.ActionSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| @ -38,8 +37,8 @@ import org.nd4j.linalg.factory.Nd4j; | |||||||
|  * |  * | ||||||
|  */ |  */ | ||||||
| @Slf4j | @Slf4j | ||||||
| public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> | public abstract class Learning<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> | ||||||
|                 implements ILearning<O, A, AS>, NeuralNetFetchable<NN> { |                 implements ILearning<OBSERVATION, A, AS>, NeuralNetFetchable<NN> { | ||||||
| 
 | 
 | ||||||
|     @Getter @Setter |     @Getter @Setter | ||||||
|     protected int stepCount = 0; |     protected int stepCount = 0; | ||||||
|  | |||||||
| @ -29,10 +29,10 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListener; | |||||||
| import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | import org.deeplearning4j.rl4j.network.NeuralNet; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.deeplearning4j.rl4j.policy.IPolicy; | import org.deeplearning4j.rl4j.policy.IPolicy; | ||||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | import org.deeplearning4j.rl4j.space.ActionSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager; | import org.deeplearning4j.rl4j.util.IDataManager; | ||||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| @ -188,7 +188,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_ | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private boolean handleTraining(RunContext context) { |     private boolean handleTraining(RunContext context) { | ||||||
|         int maxTrainSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpisodeStepCount); |         int maxTrainSteps = Math.min(getConfiguration().getNStep(), getConfiguration().getMaxEpochStep() - currentEpisodeStepCount); | ||||||
|         SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps); |         SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps); | ||||||
| 
 | 
 | ||||||
|         context.obs = subEpochReturn.getLastObs(); |         context.obs = subEpochReturn.getLastObs(); | ||||||
| @ -219,7 +219,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_ | |||||||
| 
 | 
 | ||||||
|     protected abstract IAsyncGlobal<NN> getAsyncGlobal(); |     protected abstract IAsyncGlobal<NN> getAsyncGlobal(); | ||||||
| 
 | 
 | ||||||
|     protected abstract IAsyncLearningConfiguration getConf(); |     protected abstract IAsyncLearningConfiguration getConfiguration(); | ||||||
| 
 | 
 | ||||||
|     protected abstract IPolicy<ACTION> getPolicy(NN net); |     protected abstract IPolicy<ACTION> getPolicy(NN net); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -24,29 +24,22 @@ import lombok.Setter; | |||||||
| import org.deeplearning4j.gym.StepReply; | import org.deeplearning4j.gym.StepReply; | ||||||
| import org.deeplearning4j.rl4j.experience.ExperienceHandler; | import org.deeplearning4j.rl4j.experience.ExperienceHandler; | ||||||
| import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; | import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; | ||||||
| import org.deeplearning4j.rl4j.experience.ExperienceHandler; |  | ||||||
| import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; |  | ||||||
| import org.deeplearning4j.rl4j.experience.StateActionPair; |  | ||||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||||
| import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | import org.deeplearning4j.rl4j.network.NeuralNet; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.deeplearning4j.rl4j.policy.IPolicy; | import org.deeplearning4j.rl4j.policy.IPolicy; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.util.Stack; |  | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. | ||||||
|  * <p> |  * <p> | ||||||
|  * Async Learning specialized for the Discrete Domain |  * Async Learning specialized for the Discrete Domain | ||||||
|  */ |  */ | ||||||
| public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet> | public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN extends NeuralNet> | ||||||
|         extends AsyncThread<O, Integer, DiscreteSpace, NN> { |         extends AsyncThread<OBSERVATION, Integer, DiscreteSpace, NN> { | ||||||
| 
 | 
 | ||||||
|     @Getter |     @Getter | ||||||
|     private NN current; |     private NN current; | ||||||
| @ -59,7 +52,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural | |||||||
|     private ExperienceHandler experienceHandler = new StateActionExperienceHandler(); |     private ExperienceHandler experienceHandler = new StateActionExperienceHandler(); | ||||||
| 
 | 
 | ||||||
|     public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, |     public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, | ||||||
|                                MDP<O, Integer, DiscreteSpace> mdp, |                                MDP<OBSERVATION, Integer, DiscreteSpace> mdp, | ||||||
|                                TrainingListenerList listeners, |                                TrainingListenerList listeners, | ||||||
|                                int threadNumber, |                                int threadNumber, | ||||||
|                                int deviceNum) { |                                int deviceNum) { | ||||||
| @ -112,7 +105,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural | |||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action); |             StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action); | ||||||
|             accuReward += stepReply.getReward() * getConf().getRewardFactor(); |             accuReward += stepReply.getReward() * getConfiguration().getRewardFactor(); | ||||||
| 
 | 
 | ||||||
|             if (!obs.isSkipped()) { |             if (!obs.isSkipped()) { | ||||||
|                 experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone()); |                 experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone()); | ||||||
| @ -126,7 +119,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural | |||||||
| 
 | 
 | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         boolean episodeComplete = getMdp().isDone() || getConf().getMaxEpochStep() == currentEpisodeStepCount; |         boolean episodeComplete = getMdp().isDone() || getConfiguration().getMaxEpochStep() == currentEpisodeStepCount; | ||||||
| 
 | 
 | ||||||
|         if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) { |         if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) { | ||||||
|             experienceHandler.setFinalObservation(obs); |             experienceHandler.setFinalObservation(obs); | ||||||
|  | |||||||
| @ -28,9 +28,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread; | |||||||
| import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; | import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.ac.IActorCritic; | import org.deeplearning4j.rl4j.network.ac.IActorCritic; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.policy.ACPolicy; | import org.deeplearning4j.rl4j.policy.ACPolicy; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| 
 | 
 | ||||||
| @ -41,19 +41,19 @@ import org.nd4j.linalg.factory.Nd4j; | |||||||
|  * All methods are fully implemented as described in the |  * All methods are fully implemented as described in the | ||||||
|  * https://arxiv.org/abs/1602.01783 paper. |  * https://arxiv.org/abs/1602.01783 paper. | ||||||
|  */ |  */ | ||||||
| public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> { | public abstract class A3CDiscrete<OBSERVATION extends Encodable> extends AsyncLearning<OBSERVATION, Integer, DiscreteSpace, IActorCritic> { | ||||||
| 
 | 
 | ||||||
|     @Getter |     @Getter | ||||||
|     final public A3CLearningConfiguration configuration; |     final public A3CLearningConfiguration configuration; | ||||||
|     @Getter |     @Getter | ||||||
|     final protected MDP<O, Integer, DiscreteSpace> mdp; |     final protected MDP<OBSERVATION, Integer, DiscreteSpace> mdp; | ||||||
|     final private IActorCritic iActorCritic; |     final private IActorCritic iActorCritic; | ||||||
|     @Getter |     @Getter | ||||||
|     final private AsyncGlobal asyncGlobal; |     final private AsyncGlobal asyncGlobal; | ||||||
|     @Getter |     @Getter | ||||||
|     final private ACPolicy<O> policy; |     final private ACPolicy<OBSERVATION> policy; | ||||||
| 
 | 
 | ||||||
|     public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) { |     public A3CDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) { | ||||||
|         this.iActorCritic = iActorCritic; |         this.iActorCritic = iActorCritic; | ||||||
|         this.mdp = mdp; |         this.mdp = mdp; | ||||||
|         this.configuration = conf; |         this.configuration = conf; | ||||||
|  | |||||||
| @ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph; | |||||||
| import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv; | import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv; | ||||||
| import org.deeplearning4j.rl4j.network.ac.IActorCritic; | import org.deeplearning4j.rl4j.network.ac.IActorCritic; | ||||||
| import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration; | import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager; | import org.deeplearning4j.rl4j.util.IDataManager; | ||||||
| 
 | 
 | ||||||
| @ -42,19 +42,19 @@ import org.deeplearning4j.rl4j.util.IDataManager; | |||||||
|  * first layers since they're essentially doing the same dimension |  * first layers since they're essentially doing the same dimension | ||||||
|  * reduction task |  * reduction task | ||||||
|  **/ |  **/ | ||||||
| public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> { | public class A3CDiscreteConv<OBSERVATION extends Encodable> extends A3CDiscrete<OBSERVATION> { | ||||||
| 
 | 
 | ||||||
|     final private HistoryProcessor.Configuration hpconf; |     final private HistoryProcessor.Configuration hpconf; | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { |                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, actorCritic, hpconf, conf); |         this(mdp, actorCritic, hpconf, conf); | ||||||
|         addListener(new DataManagerTrainingListener(dataManager)); |         addListener(new DataManagerTrainingListener(dataManager)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { |                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { | ||||||
| 
 | 
 | ||||||
|         super(mdp, IActorCritic, conf.toLearningConfiguration()); |         super(mdp, IActorCritic, conf.toLearningConfiguration()); | ||||||
| @ -62,7 +62,7 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> { | |||||||
|         setHistoryProcessor(hpconf); |         setHistoryProcessor(hpconf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { |                            HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { | ||||||
|         super(mdp, IActorCritic, conf); |         super(mdp, IActorCritic, conf); | ||||||
|         this.hpconf = hpconf; |         this.hpconf = hpconf; | ||||||
| @ -70,35 +70,35 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { |                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); |         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { |                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { | ||||||
|         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); |         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { |                            HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { | ||||||
|         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); |         this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { |                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); |         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { |                            HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { | ||||||
|         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf); |         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf, |     public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf, | ||||||
|                            HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { |                            HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { | ||||||
|         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf); |         this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -21,8 +21,8 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; | |||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.ac.*; | import org.deeplearning4j.rl4j.network.ac.*; | ||||||
| import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; | import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager; | import org.deeplearning4j.rl4j.util.IDataManager; | ||||||
| 
 | 
 | ||||||
| @ -34,74 +34,74 @@ import org.deeplearning4j.rl4j.util.IDataManager; | |||||||
|  * We use specifically the Separate version because |  * We use specifically the Separate version because | ||||||
|  * the model is too small to have enough benefit by sharing layers |  * the model is too small to have enough benefit by sharing layers | ||||||
|  */ |  */ | ||||||
| public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> { | public class A3CDiscreteDense<OBSERVATION extends Encodable> extends A3CDiscrete<OBSERVATION> { | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf, | ||||||
|                             IDataManager dataManager) { |                             IDataManager dataManager) { | ||||||
|         this(mdp, IActorCritic, conf); |         this(mdp, IActorCritic, conf); | ||||||
|         addListener(new DataManagerTrainingListener(dataManager)); |         addListener(new DataManagerTrainingListener(dataManager)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) { |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) { | ||||||
|         super(mdp, actorCritic, conf.toLearningConfiguration()); |         super(mdp, actorCritic, conf.toLearningConfiguration()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) { |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) { | ||||||
|         super(mdp, actorCritic, conf); |         super(mdp, actorCritic, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, | ||||||
|                             A3CConfiguration conf, IDataManager dataManager) { |                             A3CConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, |         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, | ||||||
|                 dataManager); |                 dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, | ||||||
|                             A3CConfiguration conf) { |                             A3CConfiguration conf) { | ||||||
|         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); |         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, | ||||||
|                             A3CLearningConfiguration conf) { |                             A3CLearningConfiguration conf) { | ||||||
|         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); |         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, | ||||||
|                             ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, |                             ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, | ||||||
|                             IDataManager dataManager) { |                             IDataManager dataManager) { | ||||||
|         this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager); |         this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, | ||||||
|                             ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { |                             ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { | ||||||
|         this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf); |         this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, | ||||||
|                             ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) { |                             ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) { | ||||||
|         this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); |         this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, | ||||||
|                             A3CConfiguration conf, IDataManager dataManager) { |                             A3CConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, |         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, | ||||||
|                 dataManager); |                 dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, | ||||||
|                             A3CConfiguration conf) { |                             A3CConfiguration conf) { | ||||||
|         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); |         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, |     public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, | ||||||
|                             A3CLearningConfiguration conf) { |                             A3CLearningConfiguration conf) { | ||||||
|         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); |         this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -23,23 +23,23 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; | |||||||
| import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.ac.IActorCritic; | import org.deeplearning4j.rl4j.network.ac.IActorCritic; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.policy.ACPolicy; | import org.deeplearning4j.rl4j.policy.ACPolicy; | ||||||
| import org.deeplearning4j.rl4j.policy.Policy; | import org.deeplearning4j.rl4j.policy.Policy; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. | ||||||
|  * |  * <p> | ||||||
|  * Local thread as described in the https://arxiv.org/abs/1602.01783 paper. |  * Local thread as described in the https://arxiv.org/abs/1602.01783 paper. | ||||||
|  */ |  */ | ||||||
| public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> { | public class A3CThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IActorCritic> { | ||||||
| 
 | 
 | ||||||
|     @Getter |     @Getter | ||||||
|     final protected A3CLearningConfiguration conf; |     final protected A3CLearningConfiguration configuration; | ||||||
|     @Getter |     @Getter | ||||||
|     final protected IAsyncGlobal<IActorCritic> asyncGlobal; |     final protected IAsyncGlobal<IActorCritic> asyncGlobal; | ||||||
|     @Getter |     @Getter | ||||||
| @ -47,15 +47,15 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete< | |||||||
| 
 | 
 | ||||||
|     final private Random rnd; |     final private Random rnd; | ||||||
| 
 | 
 | ||||||
|     public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal, |     public A3CThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal, | ||||||
|                              A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners, |                              A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners, | ||||||
|                              int threadNumber) { |                              int threadNumber) { | ||||||
|         super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); |         super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); | ||||||
|         this.conf = a3cc; |         this.configuration = a3cc; | ||||||
|         this.asyncGlobal = asyncGlobal; |         this.asyncGlobal = asyncGlobal; | ||||||
|         this.threadNumber = threadNumber; |         this.threadNumber = threadNumber; | ||||||
| 
 | 
 | ||||||
|         Long seed = conf.getSeed(); |         Long seed = configuration.getSeed(); | ||||||
|         rnd = Nd4j.getRandom(); |         rnd = Nd4j.getRandom(); | ||||||
|         if (seed != null) { |         if (seed != null) { | ||||||
|             rnd.setSeed(seed + threadNumber); |             rnd.setSeed(seed + threadNumber); | ||||||
| @ -69,9 +69,12 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete< | |||||||
|         return new ACPolicy(net, rnd); |         return new ACPolicy(net, rnd); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     /** | ||||||
|  |      * calc the gradients based on the n-step rewards | ||||||
|  |      */ | ||||||
|     @Override |     @Override | ||||||
|     protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() { |     protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() { | ||||||
|         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); |         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); | ||||||
|         return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma()); |         return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), configuration.getGamma()); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -28,26 +28,26 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread; | |||||||
| import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; | import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.policy.DQNPolicy; | import org.deeplearning4j.rl4j.policy.DQNPolicy; | ||||||
| import org.deeplearning4j.rl4j.policy.IPolicy; | import org.deeplearning4j.rl4j.policy.IPolicy; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. | ||||||
|  */ |  */ | ||||||
| public abstract class AsyncNStepQLearningDiscrete<O extends Encodable> | public abstract class AsyncNStepQLearningDiscrete<OBSERVATION extends Encodable> | ||||||
|         extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> { |         extends AsyncLearning<OBSERVATION, Integer, DiscreteSpace, IDQN> { | ||||||
| 
 | 
 | ||||||
|     @Getter |     @Getter | ||||||
|     final public AsyncQLearningConfiguration configuration; |     final public AsyncQLearningConfiguration configuration; | ||||||
|     @Getter |     @Getter | ||||||
|     final private MDP<O, Integer, DiscreteSpace> mdp; |     final private MDP<OBSERVATION, Integer, DiscreteSpace> mdp; | ||||||
|     @Getter |     @Getter | ||||||
|     final private AsyncGlobal<IDQN> asyncGlobal; |     final private AsyncGlobal<IDQN> asyncGlobal; | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) { |     public AsyncNStepQLearningDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) { | ||||||
|         this.mdp = mdp; |         this.mdp = mdp; | ||||||
|         this.configuration = conf; |         this.configuration = conf; | ||||||
|         this.asyncGlobal = new AsyncGlobal<>(dqn, conf); |         this.asyncGlobal = new AsyncGlobal<>(dqn, conf); | ||||||
| @ -63,7 +63,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable> | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public IPolicy<Integer> getPolicy() { |     public IPolicy<Integer> getPolicy() { | ||||||
|         return new DQNPolicy<O>(getNeuralNet()); |         return new DQNPolicy<OBSERVATION>(getNeuralNet()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Data |     @Data | ||||||
|  | |||||||
| @ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; | |||||||
| import org.deeplearning4j.rl4j.network.dqn.DQNFactory; | import org.deeplearning4j.rl4j.network.dqn.DQNFactory; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; | import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager; | import org.deeplearning4j.rl4j.util.IDataManager; | ||||||
| 
 | 
 | ||||||
| @ -35,17 +35,17 @@ import org.deeplearning4j.rl4j.util.IDataManager; | |||||||
|  * Specialized constructors for the Conv (pixels input) case |  * Specialized constructors for the Conv (pixels input) case | ||||||
|  * Specialized conf + provide additional type safety |  * Specialized conf + provide additional type safety | ||||||
|  */ |  */ | ||||||
| public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> { | public class AsyncNStepQLearningDiscreteConv<OBSERVATION extends Encodable> extends AsyncNStepQLearningDiscrete<OBSERVATION> { | ||||||
| 
 | 
 | ||||||
|     final private HistoryProcessor.Configuration hpconf; |     final private HistoryProcessor.Configuration hpconf; | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, |     public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, | ||||||
|                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { |                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, dqn, hpconf, conf); |         this(mdp, dqn, hpconf, conf); | ||||||
|         addListener(new DataManagerTrainingListener(dataManager)); |         addListener(new DataManagerTrainingListener(dataManager)); | ||||||
|     } |     } | ||||||
|     public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, |     public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, | ||||||
|                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { |                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { | ||||||
|         super(mdp, dqn, conf); |         super(mdp, dqn, conf); | ||||||
|         this.hpconf = hpconf; |         this.hpconf = hpconf; | ||||||
| @ -53,21 +53,21 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { |                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); |         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); | ||||||
|     } |     } | ||||||
|     public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { |                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { | ||||||
|         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); |         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf, |     public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf, | ||||||
|                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { |                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); |         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); | ||||||
|     } |     } | ||||||
|     public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf, |     public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf, | ||||||
|                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { |                                            HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { | ||||||
|         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); |         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio | |||||||
| import org.deeplearning4j.rl4j.network.dqn.DQNFactory; | import org.deeplearning4j.rl4j.network.dqn.DQNFactory; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; | import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager; | import org.deeplearning4j.rl4j.util.IDataManager; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16. | ||||||
|  */ |  */ | ||||||
| public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> { | public class AsyncNStepQLearningDiscreteDense<OBSERVATION extends Encodable> extends AsyncNStepQLearningDiscrete<OBSERVATION> { | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, | ||||||
|                                             AsyncNStepQLConfiguration conf, IDataManager dataManager) { |                                             AsyncNStepQLConfiguration conf, IDataManager dataManager) { | ||||||
|         super(mdp, dqn, conf.toLearningConfiguration()); |         super(mdp, dqn, conf.toLearningConfiguration()); | ||||||
|         addListener(new DataManagerTrainingListener(dataManager)); |         addListener(new DataManagerTrainingListener(dataManager)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, | ||||||
|                                             AsyncNStepQLConfiguration conf) { |                                             AsyncNStepQLConfiguration conf) { | ||||||
|         super(mdp, dqn, conf.toLearningConfiguration()); |         super(mdp, dqn, conf.toLearningConfiguration()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, | ||||||
|                                             AsyncQLearningConfiguration conf) { |                                             AsyncQLearningConfiguration conf) { | ||||||
|         super(mdp, dqn, conf); |         super(mdp, dqn, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                             AsyncNStepQLConfiguration conf, IDataManager dataManager) { |                                             AsyncNStepQLConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, |         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, | ||||||
|                 dataManager); |                 dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                             AsyncNStepQLConfiguration conf) { |                                             AsyncNStepQLConfiguration conf) { | ||||||
|         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); |         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                             AsyncQLearningConfiguration conf) { |                                             AsyncQLearningConfiguration conf) { | ||||||
|         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); |         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, | ||||||
|                                             DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { |                                             DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); |         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, | ||||||
|                                             DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) { |                                             DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) { | ||||||
|         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); |         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, |     public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, | ||||||
|                                             DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) { |                                             DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) { | ||||||
|         this(mdp, new DQNFactoryStdDense(netConf), conf); |         this(mdp, new DQNFactoryStdDense(netConf), conf); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -25,21 +25,21 @@ import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguratio | |||||||
| import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.policy.DQNPolicy; | import org.deeplearning4j.rl4j.policy.DQNPolicy; | ||||||
| import org.deeplearning4j.rl4j.policy.EpsGreedy; | import org.deeplearning4j.rl4j.policy.EpsGreedy; | ||||||
| import org.deeplearning4j.rl4j.policy.Policy; | import org.deeplearning4j.rl4j.policy.Policy; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. | ||||||
|  */ |  */ | ||||||
| public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> { | public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IDQN> { | ||||||
| 
 | 
 | ||||||
|     @Getter |     @Getter | ||||||
|     final protected AsyncQLearningConfiguration conf; |     final protected AsyncQLearningConfiguration configuration; | ||||||
|     @Getter |     @Getter | ||||||
|     final protected IAsyncGlobal<IDQN> asyncGlobal; |     final protected IAsyncGlobal<IDQN> asyncGlobal; | ||||||
|     @Getter |     @Getter | ||||||
| @ -47,16 +47,16 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn | |||||||
| 
 | 
 | ||||||
|     final private Random rnd; |     final private Random rnd; | ||||||
| 
 | 
 | ||||||
|     public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal, |     public AsyncNStepQLearningThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal, | ||||||
|                                              AsyncQLearningConfiguration conf, |                                              AsyncQLearningConfiguration configuration, | ||||||
|                                              TrainingListenerList listeners, int threadNumber, int deviceNum) { |                                              TrainingListenerList listeners, int threadNumber, int deviceNum) { | ||||||
|         super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); |         super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); | ||||||
|         this.conf = conf; |         this.configuration = configuration; | ||||||
|         this.asyncGlobal = asyncGlobal; |         this.asyncGlobal = asyncGlobal; | ||||||
|         this.threadNumber = threadNumber; |         this.threadNumber = threadNumber; | ||||||
|         rnd = Nd4j.getRandom(); |         rnd = Nd4j.getRandom(); | ||||||
| 
 | 
 | ||||||
|         Long seed = conf.getSeed(); |         Long seed = configuration.getSeed(); | ||||||
|         if(seed != null) { |         if(seed != null) { | ||||||
|             rnd.setSeed(seed + threadNumber); |             rnd.setSeed(seed + threadNumber); | ||||||
|         } |         } | ||||||
| @ -65,13 +65,13 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public Policy<Integer> getPolicy(IDQN nn) { |     public Policy<Integer> getPolicy(IDQN nn) { | ||||||
|         return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(), |         return new EpsGreedy(new DQNPolicy(nn), getMdp(), configuration.getUpdateStart(), configuration.getEpsilonNbStep(), | ||||||
|                 rnd, conf.getMinEpsilon(), this); |                 rnd, configuration.getMinEpsilon(), this); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() { |     protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() { | ||||||
|         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); |         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); | ||||||
|         return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma()); |         return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma()); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -32,10 +32,10 @@ import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; | |||||||
| import org.deeplearning4j.rl4j.learning.sync.SyncLearning; | import org.deeplearning4j.rl4j.learning.sync.SyncLearning; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.deeplearning4j.rl4j.policy.EpsGreedy; | import org.deeplearning4j.rl4j.policy.EpsGreedy; | ||||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | import org.deeplearning4j.rl4j.space.ActionSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager.StatEntry; | import org.deeplearning4j.rl4j.util.IDataManager.StatEntry; | ||||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -33,11 +33,11 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorith | |||||||
| import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; | import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.deeplearning4j.rl4j.policy.DQNPolicy; | import org.deeplearning4j.rl4j.policy.DQNPolicy; | ||||||
| import org.deeplearning4j.rl4j.policy.EpsGreedy; | import org.deeplearning4j.rl4j.policy.EpsGreedy; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
|  | |||||||
| @ -24,8 +24,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; | |||||||
| import org.deeplearning4j.rl4j.network.dqn.DQNFactory; | import org.deeplearning4j.rl4j.network.dqn.DQNFactory; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; | import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager; | import org.deeplearning4j.rl4j.util.IDataManager; | ||||||
| 
 | 
 | ||||||
| @ -34,59 +34,59 @@ import org.deeplearning4j.rl4j.util.IDataManager; | |||||||
|  * Specialized constructors for the Conv (pixels input) case |  * Specialized constructors for the Conv (pixels input) case | ||||||
|  * Specialized conf + provide additional type safety |  * Specialized conf + provide additional type safety | ||||||
|  */ |  */ | ||||||
| public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> { | public class QLearningDiscreteConv<OBSERVATION extends Encodable> extends QLearningDiscrete<OBSERVATION> { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, | ||||||
|                                  QLConfiguration conf, IDataManager dataManager) { |                                  QLConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, dqn, hpconf, conf); |         this(mdp, dqn, hpconf, conf); | ||||||
|         addListener(new DataManagerTrainingListener(dataManager)); |         addListener(new DataManagerTrainingListener(dataManager)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, | ||||||
|                                  QLConfiguration conf) { |                                  QLConfiguration conf) { | ||||||
|         super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame()); |         super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame()); | ||||||
|         setHistoryProcessor(hpconf); |         setHistoryProcessor(hpconf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, | ||||||
|                                  QLearningConfiguration conf) { |                                  QLearningConfiguration conf) { | ||||||
|         super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); |         super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); | ||||||
|         setHistoryProcessor(hpconf); |         setHistoryProcessor(hpconf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                  HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { |                                  HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); |         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                  HistoryProcessor.Configuration hpconf, QLConfiguration conf) { |                                  HistoryProcessor.Configuration hpconf, QLConfiguration conf) { | ||||||
|         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); |         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                  HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { |                                  HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { | ||||||
|         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); |         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, | ||||||
|                                  HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { |                                  HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); |         this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, | ||||||
|                                  HistoryProcessor.Configuration hpconf, QLConfiguration conf) { |                                  HistoryProcessor.Configuration hpconf, QLConfiguration conf) { | ||||||
|         this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf); |         this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf, |     public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf, | ||||||
|                                  HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { |                                  HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { | ||||||
|         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); |         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio | |||||||
| import org.deeplearning4j.rl4j.network.dqn.DQNFactory; | import org.deeplearning4j.rl4j.network.dqn.DQNFactory; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; | import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager; | import org.deeplearning4j.rl4j.util.IDataManager; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16. | ||||||
|  */ |  */ | ||||||
| public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> { | public class QLearningDiscreteDense<OBSERVATION extends Encodable> extends QLearningDiscrete<OBSERVATION> { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf, |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf, | ||||||
|                                   IDataManager dataManager) { |                                   IDataManager dataManager) { | ||||||
|         this(mdp, dqn, conf); |         this(mdp, dqn, conf); | ||||||
|         addListener(new DataManagerTrainingListener(dataManager)); |         addListener(new DataManagerTrainingListener(dataManager)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) { |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) { | ||||||
|         super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep()); |         super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) { |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) { | ||||||
|         super(mdp, dqn, conf, conf.getEpsilonNbStep()); |         super(mdp, dqn, conf, conf.getEpsilonNbStep()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                   QLearning.QLConfiguration conf, IDataManager dataManager) { |                                   QLearning.QLConfiguration conf, IDataManager dataManager) { | ||||||
|         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, |         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, | ||||||
|                         dataManager); |                         dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                   QLearning.QLConfiguration conf) { |                                   QLearning.QLConfiguration conf) { | ||||||
|         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); |         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory, | ||||||
|                                   QLearningConfiguration conf) { |                                   QLearningConfiguration conf) { | ||||||
|         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); |         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, | ||||||
|                                   QLearning.QLConfiguration conf, IDataManager dataManager) { |                                   QLearning.QLConfiguration conf, IDataManager dataManager) { | ||||||
| 
 | 
 | ||||||
|         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); |         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, | ||||||
|                                   QLearning.QLConfiguration conf) { |                                   QLearning.QLConfiguration conf) { | ||||||
|         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); |         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf, |     public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf, | ||||||
|                                   QLearningConfiguration conf) { |                                   QLearningConfiguration conf) { | ||||||
|         this(mdp, new DQNFactoryStdDense(netConf), conf); |         this(mdp, new DQNFactoryStdDense(netConf), conf); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -4,8 +4,8 @@ import lombok.Getter; | |||||||
| import lombok.Setter; | import lombok.Setter; | ||||||
| import org.deeplearning4j.gym.StepReply; | import org.deeplearning4j.gym.StepReply; | ||||||
| import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.deeplearning4j.rl4j.space.ObservationSpace; | import org.deeplearning4j.rl4j.space.ObservationSpace; | ||||||
| 
 | 
 | ||||||
| import java.util.Random; | import java.util.Random; | ||||||
| @ -36,7 +36,7 @@ import java.util.Random; | |||||||
| 
 | 
 | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| public class CartpoleNative implements MDP<CartpoleNative.State, Integer, DiscreteSpace> { | public class CartpoleNative implements MDP<Box, Integer, DiscreteSpace> { | ||||||
|     public enum KinematicsIntegrators { Euler, SemiImplicitEuler }; |     public enum KinematicsIntegrators { Euler, SemiImplicitEuler }; | ||||||
| 
 | 
 | ||||||
|     private static final int NUM_ACTIONS = 2; |     private static final int NUM_ACTIONS = 2; | ||||||
| @ -74,7 +74,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | |||||||
|     @Getter |     @Getter | ||||||
|     private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS); |     private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS); | ||||||
|     @Getter |     @Getter | ||||||
|     private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES }); |     private ObservationSpace<Box> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES }); | ||||||
| 
 | 
 | ||||||
|     public CartpoleNative() { |     public CartpoleNative() { | ||||||
|         rnd = new Random(); |         rnd = new Random(); | ||||||
| @ -85,7 +85,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public State reset() { |     public Box reset() { | ||||||
| 
 | 
 | ||||||
|         x = 0.1 * rnd.nextDouble() - 0.05; |         x = 0.1 * rnd.nextDouble() - 0.05; | ||||||
|         xDot = 0.1 * rnd.nextDouble() - 0.05; |         xDot = 0.1 * rnd.nextDouble() - 0.05; | ||||||
| @ -94,7 +94,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | |||||||
|         stepsBeyondDone = null; |         stepsBeyondDone = null; | ||||||
|         done = false; |         done = false; | ||||||
| 
 | 
 | ||||||
|         return new State(new double[] { x, xDot, theta, thetaDot }); |         return new Box(x, xDot, theta, thetaDot); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
| @ -103,7 +103,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public StepReply<State> step(Integer action) { |     public StepReply<Box> step(Integer action) { | ||||||
|         double force = action == ACTION_RIGHT ? forceMag : -forceMag; |         double force = action == ACTION_RIGHT ? forceMag : -forceMag; | ||||||
|         double cosTheta = Math.cos(theta); |         double cosTheta = Math.cos(theta); | ||||||
|         double sinTheta = Math.sin(theta); |         double sinTheta = Math.sin(theta); | ||||||
| @ -143,26 +143,12 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | |||||||
|             reward = 0; |             reward = 0; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return new StepReply<>(new State(new double[] { x, xDot, theta, thetaDot }), reward, done, null); |         return new StepReply<>(new Box(x, xDot, theta, thetaDot), reward, done, null); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public MDP<State, Integer, DiscreteSpace> newInstance() { |     public MDP<Box, Integer, DiscreteSpace> newInstance() { | ||||||
|         return new CartpoleNative(); |         return new CartpoleNative(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public static class State  implements Encodable { |  | ||||||
| 
 |  | ||||||
|         private final double[] state; |  | ||||||
| 
 |  | ||||||
|         State(double[] state) { |  | ||||||
| 
 |  | ||||||
|             this.state = state; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         @Override |  | ||||||
|         public double[] toArray() { |  | ||||||
|             return state; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } | } | ||||||
|  | |||||||
| @ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy; | |||||||
| 
 | 
 | ||||||
| import lombok.Value; | import lombok.Value; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. | ||||||
| @ -31,4 +32,19 @@ public class HardToyState implements Encodable { | |||||||
|     public double[] toArray() { |     public double[] toArray() { | ||||||
|         return values; |         return values; | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public boolean isSkipped() { | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public INDArray getData() { | ||||||
|  |         return null; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public Encodable dup() { | ||||||
|  |         return null; | ||||||
|  |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.learning.NeuralNetFetchable; | |||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
| import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.ObservationSpace; | import org.deeplearning4j.rl4j.space.ObservationSpace; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| @ -40,7 +41,6 @@ import org.nd4j.linalg.factory.Nd4j; | |||||||
| public class SimpleToy implements MDP<SimpleToyState, Integer, DiscreteSpace> { | public class SimpleToy implements MDP<SimpleToyState, Integer, DiscreteSpace> { | ||||||
| 
 | 
 | ||||||
|     final private int maxStep; |     final private int maxStep; | ||||||
|     //TODO 10 steps toy (always +1 reward2 actions), toylong (1000 steps), toyhard (7 actions, +1 only if actiion = (step/100+step)%7, and toyStoch (like last but reward has 0.10 odd to be somewhere else). |  | ||||||
|     @Getter |     @Getter | ||||||
|     private DiscreteSpace actionSpace = new DiscreteSpace(2); |     private DiscreteSpace actionSpace = new DiscreteSpace(2); | ||||||
|     @Getter |     @Getter | ||||||
|  | |||||||
| @ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy; | |||||||
| 
 | 
 | ||||||
| import lombok.Value; | import lombok.Value; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. |  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. | ||||||
| @ -28,11 +29,24 @@ public class SimpleToyState implements Encodable { | |||||||
|     int i; |     int i; | ||||||
|     int step; |     int step; | ||||||
| 
 | 
 | ||||||
|     @Override |  | ||||||
|     public double[] toArray() { |     public double[] toArray() { | ||||||
|         double[] ar = new double[1]; |         double[] ar = new double[1]; | ||||||
|         ar[0] = (20 - i); |         ar[0] = (20 - i); | ||||||
|         return ar; |         return ar; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     @Override | ||||||
|  |     public boolean isSkipped() { | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public INDArray getData() { | ||||||
|  |         return null; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public Encodable dup() { | ||||||
|  |         return null; | ||||||
|  |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; | |||||||
|  * |  * | ||||||
|  * @author Alexandre Boulanger |  * @author Alexandre Boulanger | ||||||
|  */ |  */ | ||||||
| public class Observation { | public class Observation implements Encodable { | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * A singleton representing a skipped observation |      * A singleton representing a skipped observation | ||||||
| @ -38,6 +38,11 @@ public class Observation { | |||||||
|     @Getter |     @Getter | ||||||
|     private final INDArray data; |     private final INDArray data; | ||||||
| 
 | 
 | ||||||
|  |     @Override | ||||||
|  |     public double[] toArray() { | ||||||
|  |         return data.data().asDouble(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     public boolean isSkipped() { |     public boolean isSkipped() { | ||||||
|         return data == null; |         return data == null; | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -13,29 +13,16 @@ | |||||||
|  * |  * | ||||||
|  * SPDX-License-Identifier: Apache-2.0 |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  ******************************************************************************/ |  ******************************************************************************/ | ||||||
| package org.deeplearning4j.rl4j.observation.transform.legacy; |  | ||||||
| 
 | 
 | ||||||
| import org.bytedeco.javacv.OpenCVFrameConverter; | package org.deeplearning4j.rl4j.observation.transform; | ||||||
| import org.bytedeco.opencv.opencv_core.Mat; | 
 | ||||||
| import org.datavec.api.transform.Operation; | import org.datavec.api.transform.Operation; | ||||||
| import org.datavec.image.data.ImageWritable; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import static org.bytedeco.opencv.global.opencv_core.CV_32FC; |  | ||||||
| 
 | 
 | ||||||
| public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> { | public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> { | ||||||
| 
 |  | ||||||
|     private final int[] shape; |  | ||||||
| 
 |  | ||||||
|     public EncodableToINDArrayTransform(int[] shape) { |  | ||||||
|         this.shape = shape; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |     @Override | ||||||
|     public INDArray transform(Encodable encodable) { |     public INDArray transform(Encodable encodable) { | ||||||
|         return Nd4j.create(encodable.toArray()).reshape(shape); |         return encodable.getData(); | ||||||
|     } |     } | ||||||
| 
 |  | ||||||
| } | } | ||||||
| @ -15,34 +15,32 @@ | |||||||
|  ******************************************************************************/ |  ******************************************************************************/ | ||||||
| package org.deeplearning4j.rl4j.observation.transform.legacy; | package org.deeplearning4j.rl4j.observation.transform.legacy; | ||||||
| 
 | 
 | ||||||
|  | import org.bytedeco.javacv.Frame; | ||||||
| import org.bytedeco.javacv.OpenCVFrameConverter; | import org.bytedeco.javacv.OpenCVFrameConverter; | ||||||
| import org.bytedeco.opencv.opencv_core.Mat; | import org.bytedeco.opencv.opencv_core.Mat; | ||||||
| import org.datavec.api.transform.Operation; | import org.datavec.api.transform.Operation; | ||||||
| import org.datavec.image.data.ImageWritable; | import org.datavec.image.data.ImageWritable; | ||||||
|  | import org.datavec.image.loader.NativeImageLoader; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| 
 | 
 | ||||||
| import static org.bytedeco.opencv.global.opencv_core.CV_32FC; | import static org.bytedeco.opencv.global.opencv_core.CV_32FC; | ||||||
|  | import static org.bytedeco.opencv.global.opencv_core.CV_32FC3; | ||||||
|  | import static org.bytedeco.opencv.global.opencv_core.CV_32S; | ||||||
|  | import static org.bytedeco.opencv.global.opencv_core.CV_32SC; | ||||||
|  | import static org.bytedeco.opencv.global.opencv_core.CV_32SC3; | ||||||
|  | import static org.bytedeco.opencv.global.opencv_core.CV_64FC; | ||||||
|  | import static org.bytedeco.opencv.global.opencv_core.CV_8UC3; | ||||||
| 
 | 
 | ||||||
| public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> { | public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> { | ||||||
| 
 | 
 | ||||||
|     private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); |     final static NativeImageLoader nativeImageLoader = new NativeImageLoader(); | ||||||
|     private final int height; |  | ||||||
|     private final int width; |  | ||||||
|     private final int colorChannels; |  | ||||||
| 
 |  | ||||||
|     public EncodableToImageWritableTransform(int height, int width, int colorChannels) { |  | ||||||
|         this.height = height; |  | ||||||
|         this.width = width; |  | ||||||
|         this.colorChannels = colorChannels; |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public ImageWritable transform(Encodable encodable) { |     public ImageWritable transform(Encodable encodable) { | ||||||
|         INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels); |         return new ImageWritable(nativeImageLoader.asFrame(encodable.getData(), Frame.DEPTH_UBYTE)); | ||||||
|         Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer()); |  | ||||||
|         return new ImageWritable(converter.convert(mat)); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  | |||||||
| @ -18,34 +18,31 @@ package org.deeplearning4j.rl4j.observation.transform.legacy; | |||||||
| import org.datavec.api.transform.Operation; | import org.datavec.api.transform.Operation; | ||||||
| import org.datavec.image.data.ImageWritable; | import org.datavec.image.data.ImageWritable; | ||||||
| import org.datavec.image.loader.NativeImageLoader; | import org.datavec.image.loader.NativeImageLoader; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 | 
 | ||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| 
 | 
 | ||||||
| public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> { | public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> { | ||||||
| 
 | 
 | ||||||
|     private final int height; |     private final NativeImageLoader loader = new NativeImageLoader(); | ||||||
|     private final int width; |  | ||||||
|     private final NativeImageLoader loader; |  | ||||||
| 
 |  | ||||||
|     public ImageWritableToINDArrayTransform(int height, int width) { |  | ||||||
|         this.height = height; |  | ||||||
|         this.width = width; |  | ||||||
|         this.loader = new NativeImageLoader(height, width); |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public INDArray transform(ImageWritable imageWritable) { |     public INDArray transform(ImageWritable imageWritable) { | ||||||
|  | 
 | ||||||
|  |         int height = imageWritable.getHeight(); | ||||||
|  |         int width = imageWritable.getWidth(); | ||||||
|  |         int channels = imageWritable.getFrame().imageChannels; | ||||||
|  | 
 | ||||||
|         INDArray out = null; |         INDArray out = null; | ||||||
|         try { |         try { | ||||||
|             out = loader.asMatrix(imageWritable); |             out = loader.asMatrix(imageWritable); | ||||||
|         } catch (IOException e) { |         } catch (IOException e) { | ||||||
|             e.printStackTrace(); |             e.printStackTrace(); | ||||||
|         } |         } | ||||||
|         out = out.reshape(1, height, width); | 
 | ||||||
|  |         // Convert back to uint8 and reshape to the number of channels in the image | ||||||
|  |         out = out.reshape(channels, height, width); | ||||||
|         INDArray compressed = out.castTo(DataType.UINT8); |         INDArray compressed = out.castTo(DataType.UINT8); | ||||||
|         return compressed; |         return compressed; | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -46,19 +46,20 @@ public class HistoryMergeTransform implements Operation<INDArray, INDArray>, Res | |||||||
|     private final HistoryMergeElementStore historyMergeElementStore; |     private final HistoryMergeElementStore historyMergeElementStore; | ||||||
|     private final HistoryMergeAssembler historyMergeAssembler; |     private final HistoryMergeAssembler historyMergeAssembler; | ||||||
|     private final boolean shouldStoreCopy; |     private final boolean shouldStoreCopy; | ||||||
|     private final boolean isFirstDimenstionBatch; |     private final boolean isFirstDimensionBatch; | ||||||
| 
 | 
 | ||||||
|     private HistoryMergeTransform(Builder builder) { |     private HistoryMergeTransform(Builder builder) { | ||||||
|         this.historyMergeElementStore = builder.historyMergeElementStore; |         this.historyMergeElementStore = builder.historyMergeElementStore; | ||||||
|         this.historyMergeAssembler = builder.historyMergeAssembler; |         this.historyMergeAssembler = builder.historyMergeAssembler; | ||||||
|         this.shouldStoreCopy = builder.shouldStoreCopy; |         this.shouldStoreCopy = builder.shouldStoreCopy; | ||||||
|         this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch; |         this.isFirstDimensionBatch = builder.isFirstDimenstionBatch; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public INDArray transform(INDArray input) { |     public INDArray transform(INDArray input) { | ||||||
|  | 
 | ||||||
|         INDArray element; |         INDArray element; | ||||||
|         if(isFirstDimenstionBatch) { |         if(isFirstDimensionBatch) { | ||||||
|             element = input.slice(0, 0); |             element = input.slice(0, 0); | ||||||
|         } |         } | ||||||
|         else { |         else { | ||||||
| @ -132,9 +133,9 @@ public class HistoryMergeTransform implements Operation<INDArray, INDArray>, Res | |||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         public HistoryMergeTransform build() { |         public HistoryMergeTransform build(int frameStackLength) { | ||||||
|             if(historyMergeElementStore == null) { |             if(historyMergeElementStore == null) { | ||||||
|                 historyMergeElementStore = new CircularFifoStore(); |                 historyMergeElementStore = new CircularFifoStore(frameStackLength); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             if(historyMergeAssembler == null) { |             if(historyMergeAssembler == null) { | ||||||
|  | |||||||
| @ -28,14 +28,9 @@ import org.nd4j.linalg.factory.Nd4j; | |||||||
|  * @author Alexandre Boulanger |  * @author Alexandre Boulanger | ||||||
|  */ |  */ | ||||||
| public class CircularFifoStore implements HistoryMergeElementStore { | public class CircularFifoStore implements HistoryMergeElementStore { | ||||||
|     private static final int DEFAULT_STORE_SIZE = 4; |  | ||||||
| 
 | 
 | ||||||
|     private final CircularFifoQueue<INDArray> queue; |     private final CircularFifoQueue<INDArray> queue; | ||||||
| 
 | 
 | ||||||
|     public CircularFifoStore() { |  | ||||||
|         this(DEFAULT_STORE_SIZE); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public CircularFifoStore(int size) { |     public CircularFifoStore(int size) { | ||||||
|         Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size); |         Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size); | ||||||
|         queue = new CircularFifoQueue<>(size); |         queue = new CircularFifoQueue<>(size); | ||||||
|  | |||||||
| @ -20,8 +20,8 @@ import org.deeplearning4j.rl4j.learning.Learning; | |||||||
| import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph; | import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph; | ||||||
| import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate; | import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate; | ||||||
| import org.deeplearning4j.rl4j.network.ac.IActorCritic; | import org.deeplearning4j.rl4j.network.ac.IActorCritic; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| @ -35,7 +35,7 @@ import java.io.IOException; | |||||||
|  * the softmax output of the actor critic, but objects constructed |  * the softmax output of the actor critic, but objects constructed | ||||||
|  * with a {@link Random} argument of null return the max only. |  * with a {@link Random} argument of null return the max only. | ||||||
|  */ |  */ | ||||||
| public class ACPolicy<O extends Encodable> extends Policy<Integer> { | public class ACPolicy<OBSERVATION extends Encodable> extends Policy<Integer> { | ||||||
| 
 | 
 | ||||||
|     final private IActorCritic actorCritic; |     final private IActorCritic actorCritic; | ||||||
|     Random rnd; |     Random rnd; | ||||||
| @ -48,18 +48,18 @@ public class ACPolicy<O extends Encodable> extends Policy<Integer> { | |||||||
|         this.rnd = rnd; |         this.rnd = rnd; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException { |     public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String path) throws IOException { | ||||||
|         return new ACPolicy<O>(ActorCriticCompGraph.load(path)); |         return new ACPolicy<>(ActorCriticCompGraph.load(path)); | ||||||
|     } |     } | ||||||
|     public static <O extends Encodable> ACPolicy<O> load(String path, Random rnd) throws IOException { |     public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String path, Random rnd) throws IOException { | ||||||
|         return new ACPolicy<O>(ActorCriticCompGraph.load(path), rnd); |         return new ACPolicy<>(ActorCriticCompGraph.load(path), rnd); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException { |     public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String pathValue, String pathPolicy) throws IOException { | ||||||
|         return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy)); |         return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy)); | ||||||
|     } |     } | ||||||
|     public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rnd) throws IOException { |     public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String pathValue, String pathPolicy, Random rnd) throws IOException { | ||||||
|         return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd); |         return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public IActorCritic getNeuralNet() { |     public IActorCritic getNeuralNet() { | ||||||
|  | |||||||
| @ -17,8 +17,8 @@ | |||||||
| package org.deeplearning4j.rl4j.policy; | package org.deeplearning4j.rl4j.policy; | ||||||
| 
 | 
 | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
| 
 | 
 | ||||||
| @ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp; | |||||||
|  * Boltzmann exploration is a stochastic policy wrt to the |  * Boltzmann exploration is a stochastic policy wrt to the | ||||||
|  * exponential Q-values as evaluated by the dqn model. |  * exponential Q-values as evaluated by the dqn model. | ||||||
|  */ |  */ | ||||||
| public class BoltzmannQ<O extends Encodable> extends Policy<Integer> { | public class BoltzmannQ<OBSERVATION extends Encodable> extends Policy<Integer> { | ||||||
| 
 | 
 | ||||||
|     final private IDQN dqn; |     final private IDQN dqn; | ||||||
|     final private Random rnd; |     final private Random rnd; | ||||||
|  | |||||||
| @ -20,8 +20,8 @@ import lombok.AllArgsConstructor; | |||||||
| import org.deeplearning4j.rl4j.learning.Learning; | import org.deeplearning4j.rl4j.learning.Learning; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.DQN; | import org.deeplearning4j.rl4j.network.dqn.DQN; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| 
 | 
 | ||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| @ -35,12 +35,12 @@ import java.io.IOException; | |||||||
| 
 | 
 | ||||||
| // FIXME: Should we rename this "GreedyPolicy"? | // FIXME: Should we rename this "GreedyPolicy"? | ||||||
| @AllArgsConstructor | @AllArgsConstructor | ||||||
| public class DQNPolicy<O> extends Policy<Integer> { | public class DQNPolicy<OBSERVATION> extends Policy<Integer> { | ||||||
| 
 | 
 | ||||||
|     final private IDQN dqn; |     final private IDQN dqn; | ||||||
| 
 | 
 | ||||||
|     public static <O extends Encodable> DQNPolicy<O> load(String path) throws IOException { |     public static <OBSERVATION extends Encodable> DQNPolicy<OBSERVATION> load(String path) throws IOException { | ||||||
|         return new DQNPolicy<O>(DQN.load(path)); |         return new DQNPolicy<>(DQN.load(path)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public IDQN getNeuralNet() { |     public IDQN getNeuralNet() { | ||||||
|  | |||||||
| @ -20,12 +20,11 @@ package org.deeplearning4j.rl4j.policy; | |||||||
| import lombok.AllArgsConstructor; | import lombok.AllArgsConstructor; | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.deeplearning4j.rl4j.learning.IEpochTrainer; | import org.deeplearning4j.rl4j.learning.IEpochTrainer; | ||||||
| import org.deeplearning4j.rl4j.learning.ILearning; |  | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | import org.deeplearning4j.rl4j.network.NeuralNet; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | import org.deeplearning4j.rl4j.space.ActionSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
| 
 | 
 | ||||||
| @ -41,10 +40,10 @@ import org.nd4j.linalg.api.rng.Random; | |||||||
|  */ |  */ | ||||||
| @AllArgsConstructor | @AllArgsConstructor | ||||||
| @Slf4j | @Slf4j | ||||||
| public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> { | public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> { | ||||||
| 
 | 
 | ||||||
|     final private Policy<A> policy; |     final private Policy<A> policy; | ||||||
|     final private MDP<O, A, AS> mdp; |     final private MDP<OBSERVATION, A, AS> mdp; | ||||||
|     final private int updateStart; |     final private int updateStart; | ||||||
|     final private int epsilonNbStep; |     final private int epsilonNbStep; | ||||||
|     final private Random rnd; |     final private Random rnd; | ||||||
|  | |||||||
| @ -22,9 +22,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | |||||||
| import org.deeplearning4j.rl4j.learning.Learning; | import org.deeplearning4j.rl4j.learning.Learning; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | import org.deeplearning4j.rl4j.network.NeuralNet; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | import org.deeplearning4j.rl4j.space.ActionSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  | |||||||
| @ -7,22 +7,22 @@ import org.datavec.image.transform.ColorConversionTransform; | |||||||
| import org.datavec.image.transform.CropImageTransform; | import org.datavec.image.transform.CropImageTransform; | ||||||
| import org.datavec.image.transform.MultiImageTransform; | import org.datavec.image.transform.MultiImageTransform; | ||||||
| import org.datavec.image.transform.ResizeImageTransform; | import org.datavec.image.transform.ResizeImageTransform; | ||||||
|  | import org.datavec.image.transform.ShowImageTransform; | ||||||
| import org.deeplearning4j.gym.StepReply; | import org.deeplearning4j.gym.StepReply; | ||||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
|  | import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.TransformProcess; | import org.deeplearning4j.rl4j.observation.transform.TransformProcess; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; | import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform; |  | ||||||
| import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform; | import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform; | import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; | import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; | import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; | ||||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | import org.deeplearning4j.rl4j.space.ActionSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.deeplearning4j.rl4j.space.ObservationSpace; | import org.deeplearning4j.rl4j.space.ObservationSpace; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 | 
 | ||||||
| import java.util.HashMap; | import java.util.HashMap; | ||||||
| import java.util.Map; | import java.util.Map; | ||||||
| @ -46,6 +46,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio | |||||||
|     private int skipFrame = 1; |     private int skipFrame = 1; | ||||||
|     private int steps = 0; |     private int steps = 0; | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|     public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) { |     public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) { | ||||||
|         this.wrappedMDP = wrappedMDP; |         this.wrappedMDP = wrappedMDP; | ||||||
|         this.shape = wrappedMDP.getObservationSpace().getShape(); |         this.shape = wrappedMDP.getObservationSpace().getShape(); | ||||||
| @ -66,28 +67,33 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio | |||||||
| 
 | 
 | ||||||
|         if(historyProcessor != null && shape.length == 3) { |         if(historyProcessor != null && shape.length == 3) { | ||||||
|             int skipFrame = historyProcessor.getConf().getSkipFrame(); |             int skipFrame = historyProcessor.getConf().getSkipFrame(); | ||||||
|  |             int frameStackLength = historyProcessor.getConf().getHistoryLength(); | ||||||
| 
 | 
 | ||||||
|             int finalHeight = historyProcessor.getConf().getCroppingHeight(); |             int height = shape[1]; | ||||||
|             int finalWidth = historyProcessor.getConf().getCroppingWidth(); |             int width = shape[2]; | ||||||
|  | 
 | ||||||
|  |             int cropBottom = height - historyProcessor.getConf().getCroppingHeight(); | ||||||
|  |             int cropRight = width - historyProcessor.getConf().getCroppingWidth(); | ||||||
| 
 | 
 | ||||||
|             transformProcess = TransformProcess.builder() |             transformProcess = TransformProcess.builder() | ||||||
|                     .filter(new UniformSkippingFilter(skipFrame)) |                     .filter(new UniformSkippingFilter(skipFrame)) | ||||||
|                     .transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2])) |                     .transform("data", new EncodableToImageWritableTransform()) | ||||||
|                     .transform("data", new MultiImageTransform( |                     .transform("data", new MultiImageTransform( | ||||||
|  |                             new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), cropBottom, cropRight), | ||||||
|                             new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()), |                             new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()), | ||||||
|                             new ColorConversionTransform(COLOR_BGR2GRAY), |                             new ColorConversionTransform(COLOR_BGR2GRAY) | ||||||
|                             new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth) |                             //new ShowImageTransform("crop + resize + greyscale") | ||||||
|                     )) |                     )) | ||||||
|                     .transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth)) |                     .transform("data", new ImageWritableToINDArrayTransform()) | ||||||
|                     .transform("data", new SimpleNormalizationTransform(0.0, 255.0)) |                     .transform("data", new SimpleNormalizationTransform(0.0, 255.0)) | ||||||
|                     .transform("data", HistoryMergeTransform.builder() |                     .transform("data", HistoryMergeTransform.builder() | ||||||
|                             .isFirstDimenstionBatch(true) |                             .isFirstDimenstionBatch(true) | ||||||
|                             .build()) |                             .build(frameStackLength)) | ||||||
|                     .build("data"); |                     .build("data"); | ||||||
|         } |         } | ||||||
|         else { |         else { | ||||||
|             transformProcess = TransformProcess.builder() |             transformProcess = TransformProcess.builder() | ||||||
|                     .transform("data", new EncodableToINDArrayTransform(shape)) |                     .transform("data", new EncodableToINDArrayTransform()) | ||||||
|                     .build("data"); |                     .build("data"); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -127,6 +133,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio | |||||||
| 
 | 
 | ||||||
|         Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation()); |         Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation()); | ||||||
|         Observation observation =  transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone()); |         Observation observation =  transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone()); | ||||||
|  | 
 | ||||||
|         return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); |         return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -161,12 +168,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private INDArray getInput(OBSERVATION obs) { |     private INDArray getInput(OBSERVATION obs) { | ||||||
|         INDArray arr = Nd4j.create(obs.toArray()); |         return obs.getData(); | ||||||
|         int[] shape = observationSpace.getShape(); |  | ||||||
|         if (shape.length == 1) |  | ||||||
|             return arr.reshape(new long[] {1, arr.length()}); |  | ||||||
|         else |  | ||||||
|             return arr.reshape(shape); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public static class WrapperObservationSpace implements ObservationSpace<Observation> { |     public static class WrapperObservationSpace implements ObservationSpace<Observation> { | ||||||
|  | |||||||
| @ -16,26 +16,21 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.rl4j.util; | package org.deeplearning4j.rl4j.util; | ||||||
| 
 | 
 | ||||||
| import lombok.Getter; |  | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.bytedeco.javacpp.BytePointer; |  | ||||||
| import org.bytedeco.javacpp.Pointer; |  | ||||||
| import org.bytedeco.javacv.FFmpegFrameRecorder; | import org.bytedeco.javacv.FFmpegFrameRecorder; | ||||||
| import org.bytedeco.javacv.Frame; | import org.bytedeco.javacv.Frame; | ||||||
| import org.bytedeco.javacv.OpenCVFrameConverter; | import org.datavec.image.loader.NativeImageLoader; | ||||||
| import org.bytedeco.opencv.global.opencv_core; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.bytedeco.opencv.global.opencv_imgproc; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.bytedeco.opencv.opencv_core.Mat; |  | ||||||
| import org.bytedeco.opencv.opencv_core.Rect; |  | ||||||
| import org.bytedeco.opencv.opencv_core.Size; |  | ||||||
| import org.opencv.imgproc.Imgproc; |  | ||||||
| 
 | 
 | ||||||
| import static org.bytedeco.ffmpeg.global.avcodec.*; | import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_H264; | ||||||
| import static org.bytedeco.opencv.global.opencv_core.*; | import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_MPEG4; | ||||||
|  | import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB0; | ||||||
|  | import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB24; | ||||||
|  | import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB8; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * VideoRecorder is used to create a video from a sequence of individual frames. If using 3 channels |  * VideoRecorder is used to create a video from a sequence of INDArray frames. INDArrays are assumed to be in CHW format where C=3 and pixels are RGB encoded<br> | ||||||
|  * images, it expects B-G-R order. A RGB order can be used by calling isRGBOrder(true).<br> |  | ||||||
|  * Example:<br> |  * Example:<br> | ||||||
|  * <pre> |  * <pre> | ||||||
|  * {@code |  * {@code | ||||||
| @ -45,11 +40,8 @@ import static org.bytedeco.opencv.global.opencv_core.*; | |||||||
|  *             .build(); |  *             .build(); | ||||||
|  *         recorder.startRecording("myVideo.mp4"); |  *         recorder.startRecording("myVideo.mp4"); | ||||||
|  *         while(...) { |  *         while(...) { | ||||||
|  *             byte[] data = new byte[160*100*3]; |  *             INDArray chwData = Nd4j.create() | ||||||
|  *             // Todo: Fill data |  *             recorder.record(chwData); | ||||||
|  *             VideoRecorder.VideoFrame frame = recorder.createFrame(data); |  | ||||||
|  *             // Todo: Apply cropping or resizing to frame |  | ||||||
|  *             recorder.record(frame); |  | ||||||
|  *         } |  *         } | ||||||
|  *         recorder.stopRecording(); |  *         recorder.stopRecording(); | ||||||
|  * } |  * } | ||||||
| @ -60,16 +52,13 @@ import static org.bytedeco.opencv.global.opencv_core.*; | |||||||
| @Slf4j | @Slf4j | ||||||
| public class VideoRecorder implements AutoCloseable { | public class VideoRecorder implements AutoCloseable { | ||||||
| 
 | 
 | ||||||
|     public enum FrameInputTypes { BGR, RGB, Float } |     private final NativeImageLoader nativeImageLoader = new NativeImageLoader(); | ||||||
| 
 | 
 | ||||||
|     private final int height; |     private final int height; | ||||||
|     private final int width; |     private final int width; | ||||||
|     private final int imageType; |  | ||||||
|     private final OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat(); |  | ||||||
|     private final int codec; |     private final int codec; | ||||||
|     private final double framerate; |     private final double framerate; | ||||||
|     private final int videoQuality; |     private final int videoQuality; | ||||||
|     private final FrameInputTypes frameInputType; |  | ||||||
| 
 | 
 | ||||||
|     private FFmpegFrameRecorder fmpegFrameRecorder = null; |     private FFmpegFrameRecorder fmpegFrameRecorder = null; | ||||||
| 
 | 
 | ||||||
| @ -83,11 +72,9 @@ public class VideoRecorder implements AutoCloseable { | |||||||
|     private VideoRecorder(Builder builder) { |     private VideoRecorder(Builder builder) { | ||||||
|         this.height = builder.height; |         this.height = builder.height; | ||||||
|         this.width = builder.width; |         this.width = builder.width; | ||||||
|         imageType = CV_8UC(builder.numChannels); |  | ||||||
|         codec = builder.codec; |         codec = builder.codec; | ||||||
|         framerate = builder.frameRate; |         framerate = builder.frameRate; | ||||||
|         videoQuality = builder.videoQuality; |         videoQuality = builder.videoQuality; | ||||||
|         frameInputType = builder.frameInputType; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
| @ -119,59 +106,11 @@ public class VideoRecorder implements AutoCloseable { | |||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * Add a frame to the video |      * Add a frame to the video | ||||||
|      * @param frame the VideoFrame to add to the video |      * @param imageArray the INDArray that contains the data to be recorded, the data must be in CHW format | ||||||
|      * @throws Exception |      * @throws Exception | ||||||
|      */ |      */ | ||||||
|     public void record(VideoFrame frame) throws Exception { |     public void record(INDArray imageArray) throws Exception { | ||||||
|         Size size = frame.getMat().size(); |         fmpegFrameRecorder.record(nativeImageLoader.asFrame(imageArray, Frame.DEPTH_UBYTE)); | ||||||
|         if(size.height() != height || size.width() != width) { |  | ||||||
|             throw new IllegalArgumentException(String.format("Wrong frame size. Got (%dh x %dw) expected (%dh x %dw)", size.height(), size.width(), height, width)); |  | ||||||
|         } |  | ||||||
|         Frame cvFrame = openCVFrameConverter.convert(frame.getMat()); |  | ||||||
|         fmpegFrameRecorder.record(cvFrame); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create a VideoFrame from a byte array. |  | ||||||
|      * @param data A byte array. Expect the index to be of the form [(Y*Width + X) * NumChannels + channel] |  | ||||||
|      * @return An instance of VideoFrame |  | ||||||
|      */ |  | ||||||
|     public VideoFrame createFrame(byte[] data) { |  | ||||||
|         return createFrame(new BytePointer(data)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create a VideoFrame from a byte array with different height and width than the video |  | ||||||
|      * the frame will need to be cropped or resized before being added to the video) |  | ||||||
|      * |  | ||||||
|      * @param data A byte array Expect the index to be of the form [(Y*customWidth + X) * NumChannels + channel] |  | ||||||
|      * @param customHeight The actual height of the data |  | ||||||
|      * @param customWidth The actual width of the data |  | ||||||
|      * @return A VideoFrame instance |  | ||||||
|      */ |  | ||||||
|     public VideoFrame createFrame(byte[] data, int customHeight, int customWidth) { |  | ||||||
|         return createFrame(new BytePointer(data), customHeight, customWidth); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create a VideoFrame from a Pointer (to use for example with a INDarray). |  | ||||||
|      * @param data A Pointer (for example myINDArray.data().pointer()) |  | ||||||
|      * @return An instance of VideoFrame |  | ||||||
|      */ |  | ||||||
|     public VideoFrame createFrame(Pointer data) { |  | ||||||
|         return new VideoFrame(height, width, imageType, frameInputType, data); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      *  Create a VideoFrame from a Pointer with different height and width than the video |  | ||||||
|      * the frame will need to be cropped or resized before being added to the video) |  | ||||||
|      * @param data |  | ||||||
|      * @param customHeight The actual height of the data |  | ||||||
|      * @param customWidth The actual width of the data |  | ||||||
|      * @return A VideoFrame instance |  | ||||||
|      */ |  | ||||||
|     public VideoFrame createFrame(Pointer data, int customHeight, int customWidth) { |  | ||||||
|         return new VideoFrame(customHeight, customWidth, imageType, frameInputType, data); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
| @ -192,69 +131,12 @@ public class VideoRecorder implements AutoCloseable { | |||||||
|         return new Builder(height, width); |         return new Builder(height, width); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |  | ||||||
|      * An individual frame for the video |  | ||||||
|      */ |  | ||||||
|     public static class VideoFrame { |  | ||||||
| 
 |  | ||||||
|         private final int height; |  | ||||||
|         private final int width; |  | ||||||
|         private final int imageType; |  | ||||||
|         @Getter |  | ||||||
|         private Mat mat; |  | ||||||
| 
 |  | ||||||
|         private VideoFrame(int height, int width, int imageType, FrameInputTypes frameInputType, Pointer data) { |  | ||||||
|             this.height = height; |  | ||||||
|             this.width = width; |  | ||||||
|             this.imageType = imageType; |  | ||||||
| 
 |  | ||||||
|             switch(frameInputType) { |  | ||||||
|                 case RGB: |  | ||||||
|                     Mat src = new Mat(height, width, imageType, data); |  | ||||||
|                     mat = new Mat(height, width, imageType); |  | ||||||
|                     opencv_imgproc.cvtColor(src, mat, Imgproc.COLOR_RGB2BGR); |  | ||||||
|                     break; |  | ||||||
| 
 |  | ||||||
|                 case BGR: |  | ||||||
|                     mat = new Mat(height, width, imageType, data); |  | ||||||
|                     break; |  | ||||||
| 
 |  | ||||||
|                 case Float: |  | ||||||
|                     Mat tmpMat = new Mat(height, width, CV_32FC(3), data); |  | ||||||
|                     mat = new Mat(height, width, imageType); |  | ||||||
|                     tmpMat.convertTo(mat, CV_8UC(3), 255.0, 0.0); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         /** |  | ||||||
|          * Crop the video to a specified size |  | ||||||
|          * @param newHeight The new height of the frame |  | ||||||
|          * @param newWidth The new width of the frame |  | ||||||
|          * @param heightOffset The starting height offset in the uncropped frame |  | ||||||
|          * @param widthOffset The starting weight offset in the uncropped frame |  | ||||||
|          */ |  | ||||||
|         public void crop(int newHeight, int newWidth, int heightOffset, int widthOffset) { |  | ||||||
|             mat = mat.apply(new Rect(widthOffset, heightOffset, newWidth, newHeight)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         /** |  | ||||||
|          * Resize the frame to a specified size |  | ||||||
|          * @param newHeight The new height of the frame |  | ||||||
|          * @param newWidth The new width of the frame |  | ||||||
|          */ |  | ||||||
|         public void resize(int newHeight, int newWidth) { |  | ||||||
|             mat = new Mat(newHeight, newWidth, imageType); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |     /** | ||||||
|      * A builder class for the VideoRecorder |      * A builder class for the VideoRecorder | ||||||
|      */ |      */ | ||||||
|     public static class Builder { |     public static class Builder { | ||||||
|         private final int height; |         private final int height; | ||||||
|         private final int width; |         private final int width; | ||||||
|         private int numChannels = 3; |  | ||||||
|         private FrameInputTypes frameInputType = FrameInputTypes.BGR; |  | ||||||
|         private int codec = AV_CODEC_ID_H264; |         private int codec = AV_CODEC_ID_H264; | ||||||
|         private double frameRate = 30.0; |         private double frameRate = 30.0; | ||||||
|         private int videoQuality = 30; |         private int videoQuality = 30; | ||||||
| @ -268,24 +150,6 @@ public class VideoRecorder implements AutoCloseable { | |||||||
|             this.width = width; |             this.width = width; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         /** |  | ||||||
|          * Specify the number of channels. Default is 3 |  | ||||||
|          * @param numChannels |  | ||||||
|          */ |  | ||||||
|         public Builder numChannels(int numChannels) { |  | ||||||
|             this.numChannels = numChannels; |  | ||||||
|             return this; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         /** |  | ||||||
|          * Tell the VideoRecorder what data it will receive (default is BGR) |  | ||||||
|          * @param frameInputType (See {@link FrameInputTypes}} |  | ||||||
|          */ |  | ||||||
|         public Builder frameInputType(FrameInputTypes frameInputType) { |  | ||||||
|             this.frameInputType = frameInputType; |  | ||||||
|             return this; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         /** |         /** | ||||||
|          * The codec to use for the video. Default is AV_CODEC_ID_H264 |          * The codec to use for the video. Default is AV_CODEC_ID_H264 | ||||||
|          * @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes}) |          * @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes}) | ||||||
|  | |||||||
| @ -115,7 +115,7 @@ public class AsyncThreadDiscreteTest { | |||||||
| 
 | 
 | ||||||
|         asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm); |         asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm); | ||||||
| 
 | 
 | ||||||
|         when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration); |         when(asyncThreadDiscrete.getConfiguration()).thenReturn(mockAsyncConfiguration); | ||||||
|         when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0); |         when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0); | ||||||
|         when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal); |         when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal); | ||||||
|         when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy); |         when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy); | ||||||
|  | |||||||
| @ -39,7 +39,6 @@ import static org.junit.Assert.assertEquals; | |||||||
| import static org.mockito.ArgumentMatchers.any; | import static org.mockito.ArgumentMatchers.any; | ||||||
| import static org.mockito.ArgumentMatchers.anyInt; | import static org.mockito.ArgumentMatchers.anyInt; | ||||||
| import static org.mockito.ArgumentMatchers.eq; | import static org.mockito.ArgumentMatchers.eq; | ||||||
| import static org.mockito.Mockito.clearInvocations; |  | ||||||
| import static org.mockito.Mockito.doAnswer; | import static org.mockito.Mockito.doAnswer; | ||||||
| import static org.mockito.Mockito.mock; | import static org.mockito.Mockito.mock; | ||||||
| import static org.mockito.Mockito.times; | import static org.mockito.Mockito.times; | ||||||
| @ -130,7 +129,7 @@ public class AsyncThreadTest { | |||||||
| 
 | 
 | ||||||
|         when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode); |         when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode); | ||||||
|         when(mockAsyncConfiguration.getNStep()).thenReturn(nstep); |         when(mockAsyncConfiguration.getNStep()).thenReturn(nstep); | ||||||
|         when(thread.getConf()).thenReturn(mockAsyncConfiguration); |         when(thread.getConfiguration()).thenReturn(mockAsyncConfiguration); | ||||||
| 
 | 
 | ||||||
|         // if we hit the max step count |         // if we hit the max step count | ||||||
|         when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps); |         when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps); | ||||||
|  | |||||||
| @ -18,24 +18,16 @@ | |||||||
| package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; | package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; | ||||||
| 
 | 
 | ||||||
| import org.deeplearning4j.gym.StepReply; | import org.deeplearning4j.gym.StepReply; | ||||||
| import org.deeplearning4j.rl4j.experience.ExperienceHandler; |  | ||||||
| import org.deeplearning4j.rl4j.experience.StateActionPair; |  | ||||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||||
| import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; | import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; | ||||||
| import org.deeplearning4j.rl4j.learning.sync.IExpReplay; |  | ||||||
| import org.deeplearning4j.rl4j.learning.sync.Transition; |  | ||||||
| import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; | import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.observation.Observation; | import org.deeplearning4j.rl4j.observation.Observation; | ||||||
| import org.deeplearning4j.rl4j.space.Box; | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.deeplearning4j.rl4j.space.ObservationSpace; | import org.deeplearning4j.rl4j.space.ObservationSpace; | ||||||
| import org.deeplearning4j.rl4j.support.*; |  | ||||||
| import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; |  | ||||||
| import org.deeplearning4j.rl4j.util.IDataManager; |  | ||||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; |  | ||||||
| import org.junit.Before; | import org.junit.Before; | ||||||
| import org.junit.Test; | import org.junit.Test; | ||||||
| import org.junit.runner.RunWith; | import org.junit.runner.RunWith; | ||||||
| @ -43,17 +35,17 @@ import org.mockito.Mock; | |||||||
| import org.mockito.Mockito; | import org.mockito.Mockito; | ||||||
| import org.mockito.junit.MockitoJUnitRunner; | import org.mockito.junit.MockitoJUnitRunner; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.dataset.api.DataSet; |  | ||||||
| import org.nd4j.linalg.api.rng.Random; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| 
 | 
 | ||||||
| import java.util.ArrayList; | import static org.junit.Assert.assertEquals; | ||||||
| import java.util.List; | import static org.junit.Assert.assertFalse; | ||||||
| 
 | import static org.junit.Assert.assertTrue; | ||||||
| import static org.junit.Assert.*; | import static org.mockito.ArgumentMatchers.any; | ||||||
| import static org.mockito.ArgumentMatchers.anyInt; | import static org.mockito.ArgumentMatchers.anyInt; | ||||||
| import static org.mockito.ArgumentMatchers.eq; | import static org.mockito.ArgumentMatchers.eq; | ||||||
| import static org.mockito.Mockito.mock; | import static org.mockito.Mockito.mock; | ||||||
|  | import static org.mockito.Mockito.never; | ||||||
|  | import static org.mockito.Mockito.verify; | ||||||
| import static org.mockito.Mockito.when; | import static org.mockito.Mockito.when; | ||||||
| 
 | 
 | ||||||
| @RunWith(MockitoJUnitRunner.class) | @RunWith(MockitoJUnitRunner.class) | ||||||
| @ -82,6 +74,7 @@ public class QLearningDiscreteTest { | |||||||
|     @Mock |     @Mock | ||||||
|     QLearningConfiguration mockQlearningConfiguration; |     QLearningConfiguration mockQlearningConfiguration; | ||||||
| 
 | 
 | ||||||
|  |     // HWC | ||||||
|     int[] observationShape = new int[]{3, 10, 10}; |     int[] observationShape = new int[]{3, 10, 10}; | ||||||
|     int totalObservationSize = 1; |     int totalObservationSize = 1; | ||||||
| 
 | 
 | ||||||
| @ -123,6 +116,7 @@ public class QLearningDiscreteTest { | |||||||
|         when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]); |         when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]); | ||||||
|         when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]); |         when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]); | ||||||
|         when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames); |         when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames); | ||||||
|  |         when(mockHistoryConfiguration.getHistoryLength()).thenReturn(1); | ||||||
|         when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration); |         when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration); | ||||||
| 
 | 
 | ||||||
|         qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor); |         qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor); | ||||||
| @ -148,7 +142,7 @@ public class QLearningDiscreteTest { | |||||||
|         Observation observation = new Observation(Nd4j.zeros(observationShape)); |         Observation observation = new Observation(Nd4j.zeros(observationShape)); | ||||||
|         when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f})); |         when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f})); | ||||||
| 
 | 
 | ||||||
|         when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null)); |         when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Observation(Nd4j.zeros(observationShape)), 0, false, null)); | ||||||
| 
 | 
 | ||||||
|         // Act |         // Act | ||||||
|         QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation); |         QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation); | ||||||
| @ -170,25 +164,26 @@ public class QLearningDiscreteTest { | |||||||
|         // Arrange |         // Arrange | ||||||
|         mockTestContext(100,0,2,1.0, 10); |         mockTestContext(100,0,2,1.0, 10); | ||||||
| 
 | 
 | ||||||
|         mockHistoryProcessor(2); |         Observation skippedObservation = Observation.SkippedObservation; | ||||||
|  |         Observation nextObservation = new Observation(Nd4j.zeros(observationShape)); | ||||||
| 
 | 
 | ||||||
|         // An example observation and 2 Q values output (2 actions) |         when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(nextObservation, 0, false, null)); | ||||||
|         Observation observation = new Observation(Nd4j.zeros(observationShape)); |  | ||||||
|         when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f})); |  | ||||||
| 
 |  | ||||||
|         when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null)); |  | ||||||
| 
 | 
 | ||||||
|         // Act |         // Act | ||||||
|         QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation); |         QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(skippedObservation); | ||||||
| 
 | 
 | ||||||
|         // Assert |         // Assert | ||||||
|         assertEquals(1.0, stepReturn.getMaxQ(), 1e-5); |         assertEquals(Double.NaN, stepReturn.getMaxQ(), 1e-5); | ||||||
| 
 | 
 | ||||||
|         StepReply<Observation> stepReply = stepReturn.getStepReply(); |         StepReply<Observation> stepReply = stepReturn.getStepReply(); | ||||||
| 
 | 
 | ||||||
|         assertEquals(0, stepReply.getReward(), 1e-5); |         assertEquals(0, stepReply.getReward(), 1e-5); | ||||||
|         assertFalse(stepReply.isDone()); |         assertFalse(stepReply.isDone()); | ||||||
|         assertTrue(stepReply.getObservation().isSkipped()); |         assertFalse(stepReply.getObservation().isSkipped()); | ||||||
|  |         assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize()); | ||||||
|  | 
 | ||||||
|  |         verify(mockDQN, never()).output(any(INDArray.class)); | ||||||
|  | 
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     //TODO: there are much more test cases here that can be improved upon |     //TODO: there are much more test cases here that can be improved upon | ||||||
|  | |||||||
| @ -17,7 +17,7 @@ public class HistoryMergeTransformTest { | |||||||
|         HistoryMergeTransform sut = HistoryMergeTransform.builder() |         HistoryMergeTransform sut = HistoryMergeTransform.builder() | ||||||
|                 .isFirstDimenstionBatch(false) |                 .isFirstDimenstionBatch(false) | ||||||
|                 .elementStore(store) |                 .elementStore(store) | ||||||
|                 .build(); |                 .build(4); | ||||||
|         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); |         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); | ||||||
| 
 | 
 | ||||||
|         // Act |         // Act | ||||||
| @ -35,7 +35,7 @@ public class HistoryMergeTransformTest { | |||||||
|         HistoryMergeTransform sut = HistoryMergeTransform.builder() |         HistoryMergeTransform sut = HistoryMergeTransform.builder() | ||||||
|                 .isFirstDimenstionBatch(true) |                 .isFirstDimenstionBatch(true) | ||||||
|                 .elementStore(store) |                 .elementStore(store) | ||||||
|                 .build(); |                 .build(4); | ||||||
|         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3); |         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3); | ||||||
| 
 | 
 | ||||||
|         // Act |         // Act | ||||||
| @ -53,7 +53,7 @@ public class HistoryMergeTransformTest { | |||||||
|         HistoryMergeTransform sut = HistoryMergeTransform.builder() |         HistoryMergeTransform sut = HistoryMergeTransform.builder() | ||||||
|                 .isFirstDimenstionBatch(true) |                 .isFirstDimenstionBatch(true) | ||||||
|                 .elementStore(store) |                 .elementStore(store) | ||||||
|                 .build(); |                 .build(4); | ||||||
|         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); |         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); | ||||||
| 
 | 
 | ||||||
|         // Act |         // Act | ||||||
| @ -70,7 +70,7 @@ public class HistoryMergeTransformTest { | |||||||
|         HistoryMergeTransform sut = HistoryMergeTransform.builder() |         HistoryMergeTransform sut = HistoryMergeTransform.builder() | ||||||
|                 .shouldStoreCopy(false) |                 .shouldStoreCopy(false) | ||||||
|                 .elementStore(store) |                 .elementStore(store) | ||||||
|                 .build(); |                 .build(4); | ||||||
|         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); |         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); | ||||||
| 
 | 
 | ||||||
|         // Act |         // Act | ||||||
| @ -87,7 +87,7 @@ public class HistoryMergeTransformTest { | |||||||
|         HistoryMergeTransform sut = HistoryMergeTransform.builder() |         HistoryMergeTransform sut = HistoryMergeTransform.builder() | ||||||
|                 .shouldStoreCopy(true) |                 .shouldStoreCopy(true) | ||||||
|                 .elementStore(store) |                 .elementStore(store) | ||||||
|                 .build(); |                 .build(4); | ||||||
|         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); |         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); | ||||||
| 
 | 
 | ||||||
|         // Act |         // Act | ||||||
| @ -107,7 +107,7 @@ public class HistoryMergeTransformTest { | |||||||
|         HistoryMergeTransform sut = HistoryMergeTransform.builder() |         HistoryMergeTransform sut = HistoryMergeTransform.builder() | ||||||
|                 .elementStore(store) |                 .elementStore(store) | ||||||
|                 .assembler(assemble) |                 .assembler(assemble) | ||||||
|                 .build(); |                 .build(4); | ||||||
|         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); |         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); | ||||||
| 
 | 
 | ||||||
|         // Act |         // Act | ||||||
|  | |||||||
| @ -252,8 +252,8 @@ public class PolicyTest { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         @Override |         @Override | ||||||
|         protected <O extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, Integer, AS> mdpWrapper, IHistoryProcessor hp) { |         protected <MockObservation extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockObservation, Integer, AS> mdpWrapper, IHistoryProcessor hp) { | ||||||
|             mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength)); |             mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(skipFrame, historyLength)); | ||||||
|             return super.refacInitMdp(mdpWrapper, hp); |             return super.refacInitMdp(mdpWrapper, hp); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -1,18 +0,0 @@ | |||||||
| package org.deeplearning4j.rl4j.support; |  | ||||||
| 
 |  | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| 
 |  | ||||||
| public class MockEncodable implements Encodable { |  | ||||||
| 
 |  | ||||||
|     private final int value; |  | ||||||
| 
 |  | ||||||
|     public MockEncodable(int value) { |  | ||||||
| 
 |  | ||||||
|         this.value = value; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public double[] toArray() { |  | ||||||
|         return new double[] { value }; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -2,9 +2,10 @@ package org.deeplearning4j.rl4j.support; | |||||||
| 
 | 
 | ||||||
| import org.deeplearning4j.gym.StepReply; | import org.deeplearning4j.gym.StepReply; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
|  | import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.TransformProcess; | import org.deeplearning4j.rl4j.observation.transform.TransformProcess; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; | import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform; | import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; | import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; | import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; | ||||||
| import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore; | import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore; | ||||||
| @ -15,7 +16,7 @@ import org.nd4j.linalg.api.rng.Random; | |||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| 
 | 
 | ||||||
| public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> { | public class MockMDP implements MDP<MockObservation, Integer, DiscreteSpace> { | ||||||
| 
 | 
 | ||||||
|     private final DiscreteSpace actionSpace; |     private final DiscreteSpace actionSpace; | ||||||
|     private final int stepsUntilDone; |     private final int stepsUntilDone; | ||||||
| @ -55,11 +56,11 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public MockEncodable reset() { |     public MockObservation reset() { | ||||||
|         ++resetCount; |         ++resetCount; | ||||||
|         currentObsValue = 0; |         currentObsValue = 0; | ||||||
|         step = 0; |         step = 0; | ||||||
|         return new MockEncodable(currentObsValue++); |         return new MockObservation(currentObsValue++); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
| @ -68,10 +69,10 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public StepReply<MockEncodable> step(Integer action) { |     public StepReply<MockObservation> step(Integer action) { | ||||||
|         actions.add(action); |         actions.add(action); | ||||||
|         ++step; |         ++step; | ||||||
|         return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null); |         return new StepReply<>(new MockObservation(currentObsValue), (double) currentObsValue++, isDone(), null); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
| @ -84,14 +85,14 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> { | |||||||
|         return null; |         return null; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) { |     public static TransformProcess buildTransformProcess(int skipFrame, int historyLength) { | ||||||
|         return TransformProcess.builder() |         return TransformProcess.builder() | ||||||
|                 .filter(new UniformSkippingFilter(skipFrame)) |                 .filter(new UniformSkippingFilter(skipFrame)) | ||||||
|                 .transform("data", new EncodableToINDArrayTransform(shape)) |                 .transform("data", new EncodableToINDArrayTransform()) | ||||||
|                 .transform("data", new SimpleNormalizationTransform(0.0, 255.0)) |                 .transform("data", new SimpleNormalizationTransform(0.0, 255.0)) | ||||||
|                 .transform("data", HistoryMergeTransform.builder() |                 .transform("data", HistoryMergeTransform.builder() | ||||||
|                         .elementStore(new CircularFifoStore(historyLength)) |                         .elementStore(new CircularFifoStore(historyLength)) | ||||||
|                         .build()) |                         .build(4)) | ||||||
|                 .build("data"); |                 .build("data"); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -0,0 +1,51 @@ | |||||||
|  | /******************************************************************************* | ||||||
|  |  * Copyright (c) 2020 Konduit K. K. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0. | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | package org.deeplearning4j.rl4j.support; | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
|  | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
|  | import org.nd4j.linalg.factory.Nd4j; | ||||||
|  | 
 | ||||||
|  | public class MockObservation implements Encodable { | ||||||
|  | 
 | ||||||
|  |     final INDArray data; | ||||||
|  | 
 | ||||||
|  |     public MockObservation(int value) { | ||||||
|  |         this.data = Nd4j.ones(1).mul(value); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public double[] toArray() { | ||||||
|  |         return data.data().asDouble(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public boolean isSkipped() { | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public INDArray getData() { | ||||||
|  |         return data; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public Encodable dup() { | ||||||
|  |         return null; | ||||||
|  |     } | ||||||
|  | } | ||||||
| @ -17,7 +17,7 @@ public class MockPolicy implements IPolicy<Integer> { | |||||||
|     public List<INDArray> actionInputs = new ArrayList<INDArray>(); |     public List<INDArray> actionInputs = new ArrayList<INDArray>(); | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public <O extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<O, Integer, AS> mdp, IHistoryProcessor hp) { |     public <MockObservation extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<MockObservation, Integer, AS> mdp, IHistoryProcessor hp) { | ||||||
|         ++playCallCount; |         ++playCallCount; | ||||||
|         return 0; |         return 0; | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -28,6 +28,9 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | |||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.space.ObservationSpace; | import org.deeplearning4j.rl4j.space.ObservationSpace; | ||||||
|  | import org.nd4j.linalg.api.buffer.DataType; | ||||||
|  | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
|  | import org.nd4j.linalg.factory.Nd4j; | ||||||
| import vizdoom.*; | import vizdoom.*; | ||||||
| 
 | 
 | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| @ -155,7 +158,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre | |||||||
|                         + Pointer.formatBytes(Pointer.totalPhysicalBytes())); |                         + Pointer.formatBytes(Pointer.totalPhysicalBytes())); | ||||||
| 
 | 
 | ||||||
|         game.newEpisode(); |         game.newEpisode(); | ||||||
|         return new GameScreen(game.getState().screenBuffer); |         return new GameScreen(observationSpace.getShape(), game.getState().screenBuffer); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -168,7 +171,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre | |||||||
| 
 | 
 | ||||||
|         double r = game.makeAction(actions.get(action)) * scaleFactor; |         double r = game.makeAction(actions.get(action)) * scaleFactor; | ||||||
|         log.info(game.getEpisodeTime() + " " + r + " " + action + " "); |         log.info(game.getEpisodeTime() + " " + r + " " + action + " "); | ||||||
|         return new StepReply(new GameScreen(game.isEpisodeFinished() |         return new StepReply(new GameScreen(observationSpace.getShape(), game.isEpisodeFinished() | ||||||
|                 ? new byte[game.getScreenSize()] |                 ? new byte[game.getScreenSize()] | ||||||
|                 : game.getState().screenBuffer), r, game.isEpisodeFinished(), null); |                 : game.getState().screenBuffer), r, game.isEpisodeFinished(), null); | ||||||
| 
 | 
 | ||||||
| @ -201,18 +204,34 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre | |||||||
| 
 | 
 | ||||||
|     public static class GameScreen implements Encodable { |     public static class GameScreen implements Encodable { | ||||||
| 
 | 
 | ||||||
|  |         final INDArray data; | ||||||
|  |         public GameScreen(int[] shape, byte[] screen) { | ||||||
| 
 | 
 | ||||||
|         double[] array; |             data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1); | ||||||
| 
 |  | ||||||
|         public GameScreen(byte[] screen) { |  | ||||||
|             array = new double[screen.length]; |  | ||||||
|             for (int i = 0; i < screen.length; i++) { |  | ||||||
|                 array[i] = (screen[i] & 0xFF) / 255.0; |  | ||||||
|             } |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         private GameScreen(INDArray toDup) { | ||||||
|  |             data = toDup.dup(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|         public double[] toArray() { |         public double[] toArray() { | ||||||
|             return array; |             return data.data().asDouble(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|  |         public boolean isSkipped() { | ||||||
|  |             return false; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|  |         public INDArray getData() { | ||||||
|  |             return data; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|  |         public GameScreen dup() { | ||||||
|  |             return new GameScreen(data); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -19,8 +19,6 @@ package org.deeplearning4j.rl4j.mdp.gym; | |||||||
| 
 | 
 | ||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| import lombok.Getter; | import lombok.Getter; | ||||||
| import lombok.Setter; |  | ||||||
| import lombok.Value; |  | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.bytedeco.javacpp.DoublePointer; | import org.bytedeco.javacpp.DoublePointer; | ||||||
| import org.bytedeco.javacpp.Pointer; | import org.bytedeco.javacpp.Pointer; | ||||||
| @ -31,8 +29,8 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace; | |||||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | import org.deeplearning4j.rl4j.space.ActionSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Box; | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; |  | ||||||
| import org.deeplearning4j.rl4j.space.HighLowDiscrete; | import org.deeplearning4j.rl4j.space.HighLowDiscrete; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.space.ObservationSpace; | import org.deeplearning4j.rl4j.space.ObservationSpace; | ||||||
| 
 | 
 | ||||||
| import org.bytedeco.cpython.*; | import org.bytedeco.cpython.*; | ||||||
| @ -47,7 +45,7 @@ import static org.bytedeco.numpy.global.numpy.*; | |||||||
|  * @author saudet |  * @author saudet | ||||||
|  */ |  */ | ||||||
| @Slf4j | @Slf4j | ||||||
| public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | public class GymEnv<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<OBSERVATION, A, AS> { | ||||||
| 
 | 
 | ||||||
|     public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn"; |     public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn"; | ||||||
| 
 | 
 | ||||||
| @ -82,7 +80,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | |||||||
|     private PyObject locals; |     private PyObject locals; | ||||||
| 
 | 
 | ||||||
|     final protected DiscreteSpace actionSpace; |     final protected DiscreteSpace actionSpace; | ||||||
|     final protected ObservationSpace<O> observationSpace; |     final protected ObservationSpace<OBSERVATION> observationSpace; | ||||||
|     @Getter |     @Getter | ||||||
|     final private String envId; |     final private String envId; | ||||||
|     @Getter |     @Getter | ||||||
| @ -119,7 +117,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | |||||||
|             for (int i = 0; i < shape.length; i++) { |             for (int i = 0; i < shape.length; i++) { | ||||||
|                 shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i)); |                 shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i)); | ||||||
|             } |             } | ||||||
|             observationSpace = (ObservationSpace<O>) new ArrayObservationSpace<Box>(shape); |             observationSpace = (ObservationSpace<OBSERVATION>) new ArrayObservationSpace<Box>(shape); | ||||||
|             Py_DecRef(shapeTuple); |             Py_DecRef(shapeTuple); | ||||||
| 
 | 
 | ||||||
|             PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null); |             PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null); | ||||||
| @ -140,7 +138,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public ObservationSpace<O> getObservationSpace() { |     public ObservationSpace<OBSERVATION> getObservationSpace() { | ||||||
|         return observationSpace; |         return observationSpace; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -153,7 +151,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public StepReply<O> step(A action) { |     public StepReply<OBSERVATION> step(A action) { | ||||||
|         int gstate = PyGILState_Ensure(); |         int gstate = PyGILState_Ensure(); | ||||||
|         try { |         try { | ||||||
|             if (render) { |             if (render) { | ||||||
| @ -186,7 +184,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public O reset() { |     public OBSERVATION reset() { | ||||||
|         int gstate = PyGILState_Ensure(); |         int gstate = PyGILState_Ensure(); | ||||||
|         try { |         try { | ||||||
|             Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null)); |             Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null)); | ||||||
| @ -201,7 +199,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | |||||||
| 
 | 
 | ||||||
|             double[] data = new double[(int)stateData.capacity()]; |             double[] data = new double[(int)stateData.capacity()]; | ||||||
|             stateData.get(data); |             stateData.get(data); | ||||||
|             return (O) new Box(data); |             return (OBSERVATION) new Box(data); | ||||||
|         } finally { |         } finally { | ||||||
|             PyGILState_Release(gstate); |             PyGILState_Release(gstate); | ||||||
|         } |         } | ||||||
| @ -220,7 +218,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public GymEnv<O, A, AS> newInstance() { |     public GymEnv<OBSERVATION, A, AS> newInstance() { | ||||||
|         return new GymEnv<O, A, AS>(envId, render, monitor); |         return new GymEnv<OBSERVATION, A, AS>(envId, render, monitor); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -40,8 +40,8 @@ public class GymEnvTest { | |||||||
|         assertEquals(false, mdp.isDone()); |         assertEquals(false, mdp.isDone()); | ||||||
|         Box o = (Box)mdp.reset(); |         Box o = (Box)mdp.reset(); | ||||||
|         StepReply r = mdp.step(0); |         StepReply r = mdp.step(0); | ||||||
|         assertEquals(4, o.toArray().length); |         assertEquals(4, o.getData().shape()[0]); | ||||||
|         assertEquals(4, ((Box)r.getObservation()).toArray().length); |         assertEquals(4, ((Box)r.getObservation()).getData().shape()[0]); | ||||||
|         assertNotEquals(null, mdp.newInstance()); |         assertNotEquals(null, mdp.newInstance()); | ||||||
|         mdp.close(); |         mdp.close(); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| /******************************************************************************* | /******************************************************************************* | ||||||
|  * Copyright (c) 2015-2018 Skymind, Inc. |  * Copyright (c) 2020 Konduit K. K. | ||||||
|  * |  * | ||||||
|  * This program and the accompanying materials are made available under the |  * This program and the accompanying materials are made available under the | ||||||
|  * terms of the Apache License, Version 2.0 which is available at |  * terms of the Apache License, Version 2.0 which is available at | ||||||
| @ -16,33 +16,13 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.malmo; | package org.deeplearning4j.malmo; | ||||||
| 
 | 
 | ||||||
| import java.util.Arrays; | import org.deeplearning4j.rl4j.space.Box; | ||||||
|  | import org.nd4j.linalg.factory.Nd4j; | ||||||
| 
 | 
 | ||||||
| import org.deeplearning4j.rl4j.space.Encodable; | @Deprecated | ||||||
|  | public class MalmoBox extends Box { | ||||||
| 
 | 
 | ||||||
| /** |     public MalmoBox(double... arr) { | ||||||
|  * Encodable state as a simple value array similar to Gym Box model, but without a JSON constructor |         super(arr); | ||||||
|  * @author howard-abrams (howard.abrams@ca.com) on 1/12/17. |  | ||||||
|  */ |  | ||||||
| public class MalmoBox implements Encodable { |  | ||||||
|     double[] value; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Construct state from an array of doubles |  | ||||||
|      * @param value state values |  | ||||||
|      */ |  | ||||||
|     //TODO: If this constructor was added to "Box", we wouldn't need this class at all. |  | ||||||
|     public MalmoBox(double... value) { |  | ||||||
|         this.value = value; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public double[] toArray() { |  | ||||||
|         return value; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String toString() { |  | ||||||
|         return Arrays.toString(value); |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -19,6 +19,8 @@ package org.deeplearning4j.malmo; | |||||||
| import java.util.Arrays; | import java.util.Arrays; | ||||||
| 
 | 
 | ||||||
| import com.microsoft.msr.malmo.WorldState; | import com.microsoft.msr.malmo.WorldState; | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
|  | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * A Malmo consistency policy that ensures both that there is a reward and that the next observation has a different position than the previous one. |  * A Malmo consistency policy that ensures both that there is a reward and that the next observation has a different position than the previous one. | ||||||
| @ -30,14 +32,14 @@ public class MalmoDescretePositionPolicy implements MalmoObservationPolicy { | |||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) { |     public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) { | ||||||
|         MalmoBox last_observation = observationSpace.getObservation(world_state); |         Box last_observation = observationSpace.getObservation(world_state); | ||||||
|         MalmoBox old_observation = observationSpace.getObservation(original_world_state); |         Box old_observation = observationSpace.getObservation(original_world_state); | ||||||
| 
 | 
 | ||||||
|         double[] newvalues = old_observation == null ? null : old_observation.toArray(); |         INDArray newvalues = old_observation == null ? null : old_observation.getData(); | ||||||
|         double[] oldvalues = last_observation == null ? null : last_observation.toArray(); |         INDArray oldvalues = last_observation == null ? null : last_observation.getData(); | ||||||
| 
 | 
 | ||||||
|         return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty() |         return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty() | ||||||
|                         || Arrays.equals(oldvalues, newvalues)); |                         || oldvalues.eq(newvalues).all()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  | |||||||
| @ -21,6 +21,7 @@ import java.nio.file.Paths; | |||||||
| 
 | 
 | ||||||
| import org.deeplearning4j.gym.StepReply; | import org.deeplearning4j.gym.StepReply; | ||||||
| import org.deeplearning4j.rl4j.mdp.MDP; | import org.deeplearning4j.rl4j.mdp.MDP; | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||||
| 
 | 
 | ||||||
| import com.microsoft.msr.malmo.AgentHost; | import com.microsoft.msr.malmo.AgentHost; | ||||||
| @ -34,6 +35,7 @@ import com.microsoft.msr.malmo.WorldState; | |||||||
| import lombok.Setter; | import lombok.Setter; | ||||||
| import lombok.Getter; | import lombok.Getter; | ||||||
| 
 | 
 | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.slf4j.Logger; | import org.slf4j.Logger; | ||||||
| import org.slf4j.LoggerFactory; | import org.slf4j.LoggerFactory; | ||||||
| 
 | 
 | ||||||
| @ -233,7 +235,7 @@ public class MalmoEnv implements MDP<MalmoBox, Integer, DiscreteSpace> { | |||||||
|             logger.info("Mission ended"); |             logger.info("Mission ended"); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         return new StepReply<MalmoBox>(last_observation, getRewards(last_world_state), isDone(), null); |         return new StepReply<>(last_observation, getRewards(last_world_state), isDone(), null); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private double getRewards(WorldState world_state) { |     private double getRewards(WorldState world_state) { | ||||||
|  | |||||||
| @ -16,6 +16,8 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.malmo; | package org.deeplearning4j.malmo; | ||||||
| 
 | 
 | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
|  | import org.deeplearning4j.rl4j.space.Encodable; | ||||||
| import org.deeplearning4j.rl4j.space.ObservationSpace; | import org.deeplearning4j.rl4j.space.ObservationSpace; | ||||||
| 
 | 
 | ||||||
| import com.microsoft.msr.malmo.WorldState; | import com.microsoft.msr.malmo.WorldState; | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ package org.deeplearning4j.malmo; | |||||||
| 
 | 
 | ||||||
| import com.microsoft.msr.malmo.TimestampedStringVector; | import com.microsoft.msr.malmo.TimestampedStringVector; | ||||||
| import com.microsoft.msr.malmo.WorldState; | import com.microsoft.msr.malmo.WorldState; | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.json.JSONArray; | import org.json.JSONArray; | ||||||
| import org.json.JSONObject; | import org.json.JSONObject; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
|  | |||||||
| @ -18,6 +18,7 @@ package org.deeplearning4j.malmo; | |||||||
| 
 | 
 | ||||||
| import java.util.HashMap; | import java.util.HashMap; | ||||||
| 
 | 
 | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -16,6 +16,7 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.malmo; | package org.deeplearning4j.malmo; | ||||||
| 
 | 
 | ||||||
|  | import org.deeplearning4j.rl4j.space.Box; | ||||||
| import org.json.JSONObject; | import org.json.JSONObject; | ||||||
| 
 | 
 | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user