Added interface IDataManager (#8034)
Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
22993f853f
commit
87d2b2cd3d
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
|||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -125,7 +125,7 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
|||
return nshape;
|
||||
}
|
||||
|
||||
protected abstract DataManager getDataManager();
|
||||
protected abstract IDataManager getDataManager();
|
||||
|
||||
public abstract NN getNeuralNet();
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
*
|
||||
*/
|
||||
@Slf4j
|
||||
public class AsyncGlobal<NN extends NeuralNet> extends Thread {
|
||||
public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
|
||||
|
||||
@Getter
|
||||
final private NN current;
|
||||
|
|
|
@ -45,7 +45,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
|||
|
||||
protected abstract AsyncThread newThread(int i);
|
||||
|
||||
protected abstract AsyncGlobal<NN> getAsyncGlobal();
|
||||
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||
|
||||
protected void startGlobalThread() {
|
||||
getAsyncGlobal().start();
|
||||
|
|
|
@ -31,7 +31,7 @@ import org.deeplearning4j.rl4j.policy.Policy;
|
|||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.Constants;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
|
@ -57,7 +57,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
@Getter
|
||||
private int lastMonitor = -Constants.MONITOR_FREQ;
|
||||
|
||||
public AsyncThread(AsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
||||
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
||||
this.threadNumber = threadNumber;
|
||||
}
|
||||
|
||||
|
@ -109,7 +109,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||
postEpoch();
|
||||
|
||||
DataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
|
||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
|
||||
getDataManager().appendStat(statEntry);
|
||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
||||
|
||||
|
@ -136,13 +136,13 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
|
||||
protected abstract int getThreadNumber();
|
||||
|
||||
protected abstract AsyncGlobal<NN> getAsyncGlobal();
|
||||
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||
|
||||
protected abstract MDP<O, A, AS> getMdp();
|
||||
|
||||
protected abstract AsyncConfiguration getConf();
|
||||
|
||||
protected abstract DataManager getDataManager();
|
||||
protected abstract IDataManager getDataManager();
|
||||
|
||||
protected abstract Policy<O, A> getPolicy(NN net);
|
||||
|
||||
|
@ -159,7 +159,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
|
||||
@AllArgsConstructor
|
||||
@Value
|
||||
public static class AsyncStatEntry implements DataManager.StatEntry {
|
||||
public static class AsyncStatEntry implements IDataManager.StatEntry {
|
||||
int stepCounter;
|
||||
int epochCounter;
|
||||
double reward;
|
||||
|
|
|
@ -44,7 +44,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
|||
@Getter
|
||||
private NN current;
|
||||
|
||||
public AsyncThreadDiscrete(AsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
||||
super(asyncGlobal, threadNumber);
|
||||
synchronized (asyncGlobal) {
|
||||
current = (NN)asyncGlobal.getCurrent().clone();
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public interface IAsyncGlobal<NN extends NeuralNet> {
|
||||
boolean isRunning();
|
||||
void setRunning(boolean value);
|
||||
boolean isTrainingComplete();
|
||||
void start();
|
||||
AtomicInteger getT();
|
||||
NN getCurrent();
|
||||
NN getTarget();
|
||||
void enqueue(Gradient[] gradient, Integer nstep);
|
||||
}
|
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
|||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||
|
@ -48,10 +48,10 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
|||
@Getter
|
||||
final private ACPolicy<O> policy;
|
||||
@Getter
|
||||
final private DataManager dataManager;
|
||||
final private IDataManager dataManager;
|
||||
|
||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf,
|
||||
DataManager dataManager) {
|
||||
IDataManager dataManager) {
|
||||
super(conf);
|
||||
this.iActorCritic = iActorCritic;
|
||||
this.mdp = mdp;
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
|
|||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
||||
|
@ -44,7 +44,7 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
|||
final private HistoryProcessor.Configuration hpconf;
|
||||
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) {
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||
super(mdp, IActorCritic, conf, dataManager);
|
||||
this.hpconf = hpconf;
|
||||
setHistoryProcessor(hpconf);
|
||||
|
@ -52,13 +52,13 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
|||
|
||||
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) {
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf,
|
||||
dataManager);
|
||||
}
|
||||
|
||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) {
|
||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
|||
import org.deeplearning4j.rl4j.network.ac.*;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
||||
|
@ -34,31 +34,31 @@ import org.deeplearning4j.rl4j.util.DataManager;
|
|||
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
||||
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
||||
DataManager dataManager) {
|
||||
IDataManager dataManager) {
|
||||
super(mdp, IActorCritic, conf, dataManager);
|
||||
}
|
||||
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||
A3CConfiguration conf, DataManager dataManager) {
|
||||
A3CConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||
dataManager);
|
||||
}
|
||||
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
||||
DataManager dataManager) {
|
||||
IDataManager dataManager) {
|
||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
|
||||
}
|
||||
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||
A3CConfiguration conf, DataManager dataManager) {
|
||||
A3CConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||
dataManager);
|
||||
}
|
||||
|
||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
|
||||
DataManager dataManager) {
|
||||
IDataManager dataManager) {
|
||||
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
|
||||
}
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.policy.ACPolicy;
|
|||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
@ -52,12 +52,12 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
@Getter
|
||||
final protected int threadNumber;
|
||||
@Getter
|
||||
final protected DataManager dataManager;
|
||||
final protected IDataManager dataManager;
|
||||
|
||||
final private Random random;
|
||||
|
||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
||||
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, DataManager dataManager) {
|
||||
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) {
|
||||
super(asyncGlobal, threadNumber);
|
||||
this.conf = a3cc;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
|||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
|
@ -40,13 +40,13 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
|||
@Getter
|
||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
||||
@Getter
|
||||
final private DataManager dataManager;
|
||||
final private IDataManager dataManager;
|
||||
@Getter
|
||||
final private AsyncGlobal<IDQN> asyncGlobal;
|
||||
|
||||
|
||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf,
|
||||
DataManager dataManager) {
|
||||
IDataManager dataManager) {
|
||||
super(conf);
|
||||
this.mdp = mdp;
|
||||
this.dataManager = dataManager;
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
|
||||
|
@ -36,19 +36,19 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
|
|||
final private HistoryProcessor.Configuration hpconf;
|
||||
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
super(mdp, dqn, conf, dataManager);
|
||||
this.hpconf = hpconf;
|
||||
setHistoryProcessor(hpconf);
|
||||
}
|
||||
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||
}
|
||||
|
||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
|
||||
|
@ -30,18 +30,18 @@ import org.deeplearning4j.rl4j.util.DataManager;
|
|||
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
|
||||
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||
AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
super(mdp, dqn, conf, dataManager);
|
||||
}
|
||||
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||
dataManager);
|
||||
}
|
||||
|
||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
|||
import lombok.Getter;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
|
@ -29,7 +29,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
|||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -46,17 +46,17 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
@Getter
|
||||
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
||||
@Getter
|
||||
final protected AsyncGlobal<IDQN> asyncGlobal;
|
||||
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
||||
@Getter
|
||||
final protected int threadNumber;
|
||||
@Getter
|
||||
final protected DataManager dataManager;
|
||||
final protected IDataManager dataManager;
|
||||
|
||||
final private Random random;
|
||||
|
||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IDQN> asyncGlobal,
|
||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
|
||||
DataManager dataManager) {
|
||||
IDataManager dataManager) {
|
||||
super(asyncGlobal, threadNumber);
|
||||
this.conf = conf;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
|
|
|
@ -22,7 +22,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
|||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.Constants;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
|
||||
|
@ -51,7 +51,7 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
|||
|
||||
while (getStepCounter() < getConfiguration().getMaxStep()) {
|
||||
preEpoch();
|
||||
DataManager.StatEntry statEntry = trainEpoch();
|
||||
IDataManager.StatEntry statEntry = trainEpoch();
|
||||
postEpoch();
|
||||
|
||||
incrementEpoch();
|
||||
|
@ -79,6 +79,6 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
|||
|
||||
protected abstract void postEpoch();
|
||||
|
||||
protected abstract DataManager.StatEntry trainEpoch();
|
||||
protected abstract IDataManager.StatEntry trainEpoch();
|
||||
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
|||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager.StatEntry;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
|||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.Constants;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
|
@ -53,7 +53,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
@Getter
|
||||
final private QLConfiguration configuration;
|
||||
@Getter
|
||||
final private DataManager dataManager;
|
||||
final private IDataManager dataManager;
|
||||
@Getter
|
||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
||||
@Getter
|
||||
|
@ -72,7 +72,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
|
||||
|
||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||
DataManager dataManager, int epsilonNbStep) {
|
||||
IDataManager dataManager, int epsilonNbStep) {
|
||||
super(conf);
|
||||
this.configuration = conf;
|
||||
this.mdp = mdp;
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
|
||||
|
@ -35,18 +35,18 @@ public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscret
|
|||
|
||||
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||
QLConfiguration conf, DataManager dataManager) {
|
||||
QLConfiguration conf, IDataManager dataManager) {
|
||||
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||
setHistoryProcessor(hpconf);
|
||||
}
|
||||
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, DataManager dataManager) {
|
||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||
}
|
||||
|
||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, DataManager dataManager) {
|
||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
|||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.DataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
|
||||
|
@ -33,18 +33,18 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
|
|||
|
||||
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
|
||||
DataManager dataManager) {
|
||||
IDataManager dataManager) {
|
||||
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep());
|
||||
}
|
||||
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||
QLearning.QLConfiguration conf, DataManager dataManager) {
|
||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||
dataManager);
|
||||
}
|
||||
|
||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||
QLearning.QLConfiguration conf, DataManager dataManager) {
|
||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
||||
}
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ import java.util.zip.ZipOutputStream;
|
|||
* the folder for every training and handle every path and model savings
|
||||
*/
|
||||
@Slf4j
|
||||
public class DataManager {
|
||||
public class DataManager implements IDataManager {
|
||||
|
||||
final private String home = System.getProperty("user.home");
|
||||
final private ObjectMapper mapper = new ObjectMapper();
|
||||
|
@ -266,16 +266,6 @@ public class DataManager {
|
|||
|
||||
}
|
||||
|
||||
//In order for jackson to serialize StatEntry
|
||||
//please use Lombok @Value (see QLStatEntry)
|
||||
public interface StatEntry {
|
||||
int getEpochCounter();
|
||||
|
||||
int getStepCounter();
|
||||
|
||||
double getReward();
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@Value
|
||||
@Builder
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.util;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public interface IDataManager {
|
||||
|
||||
boolean isSaveData();
|
||||
String getVideoDir();
|
||||
void appendStat(StatEntry statEntry) throws IOException;
|
||||
void writeInfo(ILearning iLearning) throws IOException;
|
||||
void save(Learning learning) throws IOException;
|
||||
|
||||
//In order for jackson to serialize StatEntry
|
||||
//please use Lombok @Value (see QLStatEntry)
|
||||
interface StatEntry {
|
||||
int getEpochCounter();
|
||||
|
||||
int getStepCounter();
|
||||
|
||||
double getReward();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,458 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class AsyncThreadTest {
|
||||
|
||||
@Test
|
||||
public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() {
|
||||
// Arrange
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2);
|
||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(4, dataManager.statEntries.size());
|
||||
|
||||
IDataManager.StatEntry entry = dataManager.statEntries.get(0);
|
||||
assertEquals(2, entry.getStepCounter());
|
||||
assertEquals(0, entry.getEpochCounter());
|
||||
assertEquals(2.0, entry.getReward(), 0.0);
|
||||
|
||||
entry = dataManager.statEntries.get(1);
|
||||
assertEquals(4, entry.getStepCounter());
|
||||
assertEquals(1, entry.getEpochCounter());
|
||||
assertEquals(2.0, entry.getReward(), 0.0);
|
||||
|
||||
entry = dataManager.statEntries.get(2);
|
||||
assertEquals(6, entry.getStepCounter());
|
||||
assertEquals(2, entry.getEpochCounter());
|
||||
assertEquals(2.0, entry.getReward(), 0.0);
|
||||
|
||||
entry = dataManager.statEntries.get(3);
|
||||
assertEquals(8, entry.getStepCounter());
|
||||
assertEquals(3, entry.getEpochCounter());
|
||||
assertEquals(2.0, entry.getReward(), 0.0);
|
||||
|
||||
assertEquals(0, dataManager.isSaveDataCallCount);
|
||||
assertEquals(0, dataManager.getVideoDirCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() {
|
||||
// Arrange
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
|
||||
|
||||
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
||||
.build();
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
|
||||
|
||||
|
||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
|
||||
sut.setHistoryProcessor(hp);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(9, dataManager.statEntries.size());
|
||||
|
||||
for(int i = 0; i < 9; ++i) {
|
||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||
assertEquals(i + 1, entry.getStepCounter());
|
||||
assertEquals(i, entry.getEpochCounter());
|
||||
assertEquals(1.0, entry.getReward(), 0.0);
|
||||
}
|
||||
|
||||
assertEquals(10, dataManager.isSaveDataCallCount);
|
||||
assertEquals(0, dataManager.getVideoDirCallCount);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() {
|
||||
// Arrange
|
||||
MockDataManager dataManager = new MockDataManager(true);
|
||||
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
|
||||
|
||||
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
||||
.build();
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
|
||||
|
||||
|
||||
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
|
||||
sut.setHistoryProcessor(hp);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(9, dataManager.statEntries.size());
|
||||
|
||||
for(int i = 0; i < 9; ++i) {
|
||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||
assertEquals(i + 1, entry.getStepCounter());
|
||||
assertEquals(i, entry.getEpochCounter());
|
||||
assertEquals(1.0, entry.getReward(), 0.0);
|
||||
}
|
||||
|
||||
assertEquals(1, dataManager.isSaveDataCallCount);
|
||||
assertEquals(1, dataManager.getVideoDirCallCount);
|
||||
}
|
||||
|
||||
public static class MockAsyncGlobal implements IAsyncGlobal {
|
||||
|
||||
private final int maxLoops;
|
||||
private int currentLoop = 0;
|
||||
|
||||
public MockAsyncGlobal(int maxLoops) {
|
||||
|
||||
this.maxLoops = maxLoops;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRunning() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setRunning(boolean value) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTrainingComplete() {
|
||||
return ++currentLoop >= maxLoops;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public AtomicInteger getT() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getCurrent() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getTarget() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockAsyncThread extends AsyncThread {
|
||||
|
||||
IAsyncGlobal asyncGlobal;
|
||||
private final MockNeuralNet neuralNet;
|
||||
private final MDP mdp;
|
||||
private final AsyncConfiguration conf;
|
||||
private final IDataManager dataManager;
|
||||
|
||||
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
|
||||
super(asyncGlobal, threadNumber);
|
||||
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.neuralNet = neuralNet;
|
||||
this.mdp = mdp;
|
||||
this.conf = conf;
|
||||
this.dataManager = dataManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NeuralNet getCurrent() {
|
||||
return neuralNet;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int getThreadNumber() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IAsyncGlobal getAsyncGlobal() {
|
||||
return asyncGlobal;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MDP getMdp() {
|
||||
return mdp;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AsyncConfiguration getConf() {
|
||||
return conf;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IDataManager getDataManager() {
|
||||
return dataManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Policy getPolicy(NeuralNet net) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
|
||||
return new SubEpochReturn(1, null, 1.0, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockNeuralNet implements NeuralNet {
|
||||
|
||||
@Override
|
||||
public NeuralNetwork[] getNeuralNetworks() {
|
||||
return new NeuralNetwork[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRecurrent() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] outputAll(INDArray batch) {
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet clone() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||
return new Gradient[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(INDArray input, INDArray[] labels) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyGradient(Gradient[] gradients, int batchSize) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getLatestScore() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(OutputStream os) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(String filename) throws IOException {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockEncodable implements Encodable {
|
||||
|
||||
private final int value;
|
||||
|
||||
public MockEncodable(int value) {
|
||||
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] toArray() {
|
||||
return new double[] { value };
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockObservationSpace implements ObservationSpace {
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int[] getShape() {
|
||||
return new int[] { 1 };
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getLow() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getHigh() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||
|
||||
private final DiscreteSpace actionSpace;
|
||||
private int currentObsValue = 0;
|
||||
private final ObservationSpace observationSpace;
|
||||
|
||||
public MockMDP(ObservationSpace observationSpace) {
|
||||
actionSpace = new DiscreteSpace(5);
|
||||
this.observationSpace = observationSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObservationSpace getObservationSpace() {
|
||||
return observationSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DiscreteSpace getActionSpace() {
|
||||
return actionSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockEncodable reset() {
|
||||
return new MockEncodable(++currentObsValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public StepReply<MockEncodable> step(Integer obs) {
|
||||
return new StepReply<MockEncodable>(new MockEncodable(obs), (double)obs, isDone(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDone() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MDP newInstance() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockAsyncConfiguration implements AsyncConfiguration {
|
||||
|
||||
private final int nStep;
|
||||
private final int maxEpochStep;
|
||||
|
||||
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
|
||||
this.nStep = nStep;
|
||||
|
||||
this.maxEpochStep = maxEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSeed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxEpochStep() {
|
||||
return maxEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxStep() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumThread() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNstep() {
|
||||
return nStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTargetDqnUpdateFreq() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getUpdateStart() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getRewardFactor() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getGamma() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getErrorClamp() {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Value;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class SyncLearningTest {
|
||||
|
||||
@Test
|
||||
public void refac_checkDataManagerCallsRemainTheSame() {
|
||||
// Arrange
|
||||
MockLConfiguration lconfig = new MockLConfiguration(10);
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
MockSyncLearning sut = new MockSyncLearning(lconfig, dataManager, 2);
|
||||
|
||||
// Act
|
||||
sut.train();
|
||||
|
||||
assertEquals(10, dataManager.statEntries.size());
|
||||
for(int i = 0; i < 10; ++i) {
|
||||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||
assertEquals(2, entry.getEpochCounter());
|
||||
assertEquals(i+1, entry.getStepCounter());
|
||||
assertEquals(1.0, entry.getReward(), 0.0);
|
||||
|
||||
}
|
||||
assertEquals(0, dataManager.isSaveDataCallCount);
|
||||
assertEquals(0, dataManager.getVideoDirCallCount);
|
||||
assertEquals(11, dataManager.writeInfoCallCount);
|
||||
assertEquals(1, dataManager.saveCallCount);
|
||||
}
|
||||
|
||||
public static class MockSyncLearning extends SyncLearning {
|
||||
|
||||
private final IDataManager dataManager;
|
||||
private LConfiguration conf;
|
||||
private final int epochSteps;
|
||||
|
||||
public MockSyncLearning(LConfiguration conf, IDataManager dataManager, int epochSteps) {
|
||||
super(conf);
|
||||
this.dataManager = dataManager;
|
||||
this.conf = conf;
|
||||
this.epochSteps = epochSteps;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void preEpoch() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void postEpoch() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IDataManager.StatEntry trainEpoch() {
|
||||
setStepCounter(getStepCounter() + 1);
|
||||
return new MockStatEntry(epochSteps, getStepCounter(), 1.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IDataManager getDataManager() {
|
||||
return dataManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet getNeuralNet() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Policy getPolicy() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LConfiguration getConfiguration() {
|
||||
return conf;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MDP getMdp() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockLConfiguration implements ILearning.LConfiguration {
|
||||
|
||||
private final int maxStep;
|
||||
|
||||
public MockLConfiguration(int maxStep) {
|
||||
this.maxStep = maxStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSeed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxEpochStep() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxStep() {
|
||||
return maxStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getGamma() {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@Value
|
||||
public static class MockStatEntry implements IDataManager.StatEntry {
|
||||
int epochCounter;
|
||||
int stepCounter;
|
||||
double reward;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockDataManager implements IDataManager {
|
||||
|
||||
private final boolean isSaveData;
|
||||
public List<StatEntry> statEntries = new ArrayList<>();
|
||||
public int isSaveDataCallCount = 0;
|
||||
public int getVideoDirCallCount = 0;
|
||||
public int writeInfoCallCount = 0;
|
||||
public int saveCallCount = 0;
|
||||
|
||||
public MockDataManager(boolean isSaveData) {
|
||||
this.isSaveData = isSaveData;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSaveData() {
|
||||
++isSaveDataCallCount;
|
||||
return isSaveData;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getVideoDir() {
|
||||
++getVideoDirCallCount;
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void appendStat(StatEntry statEntry) throws IOException {
|
||||
statEntries.add(statEntry);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeInfo(ILearning iLearning) throws IOException {
|
||||
++writeInfoCallCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(Learning learning) throws IOException {
|
||||
++saveCallCount;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
public class MockHistoryProcessor implements IHistoryProcessor {
|
||||
|
||||
private final Configuration config;
|
||||
|
||||
public MockHistoryProcessor(Configuration config) {
|
||||
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Configuration getConf() {
|
||||
return config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] getHistory() {
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void record(INDArray image) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(INDArray image) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void startMonitor(String filename, int[] shape) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stopMonitor() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isMonitoring() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getScale() {
|
||||
return 0;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue