RL4J: Use Nd4j Random instead of java.util.Random (#8282)

* Changed to use Nd4j Random instead of java.util.Random

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Changed to use Nd4j.getRandom() instead of the factory

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2019-10-15 21:56:24 -04:00 committed by Samuel Audet
parent 2d750b69e5
commit 171ce51f46
23 changed files with 504 additions and 96 deletions

View File

@ -42,7 +42,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> ex
interface LConfiguration {
int getSeed();
Integer getSeed();
int getMaxEpochStep();

View File

@ -29,8 +29,6 @@ import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Random;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/27/16.
*
@ -43,8 +41,7 @@ import java.util.Random;
@Slf4j
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
@Getter
final private Random random;
@Getter @Setter
private int stepCounter = 0;
@Getter @Setter
@ -52,10 +49,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>
@Getter @Setter
private IHistoryProcessor historyProcessor = null;
public Learning(LConfiguration conf) {
random = new Random(conf.getSeed());
}
public static Integer getMaxAction(INDArray vector) {
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
}

View File

@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.learning.ILearning;
*/
public interface AsyncConfiguration extends ILearning.LConfiguration {
int getSeed();
Integer getSeed();
int getMaxEpochStep();

View File

@ -42,10 +42,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
@Getter(AccessLevel.PROTECTED)
private final TrainingListenerList listeners = new TrainingListenerList();
public AsyncLearning(AsyncConfiguration conf) {
super(conf);
}
/**
* Add a {@link TrainingListener} listener at the end of the listener list.
*

View File

@ -26,6 +26,8 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
@ -48,13 +50,19 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
final private ACPolicy<O> policy;
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
super(conf);
this.iActorCritic = iActorCritic;
this.mdp = mdp;
this.configuration = conf;
policy = new ACPolicy<>(iActorCritic, getRandom());
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
mdp.getActionSpace().setSeed(conf.getSeed());
Integer seed = conf.getSeed();
Random rnd = Nd4j.getRandom();
if(seed != null) {
mdp.getActionSpace().setSeed(seed);
rnd.setSeed(seed);
}
policy = new ACPolicy<>(iActorCritic, rnd);
}
protected AsyncThread newThread(int i, int deviceNum) {
@ -71,7 +79,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
@EqualsAndHashCode(callSuper = false)
public static class A3CConfiguration implements AsyncConfiguration {
int seed;
Integer seed;
int maxEpochStep;
int maxStep;
int numThread;

View File

@ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.api.rng.Random;
import java.util.Random;
import java.util.Stack;
/**
@ -50,7 +50,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
@Getter
final protected int threadNumber;
final private Random random;
final private Random rnd;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
@ -59,13 +59,18 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
this.conf = a3cc;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
random = new Random(conf.getSeed() + threadNumber);
Integer seed = conf.getSeed();
rnd = Nd4j.getRandom();
if(seed != null) {
mdp.getActionSpace().setSeed(seed + threadNumber);
rnd.setSeed(seed);
}
}
@Override
protected Policy<O, Integer> getPolicy(IActorCritic net) {
return new ACPolicy(net, random);
return new ACPolicy(net, rnd);
}
/**

View File

@ -43,11 +43,13 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
super(conf);
this.mdp = mdp;
this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
mdp.getActionSpace().setSeed(conf.getSeed());
Integer seed = conf.getSeed();
if(seed != null) {
mdp.getActionSpace().setSeed(seed);
}
}
@Override
@ -70,7 +72,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
@EqualsAndHashCode(callSuper = false)
public static class AsyncNStepQLConfiguration implements AsyncConfiguration {
int seed;
Integer seed;
int maxEpochStep;
int maxStep;
int numThread;

View File

@ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
import java.util.Random;
import java.util.Stack;
/**
@ -48,7 +48,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
@Getter
final protected int threadNumber;
final private Random random;
final private Random rnd;
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
@ -57,13 +57,18 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;
mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber);
random = new Random(conf.getSeed() + threadNumber);
rnd = Nd4j.getRandom();
Integer seed = conf.getSeed();
if(seed != null) {
mdp.getActionSpace().setSeed(seed + threadNumber);
rnd.setSeed(seed + threadNumber);
}
}
public Policy<O, Integer> getPolicy(IDQN nn) {
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
random, conf.getMinEpsilon(), this);
rnd, conf.getMinEpsilon(), this);
}

View File

@ -20,9 +20,9 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.queue.CircularFifoQueue;
import org.nd4j.linalg.api.rng.Random;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.ArrayList;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
@ -36,30 +36,28 @@ import java.util.concurrent.ThreadLocalRandom;
public class ExpReplay<A> implements IExpReplay<A> {
final private int batchSize;
final private Random random;
final private Random rnd;
//Implementing this as a circular buffer queue
private CircularFifoQueue<Transition<A>> storage;
public ExpReplay(int maxSize, int batchSize, int seed) {
public ExpReplay(int maxSize, int batchSize, Random rnd) {
this.batchSize = batchSize;
this.random = new Random(seed);
this.rnd = rnd;
storage = new CircularFifoQueue<>(maxSize);
}
public ArrayList<Transition<A>> getBatch(int size) {
ArrayList<Transition<A>> batch = new ArrayList<>(size);
int storageSize = storage.size();
int actualBatchSize = Math.min(storageSize, size);
int[] actualIndex = new int[actualBatchSize];
ThreadLocalRandom r = ThreadLocalRandom.current();
IntSet set = new IntOpenHashSet();
for( int i=0; i<actualBatchSize; i++ ){
int next = r.nextInt(storageSize);
int next = rnd.nextInt(storageSize);
while(set.contains(next)){
next = r.nextInt(storageSize);
next = rnd.nextInt(storageSize);
}
set.add(next);
actualIndex[i] = next;

View File

@ -41,10 +41,6 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac
private final TrainingListenerList listeners = new TrainingListenerList();
public SyncLearning(LConfiguration conf) {
super(conf);
}
/**
* Add a listener at the end of the listener list.
*

View File

@ -30,7 +30,8 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
import java.util.ArrayList;
import java.util.List;
@ -53,8 +54,20 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
protected IExpReplay<A> expReplay;
public QLearning(QLConfiguration conf) {
super(conf);
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), conf.getSeed());
this(conf, getSeededRandom(conf.getSeed()));
}
public QLearning(QLConfiguration conf, Random random) {
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
}
private static Random getSeededRandom(Integer seed) {
Random rnd = Nd4j.getRandom();
if(seed != null) {
rnd.setSeed(seed);
}
return rnd;
}
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
@ -160,7 +173,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
@JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
public static class QLConfiguration implements LConfiguration {
int seed;
Integer seed;
int maxEpochStep;
int maxStep;
int expRepMaxSize;

View File

@ -31,7 +31,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.ArrayList;
@ -70,13 +72,18 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep) {
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep, Random random) {
super(conf);
this.configuration = conf;
this.mdp = mdp;
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(),
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
this);
mdp.getActionSpace().setSeed(conf.getSeed());

View File

@ -57,7 +57,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
private static final double xThreshold = 2.4;
private final Random rnd = new Random();
private final Random rnd;
@Getter @Setter
private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;
@ -76,6 +76,14 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
@Getter
private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
public CartpoleNative() {
rnd = new Random();
}
public CartpoleNative(int seed) {
rnd = new Random(seed);
}
@Override
public State reset() {

View File

@ -16,18 +16,16 @@
package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
import java.util.Random;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -38,47 +36,41 @@ import java.util.Random;
*/
public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
final private IActorCritic IActorCritic;
Random rd;
final private IActorCritic actorCritic;
Random rnd;
public ACPolicy(IActorCritic IActorCritic) {
this.IActorCritic = IActorCritic;
NeuralNetwork nn = IActorCritic.getNeuralNetworks()[0];
if (nn instanceof ComputationGraph) {
rd = new Random(((ComputationGraph)nn).getConfiguration().getDefaultConfiguration().getSeed());
} else if (nn instanceof MultiLayerNetwork) {
rd = new Random(((MultiLayerNetwork)nn).getDefaultConfiguration().getSeed());
}
public ACPolicy(IActorCritic actorCritic) {
this(actorCritic, Nd4j.getRandom());
}
public ACPolicy(IActorCritic IActorCritic, Random rd) {
this.IActorCritic = IActorCritic;
this.rd = rd;
public ACPolicy(IActorCritic actorCritic, Random rnd) {
this.actorCritic = actorCritic;
this.rnd = rnd;
}
public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException {
return new ACPolicy<O>(ActorCriticCompGraph.load(path));
}
public static <O extends Encodable> ACPolicy<O> load(String path, Random rd) throws IOException {
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rd);
public static <O extends Encodable> ACPolicy<O> load(String path, Random rnd) throws IOException {
return new ACPolicy<O>(ActorCriticCompGraph.load(path), rnd);
}
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException {
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy));
}
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rd) throws IOException {
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rd);
public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rnd) throws IOException {
return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
}
public IActorCritic getNeuralNet() {
return IActorCritic;
return actorCritic;
}
public Integer nextAction(INDArray input) {
INDArray output = IActorCritic.outputAll(input)[1];
if (rd == null) {
INDArray output = actorCritic.outputAll(input)[1];
if (rnd == null) {
return Learning.getMaxAction(output);
}
float rVal = rd.nextFloat();
float rVal = rnd.nextFloat();
for (int i = 0; i < output.length(); i++) {
//System.out.println(i + " " + rVal + " " + output.getFloat(i));
if (rVal < output.getFloat(i)) {
@ -91,11 +83,11 @@ public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
}
public void save(String filename) throws IOException {
IActorCritic.save(filename);
actorCritic.save(filename);
}
public void save(String filenameValue, String filenamePolicy) throws IOException {
IActorCritic.save(filenameValue, filenamePolicy);
actorCritic.save(filenameValue, filenamePolicy);
}
}

View File

@ -16,12 +16,11 @@
package org.deeplearning4j.rl4j.policy;
import lombok.AllArgsConstructor;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Random;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import static org.nd4j.linalg.ops.transforms.Transforms.exp;
@ -31,11 +30,15 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
* Boltzmann exploration is a stochastic policy wrt to the
* exponential Q-values as evaluated by the dqn model.
*/
@AllArgsConstructor
public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
final private IDQN dqn;
final private Random rd = new Random(123);
final private Random rnd;
public BoltzmannQ(IDQN dqn, Random random) {
this.dqn = dqn;
this.rnd = random;
}
public IDQN getNeuralNet() {
return dqn;
@ -47,7 +50,7 @@ public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
INDArray exp = exp(output);
double sum = exp.sum(1).getDouble(0);
double picked = rd.nextDouble() * sum;
double picked = rnd.nextDouble() * sum;
for (int i = 0; i < exp.columns(); i++) {
if (picked < exp.getDouble(i))
return i;

View File

@ -24,8 +24,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Random;
import org.nd4j.linalg.api.rng.Random;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
@ -45,7 +44,7 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend
final private MDP<O, A, AS> mdp;
final private int updateStart;
final private int epsilonNbStep;
final private Random rd;
final private Random rnd;
final private float minEpsilon;
final private StepCountable learning;
@ -58,7 +57,7 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend
float ep = getEpsilon();
if (learning.getStepCounter() % 500 == 1)
log.info("EP: " + ep + " " + learning.getStepCounter());
if (rd.nextFloat() > ep)
if (rnd.nextFloat() > ep)
return policy.nextAction(input);
else
return mdp.getActionSpace().randomAction();

View File

@ -87,7 +87,6 @@ public class AsyncLearningTest {
private final IPolicy<MockEncodable, Integer> policy;
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
super(conf);
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.policy = policy;

View File

@ -0,0 +1,180 @@
package org.deeplearning4j.rl4j.learning.sync;
import org.deeplearning4j.rl4j.support.MockRandom;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class ExpReplayTest {
@Test
public void when_storingElementWithStorageNotFull_expect_elementStored() {
// Arrange
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
// Act
Transition<Integer> transition = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 123, 234, false, Nd4j.create(1));
sut.store(transition);
List<Transition<Integer>> results = sut.getBatch(1);
// Assert
assertEquals(1, results.size());
assertEquals(123, (int)results.get(0).getAction());
assertEquals(234, (int)results.get(0).getReward());
}
@Test
public void when_storingElementWithStorageFull_expect_oldestElementReplacedByStored() {
// Arrange
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1 });
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
List<Transition<Integer>> results = sut.getBatch(2);
// Assert
assertEquals(2, results.size());
assertEquals(3, (int)results.get(0).getAction());
assertEquals(4, (int)results.get(0).getReward());
assertEquals(5, (int)results.get(1).getAction());
assertEquals(6, (int)results.get(1).getReward());
}
@Test
public void when_askBatchSizeZeroAndStorageEmpty_expect_emptyBatch() {
// Arrange
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
List<Transition<Integer>> results = sut.getBatch(0);
// Assert
assertEquals(0, results.size());
}
@Test
public void when_askBatchSizeZeroAndStorageNotEmpty_expect_emptyBatch() {
// Arrange
MockRandom randomMock = new MockRandom(null, new int[] { 0 });
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
List<Transition<Integer>> results = sut.getBatch(0);
// Assert
assertEquals(0, results.size());
}
@Test
public void when_askBatchSizeGreaterThanStoredCount_expect_batchWithStoredCountElements() {
// Arrange
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2 });
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
List<Transition<Integer>> results = sut.getBatch(10);
// Assert
assertEquals(3, results.size());
assertEquals(1, (int)results.get(0).getAction());
assertEquals(2, (int)results.get(0).getReward());
assertEquals(3, (int)results.get(1).getAction());
assertEquals(4, (int)results.get(1).getReward());
assertEquals(5, (int)results.get(2).getAction());
assertEquals(6, (int)results.get(2).getReward());
}
@Test
public void when_askBatchSizeSmallerThanStoredCount_expect_batchWithAskedElements() {
// Arrange
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 3, 4 });
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
sut.store(transition4);
sut.store(transition5);
List<Transition<Integer>> results = sut.getBatch(3);
// Assert
assertEquals(3, results.size());
assertEquals(1, (int)results.get(0).getAction());
assertEquals(2, (int)results.get(0).getReward());
assertEquals(3, (int)results.get(1).getAction());
assertEquals(4, (int)results.get(1).getReward());
assertEquals(5, (int)results.get(2).getAction());
assertEquals(6, (int)results.get(2).getReward());
}
@Test
public void when_randomGivesDuplicates_expect_noDuplicatesInBatch() {
// Arrange
MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 1, 3, 1, 4 });
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act
Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1));
Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1));
Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1));
Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1));
Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1));
sut.store(transition1);
sut.store(transition2);
sut.store(transition3);
sut.store(transition4);
sut.store(transition5);
List<Transition<Integer>> results = sut.getBatch(3);
// Assert
assertEquals(3, results.size());
assertEquals(1, (int)results.get(0).getAction());
assertEquals(2, (int)results.get(0).getReward());
assertEquals(3, (int)results.get(1).getAction());
assertEquals(4, (int)results.get(1).getReward());
assertEquals(5, (int)results.get(2).getAction());
assertEquals(6, (int)results.get(2).getReward());
}
}

