diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index 910e625f0..d3927e8c9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -125,7 +125,7 @@ public abstract class Learning return nshape; } - protected abstract DataManager getDataManager(); + protected abstract IDataManager getDataManager(); public abstract NN getNeuralNet(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java index 0b786d926..4ff461206 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java @@ -52,7 +52,7 @@ import java.util.concurrent.atomic.AtomicInteger; * */ @Slf4j -public class AsyncGlobal extends Thread { +public class AsyncGlobal extends Thread implements IAsyncGlobal { @Getter final private NN current; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index 1815553b1..a8948b23e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -45,7 +45,7 @@ public abstract class AsyncLearning getAsyncGlobal(); + protected abstract IAsyncGlobal getAsyncGlobal(); protected void startGlobalThread() { getAsyncGlobal().start(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java index f957a3dba..5e2e2074f 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java @@ -31,7 +31,7 @@ import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.Constants; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. @@ -57,7 +57,7 @@ public abstract class AsyncThread asyncGlobal, int threadNumber) { + public AsyncThread(IAsyncGlobal asyncGlobal, int threadNumber) { this.threadNumber = threadNumber; } @@ -109,7 +109,7 @@ public abstract class AsyncThread= getConf().getMaxEpochStep() || getMdp().isDone()) { postEpoch(); - DataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score); + IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score); getDataManager().appendStat(statEntry); log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); @@ -136,13 +136,13 @@ public abstract class AsyncThread getAsyncGlobal(); + protected abstract IAsyncGlobal getAsyncGlobal(); protected abstract MDP getMdp(); protected abstract AsyncConfiguration getConf(); - protected abstract DataManager getDataManager(); + protected abstract IDataManager getDataManager(); protected abstract Policy getPolicy(NN net); @@ -159,7 +159,7 @@ public abstract class AsyncThread asyncGlobal, int threadNumber) { + public AsyncThreadDiscrete(IAsyncGlobal asyncGlobal, int threadNumber) { super(asyncGlobal, threadNumber); synchronized (asyncGlobal) { current = (NN)asyncGlobal.getCurrent().clone(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java new file mode 100644 index 000000000..138bff943 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java @@ -0,0 +1,33 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.async; + +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.NeuralNet; + +import java.util.concurrent.atomic.AtomicInteger; + +public interface IAsyncGlobal { + boolean isRunning(); + void setRunning(boolean value); + boolean isTrainingComplete(); + void start(); + AtomicInteger getT(); + NN getCurrent(); + NN getTarget(); + void enqueue(Gradient[] gradient, Integer nstep); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index c962c6088..5777e2394 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. @@ -48,10 +48,10 @@ public abstract class A3CDiscrete extends AsyncLearning policy; @Getter - final private DataManager dataManager; + final private IDataManager dataManager; public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CConfiguration conf, - DataManager dataManager) { + IDataManager dataManager) { super(conf); this.iActorCritic = iActorCritic; this.mdp = mdp; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java index 764b9773e..6e74ba3b3 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java @@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16. @@ -44,7 +44,7 @@ public class A3CDiscreteConv extends A3CDiscrete { final private HistoryProcessor.Configuration hpconf; public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) { + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { super(mdp, IActorCritic, conf, dataManager); this.hpconf = hpconf; setHistoryProcessor(hpconf); @@ -52,13 +52,13 @@ public class A3CDiscreteConv extends A3CDiscrete { public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) { + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) { + HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java index 1cb0c1119..c67659589 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java @@ -20,7 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.*; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16. @@ -34,31 +34,31 @@ import org.deeplearning4j.rl4j.util.DataManager; public class A3CDiscreteDense extends A3CDiscrete { public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf, - DataManager dataManager) { + IDataManager dataManager) { super(mdp, IActorCritic, conf, dataManager); } public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, - A3CConfiguration conf, DataManager dataManager) { + A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, - DataManager dataManager) { + IDataManager dataManager) { this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager); } public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, - A3CConfiguration conf, DataManager dataManager) { + A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf, - DataManager dataManager) { + IDataManager dataManager) { this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index bfc346c1f..4c5873b11 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -52,12 +52,12 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< @Getter final protected int threadNumber; @Getter - final protected DataManager dataManager; + final protected IDataManager dataManager; final private Random random; public A3CThreadDiscrete(MDP mdp, AsyncGlobal asyncGlobal, - A3CDiscrete.A3CConfiguration a3cc, int threadNumber, DataManager dataManager) { + A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) { super(asyncGlobal, threadNumber); this.conf = a3cc; this.asyncGlobal = asyncGlobal; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index be3c57abf..f0d7a3349 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -27,7 +27,7 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. @@ -40,13 +40,13 @@ public abstract class AsyncNStepQLearningDiscrete @Getter final private MDP mdp; @Getter - final private DataManager dataManager; + final private IDataManager dataManager; @Getter final private AsyncGlobal asyncGlobal; public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf, - DataManager dataManager) { + IDataManager dataManager) { super(conf); this.mdp = mdp; this.dataManager = dataManager; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java index 84ec4f49a..4da14012e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java @@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16. @@ -36,19 +36,19 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN final private HistoryProcessor.Configuration hpconf; public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) { + HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { super(mdp, dqn, conf, dataManager); this.hpconf = hpconf; setHistoryProcessor(hpconf); } public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) { + HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) { + HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java index e9dc906c9..837681981 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java @@ -22,7 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16. @@ -30,18 +30,18 @@ import org.deeplearning4j.rl4j.util.DataManager; public class AsyncNStepQLearningDiscreteDense extends AsyncNStepQLearningDiscrete { public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, - AsyncNStepQLConfiguration conf, DataManager dataManager) { + AsyncNStepQLConfiguration conf, IDataManager dataManager) { super(mdp, dqn, conf, dataManager); } public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, - AsyncNStepQLConfiguration conf, DataManager dataManager) { + AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } public AsyncNStepQLearningDiscreteDense(MDP mdp, - DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, DataManager dataManager) { + DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 670589235..4f6c3ad09 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -19,7 +19,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import lombok.Getter; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.learning.Learning; -import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; +import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.mdp.MDP; @@ -29,7 +29,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -46,17 +46,17 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn @Getter final protected MDP mdp; @Getter - final protected AsyncGlobal asyncGlobal; + final protected IAsyncGlobal asyncGlobal; @Getter final protected int threadNumber; @Getter - final protected DataManager dataManager; + final protected IDataManager dataManager; final private Random random; - public AsyncNStepQLearningThreadDiscrete(MDP mdp, AsyncGlobal asyncGlobal, + public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber, - DataManager dataManager) { + IDataManager dataManager) { super(asyncGlobal, threadNumber); this.conf = conf; this.asyncGlobal = asyncGlobal; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java index d436b7b14..fb69f2fa9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java @@ -22,7 +22,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.Constants; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16. @@ -51,7 +51,7 @@ public abstract class SyncLearning extends QLearning mdp; @Getter @@ -72,7 +72,7 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLConfiguration conf, - DataManager dataManager, int epsilonNbStep) { + IDataManager dataManager, int epsilonNbStep) { super(conf); this.configuration = conf; this.mdp = mdp; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java index 62fa09257..81d408ed5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java @@ -23,7 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16. @@ -35,18 +35,18 @@ public class QLearningDiscreteConv extends QLearningDiscret public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, - QLConfiguration conf, DataManager dataManager) { + QLConfiguration conf, IDataManager dataManager) { super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); setHistoryProcessor(hpconf); } public QLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, QLConfiguration conf, DataManager dataManager) { + HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf, - HistoryProcessor.Configuration hpconf, QLConfiguration conf, DataManager dataManager) { + HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java index c5c4f1149..341031aec 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java @@ -23,7 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; -import org.deeplearning4j.rl4j.util.DataManager; +import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16. @@ -33,18 +33,18 @@ public class QLearningDiscreteDense extends QLearningDiscre public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf, - DataManager dataManager) { + IDataManager dataManager) { super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep()); } public QLearningDiscreteDense(MDP mdp, DQNFactory factory, - QLearning.QLConfiguration conf, DataManager dataManager) { + QLearning.QLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, - QLearning.QLConfiguration conf, DataManager dataManager) { + QLearning.QLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java index 5b88f9f0b..e9c243eea 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java @@ -43,7 +43,7 @@ import java.util.zip.ZipOutputStream; * the folder for every training and handle every path and model savings */ @Slf4j -public class DataManager { +public class DataManager implements IDataManager { final private String home = System.getProperty("user.home"); final private ObjectMapper mapper = new ObjectMapper(); @@ -266,16 +266,6 @@ public class DataManager { } - //In order for jackson to serialize StatEntry - //please use Lombok @Value (see QLStatEntry) - public interface StatEntry { - int getEpochCounter(); - - int getStepCounter(); - - double getReward(); - } - @AllArgsConstructor @Value @Builder diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java new file mode 100644 index 000000000..d265bca5a --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.util; + +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.Learning; + +import java.io.IOException; + +public interface IDataManager { + + boolean isSaveData(); + String getVideoDir(); + void appendStat(StatEntry statEntry) throws IOException; + void writeInfo(ILearning iLearning) throws IOException; + void save(Learning learning) throws IOException; + + //In order for jackson to serialize StatEntry + //please use Lombok @Value (see QLStatEntry) + interface StatEntry { + int getEpochCounter(); + + int getStepCounter(); + + double getReward(); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java new file mode 100644 index 000000000..f8d1a219c --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -0,0 +1,458 @@ +package org.deeplearning4j.rl4j.learning.async; + +import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.policy.Policy; +import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.ObservationSpace; +import org.deeplearning4j.rl4j.support.MockDataManager; +import org.deeplearning4j.rl4j.support.MockHistoryProcessor; +import org.deeplearning4j.rl4j.util.IDataManager; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; + +public class AsyncThreadTest { + + @Test + public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() { + // Arrange + MockDataManager dataManager = new MockDataManager(false); + MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); + MockNeuralNet neuralNet = new MockNeuralNet(); + MockObservationSpace observationSpace = new MockObservationSpace(); + MockMDP mdp = new MockMDP(observationSpace); + MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2); + MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager); + + // Act + sut.run(); + + // Assert + assertEquals(4, dataManager.statEntries.size()); + + IDataManager.StatEntry entry = dataManager.statEntries.get(0); + assertEquals(2, entry.getStepCounter()); + assertEquals(0, entry.getEpochCounter()); + assertEquals(2.0, entry.getReward(), 0.0); + + entry = dataManager.statEntries.get(1); + assertEquals(4, entry.getStepCounter()); + assertEquals(1, entry.getEpochCounter()); + assertEquals(2.0, entry.getReward(), 0.0); + + entry = dataManager.statEntries.get(2); + assertEquals(6, entry.getStepCounter()); + assertEquals(2, entry.getEpochCounter()); + assertEquals(2.0, entry.getReward(), 0.0); + + entry = dataManager.statEntries.get(3); + assertEquals(8, entry.getStepCounter()); + assertEquals(3, entry.getEpochCounter()); + assertEquals(2.0, entry.getReward(), 0.0); + + assertEquals(0, dataManager.isSaveDataCallCount); + assertEquals(0, dataManager.getVideoDirCallCount); + } + + @Test + public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() { + // Arrange + MockDataManager dataManager = new MockDataManager(false); + MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); + MockNeuralNet neuralNet = new MockNeuralNet(); + MockObservationSpace observationSpace = new MockObservationSpace(); + MockMDP mdp = new MockMDP(observationSpace); + MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2); + + IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder() + .build(); + MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig); + + + MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager); + sut.setHistoryProcessor(hp); + + // Act + sut.run(); + + // Assert + assertEquals(9, dataManager.statEntries.size()); + + for(int i = 0; i < 9; ++i) { + IDataManager.StatEntry entry = dataManager.statEntries.get(i); + assertEquals(i + 1, entry.getStepCounter()); + assertEquals(i, entry.getEpochCounter()); + assertEquals(1.0, entry.getReward(), 0.0); + } + + assertEquals(10, dataManager.isSaveDataCallCount); + assertEquals(0, dataManager.getVideoDirCallCount); + } + + @Test + public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() { + // Arrange + MockDataManager dataManager = new MockDataManager(true); + MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10); + MockNeuralNet neuralNet = new MockNeuralNet(); + MockObservationSpace observationSpace = new MockObservationSpace(); + MockMDP mdp = new MockMDP(observationSpace); + MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2); + + IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder() + .build(); + MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig); + + + MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager); + sut.setHistoryProcessor(hp); + + // Act + sut.run(); + + // Assert + assertEquals(9, dataManager.statEntries.size()); + + for(int i = 0; i < 9; ++i) { + IDataManager.StatEntry entry = dataManager.statEntries.get(i); + assertEquals(i + 1, entry.getStepCounter()); + assertEquals(i, entry.getEpochCounter()); + assertEquals(1.0, entry.getReward(), 0.0); + } + + assertEquals(1, dataManager.isSaveDataCallCount); + assertEquals(1, dataManager.getVideoDirCallCount); + } + + public static class MockAsyncGlobal implements IAsyncGlobal { + + private final int maxLoops; + private int currentLoop = 0; + + public MockAsyncGlobal(int maxLoops) { + + this.maxLoops = maxLoops; + } + + @Override + public boolean isRunning() { + return true; + } + + @Override + public void setRunning(boolean value) { + + } + + @Override + public boolean isTrainingComplete() { + return ++currentLoop >= maxLoops; + } + + @Override + public void start() { + + } + + @Override + public AtomicInteger getT() { + return null; + } + + @Override + public NeuralNet getCurrent() { + return null; + } + + @Override + public NeuralNet getTarget() { + return null; + } + + @Override + public void enqueue(Gradient[] gradient, Integer nstep) { + + } + } + + public static class MockAsyncThread extends AsyncThread { + + IAsyncGlobal asyncGlobal; + private final MockNeuralNet neuralNet; + private final MDP mdp; + private final AsyncConfiguration conf; + private final IDataManager dataManager; + + public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) { + super(asyncGlobal, threadNumber); + + this.asyncGlobal = asyncGlobal; + this.neuralNet = neuralNet; + this.mdp = mdp; + this.conf = conf; + this.dataManager = dataManager; + } + + @Override + protected NeuralNet getCurrent() { + return neuralNet; + } + + @Override + protected int getThreadNumber() { + return 0; + } + + @Override + protected IAsyncGlobal getAsyncGlobal() { + return asyncGlobal; + } + + @Override + protected MDP getMdp() { + return mdp; + } + + @Override + protected AsyncConfiguration getConf() { + return conf; + } + + @Override + protected IDataManager getDataManager() { + return dataManager; + } + + @Override + protected Policy getPolicy(NeuralNet net) { + return null; + } + + @Override + protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) { + return new SubEpochReturn(1, null, 1.0, 1.0); + } + } + + public static class MockNeuralNet implements NeuralNet { + + @Override + public NeuralNetwork[] getNeuralNetworks() { + return new NeuralNetwork[0]; + } + + @Override + public boolean isRecurrent() { + return false; + } + + @Override + public void reset() { + + } + + @Override + public INDArray[] outputAll(INDArray batch) { + return new INDArray[0]; + } + + @Override + public NeuralNet clone() { + return null; + } + + @Override + public void copy(NeuralNet from) { + + } + + @Override + public Gradient[] gradient(INDArray input, INDArray[] labels) { + return new Gradient[0]; + } + + @Override + public void fit(INDArray input, INDArray[] labels) { + + } + + @Override + public void applyGradient(Gradient[] gradients, int batchSize) { + + } + + @Override + public double getLatestScore() { + return 0; + } + + @Override + public void save(OutputStream os) throws IOException { + + } + + @Override + public void save(String filename) throws IOException { + + } + } + + public static class MockEncodable implements Encodable { + + private final int value; + + public MockEncodable(int value) { + + this.value = value; + } + + @Override + public double[] toArray() { + return new double[] { value }; + } + } + + public static class MockObservationSpace implements ObservationSpace { + + @Override + public String getName() { + return null; + } + + @Override + public int[] getShape() { + return new int[] { 1 }; + } + + @Override + public INDArray getLow() { + return null; + } + + @Override + public INDArray getHigh() { + return null; + } + } + + public static class MockMDP implements MDP { + + private final DiscreteSpace actionSpace; + private int currentObsValue = 0; + private final ObservationSpace observationSpace; + + public MockMDP(ObservationSpace observationSpace) { + actionSpace = new DiscreteSpace(5); + this.observationSpace = observationSpace; + } + + @Override + public ObservationSpace getObservationSpace() { + return observationSpace; + } + + @Override + public DiscreteSpace getActionSpace() { + return actionSpace; + } + + @Override + public MockEncodable reset() { + return new MockEncodable(++currentObsValue); + } + + @Override + public void close() { + + } + + @Override + public StepReply step(Integer obs) { + return new StepReply(new MockEncodable(obs), (double)obs, isDone(), null); + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public MDP newInstance() { + return null; + } + } + + public static class MockAsyncConfiguration implements AsyncConfiguration { + + private final int nStep; + private final int maxEpochStep; + + public MockAsyncConfiguration(int nStep, int maxEpochStep) { + this.nStep = nStep; + + this.maxEpochStep = maxEpochStep; + } + + @Override + public int getSeed() { + return 0; + } + + @Override + public int getMaxEpochStep() { + return maxEpochStep; + } + + @Override + public int getMaxStep() { + return 0; + } + + @Override + public int getNumThread() { + return 0; + } + + @Override + public int getNstep() { + return nStep; + } + + @Override + public int getTargetDqnUpdateFreq() { + return 0; + } + + @Override + public int getUpdateStart() { + return 0; + } + + @Override + public double getRewardFactor() { + return 0; + } + + @Override + public double getGamma() { + return 0; + } + + @Override + public double getErrorClamp() { + return 0; + } + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java new file mode 100644 index 000000000..cace83d71 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -0,0 +1,132 @@ +package org.deeplearning4j.rl4j.learning.sync; + +import lombok.AllArgsConstructor; +import lombok.Value; +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.policy.Policy; +import org.deeplearning4j.rl4j.support.MockDataManager; +import org.deeplearning4j.rl4j.util.IDataManager; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class SyncLearningTest { + + @Test + public void refac_checkDataManagerCallsRemainTheSame() { + // Arrange + MockLConfiguration lconfig = new MockLConfiguration(10); + MockDataManager dataManager = new MockDataManager(false); + MockSyncLearning sut = new MockSyncLearning(lconfig, dataManager, 2); + + // Act + sut.train(); + + assertEquals(10, dataManager.statEntries.size()); + for(int i = 0; i < 10; ++i) { + IDataManager.StatEntry entry = dataManager.statEntries.get(i); + assertEquals(2, entry.getEpochCounter()); + assertEquals(i+1, entry.getStepCounter()); + assertEquals(1.0, entry.getReward(), 0.0); + + } + assertEquals(0, dataManager.isSaveDataCallCount); + assertEquals(0, dataManager.getVideoDirCallCount); + assertEquals(11, dataManager.writeInfoCallCount); + assertEquals(1, dataManager.saveCallCount); + } + + public static class MockSyncLearning extends SyncLearning { + + private final IDataManager dataManager; + private LConfiguration conf; + private final int epochSteps; + + public MockSyncLearning(LConfiguration conf, IDataManager dataManager, int epochSteps) { + super(conf); + this.dataManager = dataManager; + this.conf = conf; + this.epochSteps = epochSteps; + } + + @Override + protected void preEpoch() { + + } + + @Override + protected void postEpoch() { + + } + + @Override + protected IDataManager.StatEntry trainEpoch() { + setStepCounter(getStepCounter() + 1); + return new MockStatEntry(epochSteps, getStepCounter(), 1.0); + } + + @Override + protected IDataManager getDataManager() { + return dataManager; + } + + @Override + public NeuralNet getNeuralNet() { + return null; + } + + @Override + public Policy getPolicy() { + return null; + } + + @Override + public LConfiguration getConfiguration() { + return conf; + } + + @Override + public MDP getMdp() { + return null; + } + } + + public static class MockLConfiguration implements ILearning.LConfiguration { + + private final int maxStep; + + public MockLConfiguration(int maxStep) { + this.maxStep = maxStep; + } + + @Override + public int getSeed() { + return 0; + } + + @Override + public int getMaxEpochStep() { + return 0; + } + + @Override + public int getMaxStep() { + return maxStep; + } + + @Override + public double getGamma() { + return 0; + } + } + + @AllArgsConstructor + @Value + public static class MockStatEntry implements IDataManager.StatEntry { + int epochCounter; + int stepCounter; + double reward; + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java new file mode 100644 index 000000000..659198f79 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java @@ -0,0 +1,50 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.learning.ILearning; +import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.util.IDataManager; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class MockDataManager implements IDataManager { + + private final boolean isSaveData; + public List statEntries = new ArrayList<>(); + public int isSaveDataCallCount = 0; + public int getVideoDirCallCount = 0; + public int writeInfoCallCount = 0; + public int saveCallCount = 0; + + public MockDataManager(boolean isSaveData) { + this.isSaveData = isSaveData; + } + + @Override + public boolean isSaveData() { + ++isSaveDataCallCount; + return isSaveData; + } + + @Override + public String getVideoDir() { + ++getVideoDirCallCount; + return null; + } + + @Override + public void appendStat(StatEntry statEntry) throws IOException { + statEntries.add(statEntry); + } + + @Override + public void writeInfo(ILearning iLearning) throws IOException { + ++writeInfoCallCount; + } + + @Override + public void save(Learning learning) throws IOException { + ++saveCallCount; + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java new file mode 100644 index 000000000..9d24161b4 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java @@ -0,0 +1,54 @@ +package org.deeplearning4j.rl4j.support; + +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class MockHistoryProcessor implements IHistoryProcessor { + + private final Configuration config; + + public MockHistoryProcessor(Configuration config) { + + this.config = config; + } + + @Override + public Configuration getConf() { + return config; + } + + @Override + public INDArray[] getHistory() { + return new INDArray[0]; + } + + @Override + public void record(INDArray image) { + + } + + @Override + public void add(INDArray image) { + + } + + @Override + public void startMonitor(String filename, int[] shape) { + + } + + @Override + public void stopMonitor() { + + } + + @Override + public boolean isMonitoring() { + return false; + } + + @Override + public double getScale() { + return 0; + } +}