From 65ef0622ffb73021a543493279b829817f4189df Mon Sep 17 00:00:00 2001 From: Kohei Tamura Date: Tue, 17 Dec 2019 12:37:07 +0900 Subject: [PATCH 1/5] Update Japanese translation for Deeplearning4J UI (#8525) Signed-off-by: k-tamura --- .../src/main/resources/dl4j_i18n/train.ja | 4 ++-- .../src/main/resources/dl4j_i18n/train.model.ja | 8 ++++---- .../src/main/resources/dl4j_i18n/train.system.ja | 14 +++++++------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.ja b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.ja index 03d09991f..6833d4ff9 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.ja +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.ja @@ -4,5 +4,5 @@ train.nav.model=モデル train.nav.system=システム train.nav.userguide=ユーザーガイド train.nav.language=言語 -train.session.label=Session -train.session.worker.label=Worker +train.session.label=セッション +train.session.worker.label=ワーカー diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.model.ja b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.model.ja index 00992aed7..2995dcec1 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.model.ja +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.model.ja @@ -5,9 +5,9 @@ train.model.lrChart.title=パラメータ学習率 train.model.lrChart.titleShort=学習率 train.model.paramHistChart.title=レイヤーパラメータヒストグラム train.model.updateHistChart.title=レイヤー更新ヒストグラム -train.model.meanmag.btn.ratio=Ratio -train.model.meanmag.btn.param=Param -train.model.meanmag.btn.update=Updates +train.model.meanmag.btn.ratio=比率 +train.model.meanmag.btn.param=パラメータ +train.model.meanmag.btn.update=更新 train.model.layerinfotable.layerName=レイヤー名 train.model.layerinfotable.layerType=レイヤータイプ train.model.layerinfotable.layerNIn=入力サイズ @@ -19,4 +19,4 @@ train.model.layerinfotable.layerUpdater=更新の方法 train.model.layerinfotable.layerSubsamplingPoolingType=プーリングタイプ train.model.layerinfotable.layerCnnKernel=カーネルサイズ train.model.layerinfotable.layerCnnStride=ストライド -train.model.layerinfotable.layerCnnPadding=パディング \ No newline at end of file +train.model.layerinfotable.layerCnnPadding=パディング diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.system.ja b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.system.ja index 78ff9936f..d28c961fa 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.system.ja +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/resources/dl4j_i18n/train.system.ja @@ -1,10 +1,10 @@ train.system.title=システム詳細 -train.system.selectMachine=Select Machine +train.system.selectMachine=マシンを選択 train.system.chart.memoryShort=メモリ -train.system.chart.systemMemoryTitle=JVM and Off-Heap Memory Utilization -train.system.chart.gpuMemoryTitle=GPU Memory Utilization -train.system.chart.key.jvm=JVM Memory -train.system.chart.key.offHeap=Off Heap Memory +train.system.chart.systemMemoryTitle=JVMとオフヒープのメモリ使用率 +train.system.chart.gpuMemoryTitle=GPUメモリ使用率 +train.system.chart.key.jvm=JVMメモリ 
+train.system.chart.key.offHeap=オフヒープメモリ train.system.hwTable.title=ハードウェアの情報 train.system.hwTable.jvmCurrent=JVM現在メモリ train.system.hwTable.jvmMax=JVM最大メモリ @@ -13,7 +13,7 @@ train.system.hwTable.offHeapMax=オフヒープ最大メモリ train.system.hwTable.jvmProcs=JVM使用可能プロセッサ train.system.hwTable.computeDevices=計算デバイス数 train.system.hwTable.deviceMemory=デバイスメモリ -train.system.hwTable.deviceName=Device Name +train.system.hwTable.deviceName=デバイス名 train.system.swTable.title=ソフトウェアの情報 train.system.swTable.hostname=ホスト名 train.system.swTable.os=OSの種類 @@ -21,4 +21,4 @@ train.system.swTable.osArch=OSのアーキテクチャ train.system.swTable.jvmName=JVM名 train.system.swTable.jvmVersion=JVMバージョン train.system.swTable.nd4jBackend=ND4Jバックエンド -train.system.swTable.nd4jDataType=ND4Jデータ型 \ No newline at end of file +train.system.swTable.nd4jDataType=ND4Jデータ型 From de3975f0885047579b13a0d644e847d7a1c41392 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Wed, 18 Dec 2019 02:27:05 -0500 Subject: [PATCH 2/5] RL4J: Remove processing done on observations in Policy & Async (#8471) * Removed processing from Policy.play() and fixed missing resets Signed-off-by: unknown * Adjusted unit test to check if DQNs have been reset Signed-off-by: unknown * Fixed a couple of problems, added and updated unit tests Signed-off-by: unknown * Removed processing from AsyncThreadDiscrete Signed-off-by: unknown * Fixed a few problems Signed-off-by: unknown --- .../rl4j/learning/async/AsyncGlobal.java | 13 +- .../rl4j/learning/async/AsyncLearning.java | 21 +- .../rl4j/learning/async/AsyncThread.java | 133 +++++++++---- .../learning/async/AsyncThreadDiscrete.java | 74 ++----- .../rl4j/learning/async/IAsyncLearning.java | 21 ++ .../async/a3c/discrete/A3CDiscrete.java | 2 +- .../async/a3c/discrete/A3CThreadDiscrete.java | 5 +- .../discrete/AsyncNStepQLearningDiscrete.java | 2 +- .../learning/sync/qlearning/QLearning.java | 3 + .../qlearning/discrete/QLearningDiscrete.java | 2 +- .../rl4j/observation/Observation.java | 19 +- .../deeplearning4j/rl4j/policy/ACPolicy.java | 6 + .../rl4j/policy/BoltzmannQ.java | 6 + .../deeplearning4j/rl4j/policy/DQNPolicy.java | 6 + .../deeplearning4j/rl4j/policy/IPolicy.java | 2 + .../deeplearning4j/rl4j/policy/Policy.java | 105 ++++++---- .../rl4j/util/LegacyMDPWrapper.java | 66 +++++-- .../learning/async/AsyncLearningTest.java | 2 +- .../async/AsyncThreadDiscreteTest.java | 132 +++++++++---- .../rl4j/learning/async/AsyncThreadTest.java | 100 +++++----- .../a3c/discrete/A3CThreadDiscreteTest.java | 181 ++++++++++++++++++ ...AsyncNStepQLearningThreadDiscreteTest.java | 81 ++++++++ .../discrete/QLearningDiscreteTest.java | 10 +- .../rl4j/policy/PolicyTest.java | 6 + .../rl4j/support/MockAsyncGlobal.java | 20 +- .../deeplearning4j/rl4j/support/MockDQN.java | 10 +- .../rl4j/support/MockNeuralNet.java | 2 +- .../rl4j/support/MockPolicy.java | 8 +- .../rl4j/support/MockTrainingListener.java | 12 ++ 29 files changed, 779 insertions(+), 271 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java index fb07baf1e..5501a29e1 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java @@ -57,6 +57,7 @@ public class AsyncGlobal extends Thread implements IAsyncG final private NN current; final private ConcurrentLinkedQueue> queue; final private AsyncConfiguration a3cc; + private final IAsyncLearning learning; @Getter private AtomicInteger T = new AtomicInteger(0); @Getter @@ -64,10 +65,11 @@ public class AsyncGlobal extends Thread implements IAsyncG @Getter private boolean running = true; - public AsyncGlobal(NN initial, AsyncConfiguration a3cc) { + public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) { this.current = initial; target = (NN) initial.clone(); this.a3cc = a3cc; + this.learning = learning; queue = new ConcurrentLinkedQueue<>(); } @@ -106,11 +108,14 @@ public class AsyncGlobal extends Thread implements IAsyncG } /** - * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded. + * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too. */ public void terminate() { - running = false; - queue.clear(); + if(running) { + running = false; + queue.clear(); + learning.terminate(); + } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index 0835bf692..994ec9cb0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -37,7 +37,10 @@ import org.nd4j.linalg.factory.Nd4j; */ @Slf4j public abstract class AsyncLearning, NN extends NeuralNet> - extends Learning { + extends Learning + implements IAsyncLearning { + + private Thread monitorThread = null; @Getter(AccessLevel.PROTECTED) private final TrainingListenerList listeners = new TrainingListenerList(); @@ -126,6 +129,7 @@ public abstract class AsyncLearning, NN extends NeuralNet> +public abstract class AsyncThread, NN extends NeuralNet> extends Thread implements StepCountable, IEpochTrainer { @Getter @@ -54,26 +58,35 @@ public abstract class AsyncThread mdp; @Getter @Setter private IHistoryProcessor historyProcessor; + private boolean isEpochStarted = false; + private final LegacyMDPWrapper mdp; + private final TrainingListenerList listeners; public AsyncThread(IAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { - this.mdp = mdp; + this.mdp = new LegacyMDPWrapper(mdp, null, this); this.listeners = listeners; this.threadNumber = threadNumber; this.deviceNum = deviceNum; } + public MDP getMdp() { + return mdp.getWrappedMDP(); + } + protected LegacyMDPWrapper getLegacyMDPWrapper() { + return mdp; + } + public void setHistoryProcessor(IHistoryProcessor.Configuration conf) { - historyProcessor = new HistoryProcessor(conf); + setHistoryProcessor(new HistoryProcessor(conf)); } public void setHistoryProcessor(IHistoryProcessor historyProcessor) { this.historyProcessor = historyProcessor; + mdp.setHistoryProcessor(historyProcessor); } protected void postEpoch() { @@ -109,61 +122,63 @@ public abstract class AsyncThread context = new 
RunContext<>(); - Nd4j.getAffinityManager().unsafeSetDevice(deviceNum); + try { + RunContext context = new RunContext(); + Nd4j.getAffinityManager().unsafeSetDevice(deviceNum); - log.info("ThreadNum-" + threadNumber + " Started!"); - - boolean canContinue = initWork(context); - if (canContinue) { + log.info("ThreadNum-" + threadNumber + " Started!"); while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) { - handleTraining(context); - if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) { - canContinue = finishEpoch(context) && startNewEpoch(context); + if (!isEpochStarted) { + boolean canContinue = startNewEpoch(context); if (!canContinue) { break; } } + + handleTraining(context); + + if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) { + boolean canContinue = finishEpoch(context); + if (!canContinue) { + break; + } + + ++epochCounter; + } } } - terminateWork(); + finally { + terminateWork(); + } } - private void initNewEpoch(RunContext context) { - getCurrent().reset(); - Learning.InitMdp initMdp = Learning.initMdp(getMdp(), historyProcessor); - - context.obs = initMdp.getLastObs(); - context.rewards = initMdp.getReward(); - context.epochElapsedSteps = initMdp.getSteps(); - } - - private void handleTraining(RunContext context) { + private void handleTraining(RunContext context) { int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps); - SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); + SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); context.obs = subEpochReturn.getLastObs(); - stepCounter += subEpochReturn.getSteps(); context.epochElapsedSteps += subEpochReturn.getSteps(); context.rewards += subEpochReturn.getReward(); context.score = subEpochReturn.getScore(); } - private boolean initWork(RunContext context) { - initNewEpoch(context); - preEpoch(); - return listeners.notifyNewEpoch(this); - } - private boolean startNewEpoch(RunContext context) { - initNewEpoch(context); - epochCounter++; + getCurrent().reset(); + Learning.InitMdp initMdp = refacInitMdp(); + + context.obs = initMdp.getLastObs(); + context.rewards = initMdp.getReward(); + context.epochElapsedSteps = initMdp.getSteps(); + + isEpochStarted = true; preEpoch(); + return listeners.notifyNewEpoch(this); } private boolean finishEpoch(RunContext context) { + isEpochStarted = false; postEpoch(); IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score); @@ -173,8 +188,10 @@ public abstract class AsyncThread getPolicy(NN net); - protected abstract SubEpochReturn trainSubEpoch(O obs, int nstep); + protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep); + + private Learning.InitMdp refacInitMdp() { + LegacyMDPWrapper mdp = getLegacyMDPWrapper(); + IHistoryProcessor hp = getHistoryProcessor(); + + Observation observation = mdp.reset(); + + int step = 0; + double reward = 0; + + boolean isHistoryProcessor = hp != null; + + int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; + int requiredFrame = isHistoryProcessor ? 
skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; + + while (step < requiredFrame && !mdp.isDone()) { + + A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP + + StepReply stepReply = mdp.step(action); + reward += stepReply.getReward(); + observation = stepReply.getObservation(); + + step++; + + } + + return new Learning.InitMdp(step, observation, reward); + + } + + public void incrementStep() { + ++stepCounter; + } @AllArgsConstructor @Value - public static class SubEpochReturn { + public static class SubEpochReturn { int steps; - O lastObs; + Observation lastObs; double reward; double score; } @@ -206,8 +257,8 @@ public abstract class AsyncThread { - private O obs; + private static class RunContext { + private Observation obs; private double rewards; private int epochElapsedSteps; private double score; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index 8b8bc2861..6b0078883 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -25,9 +25,11 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.ArrayUtil; @@ -40,7 +42,7 @@ import java.util.Stack; * Async Learning specialized for the Discrete Domain * */ -public abstract class AsyncThreadDiscrete +public abstract class AsyncThreadDiscrete extends AsyncThread { @Getter @@ -61,14 +63,14 @@ public abstract class AsyncThreadDiscrete trainSubEpoch(O sObs, int nstep) { + public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) { synchronized (getAsyncGlobal()) { current.copy(getAsyncGlobal().getCurrent()); } Stack> rewards = new Stack<>(); - O obs = sObs; + Observation obs = sObs; IPolicy policy = getPolicy(current); Integer action; @@ -81,93 +83,53 @@ public abstract class AsyncThreadDiscrete stepReply = getMdp().step(action); + StepReply stepReply = getLegacyMDPWrapper().step(action); accuReward += stepReply.getReward() * getConf().getRewardFactor(); //if it's not a skipped frame, you can do a step of training if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) { - obs = stepReply.getObservation(); - if (hstack == null) { - hstack = processHistory(input); - } - INDArray[] output = current.outputAll(hstack); - rewards.add(new MiniTrans(hstack, action, output, accuReward)); + INDArray[] output = current.outputAll(obs.getData()); + rewards.add(new MiniTrans(obs.getData(), action, output, accuReward)); accuReward = 0; } + obs = stepReply.getObservation(); + reward += stepReply.getReward(); i++; + incrementStep(); lastAction = action; } //a bit of a trick usable because of how the stack is treated to init R - INDArray input = Learning.getInput(getMdp(), obs); - INDArray hstack = processHistory(input); - - if (hp != null) { - hp.record(input); - } + // FIXME: The last 
element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored. if (getMdp().isDone() && i < nstep * skipFrame) - rewards.add(new MiniTrans(hstack, null, null, 0)); + rewards.add(new MiniTrans(obs.getData(), null, null, 0)); else { INDArray[] output = null; if (getConf().getTargetDqnUpdateFreq() == -1) - output = current.outputAll(hstack); + output = current.outputAll(obs.getData()); else synchronized (getAsyncGlobal()) { - output = getAsyncGlobal().getTarget().outputAll(hstack); + output = getAsyncGlobal().getTarget().outputAll(obs.getData()); } double maxQ = Nd4j.max(output[0]).getDouble(0); - rewards.add(new MiniTrans(hstack, null, output, maxQ)); + rewards.add(new MiniTrans(obs.getData(), null, output, maxQ)); } getAsyncGlobal().enqueue(calcGradient(current, rewards), i); - return new SubEpochReturn(i, obs, reward, current.getLatestScore()); - } - - protected INDArray processHistory(INDArray input) { - IHistoryProcessor hp = getHistoryProcessor(); - INDArray[] history; - if (hp != null) { - hp.add(input); - history = hp.getHistory(); - } else - history = new INDArray[] {input}; - //concat the history into a single INDArray input - INDArray hstack = Transition.concat(history); - if (hp != null) { - hstack.muli(1.0 / hp.getScale()); - } - - if (getCurrent().isRecurrent()) { - //flatten everything for the RNN - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1)); - } else { - //if input is not 2d, you have to append that the batch is 1 length high - if (hstack.shape().length > 2) - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()))); - } - - return hstack; + return new SubEpochReturn(i, obs, reward, current.getLatestScore()); } public abstract Gradient[] calcGradient(NN nn, Stack> rewards); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java new file mode 100644 index 000000000..6bae9fddf --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.async; + +public interface IAsyncLearning { + void terminate(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index 52fa3932b..81308ba5a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -53,7 +53,7 @@ public abstract class A3CDiscrete extends AsyncLearning(iActorCritic, conf); + asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this); Integer seed = conf.getSeed(); Random rnd = Nd4j.getRandom(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index d71fa95a0..22b3894b2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; +import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; @@ -46,13 +47,13 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< @Getter final protected A3CDiscrete.A3CConfiguration conf; @Getter - final protected AsyncGlobal asyncGlobal; + final protected IAsyncGlobal asyncGlobal; @Getter final protected int threadNumber; final private Random rnd; - public A3CThreadDiscrete(MDP mdp, AsyncGlobal asyncGlobal, + public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners, int threadNumber) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index 0c9ff057f..c18de9e10 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -46,7 +46,7 @@ public abstract class AsyncNStepQLearningDiscrete public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { this.mdp = mdp; this.configuration = conf; - this.asyncGlobal = new AsyncGlobal<>(dqn, conf); + this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this); } @Override diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 098aefeaa..bfd23ef5b 100644 --- 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
@@ -150,6 +150,9 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
 
     protected Learning.InitMdp<Observation> refacInitMdp() {
+        getQNetwork().reset();
+        getTargetQNetwork().reset();
+
         LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
         IHistoryProcessor hp = getHistoryProcessor();
 
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
index 4d089d1ee..9da30ccef 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
@@ -46,7 +46,7 @@ import java.util.ArrayList;
  *
  * DQN or Deep Q-Learning in the Discrete domain
  *
- * https://arxiv.org/abs/1312.5602
+ * http://arxiv.org/abs/1312.5602
  *
  */
 public abstract class QLearningDiscrete extends QLearning {
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
index 7ca63baaf..197a8f744 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
@@ -29,7 +29,15 @@ public class Observation {
     private final DataSet data;
 
     public Observation(INDArray[] data) {
-        this(new org.nd4j.linalg.dataset.DataSet(Nd4j.concat(0, data), null));
+        this(data, false);
+    }
+
+    public Observation(INDArray[] data, boolean shouldReshape) {
+        INDArray features = Nd4j.concat(0, data);
+        if(shouldReshape) {
+            features = reshape(features);
+        }
+        this.data = new org.nd4j.linalg.dataset.DataSet(features, null);
     }
 
     // FIXME: Remove -- only used in unit tests
@@ -37,6 +45,15 @@ public class Observation {
         this.data = new org.nd4j.linalg.dataset.DataSet(data, null);
     }
 
+    private INDArray reshape(INDArray source) {
+        long[] shape = source.shape();
+        long[] nshape = new long[shape.length + 1];
+        nshape[0] = 1;
+        System.arraycopy(shape, 0, nshape, 1, shape.length);
+
+        return source.reshape(nshape);
+    }
+
     private Observation(DataSet data) {
         this.data = data;
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java
index 09e396ac4..61ba70825 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java
@@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
 import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
@@ -65,6 +66,11 @@ public class ACPolicy extends Policy {
         return actorCritic;
     }
 
+    @Override
+    public Integer nextAction(Observation obs) {
+        return nextAction(obs.getData());
+    }
+
     public Integer nextAction(INDArray input) {
         INDArray output = actorCritic.outputAll(input)[1];
         if (rnd == null) {
diff --git
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java index bff1a782c..cf2b60f41 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java @@ -17,6 +17,7 @@ package org.deeplearning4j.rl4j.policy; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; @@ -43,6 +44,11 @@ public class BoltzmannQ extends Policy { return dqn; } + @Override + public Integer nextAction(Observation obs) { + return nextAction(obs.getData()); + } + public Integer nextAction(INDArray input) { INDArray output = dqn.output(input); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java index 70562b0f1..c7ef91665 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java @@ -20,6 +20,7 @@ import lombok.AllArgsConstructor; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.network.dqn.DQN; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,6 +45,11 @@ public class DQNPolicy extends Policy { return dqn; } + @Override + public Integer nextAction(Observation obs) { + return nextAction(obs.getData()); + } + public Integer nextAction(INDArray input) { INDArray output = dqn.output(input); return Learning.getMaxAction(output); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java index 885fa36a2..0bef2a757 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java @@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.policy; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; @@ -9,4 +10,5 @@ import org.nd4j.linalg.api.ndarray.INDArray; public interface IPolicy { > double play(MDP mdp, IHistoryProcessor hp); A nextAction(INDArray input); + A nextAction(Observation observation); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index 84ef26f25..97a11b99c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -16,15 +16,21 @@ package org.deeplearning4j.rl4j.policy; +import lombok.Getter; +import lombok.Setter; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.learning.StepCountable; import 
org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.util.ArrayUtil; @@ -39,7 +45,7 @@ public abstract class Policy implements IPolicy { public abstract NeuralNet getNeuralNet(); - public abstract A nextAction(INDArray input); + public abstract A nextAction(Observation obs); public > double play(MDP mdp) { return play(mdp, (IHistoryProcessor)null); @@ -51,66 +57,81 @@ public abstract class Policy implements IPolicy { @Override public > double play(MDP mdp, IHistoryProcessor hp) { + RefacStepCountable stepCountable = new RefacStepCountable(); + LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp, stepCountable); + boolean isHistoryProcessor = hp != null; int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; - getNeuralNet().reset(); - Learning.InitMdp initMdp = Learning.initMdp(mdp, hp); - O obs = initMdp.getLastObs(); + Learning.InitMdp initMdp = refacInitMdp(mdpWrapper, hp); + Observation obs = initMdp.getLastObs(); double reward = initMdp.getReward(); - A lastAction = mdp.getActionSpace().noOp(); + A lastAction = mdpWrapper.getActionSpace().noOp(); A action; - int step = initMdp.getSteps(); - INDArray[] history = null; + stepCountable.setStepCounter(initMdp.getSteps()); - INDArray input = Learning.getInput(mdp, obs); + while (!mdpWrapper.isDone()) { - while (!mdp.isDone()) { - - if (step % skipFrame != 0) { + if (stepCountable.getStepCounter() % skipFrame != 0) { action = lastAction; } else { - - if (history == null) { - if (isHistoryProcessor) { - hp.add(input); - history = hp.getHistory(); - } else - history = new INDArray[] {input}; - } - INDArray hstack = Transition.concat(history); - if (isHistoryProcessor) { - hstack.muli(1.0 / hp.getScale()); - } - if (getNeuralNet().isRecurrent()) { - //flatten everything for the RNN - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1)); - } else { - if (hstack.shape().length > 2) - hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()))); - } - action = nextAction(hstack); + action = nextAction(obs); } + lastAction = action; - StepReply stepReply = mdp.step(action); + StepReply stepReply = mdpWrapper.step(action); reward += stepReply.getReward(); - input = Learning.getInput(mdp, stepReply.getObservation()); - if (isHistoryProcessor) { - hp.record(input); - hp.add(input); - } - - history = isHistoryProcessor ? hp.getHistory() - : new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())}; - step++; + obs = stepReply.getObservation(); + stepCountable.increment(); } - return reward; } + private > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { + getNeuralNet().reset(); + Observation observation = mdpWrapper.reset(); + + int step = 0; + double reward = 0; + + boolean isHistoryProcessor = hp != null; + + int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; + int requiredFrame = isHistoryProcessor ? 
skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
+
+        while (step < requiredFrame && !mdpWrapper.isDone()) {
+
+            A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
+
+            StepReply stepReply = mdpWrapper.step(action);
+            reward += stepReply.getReward();
+            observation = stepReply.getObservation();
+
+            step++;
+
+        }
+
+        return new Learning.InitMdp(step, observation, reward);
+    }
+
+    private class RefacStepCountable implements StepCountable {
+
+        @Getter
+        @Setter
+        private int stepCounter = 0;
+
+        public void increment() {
+            ++stepCounter;
+        }
+
+        @Override
+        public int getStepCounter() {
+            return stepCounter;
+        }
+    }
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
index efbf29603..221409040 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
@@ -4,6 +4,7 @@ import lombok.Getter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.ILearning;
+import org.deeplearning4j.rl4j.learning.StepCountable;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
@@ -12,21 +13,53 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
-public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
+public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
 
     @Getter
     private final MDP wrappedMDP;
     @Getter
     private final WrapperObservationSpace observationSpace;
     private final ILearning learning;
+    private IHistoryProcessor historyProcessor;
+    private final StepCountable stepCountable;
 
     private int skipFrame;
     private int step = 0;
 
     public LegacyMDPWrapper(MDP wrappedMDP, ILearning learning) {
+        this(wrappedMDP, learning, null, null);
+    }
+
+    public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
+        this(wrappedMDP, null, historyProcessor, stepCountable);
+    }
+
+    private LegacyMDPWrapper(MDP wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
         this.wrappedMDP = wrappedMDP;
         this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
         this.learning = learning;
+        this.historyProcessor = historyProcessor;
+        this.stepCountable = stepCountable;
+    }
+
+    private IHistoryProcessor getHistoryProcessor() {
+        if(historyProcessor != null) {
+            return historyProcessor;
+        }
+
+        return learning.getHistoryProcessor();
+    }
+
+    public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
+        this.historyProcessor = historyProcessor;
+    }
+
+    private int getStep() {
+        if(stepCountable != null) {
+            return stepCountable.getStepCounter();
+        }
+
+        return learning.getStepCounter();
     }
 
     @Override
@@ -38,13 +71,12 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
     public Observation reset() {
         INDArray rawObservation = getInput(wrappedMDP.reset());
 
-        IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
+        IHistoryProcessor historyProcessor = getHistoryProcessor();
         if(historyProcessor != null) {
-            historyProcessor.record(rawObservation.dup());
-            rawObservation.muli(1.0 / historyProcessor.getScale());
+            historyProcessor.record(rawObservation);
         }
 
-        Observation observation = new Observation(new INDArray[] { rawObservation });
+        Observation observation = new Observation(new INDArray[] { rawObservation }, false);
 
         if(historyProcessor != null) {
             skipFrame = historyProcessor.getConf().getSkipFrame();
@@ -55,14 +87,9 @@
         return observation;
     }
 
-    @Override
-    public void close() {
-        wrappedMDP.close();
-    }
-
     @Override
     public StepReply step(A a) {
-        IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
+        IHistoryProcessor historyProcessor = getHistoryProcessor();
 
         StepReply rawStepReply = wrappedMDP.step(a);
         INDArray rawObservation = getInput(rawStepReply.getObservation());
@@ -71,11 +98,10 @@
         int requiredFrame = 0;
         if(historyProcessor != null) {
-            historyProcessor.record(rawObservation.dup());
-            rawObservation.muli(1.0 / historyProcessor.getScale());
+            historyProcessor.record(rawObservation);
 
             requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
-            if ((learning.getStepCounter() % skipFrame == 0 && step >= requiredFrame)
+            if ((getStep() % skipFrame == 0 && step >= requiredFrame)
                 || (step % skipFrame == 0 && step < requiredFrame )){
                 historyProcessor.add(rawObservation);
             }
@@ -83,15 +109,21 @@
         Observation observation;
         if(historyProcessor != null && step >= requiredFrame) {
-            observation = new Observation(historyProcessor.getHistory());
+            observation = new Observation(historyProcessor.getHistory(), true);
+            observation.getData().muli(1.0 / historyProcessor.getScale());
         }
         else {
-            observation = new Observation(new INDArray[] { rawObservation });
+            observation = new Observation(new INDArray[] { rawObservation }, false);
         }
 
         return new StepReply(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
     }
 
+    @Override
+    public void close() {
+        wrappedMDP.close();
+    }
+
     @Override
     public boolean isDone() {
         return wrappedMDP.isDone();
     }
@@ -103,7 +135,7 @@
     }
 
     private INDArray getInput(O obs) {
-        INDArray arr = Nd4j.create(obs.toArray());
+        INDArray arr = Nd4j.create(((Encodable)obs).toArray());
         int[] shape = observationSpace.getShape();
         if (shape.length == 1)
             return arr.reshape(new long[] {1, arr.length()});
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
index 8de9d864a..2302117d2 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
@@ -72,7 +72,7 @@ public class AsyncLearningTest {
         public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
         public final MockPolicy policy = new MockPolicy();
         public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
-        public final MockTrainingListener listener = new MockTrainingListener();
+        public final MockTrainingListener listener = new MockTrainingListener(asyncGlobal);
 
         public TestContext() {
             sut.addListener(listener);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
index e3658b8dd..b94541159 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
@@
-2,16 +2,17 @@ package org.deeplearning4j.rl4j.learning.async; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; -import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.support.*; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; +import java.util.ArrayList; +import java.util.List; import java.util.Stack; import static org.junit.Assert.assertEquals; @@ -21,37 +22,51 @@ public class AsyncThreadDiscreteTest { @Test public void refac_AsyncThreadDiscrete_trainSubEpoch() { // Arrange + int numEpochs = 1; MockNeuralNet nnMock = new MockNeuralNet(); + IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock); + asyncGlobalMock.setMaxLoops(hpConf.getSkipFrame() * numEpochs); MockObservationSpace observationSpace = new MockObservationSpace(); MockMDP mdpMock = new MockMDP(observationSpace); TrainingListenerList listeners = new TrainingListenerList(); MockPolicy policyMock = new MockPolicy(); - MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 0, 5,0, 0, 0, 0); - IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); - MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); + MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0); TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); - MockEncodable obs = new MockEncodable(123); - - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(1))); - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(2))); - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(3))); - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(4))); - hpMock.add(Learning.getInput(mdpMock, new MockEncodable(5))); // Act - AsyncThread.SubEpochReturn result = sut.trainSubEpoch(obs, 2); + sut.run(); // Assert - assertEquals(4, result.getSteps()); - assertEquals(6.0, result.getReward(), 0.00001); - assertEquals(0.0, result.getScore(), 0.00001); - assertEquals(3.0, result.getLastObs().toArray()[0], 0.00001); - assertEquals(1, asyncGlobalMock.enqueueCallCount); + assertEquals(2, sut.trainSubEpochResults.size()); + double[][] expectedLastObservations = new double[][] { + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, + }; + double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 }; + for(int i = 0; i < 2; ++i) { + AsyncThread.SubEpochReturn result = sut.trainSubEpochResults.get(i); + assertEquals(4, result.getSteps()); + assertEquals(expectedSubEpochReturnRewards[i], result.getReward(), 0.00001); + assertEquals(0.0, result.getScore(), 0.00001); + + double[] expectedLastObservation = expectedLastObservations[i]; + assertEquals(expectedLastObservation.length, result.getLastObs().getData().shape()[1]); + for(int j = 0; j < expectedLastObservation.length; ++j) { + assertEquals(expectedLastObservation[j], 255.0 * result.getLastObs().getData().getDouble(j), 0.00001); + } + } + 
assertEquals(2, asyncGlobalMock.enqueueCallCount); // HistoryProcessor - assertEquals(10, hpMock.addCalls.size()); - double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 }; + double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 }; + assertEquals(expectedAddValues.length, hpMock.addCalls.size()); + for(int i = 0; i < expectedAddValues.length; ++i) { + assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001); + } + + double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, }; assertEquals(expectedRecordValues.length, hpMock.recordCalls.size()); for(int i = 0; i < expectedRecordValues.length; ++i) { assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001); @@ -59,49 +74,89 @@ public class AsyncThreadDiscreteTest { // Policy double[][] expectedPolicyInputs = new double[][] { - new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 }, - new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 }, - new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 }, - new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 }, + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, }; assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size()); for(int i = 0; i < expectedPolicyInputs.length; ++i) { double[] expectedRow = expectedPolicyInputs[i]; INDArray input = policyMock.actionInputs.get(i); - assertEquals(expectedRow.length, input.shape()[0]); + assertEquals(expectedRow.length, input.shape()[1]); for(int j = 0; j < expectedRow.length; ++j) { assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001); } } // NeuralNetwork - assertEquals(1, nnMock.copyCallCount); + assertEquals(2, nnMock.copyCallCount); double[][] expectedNNInputs = new double[][] { - new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 }, - new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 }, - new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 }, - new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 }, - new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 }, + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans }; assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size()); for(int i = 0; i < expectedNNInputs.length; ++i) { double[] expectedRow = expectedNNInputs[i]; INDArray input = nnMock.outputAllInputs.get(i); - assertEquals(expectedRow.length, input.shape()[0]); + assertEquals(expectedRow.length, input.shape()[1]); for(int j = 0; j < expectedRow.length; ++j) { assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001); } } + int arrayIdx = 0; + double[][][] expectedMinitransObs = new double[][][] { + new double[][] { + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation + }, + new double[][] { + new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, + new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, + new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation + } + }; 
+ double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 }; + double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 }; + + assertEquals(2, sut.rewards.size()); + for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) { + Stack> miniTransStack = sut.rewards.get(rewardIdx); + + for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) { + MiniTrans minitrans = miniTransStack.get(i); + + // Observation + double[] expectedRow = expectedMinitransObs[rewardIdx][i]; + INDArray realRewards = minitrans.getObs(); + assertEquals(expectedRow.length, realRewards.shape()[1]); + for (int j = 0; j < expectedRow.length; ++j) { + assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001); + } + + assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001); + assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001); + ++arrayIdx; + } + } } public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete { - private final IAsyncGlobal asyncGlobal; + private final MockAsyncGlobal asyncGlobal; private final MockPolicy policy; private final MockAsyncConfiguration config; - public TestAsyncThreadDiscrete(IAsyncGlobal asyncGlobal, MDP mdp, + public final List trainSubEpochResults = new ArrayList(); + public final List>> rewards = new ArrayList>>(); + + public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy, MockAsyncConfiguration config, IHistoryProcessor hp) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); @@ -113,6 +168,7 @@ public class AsyncThreadDiscreteTest { @Override public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack> rewards) { + this.rewards.add(rewards); return new Gradient[0]; } @@ -130,5 +186,13 @@ public class AsyncThreadDiscreteTest { protected IPolicy getPolicy(MockNeuralNet net) { return policy; } + + @Override + public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) { + asyncGlobal.increaseCurrentLoop(); + SubEpochReturn result = super.trainSubEpoch(sObs, nstep); + trainSubEpochResults.add(result); + return result; + } } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 0a590a1e5..377b32175 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -6,11 +6,14 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.Policy; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; import java.util.List; @@ -23,107 +26,100 @@ public class AsyncThreadTest { @Test public void when_newEpochStarted_expect_neuralNetworkReset() { // Arrange - TestContext context = new TestContext(); - context.listener.setRemainingOnNewEpochCallCount(5); + int numberOfEpochs = 5; + TestContext 
context = new TestContext(numberOfEpochs); // Act context.sut.run(); // Assert - assertEquals(6, context.neuralNet.resetCallCount); + assertEquals(numberOfEpochs, context.neuralNet.resetCallCount); } @Test public void when_onNewEpochReturnsStop_expect_threadStopped() { // Arrange - TestContext context = new TestContext(); - context.listener.setRemainingOnNewEpochCallCount(1); + int stopAfterNumCalls = 1; + TestContext context = new TestContext(100000); + context.listener.setRemainingOnNewEpochCallCount(stopAfterNumCalls); // Act context.sut.run(); // Assert - assertEquals(2, context.listener.onNewEpochCallCount); - assertEquals(1, context.listener.onEpochTrainingResultCallCount); + assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: The call that returns stop is counted + assertEquals(stopAfterNumCalls, context.listener.onEpochTrainingResultCallCount); } @Test public void when_epochTrainingResultReturnsStop_expect_threadStopped() { // Arrange - TestContext context = new TestContext(); - context.listener.setRemainingOnEpochTrainingResult(1); + int stopAfterNumCalls = 1; + TestContext context = new TestContext(100000); + context.listener.setRemainingOnEpochTrainingResult(stopAfterNumCalls); // Act context.sut.run(); // Assert - assertEquals(2, context.listener.onNewEpochCallCount); - assertEquals(2, context.listener.onEpochTrainingResultCallCount); + assertEquals(stopAfterNumCalls + 1, context.listener.onEpochTrainingResultCallCount); // +1: The call that returns stop is counted + assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: onNewEpoch is called on the epoch that onEpochTrainingResult() will stop } @Test public void when_run_expect_preAndPostEpochCalled() { // Arrange - TestContext context = new TestContext(); + int numberOfEpochs = 5; + TestContext context = new TestContext(numberOfEpochs); // Act context.sut.run(); // Assert - assertEquals(6, context.sut.preEpochCallCount); - assertEquals(6, context.sut.postEpochCallCount); + assertEquals(numberOfEpochs, context.sut.preEpochCallCount); + assertEquals(numberOfEpochs, context.sut.postEpochCallCount); } @Test public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() { // Arrange - TestContext context = new TestContext(); + int numberOfEpochs = 5; + TestContext context = new TestContext(numberOfEpochs); // Act context.sut.run(); // Assert - assertEquals(5, context.listener.statEntries.size()); + assertEquals(numberOfEpochs, context.listener.statEntries.size()); int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 }; - for(int i = 0; i < 5; ++i) { + double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init + + 1.0; // Reward from trainSubEpoch() + for(int i = 0; i < numberOfEpochs; ++i) { IDataManager.StatEntry statEntry = context.listener.statEntries.get(i); assertEquals(expectedStepCounter[i], statEntry.getStepCounter()); assertEquals(i, statEntry.getEpochCounter()); - assertEquals(38.0, statEntry.getReward(), 0.0001); + assertEquals(expectedReward, statEntry.getReward(), 0.0001); } } - @Test - public void when_run_expect_NeuralNetIsResetAtInitAndEveryEpoch() { - // Arrange - TestContext context = new TestContext(); - - // Act - context.sut.run(); - - // Assert - assertEquals(6, context.neuralNet.resetCallCount); - } - @Test public void when_run_expect_trainSubEpochCalled() { // Arrange - TestContext context = new TestContext(); + int numberOfEpochs = 5; + TestContext context = new TestContext(numberOfEpochs); // Act 
context.sut.run(); // Assert - assertEquals(10, context.sut.trainSubEpochParams.size()); - for(int i = 0; i < 10; ++i) { + assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size()); + double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }; + for(int i = 0; i < context.sut.getEpochCounter(); ++i) { MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i); - if(i % 2 == 0) { - assertEquals(2, params.nstep); - assertEquals(8.0, params.obs.toArray()[0], 0.00001); - } - else { - assertEquals(1, params.nstep); - assertNull(params.obs); + assertEquals(2, params.nstep); + assertEquals(expectedObservation.length, params.obs.getData().shape()[1]); + for(int j = 0; j < expectedObservation.length; ++j){ + assertEquals(expectedObservation[j], 255.0 * params.obs.getData().getDouble(j), 0.00001); } } } @@ -136,30 +132,30 @@ public class AsyncThreadTest { public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0); public final TrainingListenerList listeners = new TrainingListenerList(); public final MockTrainingListener listener = new MockTrainingListener(); - private final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf); public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners); - public TestContext() { - asyncGlobal.setMaxLoops(10); + public TestContext(int numEpochs) { + asyncGlobal.setMaxLoops(numEpochs); listeners.add(listener); sut.setHistoryProcessor(historyProcessor); } } - public static class MockAsyncThread extends AsyncThread { + public static class MockAsyncThread extends AsyncThread { public int preEpochCallCount = 0; public int postEpochCallCount = 0; - private final IAsyncGlobal asyncGlobal; + private final MockAsyncGlobal asyncGlobal; private final MockNeuralNet neuralNet; private final AsyncConfiguration conf; private final List trainSubEpochParams = new ArrayList(); - public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { + public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { super(asyncGlobal, mdp, listeners, threadNumber, 0); this.asyncGlobal = asyncGlobal; @@ -180,7 +176,7 @@ public class AsyncThreadTest { } @Override - protected NeuralNet getCurrent() { + protected MockNeuralNet getCurrent() { return neuralNet; } @@ -195,20 +191,22 @@ public class AsyncThreadTest { } @Override - protected Policy getPolicy(NeuralNet net) { + protected Policy getPolicy(MockNeuralNet net) { return null; } @Override - protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) { + protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) { + asyncGlobal.increaseCurrentLoop(); trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep)); - return new SubEpochReturn(1, null, 1.0, 1.0); + setStepCounter(getStepCounter() + nstep); + return new SubEpochReturn(nstep, null, 1.0, 1.0); } @AllArgsConstructor @Getter public static class TrainSubEpochParams { - Encodable obs; + Observation obs; int nstep; } } diff --git 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java new file mode 100644 index 000000000..ef7fec7d0 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java @@ -0,0 +1,181 @@ +package org.deeplearning4j.rl4j.learning.async.a3c.discrete; + +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete; +import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.support.*; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Stack; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class A3CThreadDiscreteTest { + + @Test + public void refac_calcGradient() { + // Arrange + double gamma = 0.9; + MockObservationSpace observationSpace = new MockObservationSpace(); + MockMDP mdpMock = new MockMDP(observationSpace); + A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0); + MockActorCritic actorCriticMock = new MockActorCritic(); + IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock); + A3CThreadDiscrete sut = new A3CThreadDiscrete(mdpMock, asyncGlobalMock, config, 0, null, 0); + MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); + sut.setHistoryProcessor(hpMock); + + double[][] minitransObs = new double[][] { + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + }; + double[] outputs = new double[] { 1.0, 2.0, 3.0 }; + double[] rewards = new double[] { 0.0, 0.0, 3.0 }; + + Stack> minitransList = new Stack>(); + for(int i = 0; i < 3; ++i) { + INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1); + INDArray[] output = new INDArray[] { + Nd4j.zeros(5) + }; + output[0].putScalar(i, outputs[i]); + minitransList.push(new MiniTrans(obs, i, output, rewards[i])); + } + minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans + + // Act + sut.calcGradient(actorCriticMock, minitransList); + + // Assert + assertEquals(1, actorCriticMock.gradientParams.size()); + INDArray input = actorCriticMock.gradientParams.get(0).getFirst(); + INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond(); + + assertEquals(minitransObs.length, input.shape()[0]); + for(int i = 0; i < minitransObs.length; ++i) { + double[] expectedRow = minitransObs[i]; + assertEquals(expectedRow.length, input.shape()[1]); + for(int j = 0; j < expectedRow.length; ++j) { + assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001); + } + } + + double latestReward = (gamma * 4.0) + 3.0; + double[] 
expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward }; + for(int i = 0; i < expectedLabels0.length; ++i) { + assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001); + } + double[][] expectedLabels1 = new double[][] { + new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 }, + new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 }, + new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 }, + }; + + assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape()); + + for(int i = 0; i < expectedLabels1.length; ++i) { + double[] expectedRow = expectedLabels1[i]; + assertEquals(expectedRow.length, labels[1].shape()[1]); + for(int j = 0; j < expectedRow.length; ++j) { + assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001); + } + } + + } + + public class MockActorCritic implements IActorCritic { + + public final List> gradientParams = new ArrayList<>(); + + @Override + public NeuralNetwork[] getNeuralNetworks() { + return new NeuralNetwork[0]; + } + + @Override + public boolean isRecurrent() { + return false; + } + + @Override + public void reset() { + + } + + @Override + public void fit(INDArray input, INDArray[] labels) { + + } + + @Override + public INDArray[] outputAll(INDArray batch) { + return new INDArray[0]; + } + + @Override + public IActorCritic clone() { + return this; + } + + @Override + public void copy(NeuralNet from) { + + } + + @Override + public void copy(IActorCritic from) { + + } + + @Override + public Gradient[] gradient(INDArray input, INDArray[] labels) { + gradientParams.add(new Pair(input, labels)); + return new Gradient[0]; + } + + @Override + public void applyGradient(Gradient[] gradient, int batchSize) { + + } + + @Override + public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException { + + } + + @Override + public void save(String pathValue, String pathPolicy) throws IOException { + + } + + @Override + public double getLatestScore() { + return 0; + } + + @Override + public void save(OutputStream os) throws IOException { + + } + + @Override + public void save(String filename) throws IOException { + + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java new file mode 100644 index 000000000..d105419df --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java @@ -0,0 +1,81 @@ +package org.deeplearning4j.rl4j.learning.async.nstep.discrete; + +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.async.MiniTrans; +import org.deeplearning4j.rl4j.support.*; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Stack; + +import static org.junit.Assert.assertEquals; + +public class AsyncNStepQLearningThreadDiscreteTest { + + @Test + public void refac_calcGradient() { + // Arrange + double gamma = 0.9; + MockObservationSpace observationSpace = new MockObservationSpace(); + MockMDP mdpMock = new MockMDP(observationSpace); + AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0); + MockDQN dqnMock = new MockDQN(); + IHistoryProcessor.Configuration hpConf = 
new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); + MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); + AsyncNStepQLearningThreadDiscrete sut = new AsyncNStepQLearningThreadDiscrete(mdpMock, asyncGlobalMock, config, null, 0, 0); + MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf); + sut.setHistoryProcessor(hpMock); + + double[][] minitransObs = new double[][] { + new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, + new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, + new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, + }; + double[] outputs = new double[] { 1.0, 2.0, 3.0 }; + double[] rewards = new double[] { 0.0, 0.0, 3.0 }; + + Stack> minitransList = new Stack>(); + for(int i = 0; i < 3; ++i) { + INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1); + INDArray[] output = new INDArray[] { + Nd4j.zeros(5) + }; + output[0].putScalar(i, outputs[i]); + minitransList.push(new MiniTrans(obs, i, output, rewards[i])); + } + minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans + + // Act + sut.calcGradient(dqnMock, minitransList); + + // Assert + assertEquals(1, dqnMock.gradientParams.size()); + INDArray input = dqnMock.gradientParams.get(0).getFirst(); + INDArray labels = dqnMock.gradientParams.get(0).getSecond(); + + assertEquals(minitransObs.length, input.shape()[0]); + for(int i = 0; i < minitransObs.length; ++i) { + double[] expectedRow = minitransObs[i]; + assertEquals(expectedRow.length, input.shape()[1]); + for(int j = 0; j < expectedRow.length; ++j) { + assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001); + } + } + + double latestReward = (gamma * 4.0) + 3.0; + double[][] expectedLabels = new double[][] { + new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 }, + new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 }, + new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 }, + }; + assertEquals(minitransObs.length, labels.shape()[0]); + for(int i = 0; i < minitransObs.length; ++i) { + double[] expectedRow = expectedLabels[i]; + assertEquals(expectedRow.length, labels.shape()[1]); + for(int j = 0; j < expectedRow.length; ++j) { + assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001); + } + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 59c28551b..22e1ba49d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -63,7 +63,7 @@ public class QLearningDiscreteTest { double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 }; assertEquals(expectedAdds.length, hp.addCalls.size()); for(int i = 0; i < expectedAdds.length; ++i) { - assertEquals(expectedAdds[i], 255.0 * hp.addCalls.get(i).getDouble(0), 0.0001); + assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001); } assertEquals(0, hp.startMonitorCallCount); assertEquals(0, hp.stopMonitorCallCount); @@ -92,8 +92,8 @@ public class QLearningDiscreteTest { for(int i = 0; i < expectedDQNOutput.length; ++i) { INDArray outputParam = dqn.outputParams.get(i); - assertEquals(5, outputParam.shape()[0]); - assertEquals(1, outputParam.shape()[1]); + assertEquals(5, outputParam.shape()[1]); + 
assertEquals(1, outputParam.shape()[2]); double[] expectedRow = expectedDQNOutput[i]; for(int j = 0; j < expectedRow.length; ++j) { @@ -124,13 +124,15 @@ public class QLearningDiscreteTest { assertEquals(expectedTrActions[i], tr.getAction()); assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001); for(int j = 0; j < expectedTrObservations[i].length; ++j) { - assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(j, 0), 0.0001); + assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001); } } // trainEpoch result assertEquals(16, result.getStepCounter()); assertEquals(300.0, result.getReward(), 0.00001); + assertTrue(dqn.hasBeenReset); + assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset); } public static class TestQLearningDiscrete extends QLearningDiscrete { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index ffb3680bb..6fa2d06d4 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -28,6 +28,7 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscret import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.support.*; @@ -229,6 +230,11 @@ public class PolicyTest { return neuralNet; } + @Override + public Integer nextAction(Observation obs) { + return nextAction(obs.getData()); + } + @Override public Integer nextAction(INDArray input) { return (int)input.getDouble(0); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java index 34a2078f0..33dc82314 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncGlobal.java @@ -8,9 +8,10 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import java.util.concurrent.atomic.AtomicInteger; -public class MockAsyncGlobal implements IAsyncGlobal { +public class MockAsyncGlobal implements IAsyncGlobal { - private final NeuralNet current; + @Getter + private final NN current; public boolean hasBeenStarted = false; public boolean hasBeenTerminated = false; @@ -27,7 +28,7 @@ public class MockAsyncGlobal implements IAsyncGlobal { this(null); } - public MockAsyncGlobal(NeuralNet current) { + public MockAsyncGlobal(NN current) { maxLoops = Integer.MAX_VALUE; numLoopsStopRunning = Integer.MAX_VALUE; this.current = current; @@ -45,7 +46,7 @@ public class MockAsyncGlobal implements IAsyncGlobal { @Override public boolean isTrainingComplete() { - return ++currentLoop > maxLoops; + return currentLoop >= maxLoops; } @Override @@ -59,12 +60,7 @@ public class MockAsyncGlobal implements IAsyncGlobal { } @Override - public NeuralNet getCurrent() { - return current; - } - - @Override - public NeuralNet getTarget() { + public NN getTarget() { return current; } @@ -72,4 +68,8 @@ public class MockAsyncGlobal 
implements IAsyncGlobal { public void enqueue(Gradient[] gradient, Integer nstep) { ++enqueueCallCount; } + + public void increaseCurrentLoop() { + ++currentLoop; + } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java index 680f9a653..28d7f3914 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java @@ -15,8 +15,10 @@ import java.util.List; public class MockDQN implements IDQN { + public boolean hasBeenReset = false; public final List outputParams = new ArrayList<>(); public final List> fitParams = new ArrayList<>(); + public final List> gradientParams = new ArrayList<>(); @Override public NeuralNetwork[] getNeuralNetworks() { @@ -30,7 +32,7 @@ public class MockDQN implements IDQN { @Override public void reset() { - + hasBeenReset = true; } @Override @@ -61,7 +63,10 @@ public class MockDQN implements IDQN { @Override public IDQN clone() { - return null; + MockDQN clone = new MockDQN(); + clone.hasBeenReset = hasBeenReset; + + return clone; } @Override @@ -76,6 +81,7 @@ public class MockDQN implements IDQN { @Override public Gradient[] gradient(INDArray input, INDArray label) { + gradientParams.add(new Pair(input, label)); return new Gradient[0]; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java index 6d542934b..a5d7c5f3e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java @@ -35,7 +35,7 @@ public class MockNeuralNet implements NeuralNet { @Override public INDArray[] outputAll(INDArray batch) { outputAllInputs.add(batch); - return new INDArray[] { Nd4j.create(new double[] { 1.0 }) }; + return new INDArray[] { Nd4j.create(new double[] { outputAllInputs.size() }) }; } @Override diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java index 82adc65b7..4c4f100e9 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java @@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.support; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -23,6 +24,11 @@ public class MockPolicy implements IPolicy { @Override public Integer nextAction(INDArray input) { actionInputs.add(input); - return null; + return input.getInt(0, 0, 0); + } + + @Override + public Integer nextAction(Observation observation) { + return nextAction(observation.getData()); } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java index 97bf5cc28..d4e696248 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockTrainingListener.java @@ -11,6 +11,7 @@ import 
java.util.List; public class MockTrainingListener implements TrainingListener { + private final MockAsyncGlobal asyncGlobal; public int onTrainingStartCallCount = 0; public int onTrainingEndCallCount = 0; public int onNewEpochCallCount = 0; @@ -28,6 +29,14 @@ public class MockTrainingListener implements TrainingListener { public final List statEntries = new ArrayList<>(); + public MockTrainingListener() { + this(null); + } + + public MockTrainingListener(MockAsyncGlobal asyncGlobal) { + this.asyncGlobal = asyncGlobal; + } + @Override public ListenerResponse onTrainingStart() { @@ -55,6 +64,9 @@ public class MockTrainingListener implements TrainingListener { public ListenerResponse onTrainingProgress(ILearning learning) { ++onTrainingProgressCallCount; --remainingonTrainingProgressCallCount; + if(asyncGlobal != null) { + asyncGlobal.increaseCurrentLoop(); + } return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE; } From e718cc659b33d8e8a83d03fa1c2459ae9546ba15 Mon Sep 17 00:00:00 2001 From: Max Pumperla Date: Wed, 18 Dec 2019 15:56:03 +0100 Subject: [PATCH 3/5] python version bump --- pydatavec/setup.py | 2 +- pydl4j/setup.py | 36 ++++++++++++++++++------------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pydatavec/setup.py b/pydatavec/setup.py index cfe14da33..bf203d3f6 100644 --- a/pydatavec/setup.py +++ b/pydatavec/setup.py @@ -22,7 +22,7 @@ from setuptools import find_packages setup(name='pydatavec', - version='0.1.1', + version='0.1.2', description='Python interface for DataVec', long_description='Python interface for DataVec', diff --git a/pydl4j/setup.py b/pydl4j/setup.py index 9ef751fcd..f8598be1a 100644 --- a/pydl4j/setup.py +++ b/pydl4j/setup.py @@ -1,27 +1,27 @@ -################################################################################ -# Copyright (c) 2015-2019 Skymind, Inc. -# -# This program and the accompanying materials are made available under the -# terms of the Apache License, Version 2.0 which is available at -# https://www.apache.org/licenses/LICENSE-2.0. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -# -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - +################################################################################ +# Copyright (c) 2015-2019 Skymind, Inc. +# +# This program and the accompanying materials are made available under the +# terms of the Apache License, Version 2.0 which is available at +# https://www.apache.org/licenses/LICENSE-2.0. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + from setuptools import setup from setuptools import find_packages setup( name='pydl4j', - version='0.1.3', + version='0.1.4', packages=find_packages(), - install_requires=['Cython', 'jnius', 'requests', + install_requires=['Cython', 'pyjnius', 'requests', 'click', 'argcomplete', 'python-dateutil'], extras_require={ 'tests': ['pytest', 'pytest-pep8', 'pytest-cov'] From 4ffef95a2c511117a2c50687b141aee9499ad6e3 Mon Sep 17 00:00:00 2001 From: Max Pumperla Date: Wed, 18 Dec 2019 16:00:30 +0100 Subject: [PATCH 4/5] increase --- pydl4j/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydl4j/setup.py b/pydl4j/setup.py index f8598be1a..f60437c84 100644 --- a/pydl4j/setup.py +++ b/pydl4j/setup.py @@ -19,7 +19,7 @@ from setuptools import find_packages setup( name='pydl4j', - version='0.1.4', + version='0.1.5', packages=find_packages(), install_requires=['Cython', 'pyjnius', 'requests', 'click', 'argcomplete', 'python-dateutil'], From 9edbefdc67f6174564fd667dfe20f88c2cdfd178 Mon Sep 17 00:00:00 2001 From: Samuel Audet Date: Mon, 20 Jan 2020 17:13:57 +0900 Subject: [PATCH 5/5] RL4J: Replace gym-java-client with JavaCPP (#8595) * RL4J: Replace gym-java-client with JavaCPP Signed-off-by: Samuel Audet --- README.md | 1 - gym-java-client/.gitignore | 4 - gym-java-client/LICENSE.txt | 201 ---------- gym-java-client/README.md | 108 ------ gym-java-client/contrib/formatter.xml | 354 ------------------ gym-java-client/pom.xml | 337 ----------------- .../java/org/deeplearning4j/gym/Client.java | 200 ---------- .../org/deeplearning4j/gym/ClientFactory.java | 90 ----- .../org/deeplearning4j/gym/ClientUtils.java | 75 ---- .../rl4j/space/GymObservationSpace.java | 82 ---- .../deeplearning4j/gym/test/ClientTest.java | 140 ------- .../gym/test/JSONObjectMatcher.java | 46 --- pom.xml | 22 +- rl4j/README.md | 10 +- rl4j/pom.xml | 50 +++ rl4j/rl4j-ale/pom.xml | 9 + rl4j/rl4j-api/pom.xml | 14 +- .../org/deeplearning4j/gym/StepReply.java | 3 +- .../rl4j/space/ActionSpace.java | 0 .../rl4j/space/ArrayObservationSpace.java | 0 .../org/deeplearning4j/rl4j/space/Box.java | 12 +- .../rl4j/space/DiscreteSpace.java | 0 .../deeplearning4j/rl4j/space/Encodable.java | 0 .../rl4j/space/HighLowDiscrete.java | 12 +- .../rl4j/space/ObservationSpace.java | 0 rl4j/rl4j-core/pom.xml | 26 +- .../rl4j/mdp/toy/HardDeteministicToy.java | 3 +- .../rl4j/mdp/toy/SimpleToy.java | 3 +- .../deeplearning4j/rl4j/util/DataManager.java | 49 ++- .../rl4j/util/LegacyMDPWrapper.java | 5 +- rl4j/rl4j-doom/pom.xml | 9 + rl4j/rl4j-gym/pom.xml | 15 +- .../deeplearning4j/rl4j/mdp/gym/GymEnv.java | 169 +++++++-- .../rl4j/mdp/gym/GymEnvTest.java | 48 +++ rl4j/rl4j-malmo/pom.xml | 14 + 35 files changed, 358 insertions(+), 1753 deletions(-) delete mode 100644 gym-java-client/.gitignore delete mode 100644 gym-java-client/LICENSE.txt delete mode 100644 gym-java-client/README.md delete mode 100644 gym-java-client/contrib/formatter.xml delete mode 100644 gym-java-client/pom.xml delete mode 100644 gym-java-client/src/main/java/org/deeplearning4j/gym/Client.java delete mode 100644 gym-java-client/src/main/java/org/deeplearning4j/gym/ClientFactory.java delete mode 100644 gym-java-client/src/main/java/org/deeplearning4j/gym/ClientUtils.java delete mode 100644 gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/GymObservationSpace.java delete mode 100644 
gym-java-client/src/test/java/org/deeplearning4j/gym/test/ClientTest.java delete mode 100644 gym-java-client/src/test/java/org/deeplearning4j/gym/test/JSONObjectMatcher.java rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/gym/StepReply.java (95%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java (100%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java (100%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/Box.java (83%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java (100%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java (100%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java (82%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java (100%) create mode 100644 rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java diff --git a/README.md b/README.md index 2b69f7981..6a3d206c7 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ Welcome to the new monorepo of Deeplearning4j that contains the source code for * https://github.com/eclipse/deeplearning4j/tree/master/datavec * https://github.com/eclipse/deeplearning4j/tree/master/arbiter * https://github.com/eclipse/deeplearning4j/tree/master/nd4s - * https://github.com/eclipse/deeplearning4j/tree/master/gym-java-client * https://github.com/eclipse/deeplearning4j/tree/master/rl4j * https://github.com/eclipse/deeplearning4j/tree/master/scalnet * https://github.com/eclipse/deeplearning4j/tree/master/pydl4j diff --git a/gym-java-client/.gitignore b/gym-java-client/.gitignore deleted file mode 100644 index 089ca5bb0..000000000 --- a/gym-java-client/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -target/ -.idea/ -*.iml -*-git.properties diff --git a/gym-java-client/LICENSE.txt b/gym-java-client/LICENSE.txt deleted file mode 100644 index 5c304d1a4..000000000 --- a/gym-java-client/LICENSE.txt +++ /dev/null @@ -1,201 +0,0 @@ -Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. 
- - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/gym-java-client/README.md b/gym-java-client/README.md deleted file mode 100644 index 3ae9aa408..000000000 --- a/gym-java-client/README.md +++ /dev/null @@ -1,108 +0,0 @@ -# gym-java-client - -A java http client for [gym-http-api](https://github.com/openai/gym-http-api). - -Note: If you are encountering errors as reported in [issue #13](https://github.com/deeplearning4j/gym-java-client/issues/13), please execute the following command before launching `python gym_http_server.py`: - -```bash -$ sudo sysctl -w net.ipv4.tcp_tw_recycle=1 -``` - -# Quickstart - -To create a new Client, use the ClientFactory. If the url is not localhost:5000, provide it as a second argument - -```java -Client client = ClientFactory.build("CartPole-v0"); -``` - -"CartPole-v0" is the name of the gym environment. - -The type parameters of a client are the Observation type, the Action type, the Observation Space type and the ActionSpace type. 
- -It is a bit cumbersome to both declare an ActionSpace and an Action since an ActionSpace knows what type is an Action but unfortunately java does't support type member and path dependant types. - -Here we use Box and BoxSpace for the environment and Integer and Discrete Space because it is how [CartPole-v0](https://gym.openai.com/envs/CartPole-v0) is specified. - -The methods nomenclature follows closely the api interface of [gym-http-api](https://github.com/openai/gym-http-api#api-specification), O is Observation an A is Action: - -```java -//Static methods - -/** - * @param url url of the server - * @return set of all environments running on the server at the url - */ -public static Set listAll(String url); - -/** - * Shutdown the server at the url - * - * @param url url of the server - */ -public static void serverShutdown(String url); - - - -//Methods accessible from a Client -/** - * @return set of all environments running on the same server than this client - */ -public Set listAll(); - -/** - * Step the environment by one action - * - * @param action action to step the environment with - * @return the StepReply containing the next observation, the reward, if it is a terminal state and optional information. - */ -public StepReply step(A action); -/** - * Reset the state of the environment and return an initial observation. - * - * @return initial observation - */ -public O reset(); - -/** - * Start monitoring. - * - * @param directory path to directory in which store the monitoring file - * @param force clear out existing training data from this directory (by deleting every file prefixed with "openaigym.") - * @param resume retain the training data already in this directory, which will be merged with our new data - */ -public void monitorStart(String directory, boolean force, boolean resume); - -/** - * Flush all monitor data to disk - */ -public void monitorClose(); - -/** - * Upload monitoring data to OpenAI servers. - * - * @param trainingDir directory that contains the monitoring data - * @param apiKey personal OpenAI API key - * @param algorithmId an arbitrary string indicating the paricular version of the algorithm (including choices of parameters) you are running. - **/ -public void upload(String trainingDir, String apiKey, String algorithmId); - -/** - * Upload monitoring data to OpenAI servers. - * - * @param trainingDir directory that contains the monitoring data - * @param apiKey personal OpenAI API key - */ -public void upload(String trainingDir, String apiKey); - - -/** - * Shutdown the server at the same url than this client - */ -public void serverShutdown() - -``` - -## TODO - -* Add all ObservationSpace and ActionSpace when they will be available. 
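The client documented above is removed in favor of a JavaCPP-backed GymEnv in rl4j-gym (see the file listing at the top of this patch). For orientation, here is a minimal sketch of the same CartPole loop written against that replacement; the constructor arguments and generic parameters are assumptions inferred from the MDP interface and the renamed files, not code quoted from this patch:

```java
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;

public class GymEnvCartPole {
    public static void main(String[] args) {
        // Assumed constructor shape: (environment id, render flag, monitor flag)
        GymEnv<Box, Integer, DiscreteSpace> env = new GymEnv<>("CartPole-v0", false, false);
        try {
            Box obs = env.reset();          // initial observation, as with Client.reset()
            double reward = 0.0;
            while (!env.isDone()) {
                // Sample a random action, as the HTTP-client test further below does
                Integer action = env.getActionSpace().randomAction();
                StepReply<Box> reply = env.step(action);
                reward += reply.getReward();
            }
            System.out.println("Episode reward: " + reward);
        } finally {
            env.close();                    // frees the native environment
        }
    }
}
```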
diff --git a/gym-java-client/contrib/formatter.xml b/gym-java-client/contrib/formatter.xml
deleted file mode 100644
index 1d0adbbe1..000000000
--- a/gym-java-client/contrib/formatter.xml
+++ /dev/null
@@ -1,354 +0,0 @@
[354 deleted lines of Eclipse code-formatter XML: the element markup was lost in extraction, leaving only bare deletion markers, so the body is elided here]
diff --git a/gym-java-client/pom.xml b/gym-java-client/pom.xml
deleted file mode 100644
index fc82d01b2..000000000
--- a/gym-java-client/pom.xml
+++ /dev/null
@@ -1,337 +0,0 @@
[337 deleted lines of Maven POM XML, likewise stripped of markup in extraction. The surviving text identifies the module as org.deeplearning4j:gym-java-client ("A Java client for Open AI's Reinforcement Learning Gym", Apache-2.0, developer rubenfiszel), with dependencies on an ND4J backend, commons-codec, httpclient/httpcore/httpmime, unirest-java, objenesis, asm, and the mockito/powermock/cglib/junit/logback test stack, plus source/surefire/javadoc/lint/formatter/git-commit-id/build-helper plugin configuration, surefire-report and cobertura reporting, and an nd4j-backend CUDA profile]
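The Client class deleted next is a thin Unirest wrapper over the gym-http-api REST routes. As a point of reference, here is a minimal self-contained sketch of the step round trip it performs; the route layout, headers, and JSON fields are taken from the deleted code that follows, while the server address and instance id are placeholder values borrowed from the deleted test:

```java
import com.mashape.unirest.http.HttpResponse;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.Unirest;
import org.json.JSONObject;

public class RawStepCall {
    public static void main(String[] args) throws Exception {
        String base = "http://127.0.0.1:5000/v1/envs/";
        String instanceId = "e15739cf"; // placeholder: returned when the env instance is created
        // Same body that Client.step() builds: the encoded action plus the render flag
        JSONObject body = new JSONObject().put("action", 0).put("render", false);
        HttpResponse<JsonNode> rep = Unirest.post(base + instanceId + "/step/")
                        .header("content-type", "application/json")
                        .body(body).asJson();
        JSONObject reply = rep.getBody().getObject();
        System.out.println("reward=" + reply.getDouble("reward")
                        + " done=" + reply.getBoolean("done"));
    }
}
```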
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/gym/Client.java b/gym-java-client/src/main/java/org/deeplearning4j/gym/Client.java
deleted file mode 100644
index c3f221c3f..000000000
--- a/gym-java-client/src/main/java/org/deeplearning4j/gym/Client.java
+++ /dev/null
@@ -1,200 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.gym;
-
-
-import com.mashape.unirest.http.JsonNode;
-import lombok.Value;
-import lombok.extern.slf4j.Slf4j;
-import org.deeplearning4j.rl4j.space.GymObservationSpace;
-import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.json.JSONObject;
-
-import java.util.Set;
-
-/**
- * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16.
- *
- * A client represents an active connection to a specific instance of an environment on a gym-http-api server.
- * for API specification - * - * @param Observation type - * @param Action type - * @param Action Space type - * @see https://github.com/openai/gym-http-api#api-specification - */ -@Slf4j -@Value -public class Client> { - - - public static String V1_ROOT = "/v1"; - public static String ENVS_ROOT = V1_ROOT + "/envs/"; - - public static String MONITOR_START = "/monitor/start/"; - public static String MONITOR_CLOSE = "/monitor/close/"; - public static String CLOSE = "/close/"; - public static String RESET = "/reset/"; - public static String SHUTDOWN = "/shutdown/"; - public static String UPLOAD = "/upload/"; - public static String STEP = "/step/"; - public static String OBSERVATION_SPACE = "/observation_space/"; - public static String ACTION_SPACE = "/action_space/"; - - - String url; - String envId; - String instanceId; - GymObservationSpace observationSpace; - AS actionSpace; - boolean render; - - - /** - * @param url url of the server - * @return set of all environments running on the server at the url - */ - public static Set listAll(String url) { - JSONObject reply = ClientUtils.get(url + ENVS_ROOT); - return reply.getJSONObject("envs").keySet(); - } - - /** - * Shutdown the server at the url - * - * @param url url of the server - */ - public static void serverShutdown(String url) { - ClientUtils.post(url + ENVS_ROOT + SHUTDOWN, new JSONObject()); - } - - /** - * @return set of all environments running on the same server than this client - */ - public Set listAll() { - return listAll(url); - } - - /** - * Step the environment by one action - * - * @param action action to step the environment with - * @return the StepReply containing the next observation, the reward, if it is a terminal state and optional information. - */ - public StepReply step(A action) { - JSONObject body = new JSONObject().put("action", getActionSpace().encode(action)).put("render", render); - - JSONObject reply = ClientUtils.post(url + ENVS_ROOT + instanceId + STEP, body).getObject(); - - O observation = observationSpace.getValue(reply, "observation"); - double reward = reply.getDouble("reward"); - boolean done = reply.getBoolean("done"); - JSONObject info = reply.getJSONObject("info"); - - return new StepReply(observation, reward, done, info); - } - - /** - * Reset the state of the environment and return an initial observation. - * - * @return initial observation - */ - public O reset() { - JsonNode resetRep = ClientUtils.post(url + ENVS_ROOT + instanceId + RESET, new JSONObject()); - return observationSpace.getValue(resetRep.getObject(), "observation"); - } - - /* - Present in the doc but not working currently server-side - public void monitorStart(String directory) { - - JSONObject json = new JSONObject() - .put("directory", directory); - - monitorStartPost(json); - } - */ - - /** - * Start monitoring. 
- * - * @param directory path to directory in which store the monitoring file - * @param force clear out existing training data from this directory (by deleting every file prefixed with "openaigym.") - * @param resume retain the training data already in this directory, which will be merged with our new data - */ - public void monitorStart(String directory, boolean force, boolean resume) { - JSONObject json = new JSONObject().put("directory", directory).put("force", force).put("resume", resume); - - monitorStartPost(json); - } - - private void monitorStartPost(JSONObject json) { - ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_START, json); - } - - /** - * Flush all monitor data to disk - */ - public void monitorClose() { - ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_CLOSE, new JSONObject()); - } - - /** - * Upload monitoring data to OpenAI servers. - * - * @param trainingDir directory that contains the monitoring data - * @param apiKey personal OpenAI API key - * @param algorithmId an arbitrary string indicating the paricular version of the algorithm (including choices of parameters) you are running. - **/ - public void upload(String trainingDir, String apiKey, String algorithmId) { - JSONObject json = new JSONObject().put("training_dir", trainingDir).put("api_key", apiKey).put("algorithm_id", - algorithmId); - - uploadPost(json); - } - - /** - * Upload monitoring data to OpenAI servers. - * - * @param trainingDir directory that contains the monitoring data - * @param apiKey personal OpenAI API key - */ - public void upload(String trainingDir, String apiKey) { - JSONObject json = new JSONObject().put("training_dir", trainingDir).put("api_key", apiKey); - - uploadPost(json); - } - - private void uploadPost(JSONObject json) { - try { - ClientUtils.post(url + V1_ROOT + UPLOAD, json); - } catch (RuntimeException e) { - log.error("Impossible to upload: Wrong API key?"); - } - } - - /** - * Shutdown the server at the same url than this client - */ - public void serverShutdown() { - serverShutdown(url); - } - - public ActionSpace getActionSpace(){ - return actionSpace; - } -} diff --git a/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientFactory.java b/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientFactory.java deleted file mode 100644 index 57002fb77..000000000 --- a/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientFactory.java +++ /dev/null @@ -1,90 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.gym; - -import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.GymObservationSpace; -import org.deeplearning4j.rl4j.space.HighLowDiscrete; -import org.json.JSONArray; -import org.json.JSONException; -import org.json.JSONObject; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. - * - * ClientFactory contains builder method to create a new {@link Client} - */ -public class ClientFactory { - - public static > Client build(String url, String envId, boolean render) { - - JSONObject body = new JSONObject().put("env_id", envId); - JSONObject reply = ClientUtils.post(url + Client.ENVS_ROOT, body).getObject(); - - String instanceId; - - try { - instanceId = reply.getString("instance_id"); - } catch (JSONException e) { - throw new RuntimeException("Environment id not found", e); - } - - GymObservationSpace observationSpace = fetchObservationSpace(url, instanceId); - AS actionSpace = fetchActionSpace(url, instanceId); - - return new Client(url, envId, instanceId, observationSpace, actionSpace, render); - - } - - public static > Client build(String envId, boolean render) { - return build("http://127.0.0.1:5000", envId, render); - } - - public static AS fetchActionSpace(String url, String instanceId) { - - JSONObject reply = ClientUtils.get(url + Client.ENVS_ROOT + instanceId + Client.ACTION_SPACE); - JSONObject info = reply.getJSONObject("info"); - String infoName = info.getString("name"); - - switch (infoName) { - case "Discrete": - return (AS) new DiscreteSpace(info.getInt("n")); - case "HighLow": - int numRows = info.getInt("num_rows"); - int size = 3 * numRows; - JSONArray matrixJson = info.getJSONArray("matrix"); - INDArray matrix = Nd4j.create(numRows, 3); - for (int i = 0; i < size; i++) { - matrix.putScalar(i, matrixJson.getDouble(i)); - } - matrix.reshape(numRows, 3); - return (AS) new HighLowDiscrete(matrix); - default: - throw new RuntimeException("Unknown action space " + infoName); - } - } - - public static GymObservationSpace fetchObservationSpace(String url, String instanceId) { - JSONObject reply = ClientUtils.get(url + Client.ENVS_ROOT + instanceId + Client.OBSERVATION_SPACE); - JSONObject info = reply.getJSONObject("info"); - return new GymObservationSpace(info); - } -} diff --git a/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientUtils.java b/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientUtils.java deleted file mode 100644 index ff61026ae..000000000 --- a/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientUtils.java +++ /dev/null @@ -1,75 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.gym; - -import com.mashape.unirest.http.HttpResponse; -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.Unirest; -import com.mashape.unirest.http.exceptions.UnirestException; -import org.json.JSONException; -import org.json.JSONObject; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. - * - * ClientUtils contain the utility methods to post and get data from the server REST API through the library unirest. - */ -public class ClientUtils { - - static public JsonNode post(String url, JSONObject json) { - HttpResponse jsonResponse = null; - - try { - jsonResponse = Unirest.post(url).header("content-type", "application/json").body(json).asJson(); - } catch (UnirestException e) { - unirestCrash(e); - } - - return jsonResponse.getBody(); - } - - - static public JSONObject get(String url) { - HttpResponse jsonResponse = null; - - try { - jsonResponse = Unirest.get(url).header("content-type", "application/json").asJson(); - } catch (UnirestException e) { - unirestCrash(e); - } - - checkReply(jsonResponse, url); - - return jsonResponse.getBody().getObject(); - } - - - static public void checkReply(HttpResponse res, String url) { - if (res.getBody() == null) - throw new RuntimeException("Invalid reply at: " + url); - } - - static public void unirestCrash(UnirestException e) { - //if couldn't parse json - if (e.getCause().getCause().getCause() instanceof JSONException) - throw new RuntimeException("Couldn't parse json reply.", e); - else - throw new RuntimeException("Connection error", e); - } - - -} diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/GymObservationSpace.java b/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/GymObservationSpace.java deleted file mode 100644 index e8f9b2bbd..000000000 --- a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/GymObservationSpace.java +++ /dev/null @@ -1,82 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.space; - -import lombok.Value; -import org.json.JSONArray; -import org.json.JSONObject; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. - * - * Contain contextual information about the environment from which Observations are observed and must know how to build an Observation from json. 
- *
- * @param <O> the type of Observation
- */
-
-@Value
-public class GymObservationSpace<O> implements ObservationSpace<O> {
-
-    String name;
-    int[] shape;
-    INDArray low;
-    INDArray high;
-
-
-    public GymObservationSpace(JSONObject jsonObject) {
-
-        name = jsonObject.getString("name");
-
-        JSONArray arr = jsonObject.getJSONArray("shape");
-        int lg = arr.length();
-
-        shape = new int[lg];
-        for (int i = 0; i < lg; i++) {
-            this.shape[i] = arr.getInt(i);
-        }
-
-        low = Nd4j.create(shape);
-        high = Nd4j.create(shape);
-
-        JSONArray lowJson = jsonObject.getJSONArray("low");
-        JSONArray highJson = jsonObject.getJSONArray("high");
-
-        int size = shape[0];
-        for (int i = 1; i < shape.length; i++) {
-            size *= shape[i];
-        }
-
-        for (int i = 0; i < size; i++) {
-            low.putScalar(i, lowJson.getDouble(i));
-            high.putScalar(i, highJson.getDouble(i));
-        }
-
-    }
-
-    public O getValue(JSONObject o, String key) {
-        switch (name) {
-            case "Box":
-                JSONArray arr = o.getJSONArray(key);
-                return (O) new Box(arr);
-            default:
-                throw new RuntimeException("Invalid environment name: " + name);
-        }
-    }
-
-}
diff --git a/gym-java-client/src/test/java/org/deeplearning4j/gym/test/ClientTest.java b/gym-java-client/src/test/java/org/deeplearning4j/gym/test/ClientTest.java
deleted file mode 100644
index ab31a488b..000000000
--- a/gym-java-client/src/test/java/org/deeplearning4j/gym/test/ClientTest.java
+++ /dev/null
@@ -1,140 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.gym.test;
-
-import com.mashape.unirest.http.JsonNode;
-import org.deeplearning4j.gym.Client;
-import org.deeplearning4j.gym.ClientFactory;
-import org.deeplearning4j.gym.ClientUtils;
-import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.Box;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.json.JSONObject;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
-
-import static org.mockito.Matchers.eq;
-import static org.powermock.api.mockito.PowerMockito.mockStatic;
-import static org.powermock.api.mockito.PowerMockito.when;
-
-/**
- * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/11/16.
- */
-
-@RunWith(PowerMockRunner.class)
-@PrepareForTest(ClientUtils.class)
-public class ClientTest {
-
-    @Test
-    public void completeClientTest() {
-
-        String url = "http://127.0.0.1:5000";
-        String env = "Powermock-v0";
-        String instanceID = "e15739cf";
-        String testDir = "/tmp/testDir";
-        boolean render = true;
-        String renderStr = render ? "True" : "False";
"True" : "False"; - - mockStatic(ClientUtils.class); - - //post mock - - JSONObject buildReq = new JSONObject("{\"env_id\":\"" + env + "\"}"); - JsonNode buildRep = new JsonNode("{\"instance_id\":\"" + instanceID + "\"}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT), JSONObjectMatcher.jsonEq(buildReq))).thenReturn(buildRep); - - JSONObject monStartReq = new JSONObject("{\"resume\":false,\"directory\":\"" + testDir + "\",\"force\":true}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.MONITOR_START), - JSONObjectMatcher.jsonEq(monStartReq))).thenReturn(null); - - JSONObject monStopReq = new JSONObject("{}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.MONITOR_CLOSE), - JSONObjectMatcher.jsonEq(monStopReq))).thenReturn(null); - - JSONObject resetReq = new JSONObject("{}"); - JsonNode resetRep = new JsonNode( - "{\"observation\":[0.021729452941849317,-0.04764548144956857,-0.024914502756611293,-0.04074903379512588]}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.RESET), - JSONObjectMatcher.jsonEq(resetReq))).thenReturn(resetRep); - - JSONObject stepReq = new JSONObject("{\"action\":0, \"render\":" + renderStr + "}"); - JsonNode stepRep = new JsonNode( - "{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":false,\"info\":{}}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP), JSONObjectMatcher.jsonEq(stepReq))) - .thenReturn(stepRep); - - JSONObject stepReq2 = new JSONObject("{\"action\":1, \"render\":" + renderStr + "}"); - JsonNode stepRep2 = new JsonNode( - "{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":false,\"info\":{}}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP), - JSONObjectMatcher.jsonEq(stepReq2))).thenReturn(stepRep2); - - //get mock - JSONObject obsSpace = new JSONObject( - "{\"info\":{\"name\":\"Box\",\"shape\":[4],\"high\":[4.8,3.4028234663852886E38,0.41887902047863906,3.4028234663852886E38],\"low\":[-4.8,-3.4028234663852886E38,-0.41887902047863906,-3.4028234663852886E38]}}"); - when(ClientUtils.get(eq(url + Client.ENVS_ROOT + instanceID + Client.OBSERVATION_SPACE))).thenReturn(obsSpace); - - JSONObject actionSpace = new JSONObject("{\"info\":{\"name\":\"Discrete\",\"n\":2}}"); - when(ClientUtils.get(eq(url + Client.ENVS_ROOT + instanceID + Client.ACTION_SPACE))).thenReturn(actionSpace); - - - //test - - Client client = ClientFactory.build(url, env, render); - client.monitorStart(testDir, true, false); - - int episodeCount = 1; - int maxSteps = 200; - int reward = 0; - - for (int i = 0; i < episodeCount; i++) { - client.reset(); - - for (int j = 0; j < maxSteps; j++) { - - Integer action = ((ActionSpace)client.getActionSpace()).randomAction(); - StepReply step = client.step(action); - reward += step.getReward(); - - //return a isDone true before i == maxSteps - if (j == maxSteps - 5) { - JSONObject stepReqLoc = new JSONObject("{\"action\":0}"); - JsonNode stepRepLoc = new JsonNode( - "{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":true,\"info\":{}}"); - - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP), - JSONObjectMatcher.jsonEq(stepReqLoc))).thenReturn(stepRepLoc); - } - - if (step.isDone()) { - // System.out.println("break"); - break; - } - } - - } - - client.monitorClose(); 
-        client.upload(testDir, "YOUR_OPENAI_GYM_API_KEY");
-
-
-    }
-
-}
diff --git a/gym-java-client/src/test/java/org/deeplearning4j/gym/test/JSONObjectMatcher.java b/gym-java-client/src/test/java/org/deeplearning4j/gym/test/JSONObjectMatcher.java
deleted file mode 100644
index ff1430786..000000000
--- a/gym-java-client/src/test/java/org/deeplearning4j/gym/test/JSONObjectMatcher.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.gym.test;
-
-import org.json.JSONObject;
-import org.mockito.ArgumentMatcher;
-
-import static org.mockito.Matchers.argThat;
-
-/**
- * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/11/16.
- */
-
-
-public class JSONObjectMatcher implements ArgumentMatcher<JSONObject> {
-    private final JSONObject expected;
-
-    public JSONObjectMatcher(JSONObject expected) {
-        this.expected = expected;
-    }
-
-    public static JSONObject jsonEq(JSONObject expected) {
-        return argThat(new JSONObjectMatcher(expected));
-    }
-
-
-    @Override
-    public boolean matches(JSONObject argument) {
-        if (expected == null)
-            return argument == null;
-        return expected.toString().equals(argument.toString());
-    }
-}
diff --git a/pom.xml b/pom.xml
index ada833f12..3d7082524 100644
--- a/pom.xml
+++ b/pom.xml
@@ -132,7 +132,6 @@
         <module>deeplearning4j</module>
         <module>arbiter</module>
         <module>nd4s</module>
-        <module>gym-java-client</module>
         <module>rl4j</module>
         <module>scalnet</module>
         <module>jumpy</module>
@@ -288,22 +287,23 @@
         ${javacpp.platform}
-        <javacpp.version>1.5.2</javacpp.version>
-        <javacpp-presets.version>1.5.2</javacpp-presets.version>
-        <javacv.version>1.5.2</javacv.version>
+        <javacpp.version>1.5.3-SNAPSHOT</javacpp.version>
+        <javacpp-presets.version>1.5.3-SNAPSHOT</javacpp-presets.version>
+        <javacv.version>1.5.3-SNAPSHOT</javacv.version>
-        <python.version>3.7.5</python.version>
+        <python.version>3.7.6</python.version>
         <cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
-        <numpy.version>1.17.3</numpy.version>
+        <numpy.version>1.18.0</numpy.version>
         <numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
         <openblas.version>0.3.7</openblas.version>
         <mkl.version>2019.5</mkl.version>
-        <opencv.version>4.1.2</opencv.version>
-        <ffmpeg.version>4.2.1</ffmpeg.version>
-        <leptonica.version>1.78.0</leptonica.version>
-        <hdf5.version>1.10.5</hdf5.version>
-        <ale.version>0.6.0</ale.version>
+        <opencv.version>4.2.0</opencv.version>
+        <ffmpeg.version>4.2.2</ffmpeg.version>
+        <leptonica.version>1.79.0</leptonica.version>
+        <hdf5.version>1.10.6</hdf5.version>
+        <ale.version>0.6.1</ale.version>
+        <gym.version>0.15.4</gym.version>
         <tensorflow.version>1.15.0</tensorflow.version>
         <tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
diff --git a/rl4j/README.md b/rl4j/README.md
index ef5797701..68d5b755e 100644
--- a/rl4j/README.md
+++ b/rl4j/README.md
@@ -32,10 +32,6 @@ Comments are welcome on our gitter channel:

 # Quickstart

-** INSTALL rl4j-api before installing all (see below)!**
-
-* mvn install -pl rl4j-api
-* [if you want rl4j-gym too] Download and mvn install: [gym-java-client](https://github.com/eclipse/deeplearning4j/tree/master/gym-java-client)
 * mvn install

 # Visualisation

@@ -44,9 +40,7 @@

 # Quicktry cartpole:

-* Install [gym-http-api](https://github.com/openai/gym-http-api).
-* launch http api server.
-* run with this [main](https://github.com/rubenfiszel/rl4j-examples/blob/master/src/main/java/org/deeplearning4j/rl4j/Cartpole.java)
+* run with this [main](https://github.com/eclipse/deeplearning4j-examples/blob/master/rl4j-examples/src/main/java/org/deeplearning4j/examples/rl4j/Cartpole.java)

 # Doom

@@ -83,4 +77,4 @@ Doom is not ready yet but you can make it work if you feel adventurous with some

 * Continuous control
 * Policy Gradient
-* Update gym-java-client when gym-http-api gets compatible with pixels environments to play with Pong, Doom, etc ..
+* Update rl4j-gym to make it compatible with pixels environments to play with Pong, Doom, etc ..
diff --git a/rl4j/pom.xml b/rl4j/pom.xml
index 318446629..c91dd7aa2 100644
--- a/rl4j/pom.xml
+++ b/rl4j/pom.xml
@@ -97,6 +97,30 @@
             </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-enforcer-plugin</artifactId>
+                <version>${maven-enforcer-plugin.version}</version>
+                <executions>
+                    <execution>
+                        <phase>test</phase>
+                        <id>enforce-choice-of-nd4j-test-backend</id>
+                        <goals>
+                            <goal>enforce</goal>
+                        </goals>
+                        <configuration>
+                            <skip>${skipBackendChoice}</skip>
+                            <rules>
+                                <requireActiveProfile>
+                                    <profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
+                                    <all>false</all>
+                                </requireActiveProfile>
+                            </rules>
+                            <fail>true</fail>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
             <plugin>
                 <artifactId>maven-source-plugin</artifactId>
                 <version>${maven-source-plugin.version}</version>
@@ -265,6 +289,32 @@
         </dependency>
     </dependencies>
+
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+            <dependencies>
+                <dependency>
+                    <groupId>org.nd4j</groupId>
+                    <artifactId>nd4j-native</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+            </dependencies>
+        </profile>
+
+        <profile>
+            <id>test-nd4j-cuda-10.2</id>
+            <dependencies>
+                <dependency>
+                    <groupId>org.nd4j</groupId>
+                    <artifactId>nd4j-cuda-10.2</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+            </dependencies>
+        </profile>
+    </profiles>
diff --git a/rl4j/rl4j-ale/pom.xml b/rl4j/rl4j-ale/pom.xml
index 33211f480..360c3db1d 100644
--- a/rl4j/rl4j-ale/pom.xml
+++ b/rl4j/rl4j-ale/pom.xml
@@ -44,4 +44,13 @@
             <version>${ale.version}-${javacpp-presets.version}</version>
         </dependency>
     </dependencies>
+
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+        </profile>
+        <profile>
+            <id>test-nd4j-cuda-10.2</id>
+        </profile>
+    </profiles>
 </project>
diff --git a/rl4j/rl4j-api/pom.xml b/rl4j/rl4j-api/pom.xml
index b89c875cc..629783e15 100644
--- a/rl4j/rl4j-api/pom.xml
+++ b/rl4j/rl4j-api/pom.xml
@@ -33,15 +33,19 @@
     <dependencies>
-        <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>gym-java-client</artifactId>
-            <version>${dl4j.version}</version>
-        </dependency>
         <dependency>
             <groupId>org.nd4j</groupId>
             <artifactId>nd4j-api</artifactId>
            <version>${nd4j.version}</version>
         </dependency>
     </dependencies>
+
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+        </profile>
+        <profile>
+            <id>test-nd4j-cuda-10.2</id>
+        </profile>
+    </profiles>
 </project>
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/gym/StepReply.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java
similarity index 95%
rename from gym-java-client/src/main/java/org/deeplearning4j/gym/StepReply.java
rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java
index 0da2971fe..ab054689a 100644
--- a/gym-java-client/src/main/java/org/deeplearning4j/gym/StepReply.java
+++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java
@@ -17,7 +17,6 @@
 package org.deeplearning4j.gym;

 import lombok.Value;
-import org.json.JSONObject;

 /**
  * @param <T> type of observation
@@ -31,6 +30,6 @@ public class StepReply<T> {
     T observation;
     double reward;
     boolean done;
-    JSONObject info;
+    Object info;

 }
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java
similarity index 100%
rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java
rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java
similarity index 100%
rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java
rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Box.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java
similarity index 83%
rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Box.java
rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java
index c938747d5..e90601fda 100644
--- a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Box.java
+++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java
@@ -16,8 +16,6 @@

 package org.deeplearning4j.rl4j.space;

-import org.json.JSONArray;
-
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
  *
@@ -29,14 +27,8 @@ public class Box implements Encodable {

     private final double[] array;

-    public Box(JSONArray arr) {
-
-        int lg = arr.length();
-        this.array = new double[lg];
-
-        for (int i = 0; i < lg; i++) {
-            this.array[i] = arr.getDouble(i);
-        }
+    public Box(double[] arr) {
+        this.array = arr;
     }

     public double[] toArray() {
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java
similarity index 100%
rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java
rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java
similarity index 100%
rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java
rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java
similarity index 82%
rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java
rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java
index d9c77f285..491f6aca1 100644
--- a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java
+++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java
@@ -17,10 +17,8 @@
 package org.deeplearning4j.rl4j.space;

 import lombok.Value;
-import org.json.JSONArray;
 import org.nd4j.linalg.api.ndarray.INDArray;

-
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/26/16.
  */
@@ -37,12 +35,8 @@

     @Override
     public Object encode(Integer a) {
-        JSONArray jsonArray = new JSONArray();
-        for (int i = 0; i < size; i++) {
-            jsonArray.put(matrix.getDouble(i, 0));
-        }
-        jsonArray.put(a - 1, matrix.getDouble(a - 1, 1));
-        return jsonArray;
+        INDArray m = matrix.dup();
+        m.put(a - 1, 0, matrix.getDouble(a - 1, 1));
+        return m;
     }
-
 }
diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java
similarity index 100%
rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java
rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java
diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml
index ebdfec1a6..a78157603 100644
--- a/rl4j/rl4j-core/pom.xml
+++ b/rl4j/rl4j-core/pom.xml
@@ -44,11 +44,6 @@
             <groupId>ch.qos.logback</groupId>
             <artifactId>logback-classic</artifactId>
         </dependency>
-        <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>gym-java-client</artifactId>
-            <version>${dl4j.version}</version>
-        </dependency>

         <dependency>
             <groupId>org.bytedeco</groupId>
@@ -111,27 +106,10 @@
     <profiles>
         <profile>
-            <id>nd4j-tests-cpu</id>
-            <dependencies>
-                <dependency>
-                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-native</artifactId>
-                    <version>${project.version}</version>
-                    <scope>test</scope>
-                </dependency>
-            </dependencies>
+            <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>nd4j-tests-cuda</id>
-            <dependencies>
-                <dependency>
-                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-10.2</artifactId>
-                    <version>${project.version}</version>
-                    <scope>test</scope>
-                </dependency>
-            </dependencies>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java
index a9708f550..2d8bc0402 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java
@@ -23,7 +23,6 @@
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
 import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
-import org.json.JSONObject;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

@@ -104,7 +103,7 @@ public class HardDeteministicToy implements MDP<HardToyState, Integer, DiscreteSpace> {

     public StepReply<SimpleToyState> step(Integer a) {
         double reward = (simpleToyState.getStep() % 2 == 0) ? 1 - a : a;
         simpleToyState = new SimpleToyState(simpleToyState.getI() + 1, simpleToyState.getStep() + 1);
-        return new StepReply<>(simpleToyState, reward, isDone(), new JSONObject("{}"));
+        return new StepReply<>(simpleToyState, reward, isDone(), null);
     }

     public SimpleToy newInstance() {
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java
index 4f7d6244e..b639efdaa 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java
@@ -26,6 +26,7 @@
 import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
 import org.nd4j.linalg.primitives.Pair;
 import org.deeplearning4j.rl4j.learning.ILearning;
 import org.deeplearning4j.rl4j.learning.Learning;
+import org.deeplearning4j.rl4j.network.ac.IActorCritic;
 import org.deeplearning4j.rl4j.network.dqn.DQN;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
 import org.deeplearning4j.util.ModelSerializer;
@@ -88,19 +89,38 @@
         String json = new ObjectMapper().writeValueAsString(learning.getConfiguration());
         writeEntry(new ByteArrayInputStream(json.getBytes()), zipfile);

-        ZipEntry dqn = new ZipEntry("dqn.bin");
-        zipfile.putNextEntry(dqn);
+        try {
+            ZipEntry dqn = new ZipEntry("dqn.bin");
+            zipfile.putNextEntry(dqn);

-        ByteArrayOutputStream bos = new ByteArrayOutputStream();
-        if(learning instanceof NeuralNetFetchable) {
-            ((NeuralNetFetchable)learning).getNeuralNet().save(bos);
+            ByteArrayOutputStream bos = new ByteArrayOutputStream();
+            if(learning instanceof NeuralNetFetchable) {
+                ((NeuralNetFetchable)learning).getNeuralNet().save(bos);
+            }
+            bos.flush();
+            bos.close();
+
+            InputStream inputStream = new ByteArrayInputStream(bos.toByteArray());
+            writeEntry(inputStream, zipfile);
+        } catch (UnsupportedOperationException e) {
+            ByteArrayOutputStream bos = new ByteArrayOutputStream();
+            ByteArrayOutputStream bos2 = new ByteArrayOutputStream();
+            ((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(bos, bos2);
+
+            bos.flush();
+            bos.close();
+            InputStream inputStream = new ByteArrayInputStream(bos.toByteArray());
+            ZipEntry value = new ZipEntry("value.bin");
+            zipfile.putNextEntry(value);
+            writeEntry(inputStream, zipfile);
+
+            bos2.flush();
+            bos2.close();
+            InputStream inputStream2 = new ByteArrayInputStream(bos2.toByteArray());
+            ZipEntry policy = new ZipEntry("policy.bin");
+            zipfile.putNextEntry(policy);
+            writeEntry(inputStream2, zipfile);
         }
-        bos.flush();
-        bos.close();
-
-        InputStream inputStream = new ByteArrayInputStream(bos.toByteArray());
-        writeEntry(inputStream, zipfile);

         if (learning.getHistoryProcessor() != null) {
             ZipEntry hpconf = new ZipEntry("hpconf.bin");
@@ -268,7 +288,12 @@
         save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);

         if(learning instanceof NeuralNetFetchable) {
-            ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
+            try {
+                ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
+            } catch (UnsupportedOperationException e) {
+                String path = getModelDir() + "/" + learning.getStepCounter();
+                ((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(path + "_value.model", path + "_policy.model");
+            }
         }
     }
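A note on the DataManager change above: a saved training archive now contains either a single "dqn.bin" entry or, when the network only supports the two-stream save (actor-critic), a "value.bin"/"policy.bin" pair. A minimal sketch of how a consumer could detect which layout a given snapshot uses, written against plain java.util.zip only (the class name and path argument are illustrative, not part of this patch):

    import java.util.zip.ZipEntry;
    import java.util.zip.ZipFile;

    public class SnapshotEntries {
        public static void main(String[] args) throws Exception {
            // args[0] points at a ".training" archive written by DataManager.save()
            try (ZipFile zip = new ZipFile(args[0])) {
                ZipEntry dqn = zip.getEntry("dqn.bin");
                if (dqn != null) {
                    // single-network snapshot (e.g. a DQN)
                    System.out.println("dqn.bin: " + dqn.getSize() + " bytes");
                } else {
                    // actor-critic snapshot: value and policy networks are stored separately
                    System.out.println("value.bin present: " + (zip.getEntry("value.bin") != null));
                    System.out.println("policy.bin present: " + (zip.getEntry("policy.bin") != null));
                }
            }
        }
    }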
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
index 221409040..dbcd38ddc 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
@@ -47,7 +47,10 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<O
             <version>${project.version}</version>
+
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+        </profile>
+        <profile>
+            <id>test-nd4j-cuda-10.2</id>
+        </profile>
+    </profiles>
diff --git a/rl4j/rl4j-gym/pom.xml b/rl4j/rl4j-gym/pom.xml
index d5be31032..76e2e39ff 100644
--- a/rl4j/rl4j-gym/pom.xml
+++ b/rl4j/rl4j-gym/pom.xml
@@ -39,9 +39,18 @@
             <version>${project.version}</version>
         </dependency>
         <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>gym-java-client</artifactId>
-            <version>${project.version}</version>
+            <groupId>org.bytedeco</groupId>
+            <artifactId>gym-platform</artifactId>
+            <version>${gym.version}-${javacpp-presets.version}</version>
         </dependency>
     </dependencies>
+
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+        </profile>
+        <profile>
+            <id>test-nd4j-cuda-10.2</id>
+        </profile>
+    </profiles>
 </project>
diff --git a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
index 8f18cb29e..66be7698f 100644
--- a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
+++ b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -16,39 +17,111 @@

 package org.deeplearning4j.rl4j.mdp.gym;

-
-import org.deeplearning4j.gym.Client;
-import org.deeplearning4j.gym.ClientFactory;
+import java.io.IOException;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.Value;
+import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.javacpp.DoublePointer;
+import org.bytedeco.javacpp.Pointer;
+import org.bytedeco.javacpp.SizeTPointer;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Box;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.HighLowDiscrete;
 import org.deeplearning4j.rl4j.space.ObservationSpace;

+import org.bytedeco.cpython.*;
+import org.bytedeco.numpy.*;
+import static org.bytedeco.cpython.global.python.*;
+import static org.bytedeco.numpy.global.numpy.*;
+
 /**
+ * An MDP for OpenAI Gym: https://gym.openai.com/
+ *
  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
- *
- * Wrapper over the client of gym-java-client
- *
+ * @author saudet
  */
+@Slf4j
 public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {

-    final public static String GYM_MONITOR_DIR = "/tmp/gym-dqn";
+    public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";

-    final private Client<O, A, AS> client;
+    private static void checkPythonError() {
+        if (PyErr_Occurred() != null) {
+            PyErr_Print();
+            throw new RuntimeException("Python error occurred");
+        }
+    }
+
+    private static Pointer program;
+    private static PyObject globals;
+    static {
+        try {
+            Py_SetPath(org.bytedeco.gym.presets.gym.cachePackages());
+            program = Py_DecodeLocale(GymEnv.class.getSimpleName(), null);
+            Py_SetProgramName(program);
+            Py_Initialize();
+            PyEval_InitThreads();
+            PySys_SetArgvEx(1, program, 0);
+            if (_import_array() < 0) {
+                PyErr_Print();
+                throw new RuntimeException("numpy.core.multiarray failed to import");
+            }
+            globals = PyModule_GetDict(PyImport_AddModule("__main__"));
+            PyEval_SaveThread(); // just to release the GIL
+        } catch (IOException e) {
+            PyMem_RawFree(program);
+            throw new RuntimeException(e);
+        }
+    }
+    private PyObject locals;
+
+    final protected DiscreteSpace actionSpace;
+    final protected ObservationSpace<O> observationSpace;
     @Getter final private String envId;
     @Getter final private boolean render;
     @Getter final private boolean monitor;
     private ActionTransformer actionTransformer = null;
     private boolean done = false;

     public GymEnv(String envId, boolean render, boolean monitor) {
-        this.client = ClientFactory.build(envId, render);
         this.envId = envId;
         this.render = render;
         this.monitor = monitor;
-        if (monitor)
-            client.monitorStart(GYM_MONITOR_DIR, true, false);
+
+        int gstate = PyGILState_Ensure();
+        try {
+            locals = PyDict_New();
+
+            Py_DecRef(PyRun_StringFlags("import gym; env = gym.make('" + envId + "')", Py_single_input, globals, locals, null));
+            checkPythonError();
+            if (monitor) {
+                Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null));
+                checkPythonError();
+            }
+            PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null);
+            int[] shape = new int[(int)PyTuple_Size(shapeTuple)];
+            for (int i = 0; i < shape.length; i++) {
+                shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i));
+            }
+            observationSpace = (ObservationSpace<O>) new ArrayObservationSpace(shape);
+            Py_DecRef(shapeTuple);
+
+            PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null);
+            actionSpace = new DiscreteSpace((int)PyLong_AsLong(n));
+            Py_DecRef(n);
+            checkPythonError();
+        } finally {
+            PyGILState_Release(gstate);
+        }
     }

     public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
@@ -56,43 +129,87 @@
         this(envId, render, monitor);
         actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions);
     }

-
+    @Override
     public ObservationSpace<O> getObservationSpace() {
-        return client.getObservationSpace();
+        return observationSpace;
     }

+    @Override
     public AS getActionSpace() {
         if (actionTransformer == null)
-            return (AS) client.getActionSpace();
+            return (AS) actionSpace;
         else
             return (AS) actionTransformer;
     }
Py_DecRef(PyRun_StringFlags("state, reward, done, info = env.step(" + (Integer)action +")", Py_single_input, globals, locals, null)); + checkPythonError(); + + PyArrayObject state = new PyArrayObject(PyDict_GetItemString(locals, "state")); + DoublePointer stateData = new DoublePointer(PyArray_BYTES(state)).capacity(PyArray_Size(state)); + SizeTPointer stateDims = PyArray_DIMS(state).capacity(PyArray_NDIM(state)); + + double reward = PyFloat_AsDouble(PyDict_GetItemString(locals, "reward")); + done = PyLong_AsLong(PyDict_GetItemString(locals, "done")) != 0; + checkPythonError(); + + double[] data = new double[(int)stateData.capacity()]; + stateData.get(data); + + return new StepReply(new Box(data), reward, done, null); + } finally { + PyGILState_Release(gstate); + } } + @Override public boolean isDone() { return done; } + @Override public O reset() { - done = false; - return client.reset(); - } - - - public void upload(String apiKey) { - client.upload(GYM_MONITOR_DIR, apiKey); + int gstate = PyGILState_Ensure(); + try { + Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null)); + checkPythonError(); + + PyArrayObject state = new PyArrayObject(PyDict_GetItemString(locals, "state")); + DoublePointer stateData = new DoublePointer(PyArray_BYTES(state)).capacity(PyArray_Size(state)); + SizeTPointer stateDims = PyArray_DIMS(state).capacity(PyArray_NDIM(state)); + checkPythonError(); + + done = false; + + double[] data = new double[(int)stateData.capacity()]; + stateData.get(data); + return (O) new Box(data); + } finally { + PyGILState_Release(gstate); + } } + @Override public void close() { - if (monitor) - client.monitorClose(); + int gstate = PyGILState_Ensure(); + try { + Py_DecRef(PyRun_StringFlags("env.close()", Py_single_input, globals, locals, null)); + checkPythonError(); + Py_DecRef(locals); + } finally { + PyGILState_Release(gstate); + } } + @Override public GymEnv newInstance() { return new GymEnv(envId, render, monitor); } diff --git a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java new file mode 100644 index 000000000..2196d7b31 --- /dev/null +++ b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
diff --git a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java
new file mode 100644
index 000000000..2196d7b31
--- /dev/null
+++ b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java
@@ -0,0 +1,48 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.mdp.gym;
+
+import org.deeplearning4j.gym.StepReply;
+import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
+import org.deeplearning4j.rl4j.space.Box;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+
+/**
+ *
+ * @author saudet
+ */
+public class GymEnvTest {
+
+    @Test
+    public void testCartpole() {
+        GymEnv mdp = new GymEnv("CartPole-v0", false, false);
+        assertArrayEquals(new int[] {4}, ((ArrayObservationSpace)mdp.getObservationSpace()).getShape());
+        assertEquals(2, ((DiscreteSpace)mdp.getActionSpace()).getSize());
+        assertEquals(false, mdp.isDone());
+        Box o = (Box)mdp.reset();
+        StepReply r = mdp.step(0);
+        assertEquals(4, o.toArray().length);
+        assertEquals(4, ((Box)r.getObservation()).toArray().length);
+        assertNotEquals(null, mdp.newInstance());
+        mdp.close();
+    }
+}
diff --git a/rl4j/rl4j-malmo/pom.xml b/rl4j/rl4j-malmo/pom.xml
index 53ef7be13..3575da7a8 100644
--- a/rl4j/rl4j-malmo/pom.xml
+++ b/rl4j/rl4j-malmo/pom.xml
@@ -33,6 +33,11 @@
     <dependencies>
+        <dependency>
+            <groupId>org.json</groupId>
+            <artifactId>json</artifactId>
+            <version>20190722</version>
+        </dependency>
         <dependency>
             <groupId>org.deeplearning4j</groupId>
             <artifactId>rl4j-api</artifactId>
@@ -44,4 +49,13 @@
             <version>0.30.0</version>
         </dependency>
     </dependencies>
+
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+        </profile>
+        <profile>
+            <id>test-nd4j-cuda-10.2</id>
+        </profile>
+    </profiles>
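For reference, a random-agent episode against the new GymEnv, exercising only the MDP surface covered by GymEnvTest above (raw generics as in the test; the 200-step bound and class name are illustrative):

    import org.deeplearning4j.gym.StepReply;
    import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
    import org.deeplearning4j.rl4j.space.DiscreteSpace;

    public class RandomCartpole {
        public static void main(String[] args) {
            GymEnv mdp = new GymEnv("CartPole-v0", false, false);
            try {
                mdp.reset(); // initial observation is a Box of length 4
                double totalReward = 0;
                for (int step = 0; step < 200 && !mdp.isDone(); step++) {
                    // sample a random action from the discrete action space
                    Integer action = ((DiscreteSpace) mdp.getActionSpace()).randomAction();
                    StepReply reply = mdp.step(action);
                    totalReward += reply.getReward();
                }
                System.out.println("Episode reward: " + totalReward);
            } finally {
                mdp.close(); // shuts down the underlying Python environment
            }
        }
    }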