Added interface IDataManager (#8034)
Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
parent
22993f853f
commit
87d2b2cd3d
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
||||||
return nshape;
|
return nshape;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract DataManager getDataManager();
|
protected abstract IDataManager getDataManager();
|
||||||
|
|
||||||
public abstract NN getNeuralNet();
|
public abstract NN getNeuralNet();
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class AsyncGlobal<NN extends NeuralNet> extends Thread {
|
public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final private NN current;
|
final private NN current;
|
||||||
|
|
|
@ -45,7 +45,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
||||||
|
|
||||||
protected abstract AsyncThread newThread(int i);
|
protected abstract AsyncThread newThread(int i);
|
||||||
|
|
||||||
protected abstract AsyncGlobal<NN> getAsyncGlobal();
|
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||||
|
|
||||||
protected void startGlobalThread() {
|
protected void startGlobalThread() {
|
||||||
getAsyncGlobal().start();
|
getAsyncGlobal().start();
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.Constants;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
|
@ -57,7 +57,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
||||||
@Getter
|
@Getter
|
||||||
private int lastMonitor = -Constants.MONITOR_FREQ;
|
private int lastMonitor = -Constants.MONITOR_FREQ;
|
||||||
|
|
||||||
public AsyncThread(AsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
||||||
if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||||
postEpoch();
|
postEpoch();
|
||||||
|
|
||||||
DataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
|
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
|
||||||
getDataManager().appendStat(statEntry);
|
getDataManager().appendStat(statEntry);
|
||||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
|
||||||
|
|
||||||
|
@ -136,13 +136,13 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
||||||
|
|
||||||
protected abstract int getThreadNumber();
|
protected abstract int getThreadNumber();
|
||||||
|
|
||||||
protected abstract AsyncGlobal<NN> getAsyncGlobal();
|
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||||
|
|
||||||
protected abstract MDP<O, A, AS> getMdp();
|
protected abstract MDP<O, A, AS> getMdp();
|
||||||
|
|
||||||
protected abstract AsyncConfiguration getConf();
|
protected abstract AsyncConfiguration getConf();
|
||||||
|
|
||||||
protected abstract DataManager getDataManager();
|
protected abstract IDataManager getDataManager();
|
||||||
|
|
||||||
protected abstract Policy<O, A> getPolicy(NN net);
|
protected abstract Policy<O, A> getPolicy(NN net);
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Value
|
@Value
|
||||||
public static class AsyncStatEntry implements DataManager.StatEntry {
|
public static class AsyncStatEntry implements IDataManager.StatEntry {
|
||||||
int stepCounter;
|
int stepCounter;
|
||||||
int epochCounter;
|
int epochCounter;
|
||||||
double reward;
|
double reward;
|
||||||
|
|
|
@ -44,7 +44,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
||||||
@Getter
|
@Getter
|
||||||
private NN current;
|
private NN current;
|
||||||
|
|
||||||
public AsyncThreadDiscrete(AsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
||||||
super(asyncGlobal, threadNumber);
|
super(asyncGlobal, threadNumber);
|
||||||
synchronized (asyncGlobal) {
|
synchronized (asyncGlobal) {
|
||||||
current = (NN)asyncGlobal.getCurrent().clone();
|
current = (NN)asyncGlobal.getCurrent().clone();
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
/**
 * Contract for the shared "global" object used by the async learners.
 * In this change AsyncGlobal implements it, and AsyncThread, AsyncThreadDiscrete
 * and AsyncLearning refer to this interface instead of the concrete AsyncGlobal class.
 *
 * @param <NN> the neural network type handled by the global object
 */
public interface IAsyncGlobal<NN extends NeuralNet> {
|
||||||
|
/**
 * @return the value of the running flag.
 * NOTE(review): what "running" covers is defined by the implementation - confirm.
 */
boolean isRunning();
|
||||||
|
/**
 * @param value new value of the running flag
 * (presumably the one reported by isRunning() - confirm against the implementation)
 */
void setRunning(boolean value);
|
||||||
|
/**
 * @return whether training is complete.
 * NOTE(review): the completion criterion is not visible here - confirm against the implementation.
 */
boolean isTrainingComplete();
|
||||||
|
/**
 * Starts the global object. AsyncGlobal extends Thread, so for that implementation
 * this is satisfied by Thread.start(); AsyncLearning.startGlobalThread() calls it.
 */
void start();
|
||||||
|
/**
 * @return the AtomicInteger counter exposed as "T".
 * NOTE(review): presumably the shared global step count - confirm.
 */
AtomicInteger getT();
|
||||||
|
/**
 * @return the current network. AsyncThreadDiscrete clones this value in its
 * constructor while synchronized on the global object.
 */
NN getCurrent();
|
||||||
|
/**
 * @return the target network.
 * NOTE(review): how it relates to getCurrent() is not visible here - confirm.
 */
NN getTarget();
|
||||||
|
/**
 * Hands an array of gradients and a step count to the global object.
 * NOTE(review): nstep is a boxed Integer, so null can be passed - confirm whether that is intended.
 *
 * @param gradient gradients to submit
 * @param nstep step count associated with the gradients
 */
void enqueue(Gradient[] gradient, Integer nstep);
|
||||||
|
}
|
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||||
|
@ -48,10 +48,10 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
@Getter
|
@Getter
|
||||||
final private ACPolicy<O> policy;
|
final private ACPolicy<O> policy;
|
||||||
@Getter
|
@Getter
|
||||||
final private DataManager dataManager;
|
final private IDataManager dataManager;
|
||||||
|
|
||||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf,
|
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf,
|
||||||
DataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.iActorCritic = iActorCritic;
|
this.iActorCritic = iActorCritic;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
||||||
|
@ -44,7 +44,7 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
final private HistoryProcessor.Configuration hpconf;
|
final private HistoryProcessor.Configuration hpconf;
|
||||||
|
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, IActorCritic, conf, dataManager);
|
super(mdp, IActorCritic, conf, dataManager);
|
||||||
this.hpconf = hpconf;
|
this.hpconf = hpconf;
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
|
@ -52,13 +52,13 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
|
|
||||||
|
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf,
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, DataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.*;
|
import org.deeplearning4j.rl4j.network.ac.*;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
||||||
|
@ -34,31 +34,31 @@ import org.deeplearning4j.rl4j.util.DataManager;
|
||||||
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
||||||
DataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
super(mdp, IActorCritic, conf, dataManager);
|
super(mdp, IActorCritic, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
A3CConfiguration conf, DataManager dataManager) {
|
A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
||||||
DataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
A3CConfiguration conf, DataManager dataManager) {
|
A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
|
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
|
||||||
DataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
|
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
@ -52,12 +52,12 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
@Getter
|
@Getter
|
||||||
final protected int threadNumber;
|
final protected int threadNumber;
|
||||||
@Getter
|
@Getter
|
||||||
final protected DataManager dataManager;
|
final protected IDataManager dataManager;
|
||||||
|
|
||||||
final private Random random;
|
final private Random random;
|
||||||
|
|
||||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
||||||
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, DataManager dataManager) {
|
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) {
|
||||||
super(asyncGlobal, threadNumber);
|
super(asyncGlobal, threadNumber);
|
||||||
this.conf = a3cc;
|
this.conf = a3cc;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
|
@ -40,13 +40,13 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
@Getter
|
@Getter
|
||||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
final private MDP<O, Integer, DiscreteSpace> mdp;
|
||||||
@Getter
|
@Getter
|
||||||
final private DataManager dataManager;
|
final private IDataManager dataManager;
|
||||||
@Getter
|
@Getter
|
||||||
final private AsyncGlobal<IDQN> asyncGlobal;
|
final private AsyncGlobal<IDQN> asyncGlobal;
|
||||||
|
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf,
|
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf,
|
||||||
DataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.dataManager = dataManager;
|
this.dataManager = dataManager;
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
|
||||||
|
@ -36,19 +36,19 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
|
||||||
final private HistoryProcessor.Configuration hpconf;
|
final private HistoryProcessor.Configuration hpconf;
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf, dataManager);
|
super(mdp, dqn, conf, dataManager);
|
||||||
this.hpconf = hpconf;
|
this.hpconf = hpconf;
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16.
|
||||||
|
@ -30,18 +30,18 @@ import org.deeplearning4j.rl4j.util.DataManager;
|
||||||
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
|
public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends AsyncNStepQLearningDiscrete<O> {
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf, dataManager);
|
super(mdp, dqn, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, DataManager dataManager) {
|
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
@ -29,7 +29,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -46,17 +46,17 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
@Getter
|
@Getter
|
||||||
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
||||||
@Getter
|
@Getter
|
||||||
final protected AsyncGlobal<IDQN> asyncGlobal;
|
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
final protected int threadNumber;
|
final protected int threadNumber;
|
||||||
@Getter
|
@Getter
|
||||||
final protected DataManager dataManager;
|
final protected IDataManager dataManager;
|
||||||
|
|
||||||
final private Random random;
|
final private Random random;
|
||||||
|
|
||||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IDQN> asyncGlobal,
|
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
|
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
|
||||||
DataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
super(asyncGlobal, threadNumber);
|
super(asyncGlobal, threadNumber);
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.Constants;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/3/16.
|
||||||
|
@ -51,7 +51,7 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
||||||
|
|
||||||
while (getStepCounter() < getConfiguration().getMaxStep()) {
|
while (getStepCounter() < getConfiguration().getMaxStep()) {
|
||||||
preEpoch();
|
preEpoch();
|
||||||
DataManager.StatEntry statEntry = trainEpoch();
|
IDataManager.StatEntry statEntry = trainEpoch();
|
||||||
postEpoch();
|
postEpoch();
|
||||||
|
|
||||||
incrementEpoch();
|
incrementEpoch();
|
||||||
|
@ -79,6 +79,6 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
||||||
|
|
||||||
protected abstract void postEpoch();
|
protected abstract void postEpoch();
|
||||||
|
|
||||||
protected abstract DataManager.StatEntry trainEpoch();
|
protected abstract IDataManager.StatEntry trainEpoch();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager.StatEntry;
|
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
|
|
@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.Constants;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||||
|
@ -53,7 +53,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
@Getter
|
@Getter
|
||||||
final private QLConfiguration configuration;
|
final private QLConfiguration configuration;
|
||||||
@Getter
|
@Getter
|
||||||
final private DataManager dataManager;
|
final private IDataManager dataManager;
|
||||||
@Getter
|
@Getter
|
||||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
final private MDP<O, Integer, DiscreteSpace> mdp;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -72,7 +72,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||||
DataManager dataManager, int epsilonNbStep) {
|
IDataManager dataManager, int epsilonNbStep) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
|
|
|
@ -23,7 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
|
||||||
|
@ -35,18 +35,18 @@ public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscret
|
||||||
|
|
||||||
|
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
QLConfiguration conf, DataManager dataManager) {
|
QLConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, DataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, DataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
|
||||||
|
@ -33,18 +33,18 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
|
||||||
|
|
||||||
|
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
|
||||||
DataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep());
|
super(mdp, dqn, conf, dataManager, conf.getEpsilonNbStep());
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
QLearning.QLConfiguration conf, DataManager dataManager) {
|
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||||
QLearning.QLConfiguration conf, DataManager dataManager) {
|
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ import java.util.zip.ZipOutputStream;
|
||||||
* the folder for every training and handle every path and model savings
|
* the folder for every training and handle every path and model savings
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DataManager {
|
public class DataManager implements IDataManager {
|
||||||
|
|
||||||
final private String home = System.getProperty("user.home");
|
final private String home = System.getProperty("user.home");
|
||||||
final private ObjectMapper mapper = new ObjectMapper();
|
final private ObjectMapper mapper = new ObjectMapper();
|
||||||
|
@ -266,16 +266,6 @@ public class DataManager {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//In order for jackson to serialize StatEntry
|
|
||||||
//please use Lombok @Value (see QLStatEntry)
|
|
||||||
public interface StatEntry {
|
|
||||||
int getEpochCounter();
|
|
||||||
|
|
||||||
int getStepCounter();
|
|
||||||
|
|
||||||
double getReward();
|
|
||||||
}
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Value
|
@Value
|
||||||
@Builder
|
@Builder
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.util;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public interface IDataManager {
|
||||||
|
|
||||||
|
boolean isSaveData();
|
||||||
|
String getVideoDir();
|
||||||
|
void appendStat(StatEntry statEntry) throws IOException;
|
||||||
|
void writeInfo(ILearning iLearning) throws IOException;
|
||||||
|
void save(Learning learning) throws IOException;
|
||||||
|
|
||||||
|
//In order for jackson to serialize StatEntry
|
||||||
|
//please use Lombok @Value (see QLStatEntry)
|
||||||
|
interface StatEntry {
|
||||||
|
int getEpochCounter();
|
||||||
|
|
||||||
|
int getStepCounter();
|
||||||
|
|
||||||
|
double getReward();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,458 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class AsyncThreadTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void refac_withoutHistoryProcessor_checkDataManagerCallsRemainTheSame() {
|
||||||
|
// Arrange
|
||||||
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
|
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||||
|
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||||
|
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
|
MockMDP mdp = new MockMDP(observationSpace);
|
||||||
|
MockAsyncConfiguration config = new MockAsyncConfiguration(10, 2);
|
||||||
|
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, dataManager);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.run();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(4, dataManager.statEntries.size());
|
||||||
|
|
||||||
|
IDataManager.StatEntry entry = dataManager.statEntries.get(0);
|
||||||
|
assertEquals(2, entry.getStepCounter());
|
||||||
|
assertEquals(0, entry.getEpochCounter());
|
||||||
|
assertEquals(2.0, entry.getReward(), 0.0);
|
||||||
|
|
||||||
|
entry = dataManager.statEntries.get(1);
|
||||||
|
assertEquals(4, entry.getStepCounter());
|
||||||
|
assertEquals(1, entry.getEpochCounter());
|
||||||
|
assertEquals(2.0, entry.getReward(), 0.0);
|
||||||
|
|
||||||
|
entry = dataManager.statEntries.get(2);
|
||||||
|
assertEquals(6, entry.getStepCounter());
|
||||||
|
assertEquals(2, entry.getEpochCounter());
|
||||||
|
assertEquals(2.0, entry.getReward(), 0.0);
|
||||||
|
|
||||||
|
entry = dataManager.statEntries.get(3);
|
||||||
|
assertEquals(8, entry.getStepCounter());
|
||||||
|
assertEquals(3, entry.getEpochCounter());
|
||||||
|
assertEquals(2.0, entry.getReward(), 0.0);
|
||||||
|
|
||||||
|
assertEquals(0, dataManager.isSaveDataCallCount);
|
||||||
|
assertEquals(0, dataManager.getVideoDirCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTheSame() {
|
||||||
|
// Arrange
|
||||||
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
|
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||||
|
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||||
|
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
|
MockMDP mdp = new MockMDP(observationSpace);
|
||||||
|
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
|
||||||
|
|
||||||
|
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
||||||
|
.build();
|
||||||
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
|
||||||
|
|
||||||
|
|
||||||
|
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
|
||||||
|
sut.setHistoryProcessor(hp);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.run();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(9, dataManager.statEntries.size());
|
||||||
|
|
||||||
|
for(int i = 0; i < 9; ++i) {
|
||||||
|
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||||
|
assertEquals(i + 1, entry.getStepCounter());
|
||||||
|
assertEquals(i, entry.getEpochCounter());
|
||||||
|
assertEquals(1.0, entry.getReward(), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(10, dataManager.isSaveDataCallCount);
|
||||||
|
assertEquals(0, dataManager.getVideoDirCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainTheSame() {
|
||||||
|
// Arrange
|
||||||
|
MockDataManager dataManager = new MockDataManager(true);
|
||||||
|
MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(10);
|
||||||
|
MockNeuralNet neuralNet = new MockNeuralNet();
|
||||||
|
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
|
MockMDP mdp = new MockMDP(observationSpace);
|
||||||
|
MockAsyncConfiguration asyncConfig = new MockAsyncConfiguration(10, 2);
|
||||||
|
|
||||||
|
IHistoryProcessor.Configuration hpConfig = IHistoryProcessor.Configuration.builder()
|
||||||
|
.build();
|
||||||
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConfig);
|
||||||
|
|
||||||
|
|
||||||
|
MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, asyncConfig, dataManager);
|
||||||
|
sut.setHistoryProcessor(hp);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.run();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(9, dataManager.statEntries.size());
|
||||||
|
|
||||||
|
for(int i = 0; i < 9; ++i) {
|
||||||
|
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||||
|
assertEquals(i + 1, entry.getStepCounter());
|
||||||
|
assertEquals(i, entry.getEpochCounter());
|
||||||
|
assertEquals(1.0, entry.getReward(), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(1, dataManager.isSaveDataCallCount);
|
||||||
|
assertEquals(1, dataManager.getVideoDirCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockAsyncGlobal implements IAsyncGlobal {
|
||||||
|
|
||||||
|
private final int maxLoops;
|
||||||
|
private int currentLoop = 0;
|
||||||
|
|
||||||
|
public MockAsyncGlobal(int maxLoops) {
|
||||||
|
|
||||||
|
this.maxLoops = maxLoops;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isRunning() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setRunning(boolean value) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isTrainingComplete() {
|
||||||
|
return ++currentLoop >= maxLoops;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void start() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AtomicInteger getT() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNet getCurrent() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNet getTarget() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockAsyncThread extends AsyncThread {
|
||||||
|
|
||||||
|
IAsyncGlobal asyncGlobal;
|
||||||
|
private final MockNeuralNet neuralNet;
|
||||||
|
private final MDP mdp;
|
||||||
|
private final AsyncConfiguration conf;
|
||||||
|
private final IDataManager dataManager;
|
||||||
|
|
||||||
|
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
|
||||||
|
super(asyncGlobal, threadNumber);
|
||||||
|
|
||||||
|
this.asyncGlobal = asyncGlobal;
|
||||||
|
this.neuralNet = neuralNet;
|
||||||
|
this.mdp = mdp;
|
||||||
|
this.conf = conf;
|
||||||
|
this.dataManager = dataManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NeuralNet getCurrent() {
|
||||||
|
return neuralNet;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected int getThreadNumber() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected IAsyncGlobal getAsyncGlobal() {
|
||||||
|
return asyncGlobal;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected MDP getMdp() {
|
||||||
|
return mdp;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AsyncConfiguration getConf() {
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected IDataManager getDataManager() {
|
||||||
|
return dataManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Policy getPolicy(NeuralNet net) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
|
||||||
|
return new SubEpochReturn(1, null, 1.0, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockNeuralNet implements NeuralNet {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNetwork[] getNeuralNetworks() {
|
||||||
|
return new NeuralNetwork[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isRecurrent() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray[] outputAll(INDArray batch) {
|
||||||
|
return new INDArray[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNet clone() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void copy(NeuralNet from) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||||
|
return new Gradient[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(INDArray input, INDArray[] labels) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void applyGradient(Gradient[] gradients, int batchSize) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getLatestScore() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(OutputStream os) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(String filename) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockEncodable implements Encodable {
|
||||||
|
|
||||||
|
private final int value;
|
||||||
|
|
||||||
|
public MockEncodable(int value) {
|
||||||
|
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] toArray() {
|
||||||
|
return new double[] { value };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockObservationSpace implements ObservationSpace {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int[] getShape() {
|
||||||
|
return new int[] { 1 };
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getLow() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray getHigh() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||||
|
|
||||||
|
private final DiscreteSpace actionSpace;
|
||||||
|
private int currentObsValue = 0;
|
||||||
|
private final ObservationSpace observationSpace;
|
||||||
|
|
||||||
|
public MockMDP(ObservationSpace observationSpace) {
|
||||||
|
actionSpace = new DiscreteSpace(5);
|
||||||
|
this.observationSpace = observationSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ObservationSpace getObservationSpace() {
|
||||||
|
return observationSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DiscreteSpace getActionSpace() {
|
||||||
|
return actionSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MockEncodable reset() {
|
||||||
|
return new MockEncodable(++currentObsValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public StepReply<MockEncodable> step(Integer obs) {
|
||||||
|
return new StepReply<MockEncodable>(new MockEncodable(obs), (double)obs, isDone(), null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isDone() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MDP newInstance() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockAsyncConfiguration implements AsyncConfiguration {
|
||||||
|
|
||||||
|
private final int nStep;
|
||||||
|
private final int maxEpochStep;
|
||||||
|
|
||||||
|
public MockAsyncConfiguration(int nStep, int maxEpochStep) {
|
||||||
|
this.nStep = nStep;
|
||||||
|
|
||||||
|
this.maxEpochStep = maxEpochStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getSeed() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxEpochStep() {
|
||||||
|
return maxEpochStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxStep() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumThread() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNstep() {
|
||||||
|
return nStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getTargetDqnUpdateFreq() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getUpdateStart() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getRewardFactor() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getGamma() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getErrorClamp() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,132 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Value;
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class SyncLearningTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void refac_checkDataManagerCallsRemainTheSame() {
|
||||||
|
// Arrange
|
||||||
|
MockLConfiguration lconfig = new MockLConfiguration(10);
|
||||||
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
|
MockSyncLearning sut = new MockSyncLearning(lconfig, dataManager, 2);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.train();
|
||||||
|
|
||||||
|
assertEquals(10, dataManager.statEntries.size());
|
||||||
|
for(int i = 0; i < 10; ++i) {
|
||||||
|
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||||
|
assertEquals(2, entry.getEpochCounter());
|
||||||
|
assertEquals(i+1, entry.getStepCounter());
|
||||||
|
assertEquals(1.0, entry.getReward(), 0.0);
|
||||||
|
|
||||||
|
}
|
||||||
|
assertEquals(0, dataManager.isSaveDataCallCount);
|
||||||
|
assertEquals(0, dataManager.getVideoDirCallCount);
|
||||||
|
assertEquals(11, dataManager.writeInfoCallCount);
|
||||||
|
assertEquals(1, dataManager.saveCallCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockSyncLearning extends SyncLearning {
|
||||||
|
|
||||||
|
private final IDataManager dataManager;
|
||||||
|
private LConfiguration conf;
|
||||||
|
private final int epochSteps;
|
||||||
|
|
||||||
|
public MockSyncLearning(LConfiguration conf, IDataManager dataManager, int epochSteps) {
|
||||||
|
super(conf);
|
||||||
|
this.dataManager = dataManager;
|
||||||
|
this.conf = conf;
|
||||||
|
this.epochSteps = epochSteps;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void preEpoch() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void postEpoch() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected IDataManager.StatEntry trainEpoch() {
|
||||||
|
setStepCounter(getStepCounter() + 1);
|
||||||
|
return new MockStatEntry(epochSteps, getStepCounter(), 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected IDataManager getDataManager() {
|
||||||
|
return dataManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNet getNeuralNet() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Policy getPolicy() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LConfiguration getConfiguration() {
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MDP getMdp() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MockLConfiguration implements ILearning.LConfiguration {
|
||||||
|
|
||||||
|
private final int maxStep;
|
||||||
|
|
||||||
|
public MockLConfiguration(int maxStep) {
|
||||||
|
this.maxStep = maxStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getSeed() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxEpochStep() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxStep() {
|
||||||
|
return maxStep;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getGamma() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Value
|
||||||
|
public static class MockStatEntry implements IDataManager.StatEntry {
|
||||||
|
int epochCounter;
|
||||||
|
int stepCounter;
|
||||||
|
double reward;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MockDataManager implements IDataManager {
|
||||||
|
|
||||||
|
private final boolean isSaveData;
|
||||||
|
public List<StatEntry> statEntries = new ArrayList<>();
|
||||||
|
public int isSaveDataCallCount = 0;
|
||||||
|
public int getVideoDirCallCount = 0;
|
||||||
|
public int writeInfoCallCount = 0;
|
||||||
|
public int saveCallCount = 0;
|
||||||
|
|
||||||
|
public MockDataManager(boolean isSaveData) {
|
||||||
|
this.isSaveData = isSaveData;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSaveData() {
|
||||||
|
++isSaveDataCallCount;
|
||||||
|
return isSaveData;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getVideoDir() {
|
||||||
|
++getVideoDirCallCount;
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void appendStat(StatEntry statEntry) throws IOException {
|
||||||
|
statEntries.add(statEntry);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeInfo(ILearning iLearning) throws IOException {
|
||||||
|
++writeInfoCallCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(Learning learning) throws IOException {
|
||||||
|
++saveCallCount;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,54 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
public class MockHistoryProcessor implements IHistoryProcessor {
|
||||||
|
|
||||||
|
private final Configuration config;
|
||||||
|
|
||||||
|
public MockHistoryProcessor(Configuration config) {
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Configuration getConf() {
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray[] getHistory() {
|
||||||
|
return new INDArray[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void record(INDArray image) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void add(INDArray image) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void startMonitor(String filename, int[] shape) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stopMonitor() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isMonitoring() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getScale() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue