From 1a35ebec2ed7209e7e732246dd505c94d482a136 Mon Sep 17 00:00:00 2001 From: Chris Bamford Date: Mon, 6 Apr 2020 04:36:12 +0100 Subject: [PATCH] RL4J: Add Backwardly Compatible Builder patterns (#326) * Starting to switch configs of RL algorithms to use more fluent builder patterns. Many parameter choices in different algorithms default to SOTA and only be changed in specific cases Signed-off-by: Bam4d * remove personal gpu-build file Signed-off-by: Bam4d * refactored out configurations so they are heirarchical and re-usable, this is a step towards having a plug-and-play framework for different algorithms * backwardly compatible configurations * adding documentation to new configuration classes Signed-off-by: Bam4d * private access modifiers are better suited here Signed-off-by: Bam4d * RL4j does not compile without java 8 due to previous updates fixing null pointers when listener arrays are empty Signed-off-by: Bam4d * fixing copyright headers Signed-off-by: Bam4d * uncomment logging line Signed-off-by: Bam4d * fixing default value for learningUpdateFrequency fixing test failure due to #352 Signed-off-by: Bam4d Co-authored-by: Bam4d --- .../api/transform/split/RandomSplit.java | 1 + rl4j/rl4j-core/pom.xml | 12 ++ .../rl4j/learning/EpochStepCounter.java | 16 +++ .../rl4j/learning/ILearning.java | 16 +-- .../rl4j/learning/async/AsyncGlobal.java | 36 +++--- .../rl4j/learning/async/AsyncLearning.java | 30 +++-- .../rl4j/learning/async/AsyncThread.java | 8 +- .../learning/async/AsyncThreadDiscrete.java | 6 +- .../async/a3c/discrete/A3CDiscrete.java | 44 +++++-- .../async/a3c/discrete/A3CDiscreteConv.java | 44 +++++-- .../async/a3c/discrete/A3CDiscreteDense.java | 83 +++++++----- .../async/a3c/discrete/A3CThreadDiscrete.java | 13 +- .../discrete/AsyncNStepQLearningDiscrete.java | 56 +++++--- .../AsyncNStepQLearningDiscreteConv.java | 39 +++--- .../AsyncNStepQLearningDiscreteDense.java | 57 ++++++--- .../AsyncNStepQLearningThreadDiscrete.java | 28 ++-- .../A3CLearningConfiguration.java | 46 +++++++ .../AsyncQLearningConfiguration.java | 42 ++++++ .../IAsyncLearningConfiguration.java | 28 ++++ .../ILearningConfiguration.java} | 30 +---- .../configuration/LearningConfiguration.java | 59 +++++++++ .../configuration/QLearningConfiguration.java | 79 ++++++++++++ .../rl4j/learning/sync/SyncLearning.java | 5 +- .../learning/sync/qlearning/QLearning.java | 48 +++++-- .../qlearning/discrete/QLearningDiscrete.java | 27 ++-- .../discrete/QLearningDiscreteConv.java | 35 ++++- .../discrete/QLearningDiscreteDense.java | 28 +++- .../rl4j/network/ac/ActorCriticCompGraph.java | 11 +- .../ActorCriticFactoryCompGraphStdConv.java | 29 ++++- .../ActorCriticFactoryCompGraphStdDense.java | 38 ++---- .../ActorCriticFactorySeparateStdDense.java | 83 +++++++----- .../ActorCriticDenseNetworkConfiguration.java | 42 ++++++ .../ActorCriticNetworkConfiguration.java | 37 ++++++ .../DQNDenseNetworkConfiguration.java | 40 ++++++ .../configuration/NetworkConfiguration.java | 58 +++++++++ .../rl4j/network/dqn/DQNFactoryStdConv.java | 26 +++- .../rl4j/network/dqn/DQNFactoryStdDense.java | 63 ++++++--- .../deeplearning4j/rl4j/policy/EpsGreedy.java | 13 +- .../deeplearning4j/rl4j/util/DataManager.java | 28 +++- .../rl4j/learning/HistoryProcessorTest.java | 8 +- .../learning/async/AsyncLearningTest.java | 33 ++++- .../async/AsyncThreadDiscreteTest.java | 22 +++- .../rl4j/learning/async/AsyncThreadTest.java | 19 ++- .../a3c/discrete/A3CThreadDiscreteTest.java | 26 +++- ...AsyncNStepQLearningThreadDiscreteTest.java | 23 +++- .../rl4j/learning/sync/SyncLearningTest.java | 34 ++++- ...t.java => QLearningConfigurationTest.java} | 31 ++--- .../discrete/QLearningDiscreteTest.java | 120 +++++++++++------- .../rl4j/network/ac/ActorCriticTest.java | 49 +++---- .../rl4j/network/dqn/DQNTest.java | 20 +-- .../transform/TransformProcessTest.java | 2 +- .../rl4j/policy/PolicyTest.java | 37 ++++-- .../rl4j/support/MockAsyncConfiguration.java | 31 +++-- .../util/DataManagerTrainingListenerTest.java | 20 ++- .../malmo/MalmoObservationSpaceGrid.java | 24 ++-- 55 files changed, 1388 insertions(+), 495 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/{async/AsyncConfiguration.java => configuration/ILearningConfiguration.java} (61%) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java rename rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/{QLConfigurationTest.java => QLearningConfigurationTest.java} (52%) diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/split/RandomSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/split/RandomSplit.java index 290e26873..fe4718f48 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/split/RandomSplit.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/split/RandomSplit.java @@ -16,6 +16,7 @@ package org.datavec.api.transform.split; + import lombok.AllArgsConstructor; import lombok.Data; diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index c08615250..a93ea6345 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -20,6 +20,18 @@ xmlns="http://maven.apache.org/POM/4.0.0" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 + + + + org.apache.maven.plugins + maven-compiler-plugin + + 8 + 8 + + + + rl4j diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java index 746a71396..533209ed7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java @@ -1,3 +1,19 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning; public interface EpochStepCounter { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index d151f093b..43ed508b0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,10 +17,10 @@ package org.deeplearning4j.rl4j.learning; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16. @@ -34,21 +35,12 @@ public interface ILearning> { int getStepCounter(); - LConfiguration getConfiguration(); + ILearningConfiguration getConfiguration(); MDP getMdp(); IHistoryProcessor getHistoryProcessor(); - interface LConfiguration { - Integer getSeed(); - - int getMaxEpochStep(); - - int getMaxStep(); - - double getGamma(); - } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java index 5501a29e1..01c519b57 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,6 +20,8 @@ package org.deeplearning4j.rl4j.learning.async; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.network.NeuralNet; import org.nd4j.linalg.primitives.Pair; @@ -27,28 +30,26 @@ import java.util.concurrent.atomic.AtomicInteger; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. - * + *

* In the original paper, the authors uses Asynchronous * Gradient Descent: Hogwild! It is a way to apply gradients * and modify a model in a lock-free manner. - * + *

* As a way to implement this with dl4j, it is unfortunately * necessary at the time of writing to apply the gradient * (update the parameters) on a single separate global thread. - * + *

* This Central thread for Asynchronous Method of reinforcement learning * enqueue the gradients coming from the different threads and update its * model and target. Those neurals nets are then synced by the other threads. - * + *

* The benefits of this thread is that the updater is "shared" between all thread * we have a single updater which is the single updater of the model contained here - * + *

* This is similar to RMSProp with shared g and momentum - * + *

* When Hogwild! is implemented, this could be replaced by a simple data * structure - * - * */ @Slf4j public class AsyncGlobal extends Thread implements IAsyncGlobal { @@ -56,7 +57,7 @@ public class AsyncGlobal extends Thread implements IAsyncG @Getter final private NN current; final private ConcurrentLinkedQueue> queue; - final private AsyncConfiguration a3cc; + final private IAsyncLearningConfiguration configuration; private final IAsyncLearning learning; @Getter private AtomicInteger T = new AtomicInteger(0); @@ -65,20 +66,20 @@ public class AsyncGlobal extends Thread implements IAsyncG @Getter private boolean running = true; - public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) { + public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) { this.current = initial; target = (NN) initial.clone(); - this.a3cc = a3cc; + this.configuration = configuration; this.learning = learning; queue = new ConcurrentLinkedQueue<>(); } public boolean isTrainingComplete() { - return T.get() >= a3cc.getMaxStep(); + return T.get() >= configuration.getMaxStep(); } public void enqueue(Gradient[] gradient, Integer nstep) { - if(running && !isTrainingComplete()) { + if (running && !isTrainingComplete()) { queue.add(new Pair<>(gradient, nstep)); } } @@ -94,9 +95,8 @@ public class AsyncGlobal extends Thread implements IAsyncG synchronized (this) { current.applyGradient(gradient, pair.getSecond()); } - if (a3cc.getTargetDqnUpdateFreq() != -1 - && T.get() / a3cc.getTargetDqnUpdateFreq() > (T.get() - pair.getSecond()) - / a3cc.getTargetDqnUpdateFreq()) { + if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond()) + / configuration.getLearnerUpdateFrequency()) { log.info("TARGET UPDATE at T = " + T.get()); synchronized (this) { target.copy(current); @@ -111,7 +111,7 @@ public class AsyncGlobal extends Thread implements IAsyncG * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too. */ public void terminate() { - if(running) { + if (running) { running = false; queue.clear(); learning.terminate(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index 994ec9cb0..1c3c83972 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -21,14 +22,17 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.listener.*; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; +import org.deeplearning4j.rl4j.learning.listener.TrainingListener; +import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.factory.Nd4j; /** - * The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread() + * The entry point for async training. This class will start a number ({@link AsyncQLearningConfiguration#getNumThreads() * configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals * (see setProgressEventInterval(int)) * @@ -37,8 +41,8 @@ import org.nd4j.linalg.factory.Nd4j; */ @Slf4j public abstract class AsyncLearning, NN extends NeuralNet> - extends Learning - implements IAsyncLearning { + extends Learning + implements IAsyncLearning { private Thread monitorThread = null; @@ -56,9 +60,10 @@ public abstract class AsyncLearning, NN extends Ne } private void handleTraining(RunContext context) { - int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep); + int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep); SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); context.obs = subEpochReturn.getLastObs(); @@ -197,7 +199,7 @@ public abstract class AsyncThread, NN extends Ne protected abstract IAsyncGlobal getAsyncGlobal(); - protected abstract AsyncConfiguration getConf(); + protected abstract IAsyncLearningConfiguration getConf(); protected abstract IPolicy getPolicy(NN net); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index 27d49c366..a72abfa62 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -1,5 +1,7 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -112,7 +114,7 @@ public abstract class AsyncThreadDiscrete rewards.add(new MiniTrans(obs.getData(), null, null, 0)); else { INDArray[] output = null; - if (getConf().getTargetDqnUpdateFreq() == -1) + if (getConf().getLearnerUpdateFrequency() == -1) output = current.outputAll(obs.getData()); else synchronized (getAsyncGlobal()) { output = getAsyncGlobal().getTarget().outputAll(obs.getData()); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index 81308ba5a..0608ec5cc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,11 +17,15 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; -import lombok.*; -import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncLearning; import org.deeplearning4j.rl4j.learning.async.AsyncThread; +import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.policy.ACPolicy; @@ -32,15 +37,14 @@ import org.nd4j.linalg.factory.Nd4j; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. * Training for A3C in the Discrete Domain - * + *

* All methods are fully implemented as described in the * https://arxiv.org/abs/1602.01783 paper. - * */ public abstract class A3CDiscrete extends AsyncLearning { @Getter - final public A3CConfiguration configuration; + final public A3CLearningConfiguration configuration; @Getter final protected MDP mdp; final private IActorCritic iActorCritic; @@ -49,15 +53,15 @@ public abstract class A3CDiscrete extends AsyncLearning policy; - public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CConfiguration conf) { + public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) { this.iActorCritic = iActorCritic; this.mdp = mdp; this.configuration = conf; asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this); - Integer seed = conf.getSeed(); + Long seed = conf.getSeed(); Random rnd = Nd4j.getRandom(); - if(seed != null) { + if (seed != null) { rnd.setSeed(seed); } @@ -65,7 +69,7 @@ public abstract class A3CDiscrete extends AsyncLearning extends AsyncLearning extends AsyncLearning * Training for A3C in the Discrete Domain - * + *

* Specialized constructors for the Conv (pixels input) case * Specialized conf + provide additional type safety - * + *

* It uses CompGraph because there is benefit to combine the * first layers since they're essentially doing the same dimension * reduction task - * **/ public class A3CDiscreteConv extends A3CDiscrete { @@ -46,12 +48,22 @@ public class A3CDiscreteConv extends A3CDiscrete { @Deprecated public A3CDiscreteConv(MDP mdp, IActorCritic actorCritic, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, actorCritic, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } + + @Deprecated public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { + + super(mdp, IActorCritic, conf.toLearningConfiguration()); + this.hpconf = hpconf; + setHistoryProcessor(hpconf); + } + + public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, + HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { super(mdp, IActorCritic, conf); this.hpconf = hpconf; setHistoryProcessor(hpconf); @@ -59,21 +71,35 @@ public class A3CDiscreteConv extends A3CDiscrete { @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } + + @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, + HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { + this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); + } + @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { - this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager); + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { + this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); } + + @Deprecated public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { + this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf); + } + + public A3CDiscreteConv(MDP mdp, ActorCriticNetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java index 16b8151df..74332bf3a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,8 +17,10 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; +import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.*; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; @@ -25,67 +28,81 @@ import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16. - * + *

* Training for A3C in the Discrete Domain - * + *

* We use specifically the Separate version because * the model is too small to have enough benefit by sharing layers - * */ public class A3CDiscreteDense extends A3CDiscrete { @Deprecated public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf, - IDataManager dataManager) { + IDataManager dataManager) { this(mdp, IActorCritic, conf); addListener(new DataManagerTrainingListener(dataManager)); } + + @Deprecated public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf) { + super(mdp, actorCritic, conf.toLearningConfiguration()); + } + + public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) { super(mdp, actorCritic, conf); } @Deprecated public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, - A3CConfiguration conf, IDataManager dataManager) { + A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, - dataManager); + dataManager); } + + @Deprecated public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - @Deprecated - public A3CDiscreteDense(MDP mdp, - ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, - IDataManager dataManager) { - this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager); - } - public A3CDiscreteDense(MDP mdp, - ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { - this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); - } - - @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, - A3CConfiguration conf, IDataManager dataManager) { - this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, - dataManager); - } - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, - A3CConfiguration conf) { + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + A3CLearningConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } @Deprecated public A3CDiscreteDense(MDP mdp, - ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf, - IDataManager dataManager) { - this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager); - } - public A3CDiscreteDense(MDP mdp, - ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) { - this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf); + ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, + IDataManager dataManager) { + this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } + @Deprecated + public A3CDiscreteDense(MDP mdp, + ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { + this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf); + } + + public A3CDiscreteDense(MDP mdp, + ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) { + this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); + } + + @Deprecated + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + A3CConfiguration conf, IDataManager dataManager) { + this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, + dataManager); + } + + @Deprecated + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + A3CConfiguration conf) { + this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + A3CLearningConfiguration conf) { + this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index 22b3894b2..c2a16d6b4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,10 +20,10 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete; import lombok.Getter; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.IActorCritic; @@ -31,9 +32,9 @@ import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.linalg.api.rng.Random; import java.util.Stack; @@ -45,7 +46,7 @@ import java.util.Stack; public class A3CThreadDiscrete extends AsyncThreadDiscrete { @Getter - final protected A3CDiscrete.A3CConfiguration conf; + final protected A3CLearningConfiguration conf; @Getter final protected IAsyncGlobal asyncGlobal; @Getter @@ -54,14 +55,14 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< final private Random rnd; public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, - A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners, + A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners, int threadNumber) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); this.conf = a3cc; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; - Integer seed = conf.getSeed(); + Long seed = conf.getSeed(); rnd = Nd4j.getRandom(); if(seed != null) { rnd.setSeed(seed + threadNumber); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index c18de9e10..9a8049f6f 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -1,49 +1,53 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. * - * SPDX-License-Identifier: Apache-2.0 + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; -import lombok.*; -import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncLearning; import org.deeplearning4j.rl4j.learning.async.AsyncThread; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.factory.Nd4j; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. */ public abstract class AsyncNStepQLearningDiscrete - extends AsyncLearning { + extends AsyncLearning { @Getter - final public AsyncNStepQLConfiguration configuration; + final public AsyncQLearningConfiguration configuration; @Getter final private MDP mdp; @Getter final private AsyncGlobal asyncGlobal; - public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { + public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { this.mdp = mdp; this.configuration = conf; this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this); @@ -62,12 +66,11 @@ public abstract class AsyncNStepQLearningDiscrete return new DQNPolicy(getNeuralNet()); } - @Data @AllArgsConstructor @Builder @EqualsAndHashCode(callSuper = false) - public static class AsyncNStepQLConfiguration implements AsyncConfiguration { + public static class AsyncNStepQLConfiguration { Integer seed; int maxEpochStep; @@ -82,5 +85,22 @@ public abstract class AsyncNStepQLearningDiscrete float minEpsilon; int epsilonNbStep; + public AsyncQLearningConfiguration toLearningConfiguration() { + return AsyncQLearningConfiguration.builder() + .seed(new Long(seed)) + .maxEpochStep(maxEpochStep) + .maxStep(maxStep) + .numThreads(numThread) + .nStep(nstep) + .targetDqnUpdateFreq(targetDqnUpdateFreq) + .updateStart(updateStart) + .rewardFactor(rewardFactor) + .gamma(gamma) + .errorClamp(errorClamp) + .minEpsilon(minEpsilon) + .build(); + } + } + } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java index 83274b7f6..f92b704b6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java @@ -1,24 +1,27 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. * - * SPDX-License-Identifier: Apache-2.0 + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.async.AsyncThread; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -38,12 +41,12 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN @Deprecated public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, dqn, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { super(mdp, dqn, conf); this.hpconf = hpconf; setHistoryProcessor(hpconf); @@ -51,21 +54,21 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN @Deprecated public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } @Deprecated - public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { + public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); } - public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { + public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java index b58e15902..b6216e849 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java @@ -1,22 +1,26 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. * - * SPDX-License-Identifier: Apache-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -32,35 +36,56 @@ public class AsyncNStepQLearningDiscreteDense extends Async @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, - AsyncNStepQLConfiguration conf, IDataManager dataManager) { - super(mdp, dqn, conf); + AsyncNStepQLConfiguration conf, IDataManager dataManager) { + super(mdp, dqn, conf.toLearningConfiguration()); addListener(new DataManagerTrainingListener(dataManager)); } + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { + super(mdp, dqn, conf.toLearningConfiguration()); + } + + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + AsyncQLearningConfiguration conf) { super(mdp, dqn, conf); } @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, - AsyncNStepQLConfiguration conf, IDataManager dataManager) { + AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, - dataManager); + dataManager); } + + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncNStepQLConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + AsyncQLearningConfiguration conf) { + this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, - DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { - this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); + DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { + this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } + + @Deprecated public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) { + this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); + } + + public AsyncNStepQLearningDiscreteDense(MDP mdp, + DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf), conf); } + } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index f8c470269..71199efaf 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -1,17 +1,18 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. * - * SPDX-License-Identifier: Apache-2.0 + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; @@ -22,6 +23,7 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -42,7 +44,7 @@ import java.util.Stack; public class AsyncNStepQLearningThreadDiscrete extends AsyncThreadDiscrete { @Getter - final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf; + final protected AsyncQLearningConfiguration conf; @Getter final protected IAsyncGlobal asyncGlobal; @Getter @@ -51,7 +53,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn final private Random rnd; public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, - AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, + AsyncQLearningConfiguration conf, TrainingListenerList listeners, int threadNumber, int deviceNum) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); this.conf = conf; @@ -59,7 +61,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn this.threadNumber = threadNumber; rnd = Nd4j.getRandom(); - Integer seed = conf.getSeed(); + Long seed = conf.getSeed(); if(seed != null) { rnd.setSeed(seed + threadNumber); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java new file mode 100644 index 000000000..226fe4419 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@EqualsAndHashCode(callSuper = true) +public class A3CLearningConfiguration extends LearningConfiguration implements IAsyncLearningConfiguration { + + /** + * The number of asynchronous threads to use to generate gradients + */ + private final int numThreads; + + /** + * The number of steps to calculate gradients over + */ + private final int nStep; + + /** + * The frequency of async training iterations to update the target network. + * + * If this is set to -1 then the target network is updated after every training iteration + */ + @Builder.Default + private int learnerUpdateFrequency = -1; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java new file mode 100644 index 000000000..a60903e59 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.configuration; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.SuperBuilder; + + +@Data +@SuperBuilder +@EqualsAndHashCode(callSuper = true) +public class AsyncQLearningConfiguration extends QLearningConfiguration implements IAsyncLearningConfiguration { + + /** + * The number of asynchronous threads to use to generate experience data + */ + private final int numThreads; + + /** + * The number of steps in each training interations + */ + private final int nStep; + + public int getLearnerUpdateFrequency() { + return getTargetDqnUpdateFreq(); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java new file mode 100644 index 000000000..1e7cf3f2e --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.configuration; + +public interface IAsyncLearningConfiguration extends ILearningConfiguration { + + int getNumThreads(); + + int getNStep(); + + int getLearnerUpdateFrequency(); + + int getMaxStep(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java similarity index 61% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java index 0727db475..7ae215087 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,36 +14,16 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.rl4j.learning.async; +package org.deeplearning4j.rl4j.learning.configuration; -import org.deeplearning4j.rl4j.learning.ILearning; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/23/16. - * - * Interface configuration for all training method that inherit - * from AsyncLearning - */ -public interface AsyncConfiguration extends ILearning.LConfiguration { - - Integer getSeed(); +public interface ILearningConfiguration { + Long getSeed(); int getMaxEpochStep(); int getMaxStep(); - int getNumThread(); - - int getNstep(); - - int getTargetDqnUpdateFreq(); - - int getUpdateStart(); - - double getRewardFactor(); - double getGamma(); - double getErrorClamp(); - + double getRewardFactor(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java new file mode 100644 index 000000000..d1567e619 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@NoArgsConstructor +public class LearningConfiguration implements ILearningConfiguration { + + /** + * Seed value used for training + */ + @Builder.Default + private Long seed = System.currentTimeMillis(); + + /** + * The maximum number of steps in each episode + */ + @Builder.Default + private int maxEpochStep = 200; + + /** + * The maximum number of steps to train for + */ + @Builder.Default + private int maxStep = 150000; + + /** + * Gamma parameter used for discounted rewards + */ + @Builder.Default + private double gamma = 0.99; + + /** + * Scaling parameter for rewards + */ + @Builder.Default + private double rewardFactor = 1.0; + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java new file mode 100644 index 000000000..26ac57f0c --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java @@ -0,0 +1,79 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@NoArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class QLearningConfiguration extends LearningConfiguration { + + /** + * The maximum size of the experience replay buffer + */ + @Builder.Default + private int expRepMaxSize = 150000; + + /** + * The batch size of experience for each training iteration + */ + @Builder.Default + private int batchSize = 32; + + /** + * How many steps between target network updates + */ + @Builder.Default + private int targetDqnUpdateFreq = 100; + + /** + * The number of steps to initially wait for until samplling batches from experience replay buffer + */ + @Builder.Default + private int updateStart = 10; + + /** + * Prevent the new Q-Value from being farther than errorClamp away from the previous value. Double.NaN will result in no clamping + */ + @Builder.Default + private double errorClamp = 1.0; + + /** + * The minimum probability for random exploration action during episilon-greedy annealing + */ + @Builder.Default + private double minEpsilon = 0.1f; + + /** + * The number of steps to anneal epsilon to its minimum value. + */ + @Builder.Default + private int epsilonNbStep = 10000; + + /** + * Whether to use the double DQN algorithm + */ + @Builder.Default + private boolean doubleDQN = false; + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java index 22d936fcf..c42756145 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -63,7 +64,7 @@ public abstract class SyncLearning, NN extends N /** * This method will train the model

* The training stop when:
- * - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})
+ * - the number of steps reaches the maximum defined in the configuration (see {@link ILearningConfiguration#getMaxStep() LConfiguration.getMaxStep()})
* OR
* - a listener explicitly stops it
*

diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 0757043f0..40704d4e9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,10 +19,19 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; -import lombok.*; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import lombok.Value; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.EpochStepCounter; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.ExpReplay; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.SyncLearning; @@ -59,15 +69,15 @@ public abstract class QLearning getLegacyMDPWrapper(); - public QLearning(QLConfiguration conf) { + public QLearning(QLearningConfiguration conf) { this(conf, getSeededRandom(conf.getSeed())); } - public QLearning(QLConfiguration conf, Random random) { + public QLearning(QLearningConfiguration conf, Random random) { expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random); } - private static Random getSeededRandom(Integer seed) { + private static Random getSeededRandom(Long seed) { Random rnd = Nd4j.getRandom(); if(seed != null) { rnd.setSeed(seed); @@ -95,7 +105,7 @@ public abstract class QLearning scores; - float epsilon; + double epsilon; double startQ; double meanQ; } @@ -213,12 +223,14 @@ public abstract class QLearning * DQN or Deep Q-Learning in the Discrete domain - * + *

* http://arxiv.org/abs/1312.5602 - * */ public abstract class QLearningDiscrete extends QLearning { @Getter - final private QLConfiguration configuration; + final private QLearningConfiguration configuration; private final LegacyMDPWrapper mdp; @Getter private DQNPolicy policy; @@ -78,16 +79,15 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLConfiguration conf, - int epsilonNbStep) { + public QLearningDiscrete(MDP mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep) { this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed())); } - public QLearningDiscrete(MDP mdp, IDQN dqn, QLConfiguration conf, + public QLearningDiscrete(MDP mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) { super(conf); this.configuration = conf; - this.mdp = new LegacyMDPWrapper(mdp, null, this); + this.mdp = new LegacyMDPWrapper<>(mdp, null, this); qNetwork = dqn; targetQNetwork = dqn.clone(); policy = new DQNPolicy(getQNetwork()); @@ -125,6 +125,7 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning extends QLearning extends QLearning extends QLearningDiscret @Deprecated public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, - QLConfiguration conf, IDataManager dataManager) { + QLConfiguration conf, IDataManager dataManager) { this(mdp, dqn, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } + + @Deprecated public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { + super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame()); + setHistoryProcessor(hpconf); + } + + public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, + QLearningConfiguration conf) { super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); setHistoryProcessor(hpconf); } @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { + HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } + + @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } + public QLearningDiscreteConv(MDP mdp, DQNFactory factory, + HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { + this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); + } + @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { - this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); + HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { + this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); } + + @Deprecated public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, QLConfiguration conf) { + this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf); + } + + public QLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java index ef69ea6fb..5b95cc84e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,8 +17,10 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; @@ -38,7 +41,13 @@ public class QLearningDiscreteDense extends QLearningDiscre this(mdp, dqn, conf); addListener(new DataManagerTrainingListener(dataManager)); } + + @Deprecated public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf) { + super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep()); + } + + public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearningConfiguration conf) { super(mdp, dqn, conf, conf.getEpsilonNbStep()); } @@ -48,18 +57,33 @@ public class QLearningDiscreteDense extends QLearningDiscre this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } + + @Deprecated public QLearningDiscreteDense(MDP mdp, DQNFactory factory, QLearning.QLConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } + public QLearningDiscreteDense(MDP mdp, DQNFactory factory, + QLearningConfiguration conf) { + this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); + } + @Deprecated public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, QLearning.QLConfiguration conf, IDataManager dataManager) { - this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); + + this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } + + @Deprecated public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, QLearning.QLConfiguration conf) { + this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); + } + + public QLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf, + QLearningConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf), conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java index 274606ed9..63438bb74 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -36,7 +37,7 @@ import java.util.Collection; * * Standard implementation of ActorCriticCompGraph */ -public class ActorCriticCompGraph implements IActorCritic { +public class ActorCriticCompGraph implements IActorCritic { final protected ComputationGraph cg; @Getter @@ -73,13 +74,13 @@ public class ActorCriticCompGraph implements IA } } - public NN clone() { - NN nn = (NN)new ActorCriticCompGraph(cg.clone()); + public ActorCriticCompGraph clone() { + ActorCriticCompGraph nn = new ActorCriticCompGraph(cg.clone()); nn.cg.setListeners(cg.getListeners()); return nn; } - public void copy(NN from) { + public void copy(ActorCriticCompGraph from) { cg.setParams(from.cg.params()); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java index bdadd2969..eaccf2a10 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -31,12 +32,16 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration.ActorCriticNetworkConfigurationBuilder; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. * @@ -45,8 +50,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; @Value public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCompGraph { - - Configuration conf; + ActorCriticNetworkConfiguration conf; public ActorCriticCompGraph buildActorCritic(int shapeInputs[], int numOutputs) { @@ -109,16 +113,33 @@ public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCom return new ActorCriticCompGraph(model); } - @AllArgsConstructor @Builder @Value + @Deprecated public static class Configuration { double l2; IUpdater updater; TrainingListener[] listeners; boolean useLSTM; + + /** + * Converts the deprecated Configuration to the new NetworkConfiguration format + */ + public ActorCriticNetworkConfiguration toNetworkConfiguration() { + ActorCriticNetworkConfigurationBuilder builder = ActorCriticNetworkConfiguration.builder() + .l2(l2) + .updater(updater) + .useLSTM(useLSTM); + + if (listeners != null) { + builder.listeners(Arrays.asList(listeners)); + } + + return builder.build(); + + } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java index 7c9e3e21b..0d9dae3c6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,8 +17,6 @@ package org.deeplearning4j.rl4j.network.ac; -import lombok.AllArgsConstructor; -import lombok.Builder; import lombok.Value; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; @@ -29,12 +28,11 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; /** @@ -45,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; @Value public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCompGraph { - Configuration conf; + ActorCriticDenseNetworkConfiguration conf; public ActorCriticCompGraph buildActorCritic(int[] numInputs, int numOutputs) { int nIn = 1; @@ -65,27 +63,27 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo "input"); - for (int i = 1; i < conf.getNumLayer(); i++) { + for (int i = 1; i < conf.getNumLayers(); i++) { confB.addLayer(i + "", new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) .activation(Activation.RELU).build(), (i - 1) + ""); } if (conf.isUseLSTM()) { - confB.addLayer(getConf().getNumLayer() + "", new LSTM.Builder().activation(Activation.TANH) - .nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayer() - 1) + ""); + confB.addLayer(getConf().getNumLayers() + "", new LSTM.Builder().activation(Activation.TANH) + .nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayers() - 1) + ""); confB.addLayer("value", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nOut(1).build(), getConf().getNumLayer() + ""); + .nOut(1).build(), getConf().getNumLayers() + ""); confB.addLayer("softmax", new RnnOutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX) - .nOut(numOutputs).build(), getConf().getNumLayer() + ""); + .nOut(numOutputs).build(), getConf().getNumLayers() + ""); } else { confB.addLayer("value", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nOut(1).build(), (getConf().getNumLayer() - 1) + ""); + .nOut(1).build(), (getConf().getNumLayers() - 1) + ""); confB.addLayer("softmax", new OutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX) - .nOut(numOutputs).build(), (getConf().getNumLayer() - 1) + ""); + .nOut(numOutputs).build(), (getConf().getNumLayers() - 1) + ""); } confB.setOutputs("value", "softmax"); @@ -103,18 +101,4 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo return new ActorCriticCompGraph(model); } - @AllArgsConstructor - @Builder - @Value - public static class Configuration { - - int numLayer; - int numHiddenNodes; - double l2; - IUpdater updater; - TrainingListener[] listeners; - boolean useLSTM; - } - - } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java index a55e351c0..4ac557096 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -31,21 +32,24 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; + +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration.ActorCriticDenseNetworkConfigurationBuilder; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. - * - * */ @Value public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate { - Configuration conf; + ActorCriticDenseNetworkConfiguration conf; public ActorCriticSeparate buildActorCritic(int[] numInputs, int numOutputs) { int nIn = 1; @@ -53,27 +57,27 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep nIn *= i; } NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) - .weightInit(WeightInit.XAVIER) - .l2(conf.getL2()) - .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) + .weightInit(WeightInit.XAVIER) + .l2(conf.getL2()) + .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) + .activation(Activation.RELU).build()); - for (int i = 1; i < conf.getNumLayer(); i++) { + for (int i = 1; i < conf.getNumLayers(); i++) { confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .activation(Activation.RELU).build()); } if (conf.isUseLSTM()) { - confB.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); + confB.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); - confB.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nIn(conf.getNumHiddenNodes()).nOut(1).build()); + confB.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) + .nIn(conf.getNumHiddenNodes()).nOut(1).build()); } else { - confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nIn(conf.getNumHiddenNodes()).nOut(1).build()); + confB.layer(conf.getNumLayers(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) + .nIn(conf.getNumHiddenNodes()).nOut(1).build()); } confB.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn)); @@ -87,28 +91,28 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep } NeuralNetConfiguration.ListBuilder confB2 = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) - .weightInit(WeightInit.XAVIER) - //.regularization(true) - //.l2(conf.getL2()) - .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) + .weightInit(WeightInit.XAVIER) + //.regularization(true) + //.l2(conf.getL2()) + .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) + .activation(Activation.RELU).build()); - for (int i = 1; i < conf.getNumLayer(); i++) { + for (int i = 1; i < conf.getNumLayers(); i++) { confB2.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .activation(Activation.RELU).build()); } if (conf.isUseLSTM()) { - confB2.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); + confB2.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); - confB2.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss()) - .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); + confB2.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss()) + .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); } else { - confB2.layer(conf.getNumLayer(), new OutputLayer.Builder(new ActorCriticLoss()) - .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); + confB2.layer(conf.getNumLayers(), new OutputLayer.Builder(new ActorCriticLoss()) + .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); } confB2.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn)); @@ -128,6 +132,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep @AllArgsConstructor @Value @Builder + @Deprecated public static class Configuration { int numLayer; @@ -136,6 +141,22 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep IUpdater updater; TrainingListener[] listeners; boolean useLSTM; + + public ActorCriticDenseNetworkConfiguration toNetworkConfiguration() { + ActorCriticDenseNetworkConfigurationBuilder builder = ActorCriticDenseNetworkConfiguration.builder() + .numHiddenNodes(numHiddenNodes) + .numLayers(numLayer) + .l2(l2) + .updater(updater) + .useLSTM(useLSTM); + + if (listeners != null) { + builder.listeners(Arrays.asList(listeners)); + } + + return builder.build(); + + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java new file mode 100644 index 000000000..e85ec6356 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.network.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@EqualsAndHashCode(callSuper = true) +public class ActorCriticDenseNetworkConfiguration extends ActorCriticNetworkConfiguration { + + /** + * The number of layers in the dense network + */ + @Builder.Default + private int numLayers = 3; + + /** + * The number of hidden neurons in each layer + */ + @Builder.Default + private int numHiddenNodes = 100; + + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java new file mode 100644 index 000000000..c043f458e --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java @@ -0,0 +1,37 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.network.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class ActorCriticNetworkConfiguration extends NetworkConfiguration { + + /** + * Whether or not to add an LSTM layer to the network. + */ + @Builder.Default + private boolean useLSTM = false; + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java new file mode 100644 index 000000000..452cb83c2 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.network.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@EqualsAndHashCode(callSuper = true) +public class DQNDenseNetworkConfiguration extends NetworkConfiguration { + + /** + * The number of layers in the dense network + */ + @Builder.Default + private int numLayers = 3; + + /** + * The number of hidden neurons in each layer + */ + @Builder.Default + private int numHiddenNodes = 100; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java new file mode 100644 index 000000000..c77c379a2 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.network.configuration; + +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.Singular; +import lombok.experimental.SuperBuilder; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.learning.config.IUpdater; + +import java.util.List; + + +@Data +@SuperBuilder +@NoArgsConstructor +public class NetworkConfiguration { + + /** + * The learning rate of the network + */ + @Builder.Default + private double learningRate = 0.01; + + /** + * L2 regularization on the network + */ + @Builder.Default + private double l2 = 0.0; + + /** + * The network's gradient update algorithm + */ + private IUpdater updater; + + /** + * Training listeners attached to the network + */ + @Singular + private List listeners; + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java index ec09d1c1c..077bbf1ce 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -30,12 +31,15 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16. */ @@ -43,7 +47,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; public class DQNFactoryStdConv implements DQNFactory { - Configuration conf; + NetworkConfiguration conf; public DQN buildDQN(int shapeInputs[], int numOutputs) { @@ -80,7 +84,6 @@ public class DQNFactoryStdConv implements DQNFactory { return new DQN(model); } - @AllArgsConstructor @Builder @Value @@ -90,6 +93,23 @@ public class DQNFactoryStdConv implements DQNFactory { double l2; IUpdater updater; TrainingListener[] listeners; + + /** + * Converts the deprecated Configuration to the new NetworkConfiguration format + */ + public NetworkConfiguration toNetworkConfiguration() { + NetworkConfiguration.NetworkConfigurationBuilder builder = NetworkConfiguration.builder() + .learningRate(learningRate) + .l2(l2) + .updater(updater); + + if (listeners != null) { + builder.listeners(Arrays.asList(listeners)); + } + + return builder.build(); + + } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java index 323ca7ecb..ebe730b4d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,12 +29,16 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration.DQNDenseNetworkConfigurationBuilder; import org.deeplearning4j.rl4j.util.Constants; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16. */ @@ -41,32 +46,41 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; @Value public class DQNFactoryStdDense implements DQNFactory { - - Configuration conf; + DQNDenseNetworkConfiguration conf; public DQN buildDQN(int[] numInputs, int numOutputs) { int nIn = 1; + for (int i : numInputs) { nIn *= i; } + NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - //.updater(Updater.NESTEROVS).momentum(0.9) - //.updater(Updater.RMSPROP).rho(conf.getRmsDecay())//.rmsDecay(conf.getRmsDecay()) - .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) - .weightInit(WeightInit.XAVIER) - .l2(conf.getL2()) - .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) + .weightInit(WeightInit.XAVIER) + .l2(conf.getL2()) + .list() + .layer(0, + new DenseLayer.Builder() + .nIn(nIn) + .nOut(conf.getNumHiddenNodes()) + .activation(Activation.RELU).build() + ); - for (int i = 1; i < conf.getNumLayer(); i++) { + for (int i = 1; i < conf.getNumLayers(); i++) { confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) - .activation(Activation.RELU).build()); + .activation(Activation.RELU).build()); } - confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) - .nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); + confB.layer(conf.getNumLayers(), + new OutputLayer.Builder(LossFunctions.LossFunction.MSE) + .activation(Activation.IDENTITY) + .nIn(conf.getNumHiddenNodes()) + .nOut(numOutputs) + .build() + ); MultiLayerConfiguration mlnconf = confB.build(); @@ -83,6 +97,7 @@ public class DQNFactoryStdDense implements DQNFactory { @AllArgsConstructor @Value @Builder + @Deprecated public static class Configuration { int numLayer; @@ -90,7 +105,23 @@ public class DQNFactoryStdDense implements DQNFactory { double l2; IUpdater updater; TrainingListener[] listeners; + + /** + * Converts the deprecated Configuration to the new NetworkConfiguration format + */ + public DQNDenseNetworkConfiguration toNetworkConfiguration() { + DQNDenseNetworkConfigurationBuilder builder = DQNDenseNetworkConfiguration.builder() + .numHiddenNodes(numHiddenNodes) + .numLayers(numLayer) + .l2(l2) + .updater(updater); + + if (listeners != null) { + builder.listeners(Arrays.asList(listeners)); + } + + return builder.build(); + } } - } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java index 3ed375084..3454a37e6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -46,7 +47,7 @@ public class EpsGreedy> extends Policy { final private int updateStart; final private int epsilonNbStep; final private Random rnd; - final private float minEpsilon; + final private double minEpsilon; final private IEpochTrainer learning; public NeuralNet getNeuralNet() { @@ -55,10 +56,10 @@ public class EpsGreedy> extends Policy { public A nextAction(INDArray input) { - float ep = getEpsilon(); + double ep = getEpsilon(); if (learning.getStepCounter() % 500 == 1) log.info("EP: " + ep + " " + learning.getStepCounter()); - if (rnd.nextFloat() > ep) + if (rnd.nextDouble() > ep) return policy.nextAction(input); else return mdp.getActionSpace().randomAction(); @@ -68,7 +69,7 @@ public class EpsGreedy> extends Policy { return this.nextAction(observation.getData()); } - public float getEpsilon() { - return Math.min(1f, Math.max(minEpsilon, 1f - (learning.getStepCounter() - updateStart) * 1f / epsilonNbStep)); + public double getEpsilon() { + return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep)); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java index b639efdaa..bffafdb76 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,17 +23,30 @@ import lombok.Builder; import lombok.Getter; import lombok.Value; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.rl4j.learning.NeuralNetFetchable; -import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.rl4j.learning.ILearning; -import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.NeuralNetFetchable; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.dqn.DQN; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.linalg.primitives.Pair; -import java.io.*; -import java.nio.file.*; +import java.io.BufferedOutputStream; +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.nio.file.StandardOpenOption; import java.util.zip.ZipEntry; import java.util.zip.ZipFile; import java.util.zip.ZipOutputStream; @@ -304,7 +318,7 @@ public class DataManager implements IDataManager { public static class Info { String trainingName; String mdpName; - ILearning.LConfiguration conf; + ILearningConfiguration conf; int stepCounter; long millisTime; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java index 26ec0708f..8718d252d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,14 +17,11 @@ package org.deeplearning4j.rl4j.learning; -import java.util.Arrays; import org.junit.Test; -import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; /** * @@ -32,7 +30,7 @@ import static org.junit.Assert.assertTrue; public class HistoryProcessorTest { @Test - public void testHistoryProcessor() throws Exception { + public void testHistoryProcessor() { HistoryProcessor.Configuration conf = HistoryProcessor.Configuration.builder() .croppingHeight(2).croppingWidth(2).rescaledHeight(2).rescaledWidth(2).build(); IHistoryProcessor hp = new HistoryProcessor(conf); @@ -43,8 +41,6 @@ public class HistoryProcessorTest { hp.add(a); INDArray[] h = hp.getHistory(); assertEquals(4, h.length); -// System.out.println(Arrays.toString(a.shape())); -// System.out.println(Arrays.toString(h[0].shape())); assertEquals( 1, h[0].shape()[0]); assertEquals(a.shape()[0], h[0].shape()[1]); assertEquals(a.shape()[1], h[0].shape()[2]); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java index 2302117d2..f2941feef 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java @@ -1,9 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.support.MockAsyncConfiguration; +import org.deeplearning4j.rl4j.support.MockAsyncGlobal; +import org.deeplearning4j.rl4j.support.MockEncodable; +import org.deeplearning4j.rl4j.support.MockNeuralNet; +import org.deeplearning4j.rl4j.support.MockPolicy; +import org.deeplearning4j.rl4j.support.MockTrainingListener; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -68,7 +91,7 @@ public class AsyncLearningTest { public static class TestContext { - MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0,0, 0, 0, 0); + MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0,0, 0, 0, 0); public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); public final MockPolicy policy = new MockPolicy(); public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy); @@ -82,11 +105,11 @@ public class AsyncLearningTest { } public static class TestAsyncLearning extends AsyncLearning { - private final AsyncConfiguration conf; + private final IAsyncLearningConfiguration conf; private final IAsyncGlobal asyncGlobal; private final IPolicy policy; - public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) { + public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) { this.conf = conf; this.asyncGlobal = asyncGlobal; this.policy = policy; @@ -98,7 +121,7 @@ public class AsyncLearningTest { } @Override - public AsyncConfiguration getConfiguration() { + public IAsyncLearningConfiguration getConfiguration() { return conf; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index bc396502f..72f374db5 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -1,7 +1,25 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; @@ -32,7 +50,7 @@ public class AsyncThreadDiscreteTest { MockMDP mdpMock = new MockMDP(observationSpace); TrainingListenerList listeners = new TrainingListenerList(); MockPolicy policyMock = new MockPolicy(); - MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0); + MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5); TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); @@ -173,7 +191,7 @@ public class AsyncThreadDiscreteTest { } @Override - protected AsyncConfiguration getConf() { + protected IAsyncLearningConfiguration getConf() { return config; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 3dea25936..ff29960f1 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -3,12 +3,20 @@ package org.deeplearning4j.rl4j.learning.async; import lombok.AllArgsConstructor; import lombok.Getter; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.support.MockAsyncConfiguration; +import org.deeplearning4j.rl4j.support.MockAsyncGlobal; +import org.deeplearning4j.rl4j.support.MockEncodable; +import org.deeplearning4j.rl4j.support.MockHistoryProcessor; +import org.deeplearning4j.rl4j.support.MockMDP; +import org.deeplearning4j.rl4j.support.MockNeuralNet; +import org.deeplearning4j.rl4j.support.MockObservationSpace; +import org.deeplearning4j.rl4j.support.MockTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; @@ -16,7 +24,6 @@ import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; public class AsyncThreadTest { @@ -126,7 +133,7 @@ public class AsyncThreadTest { public final MockNeuralNet neuralNet = new MockNeuralNet(); public final MockObservationSpace observationSpace = new MockObservationSpace(); public final MockMDP mdp = new MockMDP(observationSpace); - public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0); + public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0); public final TrainingListenerList listeners = new TrainingListenerList(); public final MockTrainingListener listener = new MockTrainingListener(); public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); @@ -149,11 +156,11 @@ public class AsyncThreadTest { private final MockAsyncGlobal asyncGlobal; private final MockNeuralNet neuralNet; - private final AsyncConfiguration conf; + private final IAsyncLearningConfiguration conf; private final List trainSubEpochParams = new ArrayList(); - public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { + public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) { super(asyncGlobal, mdp, listeners, threadNumber, 0); this.asyncGlobal = asyncGlobal; @@ -184,7 +191,7 @@ public class AsyncThreadTest { } @Override - protected AsyncConfiguration getConf() { + protected IAsyncLearningConfiguration getConf() { return conf; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java index ef7fec7d0..b812a5582 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java @@ -1,11 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async.a3c.discrete; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.async.MiniTrans; -import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete; -import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete; +import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.support.*; @@ -31,7 +47,7 @@ public class A3CThreadDiscreteTest { double gamma = 0.9; MockObservationSpace observationSpace = new MockObservationSpace(); MockMDP mdpMock = new MockMDP(observationSpace); - A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0); + A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build(); MockActorCritic actorCriticMock = new MockActorCritic(); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); @@ -54,9 +70,9 @@ public class A3CThreadDiscreteTest { Nd4j.zeros(5) }; output[0].putScalar(i, outputs[i]); - minitransList.push(new MiniTrans(obs, i, output, rewards[i])); + minitransList.push(new MiniTrans<>(obs, i, output, rewards[i])); } - minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans + minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans // Act sut.calcGradient(actorCriticMock, minitransList); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java index d105419df..2a8c5b832 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java @@ -1,7 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2020 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.support.*; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; @@ -19,7 +36,7 @@ public class AsyncNStepQLearningThreadDiscreteTest { double gamma = 0.9; MockObservationSpace observationSpace = new MockObservationSpace(); MockMDP mdpMock = new MockMDP(observationSpace); - AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0); + AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build(); MockDQN dqnMock = new MockDQN(); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); @@ -42,9 +59,9 @@ public class AsyncNStepQLearningThreadDiscreteTest { Nd4j.zeros(5) }; output[0].putScalar(i, outputs[i]); - minitransList.push(new MiniTrans(obs, i, output, rewards[i])); + minitransList.push(new MiniTrans<>(obs, i, output, rewards[i])); } - minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans + minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans // Act sut.calcGradient(dqnMock, minitransList); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java index 79be025b5..22e4be3f6 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -1,6 +1,26 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.sync; import lombok.Getter; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; import org.deeplearning4j.rl4j.mdp.MDP; @@ -17,7 +37,7 @@ public class SyncLearningTest { @Test public void when_training_expect_listenersToBeCalled() { // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -34,7 +54,7 @@ public class SyncLearningTest { @Test public void when_trainingStartCanContinueFalse_expect_trainingStopped() { // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -52,7 +72,7 @@ public class SyncLearningTest { @Test public void when_newEpochCanContinueFalse_expect_trainingStopped() { // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -70,7 +90,7 @@ public class SyncLearningTest { @Test public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() { // Arrange - QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); + LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build(); MockTrainingListener listener = new MockTrainingListener(); MockSyncLearning sut = new MockSyncLearning(lconfig); sut.addListener(listener); @@ -87,12 +107,12 @@ public class SyncLearningTest { public static class MockSyncLearning extends SyncLearning { - private final LConfiguration conf; + private final ILearningConfiguration conf; @Getter private int currentEpochStep = 0; - public MockSyncLearning(LConfiguration conf) { + public MockSyncLearning(ILearningConfiguration conf) { this.conf = conf; } @@ -119,7 +139,7 @@ public class SyncLearningTest { } @Override - public LConfiguration getConfiguration() { + public ILearningConfiguration getConfiguration() { return conf; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java similarity index 52% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java rename to rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java index b12866ed2..d7d9bf072 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -17,36 +18,24 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning; import com.fasterxml.jackson.databind.ObjectMapper; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -public class QLConfigurationTest { +public class QLearningConfigurationTest { @Rule public ExpectedException thrown = ExpectedException.none(); @Test public void serialize() throws Exception { ObjectMapper mapper = new ObjectMapper(); - QLearning.QLConfiguration qlConfiguration = - new QLearning.QLConfiguration( - 123, //Random seed - 200, //Max step By epoch - 8000, //Max step - 150000, //Max size of experience replay - 32, //size of batches - 500, //target update (hard) - 10, //num step noop warmup - 0.01, //reward scaling - 0.99, //gamma - 1.0, //td error clipping - 0.1f, //min epsilon - 10000, //num step for eps greedy anneal - true //double DQN - ); + + QLearningConfiguration qLearningConfiguration = QLearningConfiguration.builder() + .build(); // Should not throw.. - String json = mapper.writeValueAsString(qlConfiguration); - QLearning.QLConfiguration cnf = mapper.readValue(json, QLearning.QLConfiguration.class); + String json = mapper.writeValueAsString(qLearningConfiguration); + QLearningConfiguration cnf = mapper.readValue(json, QLearningConfiguration.class); } -} \ No newline at end of file +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 58aaab297..fe8dd6acc 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -1,6 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; @@ -27,7 +45,7 @@ public class QLearningDiscreteTest { // Arrange MockObservationSpace observationSpace = new MockObservationSpace(); MockDQN dqn = new MockDQN(); - MockRandom random = new MockRandom(new double[] { + MockRandom random = new MockRandom(new double[]{ 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, @@ -36,14 +54,26 @@ public class QLearningDiscreteTest { 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 - }, - new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); + }, + new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}); MockMDP mdp = new MockMDP(observationSpace, random); int initStepCount = 8; - QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000, - initStepCount, 1.0, 0, 0, 0, 0, true); + QLearningConfiguration conf = QLearningConfiguration.builder() + .seed(0L) + .maxEpochStep(24) + .maxStep(0) + .expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000) + .updateStart(initStepCount) + .rewardFactor(1.0) + .gamma(0) + .errorClamp(0) + .minEpsilon(0) + .epsilonNbStep(0) + .doubleDQN(true) + .build(); + MockDataManager dataManager = new MockDataManager(false); MockExpReplay expReplay = new MockExpReplay(); TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); @@ -58,9 +88,9 @@ public class QLearningDiscreteTest { // Assert // HistoryProcessor calls - double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 }; + double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; assertEquals(expectedRecords.length, hp.recordCalls.size()); - for(int i = 0; i < expectedRecords.length; ++i) { + for (int i = 0; i < expectedRecords.length; ++i) { assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); } @@ -72,59 +102,59 @@ public class QLearningDiscreteTest { assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001); assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001); assertEquals(14, dqn.outputParams.size()); - double[][] expectedDQNOutput = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, - new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, - new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, - new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, - new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, - new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, - new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, + double[][] expectedDQNOutput = new double[][]{ + new double[]{0.0, 2.0, 4.0, 6.0, 8.0}, + new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, + new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, + new double[]{4.0, 6.0, 8.0, 10.0, 12.0}, + new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, + new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, + new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, + new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, + new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, + new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, + new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, + new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, + new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, + new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, }; - for(int i = 0; i < expectedDQNOutput.length; ++i) { + for (int i = 0; i < expectedDQNOutput.length; ++i) { INDArray outputParam = dqn.outputParams.get(i); assertEquals(5, outputParam.shape()[1]); assertEquals(1, outputParam.shape()[2]); double[] expectedRow = expectedDQNOutput[i]; - for(int j = 0; j < expectedRow.length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001); + for (int j = 0; j < expectedRow.length; ++j) { + assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001); } } // MDP calls - assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray()); + assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray()); // ExpReplay calls - double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 }; - int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; - double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; - double[][] expectedTrObservations = new double[][] { - new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, - new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, - new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, - new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, - new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, - new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, - new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, - new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, + double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0}; + int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4}; + double[] expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0}; + double[][] expectedTrObservations = new double[][]{ + new double[]{0.0, 2.0, 4.0, 6.0, 8.0}, + new double[]{2.0, 4.0, 6.0, 8.0, 10.0}, + new double[]{4.0, 6.0, 8.0, 10.0, 12.0}, + new double[]{6.0, 8.0, 10.0, 12.0, 14.0}, + new double[]{8.0, 10.0, 12.0, 14.0, 16.0}, + new double[]{10.0, 12.0, 14.0, 16.0, 18.0}, + new double[]{12.0, 14.0, 16.0, 18.0, 20.0}, + new double[]{14.0, 16.0, 18.0, 20.0, 22.0}, }; assertEquals(expectedTrObservations.length, expReplay.transitions.size()); - for(int i = 0; i < expectedTrRewards.length; ++i) { + for (int i = 0; i < expectedTrRewards.length; ++i) { Transition tr = expReplay.transitions.get(i); assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); assertEquals(expectedTrActions[i], tr.getAction()); assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001); - for(int j = 0; j < expectedTrObservations[i].length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001); + for (int j = 0; j < expectedTrObservations[i].length; ++j) { + assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001); } } @@ -132,12 +162,12 @@ public class QLearningDiscreteTest { assertEquals(initStepCount + 16, result.getStepCounter()); assertEquals(300.0, result.getReward(), 0.00001); assertTrue(dqn.hasBeenReset); - assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset); + assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset); } public static class TestQLearningDiscrete extends QLearningDiscrete { public TestQLearningDiscrete(MDP mdp, IDQN dqn, - QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, + QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, int epsilonNbStep, Random rnd) { super(mdp, dqn, conf, epsilonNbStep, rnd); addListener(new DataManagerTrainingListener(dataManager)); @@ -146,10 +176,10 @@ public class QLearningDiscreteTest { @Override protected DataSet setTarget(ArrayList> transitions) { - return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); + return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0})); } - public void setExpReplay(IExpReplay exp){ + public void setExpReplay(IExpReplay exp) { this.expReplay = exp; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java index c43c26d50..821863054 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,7 @@ package org.deeplearning4j.rl4j.network.ac; +import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; import org.junit.Test; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.ndarray.INDArray; @@ -29,30 +31,31 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** - * * @author saudet */ public class ActorCriticTest { - public static ActorCriticFactorySeparateStdDense.Configuration NET_CONF = - new ActorCriticFactorySeparateStdDense.Configuration( - 4, //number of layers - 32, //number of hidden nodes - 0.001, //l2 regularization - new RmsProp(0.0005), null, false - ); + public static ActorCriticDenseNetworkConfiguration NET_CONF = + ActorCriticDenseNetworkConfiguration.builder() + .numLayers(4) + .numHiddenNodes(32) + .l2(0.001) + .updater(new RmsProp(0.0005)) + .useLSTM(false) + .build(); - public static ActorCriticFactoryCompGraphStdDense.Configuration NET_CONF_CG = - new ActorCriticFactoryCompGraphStdDense.Configuration( - 2, //number of layers - 128, //number of hidden nodes - 0.00001, //l2 regularization - new RmsProp(0.005), null, true - ); + public static ActorCriticDenseNetworkConfiguration NET_CONF_CG = + ActorCriticDenseNetworkConfiguration.builder() + .numLayers(2) + .numHiddenNodes(128) + .l2(0.00001) + .updater(new RmsProp(0.005)) + .useLSTM(true) + .build(); @Test public void testModelLoadSave() throws IOException { - ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[] {7}, 5); + ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[]{7}, 5); File fileValue = File.createTempFile("rl4j-value-", ".model"); File filePolicy = File.createTempFile("rl4j-policy-", ".model"); @@ -63,7 +66,7 @@ public class ActorCriticTest { assertEquals(acs.valueNet, acs2.valueNet); assertEquals(acs.policyNet, acs2.policyNet); - ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[] {37}, 43); + ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[]{37}, 43); File file = File.createTempFile("rl4j-cg-", ".model"); accg.save(file.getAbsolutePath()); @@ -83,15 +86,15 @@ public class ActorCriticTest { for (double i = eps; i < n; i++) { for (double j = eps; j < n; j++) { - INDArray labels = Nd4j.create(new double[] {i / n, 1 - i / n}, new long[]{1,2}); - INDArray output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2}); + INDArray labels = Nd4j.create(new double[]{i / n, 1 - i / n}, new long[]{1, 2}); + INDArray output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2}); INDArray gradient = loss.computeGradient(labels, output, activation, null); - output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2}); + output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2}); double score = loss.computeScore(labels, output, activation, null, false); - INDArray output1 = Nd4j.create(new double[] {j / n + eps, 1 - j / n}, new long[]{1,2}); + INDArray output1 = Nd4j.create(new double[]{j / n + eps, 1 - j / n}, new long[]{1, 2}); double score1 = loss.computeScore(labels, output1, activation, null, false); - INDArray output2 = Nd4j.create(new double[] {j / n, 1 - j / n + eps}, new long[]{1,2}); + INDArray output2 = Nd4j.create(new double[]{j / n, 1 - j / n + eps}, new long[]{1, 2}); double score2 = loss.computeScore(labels, output2, activation, null, false); double gradient1 = (score1 - score) / eps; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java index 3f68b8f3c..a9997ec0c 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,7 @@ package org.deeplearning4j.rl4j.network.dqn; +import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; import org.junit.Test; import org.nd4j.linalg.learning.config.RmsProp; @@ -25,22 +27,20 @@ import java.io.IOException; import static org.junit.Assert.assertEquals; /** - * * @author saudet */ public class DQNTest { - public static DQNFactoryStdDense.Configuration NET_CONF = - new DQNFactoryStdDense.Configuration( - 3, //number of layers - 16, //number of hidden nodes - 0.001, //l2 regularization - new RmsProp(0.0005), null - ); + private static DQNDenseNetworkConfiguration NET_CONF = + DQNDenseNetworkConfiguration.builder().numLayers(3) + .numHiddenNodes(16) + .l2(0.001) + .updater(new RmsProp(0.0005)) + .build(); @Test public void testModelLoadSave() throws IOException { - DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[] {42}, 13); + DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[]{42}, 13); File file = File.createTempFile("rl4j-dqn-", ".model"); dqn.save(file.getAbsolutePath()); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java index fe79bdfc7..3f5e761a6 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java @@ -128,7 +128,7 @@ public class TransformProcessTest { // Assert assertFalse(result.isSkipped()); - assertEquals(1, result.getData().shape().length); + assertEquals(2, result.getData().shape().length); assertEquals(1, result.getData().shape()[0]); assertEquals(-10.0, result.getData().getDouble(0), 0.00001); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 0707e16ab..0dc16df09 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -24,16 +25,18 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; -import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest; -import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.support.MockDQN; +import org.deeplearning4j.rl4j.support.MockEncodable; +import org.deeplearning4j.rl4j.support.MockHistoryProcessor; +import org.deeplearning4j.rl4j.support.MockMDP; +import org.deeplearning4j.rl4j.support.MockNeuralNet; +import org.deeplearning4j.rl4j.support.MockObservationSpace; +import org.deeplearning4j.rl4j.support.MockRandom; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.junit.Test; import org.nd4j.linalg.activations.Activation; @@ -43,8 +46,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.IOException; import java.io.OutputStream; -import java.util.ArrayList; -import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -186,8 +187,22 @@ public class PolicyTest { new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); MockMDP mdp = new MockMDP(observationSpace, 30, random); - QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, - 0, 1.0, 0, 0, 0, 0, true); + QLearningConfiguration conf = QLearningConfiguration.builder() + .seed(0L) + .maxEpochStep(0) + .maxStep(0) + .expRepMaxSize(5) + .batchSize(1) + .targetDqnUpdateFreq(0) + .updateStart(0) + .rewardFactor(1.0) + .gamma(0) + .errorClamp(0) + .minEpsilon(0) + .epsilonNbStep(0) + .doubleDQN(true) + .build(); + MockNeuralNet nnMock = new MockNeuralNet(); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java index 56581cc0d..08689b032 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java @@ -1,22 +1,37 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.support; import lombok.AllArgsConstructor; -import lombok.Getter; import lombok.Value; -import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; -@AllArgsConstructor @Value -public class MockAsyncConfiguration implements AsyncConfiguration { +@AllArgsConstructor +public class MockAsyncConfiguration implements IAsyncLearningConfiguration { - private Integer seed; + private Long seed; private int maxEpochStep; private int maxStep; - private int numThread; - private int nstep; - private int targetDqnUpdateFreq; private int updateStart; private double rewardFactor; private double gamma; private double errorClamp; + private int numThreads; + private int nStep; + private int learnerUpdateFrequency; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java index a3a5598d4..3a2d5230a 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java @@ -1,3 +1,20 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + package org.deeplearning4j.rl4j.util; import lombok.Getter; @@ -5,6 +22,7 @@ import lombok.Setter; import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; import org.deeplearning4j.rl4j.mdp.MDP; @@ -162,7 +180,7 @@ public class DataManagerTrainingListenerTest { } @Override - public LConfiguration getConfiguration() { + public ILearningConfiguration getConfiguration() { return null; } diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java index 7400657ef..00b7c4f7a 100644 --- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java +++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java @@ -1,5 +1,6 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,19 +17,18 @@ package org.deeplearning4j.malmo; -import java.util.HashMap; - +import com.microsoft.msr.malmo.TimestampedStringVector; +import com.microsoft.msr.malmo.WorldState; import org.json.JSONArray; import org.json.JSONObject; - import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import com.microsoft.msr.malmo.TimestampedStringVector; -import com.microsoft.msr.malmo.WorldState; +import java.util.HashMap; /** * Observation space that contains a grid of Minecraft blocks + * * @author howard-abrams (howard.abrams@ca.com) on 1/12/17. */ public class MalmoObservationSpaceGrid extends MalmoObservationSpace { @@ -43,11 +43,11 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace { /** * Construct observation space from a array of blocks policy should distinguish between. - * - * @param name Name given to Grid element in mission specification - * @param xSize total x size of grid - * @param ySize total y size of grid - * @param zSize total z size of grid + * + * @param name Name given to Grid element in mission specification + * @param xSize total x size of grid + * @param ySize total y size of grid + * @param zSize total z size of grid * @param blocks Array of block names to distinguish between. Supports combination of individual strings and/or arrays of strings to map multiple block types to a single observation value. If not specified, it will dynamically map block names to integers - however, because these will be mapped as they are seen, different missions may have different mappings! */ public MalmoObservationSpaceGrid(String name, int xSize, int ySize, int zSize, Object... blocks) { @@ -78,7 +78,7 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace { @Override public int[] getShape() { - return new int[] {totalSize}; + return new int[]{totalSize}; } @Override