Added interface IDataManager (#8034)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2019-07-25 07:34:54 -04:00 committed by Alex Black
parent 22993f853f
commit 87d2b2cd3d
25 changed files with 828 additions and 70 deletions

View File

@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -125,7 +125,7 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
return nshape; return nshape;
} }
protected abstract DataManager getDataManager(); protected abstract IDataManager getDataManager();
public abstract NN getNeuralNet(); public abstract NN getNeuralNet();

View File

@ -52,7 +52,7 @@ import java.util.concurrent.atomic.AtomicInteger;
* *
*/ */
@Slf4j @Slf4j
public class AsyncGlobal<NN extends NeuralNet> extends Thread { public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
@Getter @Getter
final private NN current; final private NN current;

View File

@ -45,7 +45,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
protected abstract AsyncThread newThread(int i); protected abstract AsyncThread newThread(int i);
protected abstract AsyncGlobal<NN> getAsyncGlobal(); protected abstract IAsyncGlobal<NN> getAsyncGlobal();
protected void startGlobalThread() { protected void startGlobalThread() {
getAsyncGlobal().start(); getAsyncGlobal().start();

View File

@ -31,7 +31,7 @@ import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -57,7 +57,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
@Getter @Getter
private int lastMonitor = -Constants.MONITOR_FREQ; private int lastMonitor = -Constants.MONITOR_FREQ;
public AsyncThread(AsyncGlobal<NN> asyncGlobal, int threadNumber) { public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
this.threadNumber = threadNumber; this.threadNumber = threadNumber;
} }
@ -109,7 +109,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) { if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
postEpoch(); postEpoch();
DataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score); IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
getDataManager().appendStat(statEntry); getDataManager().appendStat(statEntry);
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward()); log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
@ -136,13 +136,13 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
protected abstract int getThreadNumber(); protected abstract int getThreadNumber();
protected abstract AsyncGlobal<NN> getAsyncGlobal(); protected abstract IAsyncGlobal<NN> getAsyncGlobal();
protected abstract MDP<O, A, AS> getMdp(); protected abstract MDP<O, A, AS> getMdp();
protected abstract AsyncConfiguration getConf(); protected abstract AsyncConfiguration getConf();
protected abstract DataManager getDataManager(); protected abstract IDataManager getDataManager();
protected abstract Policy<O, A> getPolicy(NN net); protected abstract Policy<O, A> getPolicy(NN net);
@ -159,7 +159,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
@AllArgsConstructor @AllArgsConstructor
@Value @Value
public static class AsyncStatEntry implements DataManager.StatEntry { public static class AsyncStatEntry implements IDataManager.StatEntry {
int stepCounter; int stepCounter;
int epochCounter; int epochCounter;
double reward; double reward;

View File

@ -44,7 +44,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
@Getter @Getter
private NN current; private NN current;
public AsyncThreadDiscrete(AsyncGlobal<NN> asyncGlobal, int threadNumber) { public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
super(asyncGlobal, threadNumber); super(asyncGlobal, threadNumber);
synchronized (asyncGlobal) { synchronized (asyncGlobal) {
current = (NN)asyncGlobal.getCurrent().clone(); current = (NN)asyncGlobal.getCurrent().clone();

View File

@ -0,0 +1,33 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import java.util.concurrent.atomic.AtomicInteger;
public interface IAsyncGlobal<NN extends NeuralNet> {
boolean isRunning();
void setRunning(boolean value);
boolean isTrainingComplete();
void start();
AtomicInteger getT();
NN getCurrent();
NN getTarget();
void enqueue(Gradient[] gradient, Integer nstep);
}

View File

@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
@ -48,10 +48,10 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
@Getter @Getter
final private ACPolicy<O> policy; final private ACPolicy<O> policy;
@Getter @Getter
final private DataManager dataManager; final private IDataManager dataManager;
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf, public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf,
DataManager dataManager) { IDataManager dataManager) {
super(conf); super(conf);
this.iActorCritic = iActorCritic; this.iActorCritic = iActorCritic;
this.mdp = mdp; this.mdp = mdp;

View File

@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
@ -44,7 +44,7 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
final private HistoryProcessor.Configuration hpconf; final private HistoryProcessor.Configuration hpconf;
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
super(mdp, IActorCritic, conf, dataManager); super(mdp, IActorCritic, conf, dataManager);
this.hpconf = hpconf; this.hpconf = hpconf;
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
@ -52,13 +52,13 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf,
dataManager); dataManager);
} }
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager); this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
} }

View File

@ -20,7 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.*; import org.deeplearning4j.rl4j.network.ac.*;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
@ -34,31 +34,31 @@ import org.deeplearning4j.rl4j.util.DataManager;
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> { public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
DataManager dataManager) { IDataManager dataManager) {
super(mdp, IActorCritic, conf, dataManager); super(mdp, IActorCritic, conf, dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf, DataManager dataManager) { A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
DataManager dataManager) { IDataManager dataManager) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager); this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf, DataManager dataManager) { A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf, ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
DataManager dataManager) { IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager); this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
} }

View File

@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -52,12 +52,12 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
@Getter @Getter
final protected int threadNumber; final protected int threadNumber;
@Getter @Getter
final protected DataManager dataManager; final protected IDataManager dataManager;
final private Random random; final private Random random;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal, public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, DataManager dataManager) { A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) {
super(asyncGlobal, threadNumber); super(asyncGlobal, threadNumber);
this.conf = a3cc; this.conf = a3cc;
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;

View File

@ -27,7 +27,7 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -40,13 +40,13 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
@Getter @Getter
final private MDP<O, Integer, DiscreteSpace> mdp; final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter @Getter
final private DataManager dataManager; final private IDataManager dataManager;
@Getter @Getter
final private AsyncGlobal<IDQN> asyncGlobal; final private AsyncGlobal<IDQN> asyncGlobal;
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf, public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf,
DataManager dataManager) { IDataManager dataManager) {
super(conf); super(conf);
this.mdp = mdp; this.mdp = mdp;
this.dataManager = dataManager; this.dataManager = dataManager;

View File

@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
@ -36,19 +36,19 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
final private HistoryProcessor.Configuration hpconf; final private HistoryProcessor.Configuration hpconf;
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager); super(mdp, dqn, conf, dataManager);
this.hpconf = hpconf; this.hpconf = hpconf;
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
} }
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
} }
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
} }

View File

@ -22,7 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
@ -30,18 +30,18 @@ import org.deeplearning4j.rl4j.util.DataManager;
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> { public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf, DataManager dataManager) { AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager); super(mdp, dqn, conf, dataManager);
} }
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf, DataManager dataManager) { AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, DataManager dataManager) { DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
} }
} }

View File

@ -19,7 +19,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
@ -29,7 +29,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -46,17 +46,17 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
@Getter @Getter
final protected MDP<O, Integer, DiscreteSpace> mdp; final protected MDP<O, Integer, DiscreteSpace> mdp;
@Getter @Getter
final protected AsyncGlobal<IDQN> asyncGlobal; final protected IAsyncGlobal<IDQN> asyncGlobal;
@Getter @Getter
final protected int threadNumber; final protected int threadNumber;
@Getter @Getter
final protected DataManager dataManager; final protected IDataManager dataManager;
final private Random random; final private Random random;
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IDQN> asyncGlobal, public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
DataManager dataManager) { IDataManager dataManager) {
super(asyncGlobal, threadNumber); super(asyncGlobal, threadNumber);
this.conf = conf; this.conf = conf;
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;

View File

@ -22,7 +22,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
@ -51,7 +51,7 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
while (getStepCounter() < getConfiguration().getMaxStep()) { while (getStepCounter() < getConfiguration().getMaxStep()) {
preEpoch(); preEpoch();
DataManager.StatEntry statEntry = trainEpoch(); IDataManager.StatEntry statEntry = trainEpoch();
postEpoch(); postEpoch();
incrementEpoch(); incrementEpoch();
@ -79,6 +79,6 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
protected abstract void postEpoch(); protected abstract void postEpoch();
protected abstract DataManager.StatEntry trainEpoch(); protected abstract IDataManager.StatEntry trainEpoch();
} }

View File

@ -29,7 +29,7 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.EpsGreedy; import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager.StatEntry; import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList; import java.util.ArrayList;

View File

@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
@ -53,7 +53,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Getter @Getter
final private QLConfiguration configuration; final private QLConfiguration configuration;
@Getter @Getter
final private DataManager dataManager; final private IDataManager dataManager;
@Getter @Getter
final private MDP<O, Integer, DiscreteSpace> mdp; final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter @Getter
@ -72,7 +72,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf, public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
DataManager dataManager, int epsilonNbStep) { IDataManager dataManager, int epsilonNbStep) {
super(conf); super(conf);
this.configuration = conf; this.configuration = conf;
this.mdp = mdp; this.mdp = mdp;

View File

@ -23,7 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
@ -35,18 +35,18 @@ public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscret
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf, DataManager dataManager) { QLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
} }
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, DataManager dataManager) { HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
} }
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, DataManager dataManager) { HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
} }
} }

