RL4J: Sanitize Observation (#404)
* working on ALE image pipelines that appear to lose data * transformation pipeline for ALE has been broken for a while and needed some cleanup to make sure that openCV tooling for scene transforms was actually working. * allowing history length to be set and passed through to history merge transforms Signed-off-by: Bam4d <chrisbam4d@gmail.com> * native image loader is not thread-safe so should not be static Signed-off-by: Bam4d <chrisbam4d@gmail.com> * make sure the transformer for encoding observations that are not pixels converts corectly Signed-off-by: Bam4d <chrisbam4d@gmail.com> * Test fixes for ALE pixel observation shape Signed-off-by: Bam4d <chrisbam4d@gmail.com> * Fix compilation errors Signed-off-by: Samuel Audet <samuel.audet@gmail.com> * fixing some post-merge issues, and comments from PR Signed-off-by: Bam4d <chrisbam4d@gmail.com> Co-authored-by: Samuel Audet <samuel.audet@gmail.com>master
parent
75cc6e2ed7
commit
032b97912e
|
@ -25,9 +25,13 @@ import org.bytedeco.javacpp.IntPointer;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
|
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author saudet
|
* @author saudet
|
||||||
|
@ -70,10 +74,14 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
|
||||||
actions = new int[(int)a.limit()];
|
actions = new int[(int)a.limit()];
|
||||||
a.get(actions);
|
a.get(actions);
|
||||||
|
|
||||||
|
int height = (int)ale.getScreen().height();
|
||||||
|
int width = (int)(int)ale.getScreen().width();
|
||||||
|
|
||||||
discreteSpace = new DiscreteSpace(actions.length);
|
discreteSpace = new DiscreteSpace(actions.length);
|
||||||
int[] shape = {(int)ale.getScreen().height(), (int)ale.getScreen().width(), 3};
|
int[] shape = {3, height, width};
|
||||||
observationSpace = new ArrayObservationSpace<>(shape);
|
observationSpace = new ArrayObservationSpace<>(shape);
|
||||||
screenBuffer = new byte[shape[0] * shape[1] * shape[2]];
|
screenBuffer = new byte[shape[0] * shape[1] * shape[2]];
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setupGame() {
|
public void setupGame() {
|
||||||
|
@ -103,7 +111,7 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
|
||||||
public GameScreen reset() {
|
public GameScreen reset() {
|
||||||
ale.reset_game();
|
ale.reset_game();
|
||||||
ale.getScreenRGB(screenBuffer);
|
ale.getScreenRGB(screenBuffer);
|
||||||
return new GameScreen(screenBuffer);
|
return new GameScreen(observationSpace.getShape(), screenBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -115,7 +123,8 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
|
||||||
double r = ale.act(actions[action]) * scaleFactor;
|
double r = ale.act(actions[action]) * scaleFactor;
|
||||||
log.info(ale.getEpisodeFrameNumber() + " " + r + " " + action + " ");
|
log.info(ale.getEpisodeFrameNumber() + " " + r + " " + action + " ");
|
||||||
ale.getScreenRGB(screenBuffer);
|
ale.getScreenRGB(screenBuffer);
|
||||||
return new StepReply(new GameScreen(screenBuffer), r, ale.game_over(), null);
|
|
||||||
|
return new StepReply(new GameScreen(observationSpace.getShape(), screenBuffer), r, ale.game_over(), null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ObservationSpace<GameScreen> getObservationSpace() {
|
public ObservationSpace<GameScreen> getObservationSpace() {
|
||||||
|
@ -140,17 +149,35 @@ public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class GameScreen implements Encodable {
|
public static class GameScreen implements Encodable {
|
||||||
double[] array;
|
|
||||||
|
|
||||||
public GameScreen(byte[] screen) {
|
final INDArray data;
|
||||||
array = new double[screen.length];
|
public GameScreen(int[] shape, byte[] screen) {
|
||||||
for (int i = 0; i < screen.length; i++) {
|
|
||||||
array[i] = (screen[i] & 0xFF) / 255.0;
|
data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private GameScreen(INDArray toDup) {
|
||||||
|
data = toDup.dup();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public double[] toArray() {
|
public double[] toArray() {
|
||||||
return array;
|
return data.data().asDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkipped() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getData() {
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public GameScreen dup() {
|
||||||
|
return new GameScreen(data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,15 +19,15 @@ package org.deeplearning4j.gym;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param <T> type of observation
|
* @param <OBSERVATION> type of observation
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16.
|
||||||
*
|
*
|
||||||
* StepReply is the container for the data returned after each step(action).
|
* StepReply is the container for the data returned after each step(action).
|
||||||
*/
|
*/
|
||||||
@Value
|
@Value
|
||||||
public class StepReply<T> {
|
public class StepReply<OBSERVATION> {
|
||||||
|
|
||||||
T observation;
|
OBSERVATION observation;
|
||||||
double reward;
|
double reward;
|
||||||
boolean done;
|
boolean done;
|
||||||
Object info;
|
Object info;
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
* in a "functionnal manner" if step return a mdp
|
* in a "functionnal manner" if step return a mdp
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public interface MDP<OBSERVATION, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {
|
public interface MDP<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {
|
||||||
|
|
||||||
ObservationSpace<OBSERVATION> getObservationSpace();
|
ObservationSpace<OBSERVATION> getObservationSpace();
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.space;
|
package org.deeplearning4j.rl4j.space;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
|
||||||
*
|
*
|
||||||
|
@ -25,13 +28,37 @@ package org.deeplearning4j.rl4j.space;
|
||||||
*/
|
*/
|
||||||
public class Box implements Encodable {
|
public class Box implements Encodable {
|
||||||
|
|
||||||
private final double[] array;
|
private final INDArray data;
|
||||||
|
|
||||||
public Box(double[] arr) {
|
public Box(double... arr) {
|
||||||
this.array = arr;
|
this.data = Nd4j.create(arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Box(int[] shape, double... arr) {
|
||||||
|
this.data = Nd4j.create(arr).reshape(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Box(INDArray toDup) {
|
||||||
|
data = toDup.dup();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public double[] toArray() {
|
public double[] toArray() {
|
||||||
return array;
|
return data.data().asDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkipped() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getData() {
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Encodable dup() {
|
||||||
|
return new Box(data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2020 Konduit K. K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,17 +16,19 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.space;
|
package org.deeplearning4j.rl4j.space;
|
||||||
|
|
||||||
/**
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/19/16.
|
|
||||||
* Encodable is an interface that ensure that the state is convertible to a double array
|
|
||||||
*/
|
|
||||||
public interface Encodable {
|
public interface Encodable {
|
||||||
|
|
||||||
/**
|
@Deprecated
|
||||||
* $
|
|
||||||
* encodes all the information of an Observation in an array double and can be used as input of a DQN directly
|
|
||||||
*
|
|
||||||
* @return the encoded informations
|
|
||||||
*/
|
|
||||||
double[] toArray();
|
double[] toArray();
|
||||||
|
|
||||||
|
boolean isSkipped();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Any image data should be in CHW format.
|
||||||
|
*/
|
||||||
|
INDArray getData();
|
||||||
|
|
||||||
|
Encodable dup();
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,16 +24,17 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
public class INDArrayHelper {
|
public class INDArrayHelper {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray.
|
* MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types.
|
||||||
* In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape.
|
|
||||||
*
|
*
|
||||||
* @param source A INDArray
|
* We must have either shape 2 (NK) or shape 4 (NCHW)
|
||||||
* @return The source INDArray with the correct shape
|
|
||||||
*/
|
*/
|
||||||
public static INDArray forceCorrectShape(INDArray source) {
|
public static INDArray forceCorrectShape(INDArray source) {
|
||||||
|
|
||||||
return source.shape()[0] == 1 && source.shape().length > 1
|
return source.shape()[0] == 1 && source.shape().length > 1
|
||||||
? source
|
? source
|
||||||
: Nd4j.expandDims(source, 0);
|
: Nd4j.expandDims(source, 0);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,6 @@ public class HistoryProcessor implements IHistoryProcessor {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final private Configuration conf;
|
final private Configuration conf;
|
||||||
final private OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat();
|
|
||||||
private CircularFifoQueue<INDArray> history;
|
private CircularFifoQueue<INDArray> history;
|
||||||
private VideoRecorder videoRecorder;
|
private VideoRecorder videoRecorder;
|
||||||
|
|
||||||
|
@ -63,8 +62,7 @@ public class HistoryProcessor implements IHistoryProcessor {
|
||||||
|
|
||||||
public void startMonitor(String filename, int[] shape) {
|
public void startMonitor(String filename, int[] shape) {
|
||||||
if(videoRecorder == null) {
|
if(videoRecorder == null) {
|
||||||
videoRecorder = VideoRecorder.builder(shape[0], shape[1])
|
videoRecorder = VideoRecorder.builder(shape[1], shape[2])
|
||||||
.frameInputType(VideoRecorder.FrameInputTypes.Float)
|
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,14 +87,13 @@ public class HistoryProcessor implements IHistoryProcessor {
|
||||||
return videoRecorder != null && videoRecorder.isRecording();
|
return videoRecorder != null && videoRecorder.isRecording();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void record(INDArray raw) {
|
public void record(INDArray pixelArray) {
|
||||||
if(isMonitoring()) {
|
if(isMonitoring()) {
|
||||||
// before accessing the raw pointer, we need to make sure that array is actual on the host side
|
// before accessing the raw pointer, we need to make sure that array is actual on the host side
|
||||||
Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST);
|
Nd4j.getAffinityManager().ensureLocation(pixelArray, AffinityManager.Location.HOST);
|
||||||
|
|
||||||
VideoRecorder.VideoFrame frame = videoRecorder.createFrame(raw.data().pointer());
|
|
||||||
try {
|
try {
|
||||||
videoRecorder.record(frame);
|
videoRecorder.record(pixelArray);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ public interface IHistoryProcessor {
|
||||||
@Builder.Default int skipFrame = 4;
|
@Builder.Default int skipFrame = 4;
|
||||||
|
|
||||||
public int[] getShape() {
|
public int[] getShape() {
|
||||||
return new int[] {getHistoryLength(), getCroppingHeight(), getCroppingWidth()};
|
return new int[] {getHistoryLength(), getRescaledHeight(), getRescaledWidth()};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
*
|
*
|
||||||
* A common interface that any training method should implement
|
* A common interface that any training method should implement
|
||||||
*/
|
*/
|
||||||
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
|
public interface ILearning<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> {
|
||||||
|
|
||||||
IPolicy<A> getPolicy();
|
IPolicy<A> getPolicy();
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
|
||||||
|
|
||||||
ILearningConfiguration getConfiguration();
|
ILearningConfiguration getConfiguration();
|
||||||
|
|
||||||
MDP<O, A, AS> getMdp();
|
MDP<OBSERVATION, A, AS> getMdp();
|
||||||
|
|
||||||
IHistoryProcessor getHistoryProcessor();
|
IHistoryProcessor getHistoryProcessor();
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
@ -38,8 +37,8 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class Learning<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
|
implements ILearning<OBSERVATION, A, AS>, NeuralNetFetchable<NN> {
|
||||||
|
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
protected int stepCount = 0;
|
protected int stepCount = 0;
|
||||||
|
|
|
@ -29,10 +29,10 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -188,7 +188,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean handleTraining(RunContext context) {
|
private boolean handleTraining(RunContext context) {
|
||||||
int maxTrainSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpisodeStepCount);
|
int maxTrainSteps = Math.min(getConfiguration().getNStep(), getConfiguration().getMaxEpochStep() - currentEpisodeStepCount);
|
||||||
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps);
|
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps);
|
||||||
|
|
||||||
context.obs = subEpochReturn.getLastObs();
|
context.obs = subEpochReturn.getLastObs();
|
||||||
|
@ -219,7 +219,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
|
||||||
|
|
||||||
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||||
|
|
||||||
protected abstract IAsyncLearningConfiguration getConf();
|
protected abstract IAsyncLearningConfiguration getConfiguration();
|
||||||
|
|
||||||
protected abstract IPolicy<ACTION> getPolicy(NN net);
|
protected abstract IPolicy<ACTION> getPolicy(NN net);
|
||||||
|
|
||||||
|
|
|
@ -24,29 +24,22 @@ import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
|
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
|
||||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
|
||||||
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
|
|
||||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.util.Stack;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
* <p>
|
* <p>
|
||||||
* Async Learning specialized for the Discrete Domain
|
* Async Learning specialized for the Discrete Domain
|
||||||
*/
|
*/
|
||||||
public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
|
public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN extends NeuralNet>
|
||||||
extends AsyncThread<O, Integer, DiscreteSpace, NN> {
|
extends AsyncThread<OBSERVATION, Integer, DiscreteSpace, NN> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private NN current;
|
private NN current;
|
||||||
|
@ -59,7 +52,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
||||||
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
|
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
|
||||||
|
|
||||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
|
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
|
||||||
MDP<O, Integer, DiscreteSpace> mdp,
|
MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||||
TrainingListenerList listeners,
|
TrainingListenerList listeners,
|
||||||
int threadNumber,
|
int threadNumber,
|
||||||
int deviceNum) {
|
int deviceNum) {
|
||||||
|
@ -112,7 +105,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
||||||
}
|
}
|
||||||
|
|
||||||
StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
|
StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
|
||||||
accuReward += stepReply.getReward() * getConf().getRewardFactor();
|
accuReward += stepReply.getReward() * getConfiguration().getRewardFactor();
|
||||||
|
|
||||||
if (!obs.isSkipped()) {
|
if (!obs.isSkipped()) {
|
||||||
experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
|
experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
|
||||||
|
@ -126,7 +119,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean episodeComplete = getMdp().isDone() || getConf().getMaxEpochStep() == currentEpisodeStepCount;
|
boolean episodeComplete = getMdp().isDone() || getConfiguration().getMaxEpochStep() == currentEpisodeStepCount;
|
||||||
|
|
||||||
if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) {
|
if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) {
|
||||||
experienceHandler.setFinalObservation(obs);
|
experienceHandler.setFinalObservation(obs);
|
||||||
|
|
|
@ -28,9 +28,9 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -41,19 +41,19 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
* All methods are fully implemented as described in the
|
* All methods are fully implemented as described in the
|
||||||
* https://arxiv.org/abs/1602.01783 paper.
|
* https://arxiv.org/abs/1602.01783 paper.
|
||||||
*/
|
*/
|
||||||
public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
|
public abstract class A3CDiscrete<OBSERVATION extends Encodable> extends AsyncLearning<OBSERVATION, Integer, DiscreteSpace, IActorCritic> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final public A3CLearningConfiguration configuration;
|
final public A3CLearningConfiguration configuration;
|
||||||
@Getter
|
@Getter
|
||||||
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
final protected MDP<OBSERVATION, Integer, DiscreteSpace> mdp;
|
||||||
final private IActorCritic iActorCritic;
|
final private IActorCritic iActorCritic;
|
||||||
@Getter
|
@Getter
|
||||||
final private AsyncGlobal asyncGlobal;
|
final private AsyncGlobal asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
final private ACPolicy<O> policy;
|
final private ACPolicy<OBSERVATION> policy;
|
||||||
|
|
||||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
|
public A3CDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
|
||||||
this.iActorCritic = iActorCritic;
|
this.iActorCritic = iActorCritic;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
|
|
|
@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph;
|
||||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
|
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
@ -42,19 +42,19 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
* first layers since they're essentially doing the same dimension
|
* first layers since they're essentially doing the same dimension
|
||||||
* reduction task
|
* reduction task
|
||||||
**/
|
**/
|
||||||
public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
public class A3CDiscreteConv<OBSERVATION extends Encodable> extends A3CDiscrete<OBSERVATION> {
|
||||||
|
|
||||||
final private HistoryProcessor.Configuration hpconf;
|
final private HistoryProcessor.Configuration hpconf;
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, actorCritic, hpconf, conf);
|
this(mdp, actorCritic, hpconf, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
|
|
||||||
super(mdp, IActorCritic, conf.toLearningConfiguration());
|
super(mdp, IActorCritic, conf.toLearningConfiguration());
|
||||||
|
@ -62,7 +62,7 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||||
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
||||||
super(mdp, IActorCritic, conf);
|
super(mdp, IActorCritic, conf);
|
||||||
this.hpconf = hpconf;
|
this.hpconf = hpconf;
|
||||||
|
@ -70,35 +70,35 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
||||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
|
public A3CDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,8 +21,8 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.*;
|
import org.deeplearning4j.rl4j.network.ac.*;
|
||||||
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
@ -34,74 +34,74 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
* We use specifically the Separate version because
|
* We use specifically the Separate version because
|
||||||
* the model is too small to have enough benefit by sharing layers
|
* the model is too small to have enough benefit by sharing layers
|
||||||
*/
|
*/
|
||||||
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
public class A3CDiscreteDense<OBSERVATION extends Encodable> extends A3CDiscrete<OBSERVATION> {
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, IActorCritic, conf);
|
this(mdp, IActorCritic, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
|
||||||
super(mdp, actorCritic, conf.toLearningConfiguration());
|
super(mdp, actorCritic, conf.toLearningConfiguration());
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
|
||||||
super(mdp, actorCritic, conf);
|
super(mdp, actorCritic, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
A3CConfiguration conf, IDataManager dataManager) {
|
A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
A3CConfiguration conf) {
|
A3CConfiguration conf) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
A3CLearningConfiguration conf) {
|
A3CLearningConfiguration conf) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
|
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
|
||||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf);
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||||
ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) {
|
ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) {
|
||||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
A3CConfiguration conf, IDataManager dataManager) {
|
A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
A3CConfiguration conf) {
|
A3CConfiguration conf) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
A3CLearningConfiguration conf) {
|
A3CLearningConfiguration conf) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,23 +23,23 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||||
*
|
* <p>
|
||||||
* Local thread as described in the https://arxiv.org/abs/1602.01783 paper.
|
* Local thread as described in the https://arxiv.org/abs/1602.01783 paper.
|
||||||
*/
|
*/
|
||||||
public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
|
public class A3CThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IActorCritic> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final protected A3CLearningConfiguration conf;
|
final protected A3CLearningConfiguration configuration;
|
||||||
@Getter
|
@Getter
|
||||||
final protected IAsyncGlobal<IActorCritic> asyncGlobal;
|
final protected IAsyncGlobal<IActorCritic> asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -47,17 +47,17 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
|
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
|
|
||||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
|
public A3CThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
|
||||||
A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
||||||
int threadNumber) {
|
int threadNumber) {
|
||||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
this.conf = a3cc;
|
this.configuration = a3cc;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
|
|
||||||
Long seed = conf.getSeed();
|
Long seed = configuration.getSeed();
|
||||||
rnd = Nd4j.getRandom();
|
rnd = Nd4j.getRandom();
|
||||||
if(seed != null) {
|
if (seed != null) {
|
||||||
rnd.setSeed(seed + threadNumber);
|
rnd.setSeed(seed + threadNumber);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,9 +69,12 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
return new ACPolicy(net, rnd);
|
return new ACPolicy(net, rnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* calc the gradients based on the n-step rewards
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
|
protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
|
||||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
||||||
return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma());
|
return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,26 +28,26 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
*/
|
*/
|
||||||
public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
public abstract class AsyncNStepQLearningDiscrete<OBSERVATION extends Encodable>
|
||||||
extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
|
extends AsyncLearning<OBSERVATION, Integer, DiscreteSpace, IDQN> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final public AsyncQLearningConfiguration configuration;
|
final public AsyncQLearningConfiguration configuration;
|
||||||
@Getter
|
@Getter
|
||||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
final private MDP<OBSERVATION, Integer, DiscreteSpace> mdp;
|
||||||
@Getter
|
@Getter
|
||||||
final private AsyncGlobal<IDQN> asyncGlobal;
|
final private AsyncGlobal<IDQN> asyncGlobal;
|
||||||
|
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
|
public AsyncNStepQLearningDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
||||||
|
@ -63,7 +63,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
}
|
}
|
||||||
|
|
||||||
public IPolicy<Integer> getPolicy() {
|
public IPolicy<Integer> getPolicy() {
|
||||||
return new DQNPolicy<O>(getNeuralNet());
|
return new DQNPolicy<OBSERVATION>(getNeuralNet());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
@ -35,17 +35,17 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
* Specialized constructors for the Conv (pixels input) case
|
* Specialized constructors for the Conv (pixels input) case
|
||||||
* Specialized conf + provide additional type safety
|
* Specialized conf + provide additional type safety
|
||||||
*/
|
*/
|
||||||
public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
|
public class AsyncNStepQLearningDiscreteConv<OBSERVATION extends Encodable> extends AsyncNStepQLearningDiscrete<OBSERVATION> {
|
||||||
|
|
||||||
final private HistoryProcessor.Configuration hpconf;
|
final private HistoryProcessor.Configuration hpconf;
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, dqn, hpconf, conf);
|
this(mdp, dqn, hpconf, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
||||||
super(mdp, dqn, conf);
|
super(mdp, dqn, conf);
|
||||||
this.hpconf = hpconf;
|
this.hpconf = hpconf;
|
||||||
|
@ -53,21 +53,21 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
public AsyncNStepQLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
|
||||||
*/
|
*/
|
||||||
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
|
public class AsyncNStepQLearningDiscreteDense<OBSERVATION extends Encodable> extends AsyncNStepQLearningDiscrete<OBSERVATION> {
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf.toLearningConfiguration());
|
super(mdp, dqn, conf.toLearningConfiguration());
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
AsyncNStepQLConfiguration conf) {
|
AsyncNStepQLConfiguration conf) {
|
||||||
super(mdp, dqn, conf.toLearningConfiguration());
|
super(mdp, dqn, conf.toLearningConfiguration());
|
||||||
}
|
}
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
AsyncQLearningConfiguration conf) {
|
AsyncQLearningConfiguration conf) {
|
||||||
super(mdp, dqn, conf);
|
super(mdp, dqn, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
AsyncNStepQLConfiguration conf) {
|
AsyncNStepQLConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
AsyncQLearningConfiguration conf) {
|
AsyncQLearningConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
|
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
|
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public AsyncNStepQLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||||
DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) {
|
DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,21 +25,21 @@ import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguratio
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
*/
|
*/
|
||||||
public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
|
public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IDQN> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final protected AsyncQLearningConfiguration conf;
|
final protected AsyncQLearningConfiguration configuration;
|
||||||
@Getter
|
@Getter
|
||||||
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -47,17 +47,17 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
|
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
|
|
||||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
public AsyncNStepQLearningThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||||
AsyncQLearningConfiguration conf,
|
AsyncQLearningConfiguration configuration,
|
||||||
TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
this.conf = conf;
|
this.configuration = configuration;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
rnd = Nd4j.getRandom();
|
rnd = Nd4j.getRandom();
|
||||||
|
|
||||||
Long seed = conf.getSeed();
|
Long seed = configuration.getSeed();
|
||||||
if (seed != null) {
|
if(seed != null) {
|
||||||
rnd.setSeed(seed + threadNumber);
|
rnd.setSeed(seed + threadNumber);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,13 +65,13 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
}
|
}
|
||||||
|
|
||||||
public Policy<Integer> getPolicy(IDQN nn) {
|
public Policy<Integer> getPolicy(IDQN nn) {
|
||||||
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
return new EpsGreedy(new DQNPolicy(nn), getMdp(), configuration.getUpdateStart(), configuration.getEpsilonNbStep(),
|
||||||
rnd, conf.getMinEpsilon(), this);
|
rnd, configuration.getMinEpsilon(), this);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
|
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
|
||||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
||||||
return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma());
|
return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,10 +32,10 @@ import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
|
|
||||||
|
|
|
@ -33,11 +33,11 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorith
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
|
@ -24,8 +24,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
@ -34,59 +34,59 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
* Specialized constructors for the Conv (pixels input) case
|
* Specialized constructors for the Conv (pixels input) case
|
||||||
* Specialized conf + provide additional type safety
|
* Specialized conf + provide additional type safety
|
||||||
*/
|
*/
|
||||||
public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscrete<O> {
|
public class QLearningDiscreteConv<OBSERVATION extends Encodable> extends QLearningDiscrete<OBSERVATION> {
|
||||||
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
QLConfiguration conf, IDataManager dataManager) {
|
QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, dqn, hpconf, conf);
|
this(mdp, dqn, hpconf, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
QLConfiguration conf) {
|
QLConfiguration conf) {
|
||||||
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
QLearningConfiguration conf) {
|
QLearningConfiguration conf) {
|
||||||
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
|
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
|
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
public QLearningDiscreteConv(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
|
||||||
*/
|
*/
|
||||||
public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscrete<O> {
|
public class QLearningDiscreteDense<OBSERVATION extends Encodable> extends QLearningDiscrete<OBSERVATION> {
|
||||||
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, dqn, conf);
|
this(mdp, dqn, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
|
||||||
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
|
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) {
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) {
|
||||||
super(mdp, dqn, conf, conf.getEpsilonNbStep());
|
super(mdp, dqn, conf, conf.getEpsilonNbStep());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
QLearning.QLConfiguration conf) {
|
QLearning.QLConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
QLearningConfiguration conf) {
|
QLearningConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||||
|
|
||||||
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||||
QLearning.QLConfiguration conf) {
|
QLearning.QLConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
|
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf,
|
public QLearningDiscreteDense(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf,
|
||||||
QLearningConfiguration conf) {
|
QLearningConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,8 +4,8 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
|
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
@ -36,7 +36,7 @@ import java.util.Random;
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
public class CartpoleNative implements MDP<CartpoleNative.State, Integer, DiscreteSpace> {
|
public class CartpoleNative implements MDP<Box, Integer, DiscreteSpace> {
|
||||||
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
|
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
|
||||||
|
|
||||||
private static final int NUM_ACTIONS = 2;
|
private static final int NUM_ACTIONS = 2;
|
||||||
|
@ -74,7 +74,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
@Getter
|
@Getter
|
||||||
private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS);
|
private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS);
|
||||||
@Getter
|
@Getter
|
||||||
private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
|
private ObservationSpace<Box> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
|
||||||
|
|
||||||
public CartpoleNative() {
|
public CartpoleNative() {
|
||||||
rnd = new Random();
|
rnd = new Random();
|
||||||
|
@ -85,7 +85,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public State reset() {
|
public Box reset() {
|
||||||
|
|
||||||
x = 0.1 * rnd.nextDouble() - 0.05;
|
x = 0.1 * rnd.nextDouble() - 0.05;
|
||||||
xDot = 0.1 * rnd.nextDouble() - 0.05;
|
xDot = 0.1 * rnd.nextDouble() - 0.05;
|
||||||
|
@ -94,7 +94,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
stepsBeyondDone = null;
|
stepsBeyondDone = null;
|
||||||
done = false;
|
done = false;
|
||||||
|
|
||||||
return new State(new double[] { x, xDot, theta, thetaDot });
|
return new Box(x, xDot, theta, thetaDot);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -103,7 +103,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public StepReply<State> step(Integer action) {
|
public StepReply<Box> step(Integer action) {
|
||||||
double force = action == ACTION_RIGHT ? forceMag : -forceMag;
|
double force = action == ACTION_RIGHT ? forceMag : -forceMag;
|
||||||
double cosTheta = Math.cos(theta);
|
double cosTheta = Math.cos(theta);
|
||||||
double sinTheta = Math.sin(theta);
|
double sinTheta = Math.sin(theta);
|
||||||
|
@ -143,26 +143,12 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
reward = 0;
|
reward = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
return new StepReply<>(new State(new double[] { x, xDot, theta, thetaDot }), reward, done, null);
|
return new StepReply<>(new Box(x, xDot, theta, thetaDot), reward, done, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MDP<State, Integer, DiscreteSpace> newInstance() {
|
public MDP<Box, Integer, DiscreteSpace> newInstance() {
|
||||||
return new CartpoleNative();
|
return new CartpoleNative();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class State implements Encodable {
|
|
||||||
|
|
||||||
private final double[] state;
|
|
||||||
|
|
||||||
State(double[] state) {
|
|
||||||
|
|
||||||
this.state = state;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] toArray() {
|
|
||||||
return state;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy;
|
||||||
|
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
|
||||||
|
@ -31,4 +32,19 @@ public class HardToyState implements Encodable {
|
||||||
public double[] toArray() {
|
public double[] toArray() {
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkipped() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getData() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Encodable dup() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
|
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -40,7 +41,6 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
public class SimpleToy implements MDP<SimpleToyState, Integer, DiscreteSpace> {
|
public class SimpleToy implements MDP<SimpleToyState, Integer, DiscreteSpace> {
|
||||||
|
|
||||||
final private int maxStep;
|
final private int maxStep;
|
||||||
//TODO 10 steps toy (always +1 reward2 actions), toylong (1000 steps), toyhard (7 actions, +1 only if actiion = (step/100+step)%7, and toyStoch (like last but reward has 0.10 odd to be somewhere else).
|
|
||||||
@Getter
|
@Getter
|
||||||
private DiscreteSpace actionSpace = new DiscreteSpace(2);
|
private DiscreteSpace actionSpace = new DiscreteSpace(2);
|
||||||
@Getter
|
@Getter
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy;
|
||||||
|
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
||||||
|
@ -28,11 +29,24 @@ public class SimpleToyState implements Encodable {
|
||||||
int i;
|
int i;
|
||||||
int step;
|
int step;
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] toArray() {
|
public double[] toArray() {
|
||||||
double[] ar = new double[1];
|
double[] ar = new double[1];
|
||||||
ar[0] = (20 - i);
|
ar[0] = (20 - i);
|
||||||
return ar;
|
return ar;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkipped() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getData() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Encodable dup() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
*
|
*
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
public class Observation {
|
public class Observation implements Encodable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A singleton representing a skipped observation
|
* A singleton representing a skipped observation
|
||||||
|
@ -38,6 +38,11 @@ public class Observation {
|
||||||
@Getter
|
@Getter
|
||||||
private final INDArray data;
|
private final INDArray data;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] toArray() {
|
||||||
|
return data.data().asDouble();
|
||||||
|
}
|
||||||
|
|
||||||
public boolean isSkipped() {
|
public boolean isSkipped() {
|
||||||
return data == null;
|
return data == null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,41 +1,28 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2020 Konduit K.K.
|
* Copyright (c) 2020 Konduit K. K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* License for the specific language governing permissions and limitations
|
* License for the specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
|
|
||||||
|
package org.deeplearning4j.rl4j.observation.transform;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
|
||||||
import org.bytedeco.opencv.opencv_core.Mat;
|
import org.datavec.api.transform.Operation;
|
||||||
import org.datavec.api.transform.Operation;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> {
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
@Override
|
||||||
|
public INDArray transform(Encodable encodable) {
|
||||||
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
|
return encodable.getData();
|
||||||
|
}
|
||||||
public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> {
|
}
|
||||||
|
|
||||||
private final int[] shape;
|
|
||||||
|
|
||||||
public EncodableToINDArrayTransform(int[] shape) {
|
|
||||||
this.shape = shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray transform(Encodable encodable) {
|
|
||||||
return Nd4j.create(encodable.toArray()).reshape(shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -15,34 +15,32 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
|
package org.deeplearning4j.rl4j.observation.transform.legacy;
|
||||||
|
|
||||||
|
import org.bytedeco.javacv.Frame;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.bytedeco.opencv.opencv_core.Mat;
|
import org.bytedeco.opencv.opencv_core.Mat;
|
||||||
import org.datavec.api.transform.Operation;
|
import org.datavec.api.transform.Operation;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
|
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
|
||||||
|
import static org.bytedeco.opencv.global.opencv_core.CV_32FC3;
|
||||||
|
import static org.bytedeco.opencv.global.opencv_core.CV_32S;
|
||||||
|
import static org.bytedeco.opencv.global.opencv_core.CV_32SC;
|
||||||
|
import static org.bytedeco.opencv.global.opencv_core.CV_32SC3;
|
||||||
|
import static org.bytedeco.opencv.global.opencv_core.CV_64FC;
|
||||||
|
import static org.bytedeco.opencv.global.opencv_core.CV_8UC3;
|
||||||
|
|
||||||
public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> {
|
public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> {
|
||||||
|
|
||||||
private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
|
final static NativeImageLoader nativeImageLoader = new NativeImageLoader();
|
||||||
private final int height;
|
|
||||||
private final int width;
|
|
||||||
private final int colorChannels;
|
|
||||||
|
|
||||||
public EncodableToImageWritableTransform(int height, int width, int colorChannels) {
|
|
||||||
this.height = height;
|
|
||||||
this.width = width;
|
|
||||||
this.colorChannels = colorChannels;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ImageWritable transform(Encodable encodable) {
|
public ImageWritable transform(Encodable encodable) {
|
||||||
INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels);
|
return new ImageWritable(nativeImageLoader.asFrame(encodable.getData(), Frame.DEPTH_UBYTE));
|
||||||
Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
|
|
||||||
return new ImageWritable(converter.convert(mat));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,34 +18,31 @@ package org.deeplearning4j.rl4j.observation.transform.legacy;
|
||||||
import org.datavec.api.transform.Operation;
|
import org.datavec.api.transform.Operation;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
|
public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
|
||||||
|
|
||||||
private final int height;
|
private final NativeImageLoader loader = new NativeImageLoader();
|
||||||
private final int width;
|
|
||||||
private final NativeImageLoader loader;
|
|
||||||
|
|
||||||
public ImageWritableToINDArrayTransform(int height, int width) {
|
|
||||||
this.height = height;
|
|
||||||
this.width = width;
|
|
||||||
this.loader = new NativeImageLoader(height, width);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray transform(ImageWritable imageWritable) {
|
public INDArray transform(ImageWritable imageWritable) {
|
||||||
|
|
||||||
|
int height = imageWritable.getHeight();
|
||||||
|
int width = imageWritable.getWidth();
|
||||||
|
int channels = imageWritable.getFrame().imageChannels;
|
||||||
|
|
||||||
INDArray out = null;
|
INDArray out = null;
|
||||||
try {
|
try {
|
||||||
out = loader.asMatrix(imageWritable);
|
out = loader.asMatrix(imageWritable);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
}
|
}
|
||||||
out = out.reshape(1, height, width);
|
|
||||||
|
// Convert back to uint8 and reshape to the number of channels in the image
|
||||||
|
out = out.reshape(channels, height, width);
|
||||||
INDArray compressed = out.castTo(DataType.UINT8);
|
INDArray compressed = out.castTo(DataType.UINT8);
|
||||||
return compressed;
|
return compressed;
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,19 +46,20 @@ public class HistoryMergeTransform implements Operation<INDArray, INDArray>, Res
|
||||||
private final HistoryMergeElementStore historyMergeElementStore;
|
private final HistoryMergeElementStore historyMergeElementStore;
|
||||||
private final HistoryMergeAssembler historyMergeAssembler;
|
private final HistoryMergeAssembler historyMergeAssembler;
|
||||||
private final boolean shouldStoreCopy;
|
private final boolean shouldStoreCopy;
|
||||||
private final boolean isFirstDimenstionBatch;
|
private final boolean isFirstDimensionBatch;
|
||||||
|
|
||||||
private HistoryMergeTransform(Builder builder) {
|
private HistoryMergeTransform(Builder builder) {
|
||||||
this.historyMergeElementStore = builder.historyMergeElementStore;
|
this.historyMergeElementStore = builder.historyMergeElementStore;
|
||||||
this.historyMergeAssembler = builder.historyMergeAssembler;
|
this.historyMergeAssembler = builder.historyMergeAssembler;
|
||||||
this.shouldStoreCopy = builder.shouldStoreCopy;
|
this.shouldStoreCopy = builder.shouldStoreCopy;
|
||||||
this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch;
|
this.isFirstDimensionBatch = builder.isFirstDimenstionBatch;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray transform(INDArray input) {
|
public INDArray transform(INDArray input) {
|
||||||
|
|
||||||
INDArray element;
|
INDArray element;
|
||||||
if(isFirstDimenstionBatch) {
|
if(isFirstDimensionBatch) {
|
||||||
element = input.slice(0, 0);
|
element = input.slice(0, 0);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -132,9 +133,9 @@ public class HistoryMergeTransform implements Operation<INDArray, INDArray>, Res
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public HistoryMergeTransform build() {
|
public HistoryMergeTransform build(int frameStackLength) {
|
||||||
if(historyMergeElementStore == null) {
|
if(historyMergeElementStore == null) {
|
||||||
historyMergeElementStore = new CircularFifoStore();
|
historyMergeElementStore = new CircularFifoStore(frameStackLength);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(historyMergeAssembler == null) {
|
if(historyMergeAssembler == null) {
|
||||||
|
|
|
@ -28,14 +28,9 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
public class CircularFifoStore implements HistoryMergeElementStore {
|
public class CircularFifoStore implements HistoryMergeElementStore {
|
||||||
private static final int DEFAULT_STORE_SIZE = 4;
|
|
||||||
|
|
||||||
private final CircularFifoQueue<INDArray> queue;
|
private final CircularFifoQueue<INDArray> queue;
|
||||||
|
|
||||||
public CircularFifoStore() {
|
|
||||||
this(DEFAULT_STORE_SIZE);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CircularFifoStore(int size) {
|
public CircularFifoStore(int size) {
|
||||||
Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size);
|
Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size);
|
||||||
queue = new CircularFifoQueue<>(size);
|
queue = new CircularFifoQueue<>(size);
|
||||||
|
|
|
@ -20,8 +20,8 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
|
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
|
||||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
|
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -35,7 +35,7 @@ import java.io.IOException;
|
||||||
* the softmax output of the actor critic, but objects constructed
|
* the softmax output of the actor critic, but objects constructed
|
||||||
* with a {@link Random} argument of null return the max only.
|
* with a {@link Random} argument of null return the max only.
|
||||||
*/
|
*/
|
||||||
public class ACPolicy<O extends Encodable> extends Policy<Integer> {
|
public class ACPolicy<OBSERVATION extends Encodable> extends Policy<Integer> {
|
||||||
|
|
||||||
final private IActorCritic actorCritic;
|
final private IActorCritic actorCritic;
|
||||||
Random rnd;
|
Random rnd;
|
||||||
|
@ -48,18 +48,18 @@ public class ACPolicy<O extends Encodable> extends Policy<Integer> {
|
||||||
this.rnd = rnd;
|
this.rnd = rnd;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException {
|
public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String path) throws IOException {
|
||||||
return new ACPolicy<O>(ActorCriticCompGraph.load(path));
|
return new ACPolicy<>(ActorCriticCompGraph.load(path));
|
||||||
}
|
}
|
||||||
public static <O extends Encodable> ACPolicy<O> load(String path, Random rnd) throws IOException {
|
public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String path, Random rnd) throws IOException {
|
||||||
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rnd);
|
return new ACPolicy<>(ActorCriticCompGraph.load(path), rnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException {
|
public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String pathValue, String pathPolicy) throws IOException {
|
||||||
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy));
|
return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy));
|
||||||
}
|
}
|
||||||
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rnd) throws IOException {
|
public static <OBSERVATION extends Encodable> ACPolicy<OBSERVATION> load(String pathValue, String pathPolicy, Random rnd) throws IOException {
|
||||||
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
|
return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IActorCritic getNeuralNet() {
|
public IActorCritic getNeuralNet() {
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
package org.deeplearning4j.rl4j.policy;
|
package org.deeplearning4j.rl4j.policy;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
|
||||||
* Boltzmann exploration is a stochastic policy wrt to the
|
* Boltzmann exploration is a stochastic policy wrt to the
|
||||||
* exponential Q-values as evaluated by the dqn model.
|
* exponential Q-values as evaluated by the dqn model.
|
||||||
*/
|
*/
|
||||||
public class BoltzmannQ<O extends Encodable> extends Policy<Integer> {
|
public class BoltzmannQ<OBSERVATION extends Encodable> extends Policy<Integer> {
|
||||||
|
|
||||||
final private IDQN dqn;
|
final private IDQN dqn;
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
|
|
|
@ -20,8 +20,8 @@ import lombok.AllArgsConstructor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQN;
|
import org.deeplearning4j.rl4j.network.dqn.DQN;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -35,12 +35,12 @@ import java.io.IOException;
|
||||||
|
|
||||||
// FIXME: Should we rename this "GreedyPolicy"?
|
// FIXME: Should we rename this "GreedyPolicy"?
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class DQNPolicy<O> extends Policy<Integer> {
|
public class DQNPolicy<OBSERVATION> extends Policy<Integer> {
|
||||||
|
|
||||||
final private IDQN dqn;
|
final private IDQN dqn;
|
||||||
|
|
||||||
public static <O extends Encodable> DQNPolicy<O> load(String path) throws IOException {
|
public static <OBSERVATION extends Encodable> DQNPolicy<OBSERVATION> load(String path) throws IOException {
|
||||||
return new DQNPolicy<O>(DQN.load(path));
|
return new DQNPolicy<>(DQN.load(path));
|
||||||
}
|
}
|
||||||
|
|
||||||
public IDQN getNeuralNet() {
|
public IDQN getNeuralNet() {
|
||||||
|
|
|
@ -20,12 +20,11 @@ package org.deeplearning4j.rl4j.policy;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
|
@ -41,10 +40,10 @@ import org.nd4j.linalg.api.rng.Random;
|
||||||
*/
|
*/
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
|
public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
|
||||||
|
|
||||||
final private Policy<A> policy;
|
final private Policy<A> policy;
|
||||||
final private MDP<O, A, AS> mdp;
|
final private MDP<OBSERVATION, A, AS> mdp;
|
||||||
final private int updateStart;
|
final private int updateStart;
|
||||||
final private int epsilonNbStep;
|
final private int epsilonNbStep;
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
|
|
|
@ -22,9 +22,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -7,22 +7,22 @@ import org.datavec.image.transform.ColorConversionTransform;
|
||||||
import org.datavec.image.transform.CropImageTransform;
|
import org.datavec.image.transform.CropImageTransform;
|
||||||
import org.datavec.image.transform.MultiImageTransform;
|
import org.datavec.image.transform.MultiImageTransform;
|
||||||
import org.datavec.image.transform.ResizeImageTransform;
|
import org.datavec.image.transform.ResizeImageTransform;
|
||||||
|
import org.datavec.image.transform.ShowImageTransform;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
|
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
|
|
||||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
|
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
|
import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
|
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -46,6 +46,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
|
||||||
private int skipFrame = 1;
|
private int skipFrame = 1;
|
||||||
private int steps = 0;
|
private int steps = 0;
|
||||||
|
|
||||||
|
|
||||||
public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) {
|
public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) {
|
||||||
this.wrappedMDP = wrappedMDP;
|
this.wrappedMDP = wrappedMDP;
|
||||||
this.shape = wrappedMDP.getObservationSpace().getShape();
|
this.shape = wrappedMDP.getObservationSpace().getShape();
|
||||||
|
@ -66,28 +67,33 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
|
||||||
|
|
||||||
if(historyProcessor != null && shape.length == 3) {
|
if(historyProcessor != null && shape.length == 3) {
|
||||||
int skipFrame = historyProcessor.getConf().getSkipFrame();
|
int skipFrame = historyProcessor.getConf().getSkipFrame();
|
||||||
|
int frameStackLength = historyProcessor.getConf().getHistoryLength();
|
||||||
|
|
||||||
int finalHeight = historyProcessor.getConf().getCroppingHeight();
|
int height = shape[1];
|
||||||
int finalWidth = historyProcessor.getConf().getCroppingWidth();
|
int width = shape[2];
|
||||||
|
|
||||||
|
int cropBottom = height - historyProcessor.getConf().getCroppingHeight();
|
||||||
|
int cropRight = width - historyProcessor.getConf().getCroppingWidth();
|
||||||
|
|
||||||
transformProcess = TransformProcess.builder()
|
transformProcess = TransformProcess.builder()
|
||||||
.filter(new UniformSkippingFilter(skipFrame))
|
.filter(new UniformSkippingFilter(skipFrame))
|
||||||
.transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2]))
|
.transform("data", new EncodableToImageWritableTransform())
|
||||||
.transform("data", new MultiImageTransform(
|
.transform("data", new MultiImageTransform(
|
||||||
|
new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), cropBottom, cropRight),
|
||||||
new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()),
|
new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()),
|
||||||
new ColorConversionTransform(COLOR_BGR2GRAY),
|
new ColorConversionTransform(COLOR_BGR2GRAY)
|
||||||
new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth)
|
//new ShowImageTransform("crop + resize + greyscale")
|
||||||
))
|
))
|
||||||
.transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth))
|
.transform("data", new ImageWritableToINDArrayTransform())
|
||||||
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
|
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
|
||||||
.transform("data", HistoryMergeTransform.builder()
|
.transform("data", HistoryMergeTransform.builder()
|
||||||
.isFirstDimenstionBatch(true)
|
.isFirstDimenstionBatch(true)
|
||||||
.build())
|
.build(frameStackLength))
|
||||||
.build("data");
|
.build("data");
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
transformProcess = TransformProcess.builder()
|
transformProcess = TransformProcess.builder()
|
||||||
.transform("data", new EncodableToINDArrayTransform(shape))
|
.transform("data", new EncodableToINDArrayTransform())
|
||||||
.build("data");
|
.build("data");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -127,6 +133,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
|
||||||
|
|
||||||
Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
|
Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
|
||||||
Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
|
Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
|
||||||
|
|
||||||
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
|
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,12 +168,7 @@ public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends Actio
|
||||||
}
|
}
|
||||||
|
|
||||||
private INDArray getInput(OBSERVATION obs) {
|
private INDArray getInput(OBSERVATION obs) {
|
||||||
INDArray arr = Nd4j.create(obs.toArray());
|
return obs.getData();
|
||||||
int[] shape = observationSpace.getShape();
|
|
||||||
if (shape.length == 1)
|
|
||||||
return arr.reshape(new long[] {1, arr.length()});
|
|
||||||
else
|
|
||||||
return arr.reshape(shape);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class WrapperObservationSpace implements ObservationSpace<Observation> {
|
public static class WrapperObservationSpace implements ObservationSpace<Observation> {
|
||||||
|
|
|
@ -16,26 +16,21 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.util;
|
package org.deeplearning4j.rl4j.util;
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.bytedeco.javacpp.BytePointer;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
|
||||||
import org.bytedeco.javacv.FFmpegFrameRecorder;
|
import org.bytedeco.javacv.FFmpegFrameRecorder;
|
||||||
import org.bytedeco.javacv.Frame;
|
import org.bytedeco.javacv.Frame;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
import org.bytedeco.opencv.global.opencv_core;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.bytedeco.opencv.global.opencv_imgproc;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.bytedeco.opencv.opencv_core.Mat;
|
|
||||||
import org.bytedeco.opencv.opencv_core.Rect;
|
|
||||||
import org.bytedeco.opencv.opencv_core.Size;
|
|
||||||
import org.opencv.imgproc.Imgproc;
|
|
||||||
|
|
||||||
import static org.bytedeco.ffmpeg.global.avcodec.*;
|
import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_H264;
|
||||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_MPEG4;
|
||||||
|
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB0;
|
||||||
|
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB24;
|
||||||
|
import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB8;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* VideoRecorder is used to create a video from a sequence of individual frames. If using 3 channels
|
* VideoRecorder is used to create a video from a sequence of INDArray frames. INDArrays are assumed to be in CHW format where C=3 and pixels are RGB encoded<br>
|
||||||
* images, it expects B-G-R order. A RGB order can be used by calling isRGBOrder(true).<br>
|
|
||||||
* Example:<br>
|
* Example:<br>
|
||||||
* <pre>
|
* <pre>
|
||||||
* {@code
|
* {@code
|
||||||
|
@ -45,11 +40,8 @@ import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
* .build();
|
* .build();
|
||||||
* recorder.startRecording("myVideo.mp4");
|
* recorder.startRecording("myVideo.mp4");
|
||||||
* while(...) {
|
* while(...) {
|
||||||
* byte[] data = new byte[160*100*3];
|
* INDArray chwData = Nd4j.create()
|
||||||
* // Todo: Fill data
|
* recorder.record(chwData);
|
||||||
* VideoRecorder.VideoFrame frame = recorder.createFrame(data);
|
|
||||||
* // Todo: Apply cropping or resizing to frame
|
|
||||||
* recorder.record(frame);
|
|
||||||
* }
|
* }
|
||||||
* recorder.stopRecording();
|
* recorder.stopRecording();
|
||||||
* }
|
* }
|
||||||
|
@ -60,16 +52,13 @@ import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class VideoRecorder implements AutoCloseable {
|
public class VideoRecorder implements AutoCloseable {
|
||||||
|
|
||||||
public enum FrameInputTypes { BGR, RGB, Float }
|
private final NativeImageLoader nativeImageLoader = new NativeImageLoader();
|
||||||
|
|
||||||
private final int height;
|
private final int height;
|
||||||
private final int width;
|
private final int width;
|
||||||
private final int imageType;
|
|
||||||
private final OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat();
|
|
||||||
private final int codec;
|
private final int codec;
|
||||||
private final double framerate;
|
private final double framerate;
|
||||||
private final int videoQuality;
|
private final int videoQuality;
|
||||||
private final FrameInputTypes frameInputType;
|
|
||||||
|
|
||||||
private FFmpegFrameRecorder fmpegFrameRecorder = null;
|
private FFmpegFrameRecorder fmpegFrameRecorder = null;
|
||||||
|
|
||||||
|
@ -83,11 +72,9 @@ public class VideoRecorder implements AutoCloseable {
|
||||||
private VideoRecorder(Builder builder) {
|
private VideoRecorder(Builder builder) {
|
||||||
this.height = builder.height;
|
this.height = builder.height;
|
||||||
this.width = builder.width;
|
this.width = builder.width;
|
||||||
imageType = CV_8UC(builder.numChannels);
|
|
||||||
codec = builder.codec;
|
codec = builder.codec;
|
||||||
framerate = builder.frameRate;
|
framerate = builder.frameRate;
|
||||||
videoQuality = builder.videoQuality;
|
videoQuality = builder.videoQuality;
|
||||||
frameInputType = builder.frameInputType;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -119,59 +106,11 @@ public class VideoRecorder implements AutoCloseable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a frame to the video
|
* Add a frame to the video
|
||||||
* @param frame the VideoFrame to add to the video
|
* @param imageArray the INDArray that contains the data to be recorded, the data must be in CHW format
|
||||||
* @throws Exception
|
* @throws Exception
|
||||||
*/
|
*/
|
||||||
public void record(VideoFrame frame) throws Exception {
|
public void record(INDArray imageArray) throws Exception {
|
||||||
Size size = frame.getMat().size();
|
fmpegFrameRecorder.record(nativeImageLoader.asFrame(imageArray, Frame.DEPTH_UBYTE));
|
||||||
if(size.height() != height || size.width() != width) {
|
|
||||||
throw new IllegalArgumentException(String.format("Wrong frame size. Got (%dh x %dw) expected (%dh x %dw)", size.height(), size.width(), height, width));
|
|
||||||
}
|
|
||||||
Frame cvFrame = openCVFrameConverter.convert(frame.getMat());
|
|
||||||
fmpegFrameRecorder.record(cvFrame);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a VideoFrame from a byte array.
|
|
||||||
* @param data A byte array. Expect the index to be of the form [(Y*Width + X) * NumChannels + channel]
|
|
||||||
* @return An instance of VideoFrame
|
|
||||||
*/
|
|
||||||
public VideoFrame createFrame(byte[] data) {
|
|
||||||
return createFrame(new BytePointer(data));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a VideoFrame from a byte array with different height and width than the video
|
|
||||||
* the frame will need to be cropped or resized before being added to the video)
|
|
||||||
*
|
|
||||||
* @param data A byte array Expect the index to be of the form [(Y*customWidth + X) * NumChannels + channel]
|
|
||||||
* @param customHeight The actual height of the data
|
|
||||||
* @param customWidth The actual width of the data
|
|
||||||
* @return A VideoFrame instance
|
|
||||||
*/
|
|
||||||
public VideoFrame createFrame(byte[] data, int customHeight, int customWidth) {
|
|
||||||
return createFrame(new BytePointer(data), customHeight, customWidth);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a VideoFrame from a Pointer (to use for example with a INDarray).
|
|
||||||
* @param data A Pointer (for example myINDArray.data().pointer())
|
|
||||||
* @return An instance of VideoFrame
|
|
||||||
*/
|
|
||||||
public VideoFrame createFrame(Pointer data) {
|
|
||||||
return new VideoFrame(height, width, imageType, frameInputType, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a VideoFrame from a Pointer with different height and width than the video
|
|
||||||
* the frame will need to be cropped or resized before being added to the video)
|
|
||||||
* @param data
|
|
||||||
* @param customHeight The actual height of the data
|
|
||||||
* @param customWidth The actual width of the data
|
|
||||||
* @return A VideoFrame instance
|
|
||||||
*/
|
|
||||||
public VideoFrame createFrame(Pointer data, int customHeight, int customWidth) {
|
|
||||||
return new VideoFrame(customHeight, customWidth, imageType, frameInputType, data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -192,69 +131,12 @@ public class VideoRecorder implements AutoCloseable {
|
||||||
return new Builder(height, width);
|
return new Builder(height, width);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* An individual frame for the video
|
|
||||||
*/
|
|
||||||
public static class VideoFrame {
|
|
||||||
|
|
||||||
private final int height;
|
|
||||||
private final int width;
|
|
||||||
private final int imageType;
|
|
||||||
@Getter
|
|
||||||
private Mat mat;
|
|
||||||
|
|
||||||
private VideoFrame(int height, int width, int imageType, FrameInputTypes frameInputType, Pointer data) {
|
|
||||||
this.height = height;
|
|
||||||
this.width = width;
|
|
||||||
this.imageType = imageType;
|
|
||||||
|
|
||||||
switch(frameInputType) {
|
|
||||||
case RGB:
|
|
||||||
Mat src = new Mat(height, width, imageType, data);
|
|
||||||
mat = new Mat(height, width, imageType);
|
|
||||||
opencv_imgproc.cvtColor(src, mat, Imgproc.COLOR_RGB2BGR);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BGR:
|
|
||||||
mat = new Mat(height, width, imageType, data);
|
|
||||||
break;
|
|
||||||
|
|
||||||
case Float:
|
|
||||||
Mat tmpMat = new Mat(height, width, CV_32FC(3), data);
|
|
||||||
mat = new Mat(height, width, imageType);
|
|
||||||
tmpMat.convertTo(mat, CV_8UC(3), 255.0, 0.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Crop the video to a specified size
|
|
||||||
* @param newHeight The new height of the frame
|
|
||||||
* @param newWidth The new width of the frame
|
|
||||||
* @param heightOffset The starting height offset in the uncropped frame
|
|
||||||
* @param widthOffset The starting weight offset in the uncropped frame
|
|
||||||
*/
|
|
||||||
public void crop(int newHeight, int newWidth, int heightOffset, int widthOffset) {
|
|
||||||
mat = mat.apply(new Rect(widthOffset, heightOffset, newWidth, newHeight));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Resize the frame to a specified size
|
|
||||||
* @param newHeight The new height of the frame
|
|
||||||
* @param newWidth The new width of the frame
|
|
||||||
*/
|
|
||||||
public void resize(int newHeight, int newWidth) {
|
|
||||||
mat = new Mat(newHeight, newWidth, imageType);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A builder class for the VideoRecorder
|
* A builder class for the VideoRecorder
|
||||||
*/
|
*/
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
private final int height;
|
private final int height;
|
||||||
private final int width;
|
private final int width;
|
||||||
private int numChannels = 3;
|
|
||||||
private FrameInputTypes frameInputType = FrameInputTypes.BGR;
|
|
||||||
private int codec = AV_CODEC_ID_H264;
|
private int codec = AV_CODEC_ID_H264;
|
||||||
private double frameRate = 30.0;
|
private double frameRate = 30.0;
|
||||||
private int videoQuality = 30;
|
private int videoQuality = 30;
|
||||||
|
@ -268,24 +150,6 @@ public class VideoRecorder implements AutoCloseable {
|
||||||
this.width = width;
|
this.width = width;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Specify the number of channels. Default is 3
|
|
||||||
* @param numChannels
|
|
||||||
*/
|
|
||||||
public Builder numChannels(int numChannels) {
|
|
||||||
this.numChannels = numChannels;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Tell the VideoRecorder what data it will receive (default is BGR)
|
|
||||||
* @param frameInputType (See {@link FrameInputTypes}}
|
|
||||||
*/
|
|
||||||
public Builder frameInputType(FrameInputTypes frameInputType) {
|
|
||||||
this.frameInputType = frameInputType;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The codec to use for the video. Default is AV_CODEC_ID_H264
|
* The codec to use for the video. Default is AV_CODEC_ID_H264
|
||||||
* @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes})
|
* @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes})
|
||||||
|
|
|
@ -115,7 +115,7 @@ public class AsyncThreadDiscreteTest {
|
||||||
|
|
||||||
asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
|
asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
|
||||||
|
|
||||||
when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration);
|
when(asyncThreadDiscrete.getConfiguration()).thenReturn(mockAsyncConfiguration);
|
||||||
when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
|
when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
|
||||||
when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
|
when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
|
||||||
when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);
|
when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);
|
||||||
|
|
|
@ -39,7 +39,6 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.anyInt;
|
import static org.mockito.ArgumentMatchers.anyInt;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.clearInvocations;
|
|
||||||
import static org.mockito.Mockito.doAnswer;
|
import static org.mockito.Mockito.doAnswer;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.times;
|
import static org.mockito.Mockito.times;
|
||||||
|
@ -130,7 +129,7 @@ public class AsyncThreadTest {
|
||||||
|
|
||||||
when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
|
when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
|
||||||
when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
|
when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
|
||||||
when(thread.getConf()).thenReturn(mockAsyncConfiguration);
|
when(thread.getConfiguration()).thenReturn(mockAsyncConfiguration);
|
||||||
|
|
||||||
// if we hit the max step count
|
// if we hit the max step count
|
||||||
when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);
|
when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);
|
||||||
|
|
|
@ -18,24 +18,16 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
|
||||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.space.Box;
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
|
@ -43,17 +35,17 @@ import org.mockito.Mock;
|
||||||
import org.mockito.Mockito;
|
import org.mockito.Mockito;
|
||||||
import org.mockito.junit.MockitoJUnitRunner;
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import static org.junit.Assert.assertEquals;
|
||||||
import java.util.List;
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
import static org.junit.Assert.*;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.anyInt;
|
import static org.mockito.ArgumentMatchers.anyInt;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
@RunWith(MockitoJUnitRunner.class)
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
|
@ -82,6 +74,7 @@ public class QLearningDiscreteTest {
|
||||||
@Mock
|
@Mock
|
||||||
QLearningConfiguration mockQlearningConfiguration;
|
QLearningConfiguration mockQlearningConfiguration;
|
||||||
|
|
||||||
|
// HWC
|
||||||
int[] observationShape = new int[]{3, 10, 10};
|
int[] observationShape = new int[]{3, 10, 10};
|
||||||
int totalObservationSize = 1;
|
int totalObservationSize = 1;
|
||||||
|
|
||||||
|
@ -123,6 +116,7 @@ public class QLearningDiscreteTest {
|
||||||
when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
|
when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
|
||||||
when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
|
when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
|
||||||
when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
|
when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
|
||||||
|
when(mockHistoryConfiguration.getHistoryLength()).thenReturn(1);
|
||||||
when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
|
when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
|
||||||
|
|
||||||
qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
|
qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
|
||||||
|
@ -148,7 +142,7 @@ public class QLearningDiscreteTest {
|
||||||
Observation observation = new Observation(Nd4j.zeros(observationShape));
|
Observation observation = new Observation(Nd4j.zeros(observationShape));
|
||||||
when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
|
when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
|
||||||
|
|
||||||
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
|
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Observation(Nd4j.zeros(observationShape)), 0, false, null));
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
|
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
|
||||||
|
@ -170,25 +164,26 @@ public class QLearningDiscreteTest {
|
||||||
// Arrange
|
// Arrange
|
||||||
mockTestContext(100,0,2,1.0, 10);
|
mockTestContext(100,0,2,1.0, 10);
|
||||||
|
|
||||||
mockHistoryProcessor(2);
|
Observation skippedObservation = Observation.SkippedObservation;
|
||||||
|
Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
|
||||||
|
|
||||||
// An example observation and 2 Q values output (2 actions)
|
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(nextObservation, 0, false, null));
|
||||||
Observation observation = new Observation(Nd4j.zeros(observationShape));
|
|
||||||
when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
|
|
||||||
|
|
||||||
when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
|
QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(skippedObservation);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
|
assertEquals(Double.NaN, stepReturn.getMaxQ(), 1e-5);
|
||||||
|
|
||||||
StepReply<Observation> stepReply = stepReturn.getStepReply();
|
StepReply<Observation> stepReply = stepReturn.getStepReply();
|
||||||
|
|
||||||
assertEquals(0, stepReply.getReward(), 1e-5);
|
assertEquals(0, stepReply.getReward(), 1e-5);
|
||||||
assertFalse(stepReply.isDone());
|
assertFalse(stepReply.isDone());
|
||||||
assertTrue(stepReply.getObservation().isSkipped());
|
assertFalse(stepReply.getObservation().isSkipped());
|
||||||
|
assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
|
||||||
|
|
||||||
|
verify(mockDQN, never()).output(any(INDArray.class));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: there are much more test cases here that can be improved upon
|
//TODO: there are much more test cases here that can be improved upon
|
||||||
|
|
|
@ -17,7 +17,7 @@ public class HistoryMergeTransformTest {
|
||||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||||
.isFirstDimenstionBatch(false)
|
.isFirstDimenstionBatch(false)
|
||||||
.elementStore(store)
|
.elementStore(store)
|
||||||
.build();
|
.build(4);
|
||||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -35,7 +35,7 @@ public class HistoryMergeTransformTest {
|
||||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||||
.isFirstDimenstionBatch(true)
|
.isFirstDimenstionBatch(true)
|
||||||
.elementStore(store)
|
.elementStore(store)
|
||||||
.build();
|
.build(4);
|
||||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3);
|
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -53,7 +53,7 @@ public class HistoryMergeTransformTest {
|
||||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||||
.isFirstDimenstionBatch(true)
|
.isFirstDimenstionBatch(true)
|
||||||
.elementStore(store)
|
.elementStore(store)
|
||||||
.build();
|
.build(4);
|
||||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -70,7 +70,7 @@ public class HistoryMergeTransformTest {
|
||||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||||
.shouldStoreCopy(false)
|
.shouldStoreCopy(false)
|
||||||
.elementStore(store)
|
.elementStore(store)
|
||||||
.build();
|
.build(4);
|
||||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -87,7 +87,7 @@ public class HistoryMergeTransformTest {
|
||||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||||
.shouldStoreCopy(true)
|
.shouldStoreCopy(true)
|
||||||
.elementStore(store)
|
.elementStore(store)
|
||||||
.build();
|
.build(4);
|
||||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -107,7 +107,7 @@ public class HistoryMergeTransformTest {
|
||||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||||
.elementStore(store)
|
.elementStore(store)
|
||||||
.assembler(assemble)
|
.assembler(assemble)
|
||||||
.build();
|
.build(4);
|
||||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
|
|
@ -252,8 +252,8 @@ public class PolicyTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected <O extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
|
protected <MockObservation extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockObservation, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
|
||||||
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
|
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(skipFrame, historyLength));
|
||||||
return super.refacInitMdp(mdpWrapper, hp);
|
return super.refacInitMdp(mdpWrapper, hp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,18 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.support;
|
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
|
|
||||||
public class MockEncodable implements Encodable {
|
|
||||||
|
|
||||||
private final int value;
|
|
||||||
|
|
||||||
public MockEncodable(int value) {
|
|
||||||
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] toArray() {
|
|
||||||
return new double[] { value };
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -2,9 +2,10 @@ package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
|
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
|
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
|
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
|
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
|
||||||
|
@ -15,7 +16,7 @@ import org.nd4j.linalg.api.rng.Random;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
public class MockMDP implements MDP<MockObservation, Integer, DiscreteSpace> {
|
||||||
|
|
||||||
private final DiscreteSpace actionSpace;
|
private final DiscreteSpace actionSpace;
|
||||||
private final int stepsUntilDone;
|
private final int stepsUntilDone;
|
||||||
|
@ -55,11 +56,11 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MockEncodable reset() {
|
public MockObservation reset() {
|
||||||
++resetCount;
|
++resetCount;
|
||||||
currentObsValue = 0;
|
currentObsValue = 0;
|
||||||
step = 0;
|
step = 0;
|
||||||
return new MockEncodable(currentObsValue++);
|
return new MockObservation(currentObsValue++);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -68,10 +69,10 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public StepReply<MockEncodable> step(Integer action) {
|
public StepReply<MockObservation> step(Integer action) {
|
||||||
actions.add(action);
|
actions.add(action);
|
||||||
++step;
|
++step;
|
||||||
return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
|
return new StepReply<>(new MockObservation(currentObsValue), (double) currentObsValue++, isDone(), null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -84,14 +85,14 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) {
|
public static TransformProcess buildTransformProcess(int skipFrame, int historyLength) {
|
||||||
return TransformProcess.builder()
|
return TransformProcess.builder()
|
||||||
.filter(new UniformSkippingFilter(skipFrame))
|
.filter(new UniformSkippingFilter(skipFrame))
|
||||||
.transform("data", new EncodableToINDArrayTransform(shape))
|
.transform("data", new EncodableToINDArrayTransform())
|
||||||
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
|
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
|
||||||
.transform("data", HistoryMergeTransform.builder()
|
.transform("data", HistoryMergeTransform.builder()
|
||||||
.elementStore(new CircularFifoStore(historyLength))
|
.elementStore(new CircularFifoStore(historyLength))
|
||||||
.build())
|
.build(4))
|
||||||
.build("data");
|
.build("data");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K. K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
public class MockObservation implements Encodable {
|
||||||
|
|
||||||
|
final INDArray data;
|
||||||
|
|
||||||
|
public MockObservation(int value) {
|
||||||
|
this.data = Nd4j.ones(1).mul(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] toArray() {
|
||||||
|
return data.data().asDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkipped() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getData() {
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Encodable dup() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,7 +17,7 @@ public class MockPolicy implements IPolicy<Integer> {
|
||||||
public List<INDArray> actionInputs = new ArrayList<INDArray>();
|
public List<INDArray> actionInputs = new ArrayList<INDArray>();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <O extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<O, Integer, AS> mdp, IHistoryProcessor hp) {
|
public <MockObservation extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<MockObservation, Integer, AS> mdp, IHistoryProcessor hp) {
|
||||||
++playCallCount;
|
++playCallCount;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,9 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import vizdoom.*;
|
import vizdoom.*;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -155,7 +158,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
|
||||||
+ Pointer.formatBytes(Pointer.totalPhysicalBytes()));
|
+ Pointer.formatBytes(Pointer.totalPhysicalBytes()));
|
||||||
|
|
||||||
game.newEpisode();
|
game.newEpisode();
|
||||||
return new GameScreen(game.getState().screenBuffer);
|
return new GameScreen(observationSpace.getShape(), game.getState().screenBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -168,7 +171,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
|
||||||
|
|
||||||
double r = game.makeAction(actions.get(action)) * scaleFactor;
|
double r = game.makeAction(actions.get(action)) * scaleFactor;
|
||||||
log.info(game.getEpisodeTime() + " " + r + " " + action + " ");
|
log.info(game.getEpisodeTime() + " " + r + " " + action + " ");
|
||||||
return new StepReply(new GameScreen(game.isEpisodeFinished()
|
return new StepReply(new GameScreen(observationSpace.getShape(), game.isEpisodeFinished()
|
||||||
? new byte[game.getScreenSize()]
|
? new byte[game.getScreenSize()]
|
||||||
: game.getState().screenBuffer), r, game.isEpisodeFinished(), null);
|
: game.getState().screenBuffer), r, game.isEpisodeFinished(), null);
|
||||||
|
|
||||||
|
@ -201,18 +204,34 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, Discre
|
||||||
|
|
||||||
public static class GameScreen implements Encodable {
|
public static class GameScreen implements Encodable {
|
||||||
|
|
||||||
|
final INDArray data;
|
||||||
|
public GameScreen(int[] shape, byte[] screen) {
|
||||||
|
|
||||||
double[] array;
|
data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1);
|
||||||
|
|
||||||
public GameScreen(byte[] screen) {
|
|
||||||
array = new double[screen.length];
|
|
||||||
for (int i = 0; i < screen.length; i++) {
|
|
||||||
array[i] = (screen[i] & 0xFF) / 255.0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private GameScreen(INDArray toDup) {
|
||||||
|
data = toDup.dup();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public double[] toArray() {
|
public double[] toArray() {
|
||||||
return array;
|
return data.data().asDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSkipped() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getData() {
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public GameScreen dup() {
|
||||||
|
return new GameScreen(data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,8 +19,6 @@ package org.deeplearning4j.rl4j.mdp.gym;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.Value;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.bytedeco.javacpp.DoublePointer;
|
import org.bytedeco.javacpp.DoublePointer;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
@ -31,8 +29,8 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Box;
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.space.HighLowDiscrete;
|
import org.deeplearning4j.rl4j.space.HighLowDiscrete;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
|
||||||
import org.bytedeco.cpython.*;
|
import org.bytedeco.cpython.*;
|
||||||
|
@ -47,7 +45,7 @@ import static org.bytedeco.numpy.global.numpy.*;
|
||||||
* @author saudet
|
* @author saudet
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
public class GymEnv<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<OBSERVATION, A, AS> {
|
||||||
|
|
||||||
public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";
|
public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";
|
||||||
|
|
||||||
|
@ -82,7 +80,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
private PyObject locals;
|
private PyObject locals;
|
||||||
|
|
||||||
final protected DiscreteSpace actionSpace;
|
final protected DiscreteSpace actionSpace;
|
||||||
final protected ObservationSpace<O> observationSpace;
|
final protected ObservationSpace<OBSERVATION> observationSpace;
|
||||||
@Getter
|
@Getter
|
||||||
final private String envId;
|
final private String envId;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -119,7 +117,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
for (int i = 0; i < shape.length; i++) {
|
for (int i = 0; i < shape.length; i++) {
|
||||||
shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i));
|
shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i));
|
||||||
}
|
}
|
||||||
observationSpace = (ObservationSpace<O>) new ArrayObservationSpace<Box>(shape);
|
observationSpace = (ObservationSpace<OBSERVATION>) new ArrayObservationSpace<Box>(shape);
|
||||||
Py_DecRef(shapeTuple);
|
Py_DecRef(shapeTuple);
|
||||||
|
|
||||||
PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null);
|
PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null);
|
||||||
|
@ -140,7 +138,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ObservationSpace<O> getObservationSpace() {
|
public ObservationSpace<OBSERVATION> getObservationSpace() {
|
||||||
return observationSpace;
|
return observationSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,7 +151,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public StepReply<O> step(A action) {
|
public StepReply<OBSERVATION> step(A action) {
|
||||||
int gstate = PyGILState_Ensure();
|
int gstate = PyGILState_Ensure();
|
||||||
try {
|
try {
|
||||||
if (render) {
|
if (render) {
|
||||||
|
@ -186,7 +184,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public O reset() {
|
public OBSERVATION reset() {
|
||||||
int gstate = PyGILState_Ensure();
|
int gstate = PyGILState_Ensure();
|
||||||
try {
|
try {
|
||||||
Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null));
|
Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null));
|
||||||
|
@ -201,7 +199,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
|
|
||||||
double[] data = new double[(int)stateData.capacity()];
|
double[] data = new double[(int)stateData.capacity()];
|
||||||
stateData.get(data);
|
stateData.get(data);
|
||||||
return (O) new Box(data);
|
return (OBSERVATION) new Box(data);
|
||||||
} finally {
|
} finally {
|
||||||
PyGILState_Release(gstate);
|
PyGILState_Release(gstate);
|
||||||
}
|
}
|
||||||
|
@ -220,7 +218,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public GymEnv<O, A, AS> newInstance() {
|
public GymEnv<OBSERVATION, A, AS> newInstance() {
|
||||||
return new GymEnv<O, A, AS>(envId, render, monitor);
|
return new GymEnv<OBSERVATION, A, AS>(envId, render, monitor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,8 +40,8 @@ public class GymEnvTest {
|
||||||
assertEquals(false, mdp.isDone());
|
assertEquals(false, mdp.isDone());
|
||||||
Box o = (Box)mdp.reset();
|
Box o = (Box)mdp.reset();
|
||||||
StepReply r = mdp.step(0);
|
StepReply r = mdp.step(0);
|
||||||
assertEquals(4, o.toArray().length);
|
assertEquals(4, o.getData().shape()[0]);
|
||||||
assertEquals(4, ((Box)r.getObservation()).toArray().length);
|
assertEquals(4, ((Box)r.getObservation()).getData().shape()[0]);
|
||||||
assertNotEquals(null, mdp.newInstance());
|
assertNotEquals(null, mdp.newInstance());
|
||||||
mdp.close();
|
mdp.close();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2020 Konduit K. K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,33 +16,13 @@
|
||||||
|
|
||||||
package org.deeplearning4j.malmo;
|
package org.deeplearning4j.malmo;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
@Deprecated
|
||||||
|
public class MalmoBox extends Box {
|
||||||
|
|
||||||
/**
|
public MalmoBox(double... arr) {
|
||||||
* Encodable state as a simple value array similar to Gym Box model, but without a JSON constructor
|
super(arr);
|
||||||
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
|
|
||||||
*/
|
|
||||||
public class MalmoBox implements Encodable {
|
|
||||||
double[] value;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Construct state from an array of doubles
|
|
||||||
* @param value state values
|
|
||||||
*/
|
|
||||||
//TODO: If this constructor was added to "Box", we wouldn't need this class at all.
|
|
||||||
public MalmoBox(double... value) {
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] toArray() {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return Arrays.toString(value);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,8 @@ package org.deeplearning4j.malmo;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
import com.microsoft.msr.malmo.WorldState;
|
import com.microsoft.msr.malmo.WorldState;
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A Malmo consistency policy that ensures the both there is a reward and next observation has a different position that the previous one.
|
* A Malmo consistency policy that ensures the both there is a reward and next observation has a different position that the previous one.
|
||||||
|
@ -30,14 +32,14 @@ public class MalmoDescretePositionPolicy implements MalmoObservationPolicy {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) {
|
public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) {
|
||||||
MalmoBox last_observation = observationSpace.getObservation(world_state);
|
Box last_observation = observationSpace.getObservation(world_state);
|
||||||
MalmoBox old_observation = observationSpace.getObservation(original_world_state);
|
Box old_observation = observationSpace.getObservation(original_world_state);
|
||||||
|
|
||||||
double[] newvalues = old_observation == null ? null : old_observation.toArray();
|
INDArray newvalues = old_observation == null ? null : old_observation.getData();
|
||||||
double[] oldvalues = last_observation == null ? null : last_observation.toArray();
|
INDArray oldvalues = last_observation == null ? null : last_observation.getData();
|
||||||
|
|
||||||
return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty()
|
return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty()
|
||||||
|| Arrays.equals(oldvalues, newvalues));
|
|| oldvalues.eq(newvalues).all());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import java.nio.file.Paths;
|
||||||
|
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
|
|
||||||
import com.microsoft.msr.malmo.AgentHost;
|
import com.microsoft.msr.malmo.AgentHost;
|
||||||
|
@ -34,6 +35,7 @@ import com.microsoft.msr.malmo.WorldState;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
@ -233,7 +235,7 @@ public class MalmoEnv implements MDP<MalmoBox, Integer, DiscreteSpace> {
|
||||||
logger.info("Mission ended");
|
logger.info("Mission ended");
|
||||||
}
|
}
|
||||||
|
|
||||||
return new StepReply<MalmoBox>(last_observation, getRewards(last_world_state), isDone(), null);
|
return new StepReply<>(last_observation, getRewards(last_world_state), isDone(), null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private double getRewards(WorldState world_state) {
|
private double getRewards(WorldState world_state) {
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.malmo;
|
package org.deeplearning4j.malmo;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
|
||||||
import com.microsoft.msr.malmo.WorldState;
|
import com.microsoft.msr.malmo.WorldState;
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.malmo;
|
||||||
|
|
||||||
import com.microsoft.msr.malmo.TimestampedStringVector;
|
import com.microsoft.msr.malmo.TimestampedStringVector;
|
||||||
import com.microsoft.msr.malmo.WorldState;
|
import com.microsoft.msr.malmo.WorldState;
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.json.JSONArray;
|
import org.json.JSONArray;
|
||||||
import org.json.JSONObject;
|
import org.json.JSONObject;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.malmo;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.malmo;
|
package org.deeplearning4j.malmo;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.space.Box;
|
||||||
import org.json.JSONObject;
|
import org.json.JSONObject;
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
Loading…
Reference in New Issue