View File

@ -89,7 +89,6 @@ public class SyncLearningTest {
private final LConfiguration conf;
public MockSyncLearning(LConfiguration conf) {
super(conf);
this.conf = conf;
}

View File

@ -13,6 +13,7 @@ import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
@ -27,11 +28,12 @@ public class QLearningDiscreteTest {
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdp = new MockMDP(observationSpace);
MockDQN dqn = new MockDQN();
MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null);
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
0, 1.0, 0, 0, 0, 0, true);
MockDataManager dataManager = new MockDataManager(false);
MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10);
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hp);
@ -130,10 +132,10 @@ public class QLearningDiscreteTest {
}
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
int epsilonNbStep) {
super(mdp, dqn, conf, epsilonNbStep);
int epsilonNbStep, Random rnd) {
super(mdp, dqn, conf, epsilonNbStep, rnd);
addListener(new DataManagerTrainingListener(dataManager));
setExpReplay(expReplay);
}

View File

@ -127,10 +127,10 @@ public class PolicyTest {
.layer(0, new OutputLayer.Builder().nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build()).build());
ACPolicy policy = new ACPolicy(new DummyAC(cg));
assertNotNull(policy.rd);
assertNotNull(policy.rnd);
policy = new ACPolicy(new DummyAC(mln));
assertNotNull(policy.rd);
assertNotNull(policy.rnd);
INDArray input = Nd4j.create(new double[] {1.0, 0.0}, new long[]{1,2});
for (int i = 0; i < 100; i++) {

View File

@ -14,7 +14,7 @@ public class MockAsyncConfiguration implements AsyncConfiguration {
}
@Override
public int getSeed() {
public Integer getSeed() {
return 0;
}

View File

@ -0,0 +1,203 @@
package org.deeplearning4j.rl4j.support;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.ndarray.INDArray;
public class MockRandom implements org.nd4j.linalg.api.rng.Random {
private int randomDoubleValuesIdx = 0;
private final double[] randomDoubleValues;
private int randomIntValuesIdx = 0;
private final int[] randomIntValues;
public MockRandom(double[] randomDoubleValues, int[] randomIntValues) {
this.randomDoubleValues = randomDoubleValues;
this.randomIntValues = randomIntValues;
}
@Override
public void setSeed(int i) {
}
@Override
public void setSeed(int[] ints) {
}
@Override
public void setSeed(long l) {
}
@Override
public long getSeed() {
return 0;
}
@Override
public void nextBytes(byte[] bytes) {
}
@Override
public int nextInt() {
return randomIntValues[randomIntValuesIdx++];
}
@Override
public int nextInt(int i) {
return randomIntValues[randomIntValuesIdx++];
}
@Override
public int nextInt(int i, int i1) {
return randomIntValues[randomIntValuesIdx++];
}
@Override
public long nextLong() {
return randomIntValues[randomIntValuesIdx++];
}
@Override
public boolean nextBoolean() {
return false;
}
@Override
public float nextFloat() {
return (float)randomDoubleValues[randomDoubleValuesIdx++];
}
@Override
public double nextDouble() {
return randomDoubleValues[randomDoubleValuesIdx++];
}
@Override
public double nextGaussian() {
return 0;
}
@Override
public INDArray nextGaussian(int[] ints) {
return null;
}
@Override
public INDArray nextGaussian(long[] longs) {
return null;
}
@Override
public INDArray nextGaussian(char c, int[] ints) {
return null;
}
@Override
public INDArray nextGaussian(char c, long[] longs) {
return null;
}
@Override
public INDArray nextDouble(int[] ints) {
return null;
}
@Override
public INDArray nextDouble(long[] longs) {
return null;
}
@Override
public INDArray nextDouble(char c, int[] ints) {
return null;
}
@Override
public INDArray nextDouble(char c, long[] longs) {
return null;
}
@Override
public INDArray nextFloat(int[] ints) {
return null;
}
@Override
public INDArray nextFloat(long[] longs) {
return null;
}
@Override
public INDArray nextFloat(char c, int[] ints) {
return null;
}
@Override
public INDArray nextFloat(char c, long[] longs) {
return null;
}
@Override
public INDArray nextInt(int[] ints) {
return null;
}
@Override
public INDArray nextInt(long[] longs) {
return null;
}
@Override
public INDArray nextInt(int i, int[] ints) {
return null;
}
@Override
public INDArray nextInt(int i, long[] longs) {
return null;
}
@Override
public Pointer getStatePointer() {
return null;
}
@Override
public long getPosition() {
return 0;
}
@Override
public void reSeed() {
}
@Override
public void reSeed(long l) {
}
@Override
public long rootState() {
return 0;
}
@Override
public long nodeState() {
return 0;
}
@Override
public void setStates(long l, long l1) {
}
@Override
public void close() throws Exception {
}
}