View File

@ -23,7 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager; import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
@ -33,18 +33,18 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
DataManager dataManager) { IDataManager dataManager) {
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep()); super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep());
} }
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf, DataManager dataManager) { QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf, DataManager dataManager) { QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
} }

View File

@ -43,7 +43,7 @@ import java.util.zip.ZipOutputStream;
* the folder for every training and handle every path and model savings * the folder for every training and handle every path and model savings
*/ */
@Slf4j @Slf4j
public class DataManager { public class DataManager implements IDataManager {
final private String home = System.getProperty("user.home"); final private String home = System.getProperty("user.home");
final private ObjectMapper mapper = new ObjectMapper(); final private ObjectMapper mapper = new ObjectMapper();
@ -266,16 +266,6 @@ public class DataManager {
} }
//In order for jackson to serialize StatEntry
//please use Lombok @Value (see QLStatEntry)
public interface StatEntry {
int getEpochCounter();
int getStepCounter();
double getReward();
}
@AllArgsConstructor @AllArgsConstructor
@Value @Value
@Builder @Builder

View File

@ -0,0 +1,41 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.util;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import java.io.IOException;
public interface IDataManager {
boolean isSaveData();
String getVideoDir();
void appendStat(StatEntry statEntry) throws IOException;
void writeInfo(ILearning iLearning) throws IOException;
void save(Learning learning) throws IOException;
//In order for jackson to serialize StatEntry
//please use Lombok @Value (see QLStatEntry)
interface StatEntry {
int getEpochCounter();
int getStepCounter();
double getReward();
}
}

View File

@ -0,0 +1,458 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
public class AsyncThreadTest {
@Test
public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() {
// Arrange
MockDataManager dataManager = new MockDataManager(false);
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager);
// Act
sut.run();
// Assert
assertEquals(4, dataManager.statEntries.size());
IDataManager.StatEntry entry = dataManager.statEntries.get(0);
assertEquals(2, entry.getStepCounter());
assertEquals(0, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(1);
assertEquals(4, entry.getStepCounter());
assertEquals(1, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(2);
assertEquals(6, entry.getStepCounter());
assertEquals(2, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
entry = dataManager.statEntries.get(3);
assertEquals(8, entry.getStepCounter());
assertEquals(3, entry.getEpochCounter());
assertEquals(2.0, entry.getReward(), 0.0);
assertEquals(0, dataManager.isSaveDataCallCount);
assertEquals(0, dataManager.getVideoDirCallCount);
}
@Test
public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() {
// Arrange
MockDataManager dataManager = new MockDataManager(false);
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
.build();
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
sut.setHistoryProcessor(hp);
// Act
sut.run();
// Assert
assertEquals(9, dataManager.statEntries.size());
for(int i = 0; i < 9; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i + 1, entry.getStepCounter());
assertEquals(i, entry.getEpochCounter());
assertEquals(1.0, entry.getReward(), 0.0);
}
assertEquals(10, dataManager.isSaveDataCallCount);
assertEquals(0, dataManager.getVideoDirCallCount);
}
@Test
public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() {
// Arrange
MockDataManager dataManager = new MockDataManager(true);
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
MockNeuralNet neuralNet = new MockNeuralNet();
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
.build();
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
sut.setHistoryProcessor(hp);
// Act
sut.run();
// Assert
assertEquals(9, dataManager.statEntries.size());
for(int i = 0; i < 9; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i + 1, entry.getStepCounter());
assertEquals(i, entry.getEpochCounter());
assertEquals(1.0, entry.getReward(), 0.0);
}
assertEquals(1, dataManager.isSaveDataCallCount);
assertEquals(1, dataManager.getVideoDirCallCount);
}
public static class MockAsyncGlobal implements IAsyncGlobal {
private final int maxLoops;
private int currentLoop = 0;
public MockAsyncGlobal(int maxLoops) {
this.maxLoops = maxLoops;
}
@Override
public boolean isRunning() {
return true;
}
@Override
public void setRunning(boolean value) {
}
@Override
public boolean isTrainingComplete() {
return ++currentLoop >= maxLoops;
}
@Override
public void start() {
}
@Override
public AtomicInteger getT() {
return null;
}
@Override
public NeuralNet getCurrent() {
return null;
}
@Override
public NeuralNet getTarget() {
return null;
}
@Override
public void enqueue(Gradient[] gradient, Integer nstep) {
}
}
public static class MockAsyncThread extends AsyncThread {
IAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet;
private final MDP mdp;
private final AsyncConfiguration conf;
private final IDataManager dataManager;
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
super(asyncGlobal, threadNumber);
this.asyncGlobal = asyncGlobal;
this.neuralNet = neuralNet;
this.mdp = mdp;
this.conf = conf;
this.dataManager = dataManager;
}
@Override
protected NeuralNet getCurrent() {
return neuralNet;
}
@Override
protected int getThreadNumber() {
return 0;
}
@Override
protected IAsyncGlobal getAsyncGlobal() {
return asyncGlobal;
}
@Override
protected MDP getMdp() {
return mdp;
}
@Override
protected AsyncConfiguration getConf() {
return conf;
}
@Override
protected IDataManager getDataManager() {
return dataManager;
}
@Override
protected Policy getPolicy(NeuralNet net) {
return null;
}
@Override
protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
return new SubEpochReturn(1, null, 1.0, 1.0);
}
}
public static class MockNeuralNet implements NeuralNet {
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];
}
@Override
public NeuralNet clone() {
return null;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
return new Gradient[0];
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public void applyGradient(Gradient[] gradients, int batchSize) {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}
public static class MockEncodable implements Encodable {
private final int value;
public MockEncodable(int value) {
this.value = value;
}
@Override
public double[] toArray() {
return new double[] { value };
}
}
public static class MockObservationSpace implements ObservationSpace {
@Override
public String getName() {
return null;
}
@Override
public int[] getShape() {
return new int[] { 1 };
}
@Override
public INDArray getLow() {
return null;
}
@Override
public INDArray getHigh() {
return null;
}
}
public static class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
private final DiscreteSpace actionSpace;
private int currentObsValue = 0;
private final ObservationSpace observationSpace;
public MockMDP(ObservationSpace observationSpace) {
actionSpace = new DiscreteSpace(5);
this.observationSpace = observationSpace;
}
@Override
public ObservationSpace getObservationSpace() {
return observationSpace;
}
@Override
public DiscreteSpace getActionSpace() {
return actionSpace;
}
@Override
public MockEncodable reset() {
return new MockEncodable(++currentObsValue);
}
@Override
public void close() {
}
@Override
public StepReply<MockEncodable> step(Integer obs) {
return new StepReply<MockEncodable>(new MockEncodable(obs), (double)obs, isDone(), null);
}
@Override
public boolean isDone() {
return false;
}
@Override
public MDP newInstance() {
return null;
}
}
public static class MockAsyncConfiguration implements AsyncConfiguration {
private final int nStep;
private final int maxEpochStep;
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
this.nStep = nStep;
this.maxEpochStep = maxEpochStep;
}
@Override
public int getSeed() {
return 0;
}
@Override
public int getMaxEpochStep() {
return maxEpochStep;
}
@Override
public int getMaxStep() {
return 0;
}
@Override
public int getNumThread() {
return 0;
}
@Override
public int getNstep() {
return nStep;
}
@Override
public int getTargetDqnUpdateFreq() {
return 0;
}
@Override
public int getUpdateStart() {
return 0;
}
@Override
public double getRewardFactor() {
return 0;
}
@Override
public double getGamma() {
return 0;
}
@Override
public double getErrorClamp() {
return 0;
}
}
}

View File

@ -0,0 +1,132 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.AllArgsConstructor;
import lombok.Value;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class SyncLearningTest {
@Test
public void refac_checkDataManagerCallsRemainTheSame() {
// Arrange
MockLConfiguration lconfig = new MockLConfiguration(10);
MockDataManager dataManager = new MockDataManager(false);
MockSyncLearning sut = new MockSyncLearning(lconfig, dataManager, 2);
// Act
sut.train();
assertEquals(10, dataManager.statEntries.size());
for(int i = 0; i < 10; ++i) {
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(2, entry.getEpochCounter());
assertEquals(i+1, entry.getStepCounter());
assertEquals(1.0, entry.getReward(), 0.0);
}
assertEquals(0, dataManager.isSaveDataCallCount);
assertEquals(0, dataManager.getVideoDirCallCount);
assertEquals(11, dataManager.writeInfoCallCount);
assertEquals(1, dataManager.saveCallCount);
}
public static class MockSyncLearning extends SyncLearning {
private final IDataManager dataManager;
private LConfiguration conf;
private final int epochSteps;
public MockSyncLearning(LConfiguration conf, IDataManager dataManager, int epochSteps) {
super(conf);
this.dataManager = dataManager;
this.conf = conf;
this.epochSteps = epochSteps;
}
@Override
protected void preEpoch() {
}
@Override
protected void postEpoch() {
}
@Override
protected IDataManager.StatEntry trainEpoch() {
setStepCounter(getStepCounter() + 1);
return new MockStatEntry(epochSteps, getStepCounter(), 1.0);
}
@Override
protected IDataManager getDataManager() {
return dataManager;
}
@Override
public NeuralNet getNeuralNet() {
return null;
}
@Override
public Policy getPolicy() {
return null;
}
@Override
public LConfiguration getConfiguration() {
return conf;
}
@Override
public MDP getMdp() {
return null;
}
}
public static class MockLConfiguration implements ILearning.LConfiguration {
private final int maxStep;
public MockLConfiguration(int maxStep) {
this.maxStep = maxStep;
}
@Override
public int getSeed() {
return 0;
}
@Override
public int getMaxEpochStep() {
return 0;
}
@Override
public int getMaxStep() {
return maxStep;
}
@Override
public double getGamma() {
return 0;
}
}
@AllArgsConstructor
@Value
public static class MockStatEntry implements IDataManager.StatEntry {
int epochCounter;
int stepCounter;
double reward;
}
}

View File

@ -0,0 +1,50 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.util.IDataManager;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class MockDataManager implements IDataManager {
private final boolean isSaveData;
public List<StatEntry> statEntries = new ArrayList<>();
public int isSaveDataCallCount = 0;
public int getVideoDirCallCount = 0;
public int writeInfoCallCount = 0;
public int saveCallCount = 0;
public MockDataManager(boolean isSaveData) {
this.isSaveData = isSaveData;
}
@Override
public boolean isSaveData() {
++isSaveDataCallCount;
return isSaveData;
}
@Override
public String getVideoDir() {
++getVideoDirCallCount;
return null;
}
@Override
public void appendStat(StatEntry statEntry) throws IOException {
statEntries.add(statEntry);
}
@Override
public void writeInfo(ILearning iLearning) throws IOException {
++writeInfoCallCount;
}
@Override
public void save(Learning learning) throws IOException {
++saveCallCount;
}
}

View File

@ -0,0 +1,54 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
public class MockHistoryProcessor implements IHistoryProcessor {
private final Configuration config;
public MockHistoryProcessor(Configuration config) {
this.config = config;
}
@Override
public Configuration getConf() {
return config;
}
@Override
public INDArray[] getHistory() {
return new INDArray[0];
}
@Override
public void record(INDArray image) {
}
@Override
public void add(INDArray image) {
}
@Override
public void startMonitor(String filename, int[] shape) {
}
@Override
public void stopMonitor() {
}
@Override
public boolean isMonitoring() {
return false;
}
@Override
public double getScale() {
return 0;
}
}