RL4J: Use Nd4j Random instead of java.util.Random (#8282)
* Changed to use Nd4j Random instead of java.util.Random
  Signed-off-by: unknown <aboulang2002@yahoo.com>
* Changed to use Nd4j.getRandom() instead of the factory
  Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
parent
2d750b69e5
commit
171ce51f46
|
@ -42,7 +42,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> ex
|
||||||
|
|
||||||
interface LConfiguration {
|
interface LConfiguration {
|
||||||
|
|
||||||
int getSeed();
|
Integer getSeed();
|
||||||
|
|
||||||
int getMaxEpochStep();
|
int getMaxEpochStep();
|
||||||
|
|
||||||
|
|
|
@ -29,8 +29,6 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/27/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/27/16.
|
||||||
*
|
*
|
||||||
|
@ -43,8 +41,7 @@ import java.util.Random;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
|
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
|
||||||
@Getter
|
|
||||||
final private Random random;
|
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
private int stepCounter = 0;
|
private int stepCounter = 0;
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
|
@ -52,10 +49,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
private IHistoryProcessor historyProcessor = null;
|
private IHistoryProcessor historyProcessor = null;
|
||||||
|
|
||||||
public Learning(LConfiguration conf) {
|
|
||||||
random = new Random(conf.getSeed());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Integer getMaxAction(INDArray vector) {
|
public static Integer getMaxAction(INDArray vector) {
|
||||||
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
|
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
*/
|
*/
|
||||||
public interface AsyncConfiguration extends ILearning.LConfiguration {
|
public interface AsyncConfiguration extends ILearning.LConfiguration {
|
||||||
|
|
||||||
int getSeed();
|
Integer getSeed();
|
||||||
|
|
||||||
int getMaxEpochStep();
|
int getMaxEpochStep();
|
||||||
|
|
||||||
|
|
|
@ -42,10 +42,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
||||||
@Getter(AccessLevel.PROTECTED)
|
@Getter(AccessLevel.PROTECTED)
|
||||||
private final TrainingListenerList listeners = new TrainingListenerList();
|
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||||
|
|
||||||
public AsyncLearning(AsyncConfiguration conf) {
|
|
||||||
super(conf);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a {@link TrainingListener} listener at the end of the listener list.
|
* Add a {@link TrainingListener} listener at the end of the listener list.
|
||||||
*
|
*
|
||||||
|
|
|
@ -26,6 +26,8 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||||
|
@ -48,13 +50,19 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
final private ACPolicy<O> policy;
|
final private ACPolicy<O> policy;
|
||||||
|
|
||||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
|
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
|
||||||
super(conf);
|
|
||||||
this.iActorCritic = iActorCritic;
|
this.iActorCritic = iActorCritic;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
policy = new ACPolicy<>(iActorCritic, getRandom());
|
|
||||||
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
|
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
|
||||||
|
Integer seed = conf.getSeed();
|
||||||
|
Random rnd = Nd4j.getRandom();
|
||||||
|
if(seed != null) {
|
||||||
|
mdp.getActionSpace().setSeed(seed);
|
||||||
|
rnd.setSeed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
policy = new ACPolicy<>(iActorCritic, rnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected AsyncThread newThread(int i, int deviceNum) {
|
protected AsyncThread newThread(int i, int deviceNum) {
|
||||||
|
@ -71,7 +79,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public static class A3CConfiguration implements AsyncConfiguration {
|
public static class A3CConfiguration implements AsyncConfiguration {
|
||||||
|
|
||||||
int seed;
|
Integer seed;
|
||||||
int maxEpochStep;
|
int maxEpochStep;
|
||||||
int maxStep;
|
int maxStep;
|
||||||
int numThread;
|
int numThread;
|
||||||
|
|
|
@ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
import java.util.Random;
|
|
||||||
import java.util.Stack;
|
import java.util.Stack;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -50,7 +50,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
@Getter
|
@Getter
|
||||||
final protected int threadNumber;
|
final protected int threadNumber;
|
||||||
|
|
||||||
final private Random random;
|
final private Random rnd;
|
||||||
|
|
||||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
||||||
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
||||||
|
@ -59,13 +59,18 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
this.conf = a3cc;
|
this.conf = a3cc;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
|
||||||
random = new Random(conf.getSeed() + threadNumber);
|
Integer seed = conf.getSeed();
|
||||||
|
rnd = Nd4j.getRandom();
|
||||||
|
if(seed != null) {
|
||||||
|
mdp.getActionSpace().setSeed(seed + threadNumber);
|
||||||
|
rnd.setSeed(seed);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Policy<O, Integer> getPolicy(IActorCritic net) {
|
protected Policy<O, Integer> getPolicy(IActorCritic net) {
|
||||||
return new ACPolicy(net, random);
|
return new ACPolicy(net, rnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -43,11 +43,13 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
|
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
|
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
|
||||||
super(conf);
|
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
Integer seed = conf.getSeed();
|
||||||
|
if(seed != null) {
|
||||||
|
mdp.getActionSpace().setSeed(seed);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -70,7 +72,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public static class AsyncNStepQLConfiguration implements AsyncConfiguration {
|
public static class AsyncNStepQLConfiguration implements AsyncConfiguration {
|
||||||
|
|
||||||
int seed;
|
Integer seed;
|
||||||
int maxEpochStep;
|
int maxEpochStep;
|
||||||
int maxStep;
|
int maxStep;
|
||||||
int numThread;
|
int numThread;
|
||||||
|
|
|
@ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
import java.util.Random;
|
|
||||||
import java.util.Stack;
|
import java.util.Stack;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -48,7 +48,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
@Getter
|
@Getter
|
||||||
final protected int threadNumber;
|
final protected int threadNumber;
|
||||||
|
|
||||||
final private Random random;
|
final private Random rnd;
|
||||||
|
|
||||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
|
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
|
||||||
|
@ -57,13 +57,18 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
rnd = Nd4j.getRandom();
|
||||||
random = new Random(conf.getSeed() + threadNumber);
|
|
||||||
|
Integer seed = conf.getSeed();
|
||||||
|
if(seed != null) {
|
||||||
|
mdp.getActionSpace().setSeed(seed + threadNumber);
|
||||||
|
rnd.setSeed(seed + threadNumber);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Policy<O, Integer> getPolicy(IDQN nn) {
|
public Policy<O, Integer> getPolicy(IDQN nn) {
|
||||||
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
||||||
random, conf.getMinEpsilon(), this);
|
rnd, conf.getMinEpsilon(), this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,9 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
|
||||||
import it.unimi.dsi.fastutil.ints.IntSet;
|
import it.unimi.dsi.fastutil.ints.IntSet;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.ArrayList;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
|
||||||
|
@ -36,30 +36,28 @@ import java.util.concurrent.ThreadLocalRandom;
|
||||||
public class ExpReplay<A> implements IExpReplay<A> {
|
public class ExpReplay<A> implements IExpReplay<A> {
|
||||||
|
|
||||||
final private int batchSize;
|
final private int batchSize;
|
||||||
final private Random random;
|
final private Random rnd;
|
||||||
|
|
||||||
//Implementing this as a circular buffer queue
|
//Implementing this as a circular buffer queue
|
||||||
private CircularFifoQueue<Transition<A>> storage;
|
private CircularFifoQueue<Transition<A>> storage;
|
||||||
|
|
||||||
public ExpReplay(int maxSize, int batchSize, int seed) {
|
public ExpReplay(int maxSize, int batchSize, Random rnd) {
|
||||||
this.batchSize = batchSize;
|
this.batchSize = batchSize;
|
||||||
this.random = new Random(seed);
|
this.rnd = rnd;
|
||||||
storage = new CircularFifoQueue<>(maxSize);
|
storage = new CircularFifoQueue<>(maxSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public ArrayList<Transition<A>> getBatch(int size) {
|
public ArrayList<Transition<A>> getBatch(int size) {
|
||||||
ArrayList<Transition<A>> batch = new ArrayList<>(size);
|
ArrayList<Transition<A>> batch = new ArrayList<>(size);
|
||||||
int storageSize = storage.size();
|
int storageSize = storage.size();
|
||||||
int actualBatchSize = Math.min(storageSize, size);
|
int actualBatchSize = Math.min(storageSize, size);
|
||||||
|
|
||||||
int[] actualIndex = new int[actualBatchSize];
|
int[] actualIndex = new int[actualBatchSize];
|
||||||
ThreadLocalRandom r = ThreadLocalRandom.current();
|
|
||||||
IntSet set = new IntOpenHashSet();
|
IntSet set = new IntOpenHashSet();
|
||||||
for( int i=0; i<actualBatchSize; i++ ){
|
for( int i=0; i<actualBatchSize; i++ ){
|
||||||
int next = r.nextInt(storageSize);
|
int next = rnd.nextInt(storageSize);
|
||||||
while(set.contains(next)){
|
while(set.contains(next)){
|
||||||
next = r.nextInt(storageSize);
|
next = rnd.nextInt(storageSize);
|
||||||
}
|
}
|
||||||
set.add(next);
|
set.add(next);
|
||||||
actualIndex[i] = next;
|
actualIndex[i] = next;
|
||||||
|
|
|
@ -41,10 +41,6 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
||||||
|
|
||||||
private final TrainingListenerList listeners = new TrainingListenerList();
|
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||||
|
|
||||||
public SyncLearning(LConfiguration conf) {
|
|
||||||
super(conf);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a listener at the end of the listener list.
|
* Add a listener at the end of the listener list.
|
||||||
*
|
*
|
||||||
|
|
|
@ -30,7 +30,8 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -53,8 +54,20 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
protected IExpReplay<A> expReplay;
|
protected IExpReplay<A> expReplay;
|
||||||
|
|
||||||
public QLearning(QLConfiguration conf) {
|
public QLearning(QLConfiguration conf) {
|
||||||
super(conf);
|
this(conf, getSeededRandom(conf.getSeed()));
|
||||||
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), conf.getSeed());
|
}
|
||||||
|
|
||||||
|
public QLearning(QLConfiguration conf, Random random) {
|
||||||
|
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Random getSeededRandom(Integer seed) {
|
||||||
|
Random rnd = Nd4j.getRandom();
|
||||||
|
if(seed != null) {
|
||||||
|
rnd.setSeed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
return rnd;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
|
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
|
||||||
|
@ -160,7 +173,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
@JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
|
@JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
|
||||||
public static class QLConfiguration implements LConfiguration {
|
public static class QLConfiguration implements LConfiguration {
|
||||||
|
|
||||||
int seed;
|
Integer seed;
|
||||||
int maxEpochStep;
|
int maxEpochStep;
|
||||||
int maxStep;
|
int maxStep;
|
||||||
int expRepMaxSize;
|
int expRepMaxSize;
|
||||||
|
|
|
@ -31,7 +31,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -70,13 +72,18 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||||
int epsilonNbStep) {
|
int epsilonNbStep) {
|
||||||
|
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||||
|
int epsilonNbStep, Random random) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
qNetwork = dqn;
|
qNetwork = dqn;
|
||||||
targetQNetwork = dqn.clone();
|
targetQNetwork = dqn.clone();
|
||||||
policy = new DQNPolicy(getQNetwork());
|
policy = new DQNPolicy(getQNetwork());
|
||||||
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(),
|
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
|
||||||
this);
|
this);
|
||||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
|
private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
|
||||||
private static final double xThreshold = 2.4;
|
private static final double xThreshold = 2.4;
|
||||||
|
|
||||||
private final Random rnd = new Random();
|
private final Random rnd;
|
||||||
|
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;
|
private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;
|
||||||
|
@ -76,6 +76,14 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
||||||
@Getter
|
@Getter
|
||||||
private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
|
private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
|
||||||
|
|
||||||
|
public CartpoleNative() {
|
||||||
|
rnd = new Random();
|
||||||
|
}
|
||||||
|
|
||||||
|
public CartpoleNative(int seed) {
|
||||||
|
rnd = new Random(seed);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public State reset() {
|
public State reset() {
|
||||||
|
|
||||||
|
|
|
@ -16,18 +16,16 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.policy;
|
package org.deeplearning4j.rl4j.policy;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
|
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
|
||||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
|
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
|
@ -38,47 +36,41 @@ import java.util.Random;
|
||||||
*/
|
*/
|
||||||
public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
|
public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
|
||||||
|
|
||||||
final private IActorCritic IActorCritic;
|
final private IActorCritic actorCritic;
|
||||||
Random rd;
|
Random rnd;
|
||||||
|
|
||||||
public ACPolicy(IActorCritic IActorCritic) {
|
public ACPolicy(IActorCritic actorCritic) {
|
||||||
this.IActorCritic = IActorCritic;
|
this(actorCritic, Nd4j.getRandom());
|
||||||
NeuralNetwork nn = IActorCritic.getNeuralNetworks()[0];
|
|
||||||
if (nn instanceof ComputationGraph) {
|
|
||||||
rd = new Random(((ComputationGraph)nn).getConfiguration().getDefaultConfiguration().getSeed());
|
|
||||||
} else if (nn instanceof MultiLayerNetwork) {
|
|
||||||
rd = new Random(((MultiLayerNetwork)nn).getDefaultConfiguration().getSeed());
|
|
||||||
}
|
}
|
||||||
}
|
public ACPolicy(IActorCritic actorCritic, Random rnd) {
|
||||||
public ACPolicy(IActorCritic IActorCritic, Random rd) {
|
this.actorCritic = actorCritic;
|
||||||
this.IActorCritic = IActorCritic;
|
this.rnd = rnd;
|
||||||
this.rd = rd;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException {
|
public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException {
|
||||||
return new ACPolicy<O>(ActorCriticCompGraph.load(path));
|
return new ACPolicy<O>(ActorCriticCompGraph.load(path));
|
||||||
}
|
}
|
||||||
public static <O extends Encodable> ACPolicy<O> load(String path, Random rd) throws IOException {
|
public static <O extends Encodable> ACPolicy<O> load(String path, Random rnd) throws IOException {
|
||||||
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rd);
|
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException {
|
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException {
|
||||||
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy));
|
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy));
|
||||||
}
|
}
|
||||||
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rd) throws IOException {
|
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rnd) throws IOException {
|
||||||
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rd);
|
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IActorCritic getNeuralNet() {
|
public IActorCritic getNeuralNet() {
|
||||||
return IActorCritic;
|
return actorCritic;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Integer nextAction(INDArray input) {
|
public Integer nextAction(INDArray input) {
|
||||||
INDArray output = IActorCritic.outputAll(input)[1];
|
INDArray output = actorCritic.outputAll(input)[1];
|
||||||
if (rd == null) {
|
if (rnd == null) {
|
||||||
return Learning.getMaxAction(output);
|
return Learning.getMaxAction(output);
|
||||||
}
|
}
|
||||||
float rVal = rd.nextFloat();
|
float rVal = rnd.nextFloat();
|
||||||
for (int i = 0; i < output.length(); i++) {
|
for (int i = 0; i < output.length(); i++) {
|
||||||
//System.out.println(i + " " + rVal + " " + output.getFloat(i));
|
//System.out.println(i + " " + rVal + " " + output.getFloat(i));
|
||||||
if (rVal < output.getFloat(i)) {
|
if (rVal < output.getFloat(i)) {
|
||||||
|
@ -91,11 +83,11 @@ public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void save(String filename) throws IOException {
|
public void save(String filename) throws IOException {
|
||||||
IActorCritic.save(filename);
|
actorCritic.save(filename);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void save(String filenameValue, String filenamePolicy) throws IOException {
|
public void save(String filenameValue, String filenamePolicy) throws IOException {
|
||||||
IActorCritic.save(filenameValue, filenamePolicy);
|
actorCritic.save(filenameValue, filenamePolicy);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,12 +16,11 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.policy;
|
package org.deeplearning4j.rl4j.policy;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import java.util.Random;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.nd4j.linalg.ops.transforms.Transforms.exp;
|
import static org.nd4j.linalg.ops.transforms.Transforms.exp;
|
||||||
|
|
||||||
|
@ -31,11 +30,15 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
|
||||||
* Boltzmann exploration is a stochastic policy wrt to the
|
* Boltzmann exploration is a stochastic policy wrt to the
|
||||||
* exponential Q-values as evaluated by the dqn model.
|
* exponential Q-values as evaluated by the dqn model.
|
||||||
*/
|
*/
|
||||||
@AllArgsConstructor
|
|
||||||
public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
|
public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
|
||||||
|
|
||||||
final private IDQN dqn;
|
final private IDQN dqn;
|
||||||
final private Random rd = new Random(123);
|
final private Random rnd;
|
||||||
|
|
||||||
|
public BoltzmannQ(IDQN dqn, Random random) {
|
||||||
|
this.dqn = dqn;
|
||||||
|
this.rnd = random;
|
||||||
|
}
|
||||||
|
|
||||||
public IDQN getNeuralNet() {
|
public IDQN getNeuralNet() {
|
||||||
return dqn;
|
return dqn;
|
||||||
|
@ -47,7 +50,7 @@ public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
|
||||||
INDArray exp = exp(output);
|
INDArray exp = exp(output);
|
||||||
|
|
||||||
double sum = exp.sum(1).getDouble(0);
|
double sum = exp.sum(1).getDouble(0);
|
||||||
double picked = rd.nextDouble() * sum;
|
double picked = rnd.nextDouble() * sum;
|
||||||
for (int i = 0; i < exp.columns(); i++) {
|
for (int i = 0; i < exp.columns(); i++) {
|
||||||
if (picked < exp.getDouble(i))
|
if (picked < exp.getDouble(i))
|
||||||
return i;
|
return i;
|
||||||
|
|
|
@ -24,8 +24,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
|
||||||
|
@ -45,7 +44,7 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend
|
||||||
final private MDP<O, A, AS> mdp;
|
final private MDP<O, A, AS> mdp;
|
||||||
final private int updateStart;
|
final private int updateStart;
|
||||||
final private int epsilonNbStep;
|
final private int epsilonNbStep;
|
||||||
final private Random rd;
|
final private Random rnd;
|
||||||
final private float minEpsilon;
|
final private float minEpsilon;
|
||||||
final private StepCountable learning;
|
final private StepCountable learning;
|
||||||
|
|
||||||
|
@ -58,7 +57,7 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend
|
||||||
float ep = getEpsilon();
|
float ep = getEpsilon();
|
||||||
if (learning.getStepCounter() % 500 == 1)
|
if (learning.getStepCounter() % 500 == 1)
|
||||||
log.info("EP: " + ep + " " + learning.getStepCounter());
|
log.info("EP: " + ep + " " + learning.getStepCounter());
|
||||||
if (rd.nextFloat() > ep)
|
if (rnd.nextFloat() > ep)
|
||||||
return policy.nextAction(input);
|
return policy.nextAction(input);
|
||||||
else
|
else
|
||||||
return mdp.getActionSpace().randomAction();
|
return mdp.getActionSpace().randomAction();
|
||||||
|
|
|
@ -87,7 +87,6 @@ public class AsyncLearningTest {
|
||||||
private final IPolicy<MockEncodable, Integer> policy;
|
private final IPolicy<MockEncodable, Integer> policy;
|
||||||
|
|
||||||
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
|
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
|
||||||
super(conf);
|
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.policy = policy;
|
this.policy = policy;
|
||||||
|
|
|
@ -0,0 +1,180 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.support.MockRandom;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class ExpReplayTest {
|
||||||
|
@Test
|
||||||
|
public void when_storingElementWithStorageNotFull_expect_elementStored() {
|
||||||
|
// Arrange
|
||||||
|
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
|
||||||
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
Transition<Integer> transition = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 123, 234, false, Nd4j.create(1));
|
||||||
|
sut.store(transition);
|
||||||
|
List<Transition<Integer>> results = sut.getBatch(1);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, results.size());
|
||||||
|
assertEquals(123, (int)results.get(0).getAction());
|
||||||
|
assertEquals(234, (int)results.get(0).getReward());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_storingElementWithStorageFull_expect_oldestElementReplacedByStored() {
|
||||||
|
// Arrange
|
||||||
|
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1 });
|
||||||
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||||
|
sut.store(transition1);
|
||||||
|
sut.store(transition2);
|
||||||
|
sut.store(transition3);
|
||||||
|
List<Transition<Integer>> results = sut.getBatch(2);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(2, results.size());
|
||||||
|
|
||||||
|
assertEquals(3, (int)results.get(0).getAction());
|
||||||
|
assertEquals(4, (int)results.get(0).getReward());
|
||||||
|
|
||||||
|
assertEquals(5, (int)results.get(1).getAction());
|
||||||
|
assertEquals(6, (int)results.get(1).getReward());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_askBatchSizeZeroAndStorageEmpty_expect_emptyBatch() {
|
||||||
|
// Arrange
|
||||||
|
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
|
||||||
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
List<Transition<Integer>> results = sut.getBatch(0);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(0, results.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_askBatchSizeZeroAndStorageNotEmpty_expect_emptyBatch() {
|
||||||
|
// Arrange
|
||||||
|
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
|
||||||
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||||
|
sut.store(transition1);
|
||||||
|
sut.store(transition2);
|
||||||
|
sut.store(transition3);
|
||||||
|
List<Transition<Integer>> results = sut.getBatch(0);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(0, results.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_askBatchSizeGreaterThanStoredCount_expect_batchWithStoredCountElements() {
|
||||||
|
// Arrange
|
||||||
|
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2 });
|
||||||
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||||
|
sut.store(transition1);
|
||||||
|
sut.store(transition2);
|
||||||
|
sut.store(transition3);
|
||||||
|
List<Transition<Integer>> results = sut.getBatch(10);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(3, results.size());
|
||||||
|
|
||||||
|
assertEquals(1, (int)results.get(0).getAction());
|
||||||
|
assertEquals(2, (int)results.get(0).getReward());
|
||||||
|
|
||||||
|
assertEquals(3, (int)results.get(1).getAction());
|
||||||
|
assertEquals(4, (int)results.get(1).getReward());
|
||||||
|
|
||||||
|
assertEquals(5, (int)results.get(2).getAction());
|
||||||
|
assertEquals(6, (int)results.get(2).getReward());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_askBatchSizeSmallerThanStoredCount_expect_batchWithAskedElements() {
|
||||||
|
// Arrange
|
||||||
|
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 3, 4 });
|
||||||
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
|
||||||
|
sut.store(transition1);
|
||||||
|
sut.store(transition2);
|
||||||
|
sut.store(transition3);
|
||||||
|
sut.store(transition4);
|
||||||
|
sut.store(transition5);
|
||||||
|
List<Transition<Integer>> results = sut.getBatch(3);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(3, results.size());
|
||||||
|
|
||||||
|
assertEquals(1, (int)results.get(0).getAction());
|
||||||
|
assertEquals(2, (int)results.get(0).getReward());
|
||||||
|
|
||||||
|
assertEquals(3, (int)results.get(1).getAction());
|
||||||
|
assertEquals(4, (int)results.get(1).getReward());
|
||||||
|
|
||||||
|
assertEquals(5, (int)results.get(2).getAction());
|
||||||
|
assertEquals(6, (int)results.get(2).getReward());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_randomGivesDuplicates_expect_noDuplicatesInBatch() {
|
||||||
|
// Arrange
|
||||||
|
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 1, 3, 1, 4 });
|
||||||
|
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
|
||||||
|
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
|
||||||
|
sut.store(transition1);
|
||||||
|
sut.store(transition2);
|
||||||
|
sut.store(transition3);
|
||||||
|
sut.store(transition4);
|
||||||
|
sut.store(transition5);
|
||||||
|
List<Transition<Integer>> results = sut.getBatch(3);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(3, results.size());
|
||||||
|
|
||||||
|
assertEquals(1, (int)results.get(0).getAction());
|
||||||
|
assertEquals(2, (int)results.get(0).getReward());
|
||||||
|
|
||||||
|
assertEquals(3, (int)results.get(1).getAction());
|
||||||
|
assertEquals(4, (int)results.get(1).getReward());
|
||||||
|
|
||||||
|
assertEquals(5, (int)results.get(2).getAction());
|
||||||
|
assertEquals(6, (int)results.get(2).getReward());
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -89,7 +89,6 @@ public class SyncLearningTest {
|
||||||
private final LConfiguration conf;
|
private final LConfiguration conf;
|
||||||
|
|
||||||
public MockSyncLearning(LConfiguration conf) {
|
public MockSyncLearning(LConfiguration conf) {
|
||||||
super(conf);
|
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -27,11 +28,12 @@ public class QLearningDiscreteTest {
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
MockMDP mdp = new MockMDP(observationSpace);
|
MockMDP mdp = new MockMDP(observationSpace);
|
||||||
MockDQN dqn = new MockDQN();
|
MockDQN dqn = new MockDQN();
|
||||||
|
MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null);
|
||||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||||
0, 1.0, 0, 0, 0, 0, true);
|
0, 1.0, 0, 0, 0, 0, true);
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
MockExpReplay expReplay = new MockExpReplay();
|
MockExpReplay expReplay = new MockExpReplay();
|
||||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10);
|
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||||
sut.setHistoryProcessor(hp);
|
sut.setHistoryProcessor(hp);
|
||||||
|
@ -132,8 +134,8 @@ public class QLearningDiscreteTest {
|
||||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
||||||
int epsilonNbStep) {
|
int epsilonNbStep, Random rnd) {
|
||||||
super(mdp, dqn, conf, epsilonNbStep);
|
super(mdp, dqn, conf, epsilonNbStep, rnd);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
setExpReplay(expReplay);
|
setExpReplay(expReplay);
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,10 +127,10 @@ public class PolicyTest {
|
||||||
.layer(0, new OutputLayer.Builder().nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build()).build());
|
.layer(0, new OutputLayer.Builder().nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build()).build());
|
||||||
|
|
||||||
ACPolicy policy = new ACPolicy(new DummyAC(cg));
|
ACPolicy policy = new ACPolicy(new DummyAC(cg));
|
||||||
assertNotNull(policy.rd);
|
assertNotNull(policy.rnd);
|
||||||
|
|
||||||
policy = new ACPolicy(new DummyAC(mln));
|
policy = new ACPolicy(new DummyAC(mln));
|
||||||
assertNotNull(policy.rd);
|
assertNotNull(policy.rnd);
|
||||||
|
|
||||||
INDArray input = Nd4j.create(new double[] {1.0, 0.0}, new long[]{1,2});
|
INDArray input = Nd4j.create(new double[] {1.0, 0.0}, new long[]{1,2});
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
|
|
|
@ -14,7 +14,7 @@ public class MockAsyncConfiguration implements AsyncConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getSeed() {
|
public Integer getSeed() {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,203 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
public class MockRandom implements org.nd4j.linalg.api.rng.Random {
|
||||||
|
|
||||||
|
private int randomDoubleValuesIdx = 0;
|
||||||
|
private final double[] randomDoubleValues;
|
||||||
|
|
||||||
|
private int randomIntValuesIdx = 0;
|
||||||
|
private final int[] randomIntValues;
|
||||||
|
|
||||||
|
public MockRandom(double[] randomDoubleValues, int[] randomIntValues) {
|
||||||
|
this.randomDoubleValues = randomDoubleValues;
|
||||||
|
this.randomIntValues = randomIntValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setSeed(int i) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setSeed(int[] ints) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setSeed(long l) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getSeed() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void nextBytes(byte[] bytes) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextInt() {
|
||||||
|
return randomIntValues[randomIntValuesIdx++];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextInt(int i) {
|
||||||
|
return randomIntValues[randomIntValuesIdx++];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextInt(int i, int i1) {
|
||||||
|
return randomIntValues[randomIntValuesIdx++];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long nextLong() {
|
||||||
|
return randomIntValues[randomIntValuesIdx++];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean nextBoolean() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float nextFloat() {
|
||||||
|
return (float)randomDoubleValues[randomDoubleValuesIdx++];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double nextDouble() {
|
||||||
|
return randomDoubleValues[randomDoubleValuesIdx++];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double nextGaussian() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextGaussian(int[] ints) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextGaussian(long[] longs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextGaussian(char c, int[] ints) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextGaussian(char c, long[] longs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextDouble(int[] ints) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextDouble(long[] longs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextDouble(char c, int[] ints) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextDouble(char c, long[] longs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextFloat(int[] ints) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextFloat(long[] longs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextFloat(char c, int[] ints) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextFloat(char c, long[] longs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextInt(int[] ints) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextInt(long[] longs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextInt(int i, int[] ints) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray nextInt(int i, long[] longs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Pointer getStatePointer() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getPosition() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reSeed() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reSeed(long l) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long rootState() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long nodeState() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setStates(long l, long l1) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws Exception {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue