RL4J: Use Nd4j Random instead of java.util.Random (#8282)
* Changed to use Nd4j Random instead of java.util.Random Signed-off-by: unknown <aboulang2002@yahoo.com> * Changed to use Nd4j.getRandom() instead of the factory Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
2d750b69e5
commit
171ce51f46
|
@ -42,7 +42,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> ex
|
|||
|
||||
interface LConfiguration {
|
||||
|
||||
int getSeed();
|
||||
Integer getSeed();
|
||||
|
||||
int getMaxEpochStep();
|
||||
|
||||
|
|
|
@ -29,8 +29,6 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/27/16.
|
||||
*
|
||||
|
@ -43,8 +41,7 @@ import java.util.Random;
|
|||
@Slf4j
|
||||
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
|
||||
@Getter
|
||||
final private Random random;
|
||||
|
||||
@Getter @Setter
|
||||
private int stepCounter = 0;
|
||||
@Getter @Setter
|
||||
|
@ -52,10 +49,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
|
|||
@Getter @Setter
|
||||
private IHistoryProcessor historyProcessor = null;
|
||||
|
||||
public Learning(LConfiguration conf) {
|
||||
random = new Random(conf.getSeed());
|
||||
}
|
||||
|
||||
public static Integer getMaxAction(INDArray vector) {
|
||||
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.learning.ILearning;
|
|||
*/
|
||||
public interface AsyncConfiguration extends ILearning.LConfiguration {
|
||||
|
||||
int getSeed();
|
||||
Integer getSeed();
|
||||
|
||||
int getMaxEpochStep();
|
||||
|
||||
|
|
|
@ -42,10 +42,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
|||
@Getter(AccessLevel.PROTECTED)
|
||||
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||
|
||||
public AsyncLearning(AsyncConfiguration conf) {
|
||||
super(conf);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a {@link TrainingListener} listener at the end of the listener list.
|
||||
*
|
||||
|
|
|
@ -26,6 +26,8 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
|||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||
|
@ -48,13 +50,19 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
|||
final private ACPolicy<O> policy;
|
||||
|
||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
|
||||
super(conf);
|
||||
this.iActorCritic = iActorCritic;
|
||||
this.mdp = mdp;
|
||||
this.configuration = conf;
|
||||
policy = new ACPolicy<>(iActorCritic, getRandom());
|
||||
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
|
||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||
|
||||
Integer seed = conf.getSeed();
|
||||
Random rnd = Nd4j.getRandom();
|
||||
if(seed != null) {
|
||||
mdp.getActionSpace().setSeed(seed);
|
||||
rnd.setSeed(seed);
|
||||
}
|
||||
|
||||
policy = new ACPolicy<>(iActorCritic, rnd);
|
||||
}
|
||||
|
||||
protected AsyncThread newThread(int i, int deviceNum) {
|
||||
|
@ -71,7 +79,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
|||
@EqualsAndHashCode(callSuper = false)
|
||||
public static class A3CConfiguration implements AsyncConfiguration {
|
||||
|
||||
int seed;
|
||||
Integer seed;
|
||||
int maxEpochStep;
|
||||
int maxStep;
|
||||
int numThread;
|
||||
|
|
|
@ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
import java.util.Random;
|
||||
import java.util.Stack;
|
||||
|
||||
/**
|
||||
|
@ -50,7 +50,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
@Getter
|
||||
final protected int threadNumber;
|
||||
|
||||
final private Random random;
|
||||
final private Random rnd;
|
||||
|
||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
||||
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
||||
|
@ -59,13 +59,18 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
this.conf = a3cc;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.threadNumber = threadNumber;
|
||||
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
||||
random = new Random(conf.getSeed() + threadNumber);
|
||||
|
||||
Integer seed = conf.getSeed();
|
||||
rnd = Nd4j.getRandom();
|
||||
if(seed != null) {
|
||||
mdp.getActionSpace().setSeed(seed + threadNumber);
|
||||
rnd.setSeed(seed);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Policy<O, Integer> getPolicy(IActorCritic net) {
|
||||
return new ACPolicy(net, random);
|
||||
return new ACPolicy(net, rnd);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -43,11 +43,13 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
|||
|
||||
|
||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
|
||||
super(conf);
|
||||
this.mdp = mdp;
|
||||
this.configuration = conf;
|
||||
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
|
||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||
Integer seed = conf.getSeed();
|
||||
if(seed != null) {
|
||||
mdp.getActionSpace().setSeed(seed);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -70,7 +72,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
|||
@EqualsAndHashCode(callSuper = false)
|
||||
public static class AsyncNStepQLConfiguration implements AsyncConfiguration {
|
||||
|
||||
int seed;
|
||||
Integer seed;
|
||||
int maxEpochStep;
|
||||
int maxStep;
|
||||
int numThread;
|
||||
|
|
|
@ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
import java.util.Random;
|
||||
import java.util.Stack;
|
||||
|
||||
/**
|
||||
|
@ -48,7 +48,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
@Getter
|
||||
final protected int threadNumber;
|
||||
|
||||
final private Random random;
|
||||
final private Random rnd;
|
||||
|
||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
|
||||
|
@ -57,13 +57,18 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
this.conf = conf;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.threadNumber = threadNumber;
|
||||
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
|
||||
random = new Random(conf.getSeed() + threadNumber);
|
||||
rnd = Nd4j.getRandom();
|
||||
|
||||
Integer seed = conf.getSeed();
|
||||
if(seed != null) {
|
||||
mdp.getActionSpace().setSeed(seed + threadNumber);
|
||||
rnd.setSeed(seed + threadNumber);
|
||||
}
|
||||
}
|
||||
|
||||
public Policy<O, Integer> getPolicy(IDQN nn) {
|
||||
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
||||
random, conf.getMinEpsilon(), this);
|
||||
rnd, conf.getMinEpsilon(), this);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
|
|||
import it.unimi.dsi.fastutil.ints.IntSet;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
|
||||
|
@ -36,30 +36,28 @@ import java.util.concurrent.ThreadLocalRandom;
|
|||
public class ExpReplay<A> implements IExpReplay<A> {
|
||||
|
||||
final private int batchSize;
|
||||
final private Random random;
|
||||
final private Random rnd;
|
||||
|
||||
//Implementing this as a circular buffer queue
|
||||
private CircularFifoQueue<Transition<A>> storage;
|
||||
|
||||
public ExpReplay(int maxSize, int batchSize, int seed) {
|
||||
public ExpReplay(int maxSize, int batchSize, Random rnd) {
|
||||
this.batchSize = batchSize;
|
||||
this.random = new Random(seed);
|
||||
this.rnd = rnd;
|
||||
storage = new CircularFifoQueue<>(maxSize);
|
||||
}
|
||||
|
||||
|
||||
public ArrayList<Transition<A>> getBatch(int size) {
|
||||
ArrayList<Transition<A>> batch = new ArrayList<>(size);
|
||||
int storageSize = storage.size();
|
||||
int actualBatchSize = Math.min(storageSize, size);
|
||||
|
||||
int[] actualIndex = new int[actualBatchSize];
|
||||
ThreadLocalRandom r = ThreadLocalRandom.current();
|
||||
IntSet set = new IntOpenHashSet();
|
||||
for( int i=0; i<actualBatchSize; i++ ){
|
||||
int next = r.nextInt(storageSize);
|
||||
int next = rnd.nextInt(storageSize);
|
||||
while(set.contains(next)){
|
||||
next = r.nextInt(storageSize);
|
||||
next = rnd.nextInt(storageSize);
|
||||
}
|
||||
set.add(next);
|
||||
actualIndex[i] = next;
|
||||
|
|
|
@ -41,10 +41,6 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
|
|||
|
||||
private final TrainingListenerList listeners = new TrainingListenerList();
|
||||
|
||||
public SyncLearning(LConfiguration conf) {
|
||||
super(conf);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a listener at the end of the listener list.
|
||||
*
|
||||
|
|
|
@ -30,7 +30,8 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
|||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -53,8 +54,20 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
protected IExpReplay<A> expReplay;
|
||||
|
||||
public QLearning(QLConfiguration conf) {
|
||||
super(conf);
|
||||
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), conf.getSeed());
|
||||
this(conf, getSeededRandom(conf.getSeed()));
|
||||
}
|
||||
|
||||
public QLearning(QLConfiguration conf, Random random) {
|
||||
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||
}
|
||||
|
||||
private static Random getSeededRandom(Integer seed) {
|
||||
Random rnd = Nd4j.getRandom();
|
||||
if(seed != null) {
|
||||
rnd.setSeed(seed);
|
||||
}
|
||||
|
||||
return rnd;
|
||||
}
|
||||
|
||||
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
|
||||
|
@ -160,7 +173,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
@JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
|
||||
public static class QLConfiguration implements LConfiguration {
|
||||
|
||||
int seed;
|
||||
Integer seed;
|
||||
int maxEpochStep;
|
||||
int maxStep;
|
||||
int expRepMaxSize;
|
||||
|
|
|
@ -31,7 +31,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
|||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -70,13 +72,18 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
|
||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||
int epsilonNbStep) {
|
||||
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
||||
}
|
||||
|
||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
||||
int epsilonNbStep, Random random) {
|
||||
super(conf);
|
||||
this.configuration = conf;
|
||||
this.mdp = mdp;
|
||||
qNetwork = dqn;
|
||||
targetQNetwork = dqn.clone();
|
||||
policy = new DQNPolicy(getQNetwork());
|
||||
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(),
|
||||
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
|
||||
this);
|
||||
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
|||
private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
|
||||
private static final double xThreshold = 2.4;
|
||||
|
||||
private final Random rnd = new Random();
|
||||
private final Random rnd;
|
||||
|
||||
@Getter @Setter
|
||||
private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;
|
||||
|
@ -76,6 +76,14 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
|||
@Getter
|
||||
private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
|
||||
|
||||
public CartpoleNative() {
|
||||
rnd = new Random();
|
||||
}
|
||||
|
||||
public CartpoleNative(int seed) {
|
||||
rnd = new Random(seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
public State reset() {
|
||||
|
||||
|
|
|
@ -16,18 +16,16 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.policy;
|
||||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
|
||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
|
||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
|
@ -38,47 +36,41 @@ import java.util.Random;
|
|||
*/
|
||||
public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
|
||||
|
||||
final private IActorCritic IActorCritic;
|
||||
Random rd;
|
||||
final private IActorCritic actorCritic;
|
||||
Random rnd;
|
||||
|
||||
public ACPolicy(IActorCritic IActorCritic) {
|
||||
this.IActorCritic = IActorCritic;
|
||||
NeuralNetwork nn = IActorCritic.getNeuralNetworks()[0];
|
||||
if (nn instanceof ComputationGraph) {
|
||||
rd = new Random(((ComputationGraph)nn).getConfiguration().getDefaultConfiguration().getSeed());
|
||||
} else if (nn instanceof MultiLayerNetwork) {
|
||||
rd = new Random(((MultiLayerNetwork)nn).getDefaultConfiguration().getSeed());
|
||||
}
|
||||
public ACPolicy(IActorCritic actorCritic) {
|
||||
this(actorCritic, Nd4j.getRandom());
|
||||
}
|
||||
public ACPolicy(IActorCritic IActorCritic, Random rd) {
|
||||
this.IActorCritic = IActorCritic;
|
||||
this.rd = rd;
|
||||
public ACPolicy(IActorCritic actorCritic, Random rnd) {
|
||||
this.actorCritic = actorCritic;
|
||||
this.rnd = rnd;
|
||||
}
|
||||
|
||||
public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException {
|
||||
return new ACPolicy<O>(ActorCriticCompGraph.load(path));
|
||||
}
|
||||
public static <O extends Encodable> ACPolicy<O> load(String path, Random rd) throws IOException {
|
||||
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rd);
|
||||
public static <O extends Encodable> ACPolicy<O> load(String path, Random rnd) throws IOException {
|
||||
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rnd);
|
||||
}
|
||||
|
||||
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException {
|
||||
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy));
|
||||
}
|
||||
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rd) throws IOException {
|
||||
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rd);
|
||||
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rnd) throws IOException {
|
||||
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
|
||||
}
|
||||
|
||||
public IActorCritic getNeuralNet() {
|
||||
return IActorCritic;
|
||||
return actorCritic;
|
||||
}
|
||||
|
||||
public Integer nextAction(INDArray input) {
|
||||
INDArray output = IActorCritic.outputAll(input)[1];
|
||||
if (rd == null) {
|
||||
INDArray output = actorCritic.outputAll(input)[1];
|
||||
if (rnd == null) {
|
||||
return Learning.getMaxAction(output);
|
||||
}
|
||||
float rVal = rd.nextFloat();
|
||||
float rVal = rnd.nextFloat();
|
||||
for (int i = 0; i < output.length(); i++) {
|
||||
//System.out.println(i + " " + rVal + " " + output.getFloat(i));
|
||||
if (rVal < output.getFloat(i)) {
|
||||
|
@ -91,11 +83,11 @@ public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
|
|||
}
|
||||
|
||||
public void save(String filename) throws IOException {
|
||||
IActorCritic.save(filename);
|
||||
actorCritic.save(filename);
|
||||
}
|
||||
|
||||
public void save(String filenameValue, String filenamePolicy) throws IOException {
|
||||
IActorCritic.save(filenameValue, filenamePolicy);
|
||||
actorCritic.save(filenameValue, filenamePolicy);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -16,12 +16,11 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.policy;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Random;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.nd4j.linalg.ops.transforms.Transforms.exp;
|
||||
|
||||
|
@ -31,11 +30,15 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
|
|||
* Boltzmann exploration is a stochastic policy wrt to the
|
||||
* exponential Q-values as evaluated by the dqn model.
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
|
||||
|
||||
final private IDQN dqn;
|
||||
final private Random rd = new Random(123);
|
||||
final private Random rnd;
|
||||
|
||||
public BoltzmannQ(IDQN dqn, Random random) {
|
||||
this.dqn = dqn;
|
||||
this.rnd = random;
|
||||
}
|
||||
|
||||
public IDQN getNeuralNet() {
|
||||
return dqn;
|
||||
|
@ -47,7 +50,7 @@ public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
|
|||
INDArray exp = exp(output);
|
||||
|
||||
double sum = exp.sum(1).getDouble(0);
|
||||
double picked = rd.nextDouble() * sum;
|
||||
double picked = rnd.nextDouble() * sum;
|
||||
for (int i = 0; i < exp.columns(); i++) {
|
||||
if (picked < exp.getDouble(i))
|
||||
return i;
|
||||
|
|
|
@ -24,8 +24,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
|||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Random;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
|
||||
|
@ -45,7 +44,7 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend
|
|||
final private MDP<O, A, AS> mdp;
|
||||
final private int updateStart;
|
||||
final private int epsilonNbStep;
|
||||
final private Random rd;
|
||||
final private Random rnd;
|
||||
final private float minEpsilon;
|
||||
final private StepCountable learning;
|
||||
|
||||
|
@ -58,7 +57,7 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend
|
|||
float ep = getEpsilon();
|
||||
if (learning.getStepCounter() % 500 == 1)
|
||||
log.info("EP: " + ep + " " + learning.getStepCounter());
|
||||
if (rd.nextFloat() > ep)
|
||||
if (rnd.nextFloat() > ep)
|
||||
return policy.nextAction(input);
|
||||
else
|
||||
return mdp.getActionSpace().randomAction();
|
||||
|
|
|
@ -87,7 +87,6 @@ public class AsyncLearningTest {
|
|||
private final IPolicy<MockEncodable, Integer> policy;
|
||||
|
||||
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
|
||||
super(conf);
|
||||
this.conf = conf;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.policy = policy;
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync;
|
||||
|
||||
import org.deeplearning4j.rl4j.support.MockRandom;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class ExpReplayTest {
|
||||
@Test
|
||||
public void when_storingElementWithStorageNotFull_expect_elementStored() {
|
||||
// Arrange
|
||||
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
|
||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 123, 234, false, Nd4j.create(1));
|
||||
sut.store(transition);
|
||||
List<Transition<Integer>> results = sut.getBatch(1);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, results.size());
|
||||
assertEquals(123, (int)results.get(0).getAction());
|
||||
assertEquals(234, (int)results.get(0).getReward());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_storingElementWithStorageFull_expect_oldestElementReplacedByStored() {
|
||||
// Arrange
|
||||
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1 });
|
||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
List<Transition<Integer>> results = sut.getBatch(2);
|
||||
|
||||
// Assert
|
||||
assertEquals(2, results.size());
|
||||
|
||||
assertEquals(3, (int)results.get(0).getAction());
|
||||
assertEquals(4, (int)results.get(0).getReward());
|
||||
|
||||
assertEquals(5, (int)results.get(1).getAction());
|
||||
assertEquals(6, (int)results.get(1).getReward());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void when_askBatchSizeZeroAndStorageEmpty_expect_emptyBatch() {
|
||||
// Arrange
|
||||
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
|
||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
List<Transition<Integer>> results = sut.getBatch(0);
|
||||
|
||||
// Assert
|
||||
assertEquals(0, results.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_askBatchSizeZeroAndStorageNotEmpty_expect_emptyBatch() {
|
||||
// Arrange
|
||||
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
|
||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
List<Transition<Integer>> results = sut.getBatch(0);
|
||||
|
||||
// Assert
|
||||
assertEquals(0, results.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_askBatchSizeGreaterThanStoredCount_expect_batchWithStoredCountElements() {
|
||||
// Arrange
|
||||
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2 });
|
||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
List<Transition<Integer>> results = sut.getBatch(10);
|
||||
|
||||
// Assert
|
||||
assertEquals(3, results.size());
|
||||
|
||||
assertEquals(1, (int)results.get(0).getAction());
|
||||
assertEquals(2, (int)results.get(0).getReward());
|
||||
|
||||
assertEquals(3, (int)results.get(1).getAction());
|
||||
assertEquals(4, (int)results.get(1).getReward());
|
||||
|
||||
assertEquals(5, (int)results.get(2).getAction());
|
||||
assertEquals(6, (int)results.get(2).getReward());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_askBatchSizeSmallerThanStoredCount_expect_batchWithAskedElements() {
|
||||
// Arrange
|
||||
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 3, 4 });
|
||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
|
||||
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
sut.store(transition4);
|
||||
sut.store(transition5);
|
||||
List<Transition<Integer>> results = sut.getBatch(3);
|
||||
|
||||
// Assert
|
||||
assertEquals(3, results.size());
|
||||
|
||||
assertEquals(1, (int)results.get(0).getAction());
|
||||
assertEquals(2, (int)results.get(0).getReward());
|
||||
|
||||
assertEquals(3, (int)results.get(1).getAction());
|
||||
assertEquals(4, (int)results.get(1).getReward());
|
||||
|
||||
assertEquals(5, (int)results.get(2).getAction());
|
||||
assertEquals(6, (int)results.get(2).getReward());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_randomGivesDuplicates_expect_noDuplicatesInBatch() {
|
||||
// Arrange
|
||||
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 1, 3, 1, 4 });
|
||||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
|
||||
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
|
||||
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
sut.store(transition4);
|
||||
sut.store(transition5);
|
||||
List<Transition<Integer>> results = sut.getBatch(3);
|
||||
|
||||
// Assert
|
||||
assertEquals(3, results.size());
|
||||
|
||||
assertEquals(1, (int)results.get(0).getAction());
|
||||
assertEquals(2, (int)results.get(0).getReward());
|
||||
|
||||
assertEquals(3, (int)results.get(1).getAction());
|
||||
assertEquals(4, (int)results.get(1).getReward());
|
||||
|
||||
assertEquals(5, (int)results.get(2).getAction());
|
||||
assertEquals(6, (int)results.get(2).getReward());
|
||||
|
||||
}
|
||||
}
|
|
@ -89,7 +89,6 @@ public class SyncLearningTest {
|
|||
private final LConfiguration conf;
|
||||
|
||||
public MockSyncLearning(LConfiguration conf) {
|
||||
super(conf);
|
||||
this.conf = conf;
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -27,11 +28,12 @@ public class QLearningDiscreteTest {
|
|||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockDQN dqn = new MockDQN();
|
||||
MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null);
|
||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||
0, 1.0, 0, 0, 0, 0, true);
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
MockExpReplay expReplay = new MockExpReplay();
|
||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10);
|
||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
sut.setHistoryProcessor(hp);
|
||||
|
@ -130,10 +132,10 @@ public class QLearningDiscreteTest {
|
|||
}
|
||||
|
||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
|
||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
||||
int epsilonNbStep) {
|
||||
super(mdp, dqn, conf, epsilonNbStep);
|
||||
int epsilonNbStep, Random rnd) {
|
||||
super(mdp, dqn, conf, epsilonNbStep, rnd);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
setExpReplay(expReplay);
|
||||
}
|
||||
|
|
|
@ -127,10 +127,10 @@ public class PolicyTest {
|
|||
.layer(0, new OutputLayer.Builder().nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build()).build());
|
||||
|
||||
ACPolicy policy = new ACPolicy(new DummyAC(cg));
|
||||
assertNotNull(policy.rd);
|
||||
assertNotNull(policy.rnd);
|
||||
|
||||
policy = new ACPolicy(new DummyAC(mln));
|
||||
assertNotNull(policy.rd);
|
||||
assertNotNull(policy.rnd);
|
||||
|
||||
INDArray input = Nd4j.create(new double[] {1.0, 0.0}, new long[]{1,2});
|
||||
for (int i = 0; i < 100; i++) {
|
||||
|
|
|
@ -14,7 +14,7 @@ public class MockAsyncConfiguration implements AsyncConfiguration {
|
|||
}
|
||||
|
||||
@Override
|
||||
public int getSeed() {
|
||||
public Integer getSeed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
public class MockRandom implements org.nd4j.linalg.api.rng.Random {
|
||||
|
||||
private int randomDoubleValuesIdx = 0;
|
||||
private final double[] randomDoubleValues;
|
||||
|
||||
private int randomIntValuesIdx = 0;
|
||||
private final int[] randomIntValues;
|
||||
|
||||
public MockRandom(double[] randomDoubleValues, int[] randomIntValues) {
|
||||
this.randomDoubleValues = randomDoubleValues;
|
||||
this.randomIntValues = randomIntValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSeed(int i) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSeed(int[] ints) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSeed(long l) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getSeed() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void nextBytes(byte[] bytes) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextInt() {
|
||||
return randomIntValues[randomIntValuesIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextInt(int i) {
|
||||
return randomIntValues[randomIntValuesIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextInt(int i, int i1) {
|
||||
return randomIntValues[randomIntValuesIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public long nextLong() {
|
||||
return randomIntValues[randomIntValuesIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nextBoolean() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float nextFloat() {
|
||||
return (float)randomDoubleValues[randomDoubleValuesIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public double nextDouble() {
|
||||
return randomDoubleValues[randomDoubleValuesIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public double nextGaussian() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextGaussian(int[] ints) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextGaussian(long[] longs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextGaussian(char c, int[] ints) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextGaussian(char c, long[] longs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextDouble(int[] ints) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextDouble(long[] longs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextDouble(char c, int[] ints) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextDouble(char c, long[] longs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextFloat(int[] ints) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextFloat(long[] longs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextFloat(char c, int[] ints) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextFloat(char c, long[] longs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextInt(int[] ints) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextInt(long[] longs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextInt(int i, int[] ints) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray nextInt(int i, long[] longs) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pointer getStatePointer() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getPosition() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reSeed() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reSeed(long l) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public long rootState() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long nodeState() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setStates(long l, long l1) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue