RL4J: Use Nd4j Random instead of java.util.Random (#8282)
* Changed to use Nd4j Random instead of java.util.Random Signed-off-by: unknown <aboulang2002@yahoo.com> * Changed to use Nd4j.getRandom() instead of the factory Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
This commit is contained in:
		
							parent
							
								
									2d750b69e5
								
							
						
					
					
						commit
						171ce51f46
					
				| @ -42,7 +42,7 @@ public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> ex | ||||
| 
 | ||||
|     interface LConfiguration { | ||||
| 
 | ||||
|         int getSeed(); | ||||
|         Integer getSeed(); | ||||
| 
 | ||||
|         int getMaxEpochStep(); | ||||
| 
 | ||||
|  | ||||
| @ -29,8 +29,6 @@ import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.Random; | ||||
| 
 | ||||
| /** | ||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/27/16. | ||||
|  * | ||||
| @ -43,8 +41,7 @@ import java.util.Random; | ||||
| @Slf4j | ||||
| public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> | ||||
|                 implements ILearning<O, A, AS>, NeuralNetFetchable<NN> { | ||||
|     @Getter | ||||
|     final private Random random; | ||||
| 
 | ||||
|     @Getter @Setter | ||||
|     private int stepCounter = 0; | ||||
|     @Getter @Setter | ||||
| @ -52,10 +49,6 @@ public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A> | ||||
|     @Getter @Setter | ||||
|     private IHistoryProcessor historyProcessor = null; | ||||
| 
 | ||||
|     public Learning(LConfiguration conf) { | ||||
|         random = new Random(conf.getSeed()); | ||||
|     } | ||||
| 
 | ||||
|     public static Integer getMaxAction(INDArray vector) { | ||||
|         return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0); | ||||
|     } | ||||
|  | ||||
| @ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.learning.ILearning; | ||||
|  */ | ||||
| public interface AsyncConfiguration extends ILearning.LConfiguration { | ||||
| 
 | ||||
|     int getSeed(); | ||||
|     Integer getSeed(); | ||||
| 
 | ||||
|     int getMaxEpochStep(); | ||||
| 
 | ||||
|  | ||||
| @ -42,10 +42,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa | ||||
|     @Getter(AccessLevel.PROTECTED) | ||||
|     private final TrainingListenerList listeners = new TrainingListenerList(); | ||||
| 
 | ||||
|     public AsyncLearning(AsyncConfiguration conf) { | ||||
|         super(conf); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Add a {@link TrainingListener} listener at the end of the listener list. | ||||
|      * | ||||
|  | ||||
| @ -26,6 +26,8 @@ import org.deeplearning4j.rl4j.network.ac.IActorCritic; | ||||
| import org.deeplearning4j.rl4j.policy.ACPolicy; | ||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| /** | ||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. | ||||
| @ -48,13 +50,19 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, | ||||
|     final private ACPolicy<O> policy; | ||||
| 
 | ||||
|     public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) { | ||||
|         super(conf); | ||||
|         this.iActorCritic = iActorCritic; | ||||
|         this.mdp = mdp; | ||||
|         this.configuration = conf; | ||||
|         policy = new ACPolicy<>(iActorCritic, getRandom()); | ||||
|         asyncGlobal = new AsyncGlobal<>(iActorCritic, conf); | ||||
|         mdp.getActionSpace().setSeed(conf.getSeed()); | ||||
| 
 | ||||
|         Integer seed = conf.getSeed(); | ||||
|         Random rnd = Nd4j.getRandom(); | ||||
|         if(seed != null) { | ||||
|             mdp.getActionSpace().setSeed(seed); | ||||
|             rnd.setSeed(seed); | ||||
|         } | ||||
| 
 | ||||
|         policy = new ACPolicy<>(iActorCritic, rnd); | ||||
|     } | ||||
| 
 | ||||
|     protected AsyncThread newThread(int i, int deviceNum) { | ||||
| @ -71,7 +79,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, | ||||
|     @EqualsAndHashCode(callSuper = false) | ||||
|     public static class A3CConfiguration implements AsyncConfiguration { | ||||
| 
 | ||||
|         int seed; | ||||
|         Integer seed; | ||||
|         int maxEpochStep; | ||||
|         int maxStep; | ||||
|         int numThread; | ||||
|  | ||||
| @ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.indexing.NDArrayIndex; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| 
 | ||||
| import java.util.Random; | ||||
| import java.util.Stack; | ||||
| 
 | ||||
| /** | ||||
| @ -50,7 +50,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete< | ||||
|     @Getter | ||||
|     final protected int threadNumber; | ||||
| 
 | ||||
|     final private Random random; | ||||
|     final private Random rnd; | ||||
| 
 | ||||
|     public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal, | ||||
|                              A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners, | ||||
| @ -59,13 +59,18 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete< | ||||
|         this.conf = a3cc; | ||||
|         this.asyncGlobal = asyncGlobal; | ||||
|         this.threadNumber = threadNumber; | ||||
|         mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber); | ||||
|         random = new Random(conf.getSeed() + threadNumber); | ||||
| 
 | ||||
|         Integer seed = conf.getSeed(); | ||||
|         rnd = Nd4j.getRandom(); | ||||
|         if(seed != null) { | ||||
|             mdp.getActionSpace().setSeed(seed + threadNumber); | ||||
|             rnd.setSeed(seed); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected Policy<O, Integer> getPolicy(IActorCritic net) { | ||||
|         return new ACPolicy(net, random); | ||||
|         return new ACPolicy(net, rnd); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|  | ||||
| @ -43,11 +43,13 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable> | ||||
| 
 | ||||
| 
 | ||||
|     public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { | ||||
|         super(conf); | ||||
|         this.mdp = mdp; | ||||
|         this.configuration = conf; | ||||
|         this.asyncGlobal = new AsyncGlobal<>(dqn, conf); | ||||
|         mdp.getActionSpace().setSeed(conf.getSeed()); | ||||
|         Integer seed = conf.getSeed(); | ||||
|         if(seed != null) { | ||||
|             mdp.getActionSpace().setSeed(seed); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
| @ -70,7 +72,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable> | ||||
|     @EqualsAndHashCode(callSuper = false) | ||||
|     public static class AsyncNStepQLConfiguration implements AsyncConfiguration { | ||||
| 
 | ||||
|         int seed; | ||||
|         Integer seed; | ||||
|         int maxEpochStep; | ||||
|         int maxStep; | ||||
|         int numThread; | ||||
|  | ||||
| @ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| 
 | ||||
| import java.util.Random; | ||||
| import java.util.Stack; | ||||
| 
 | ||||
| /** | ||||
| @ -48,7 +48,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn | ||||
|     @Getter | ||||
|     final protected int threadNumber; | ||||
| 
 | ||||
|     final private Random random; | ||||
|     final private Random rnd; | ||||
| 
 | ||||
|     public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal, | ||||
|                                              AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, | ||||
| @ -57,13 +57,18 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn | ||||
|         this.conf = conf; | ||||
|         this.asyncGlobal = asyncGlobal; | ||||
|         this.threadNumber = threadNumber; | ||||
|         mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber); | ||||
|         random = new Random(conf.getSeed() + threadNumber); | ||||
|         rnd = Nd4j.getRandom(); | ||||
| 
 | ||||
|         Integer seed = conf.getSeed(); | ||||
|         if(seed != null) { | ||||
|             mdp.getActionSpace().setSeed(seed + threadNumber); | ||||
|             rnd.setSeed(seed + threadNumber); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public Policy<O, Integer> getPolicy(IDQN nn) { | ||||
|         return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(), | ||||
|                         random, conf.getMinEpsilon(), this); | ||||
|                 rnd, conf.getMinEpsilon(), this); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -20,9 +20,9 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet; | ||||
| import it.unimi.dsi.fastutil.ints.IntSet; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.apache.commons.collections4.queue.CircularFifoQueue; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| 
 | ||||
| import java.util.*; | ||||
| import java.util.concurrent.ThreadLocalRandom; | ||||
| import java.util.ArrayList; | ||||
| 
 | ||||
| /** | ||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16. | ||||
| @ -36,30 +36,28 @@ import java.util.concurrent.ThreadLocalRandom; | ||||
| public class ExpReplay<A> implements IExpReplay<A> { | ||||
| 
 | ||||
|     final private int batchSize; | ||||
|     final private Random random; | ||||
|     final private Random rnd; | ||||
| 
 | ||||
|     //Implementing this as a circular buffer queue | ||||
|     private CircularFifoQueue<Transition<A>> storage; | ||||
| 
 | ||||
|     public ExpReplay(int maxSize, int batchSize, int seed) { | ||||
|     public ExpReplay(int maxSize, int batchSize, Random rnd) { | ||||
|         this.batchSize = batchSize; | ||||
|         this.random = new Random(seed); | ||||
|         this.rnd = rnd; | ||||
|         storage = new CircularFifoQueue<>(maxSize); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public ArrayList<Transition<A>> getBatch(int size) { | ||||
|         ArrayList<Transition<A>> batch = new ArrayList<>(size); | ||||
|         int storageSize = storage.size(); | ||||
|         int actualBatchSize = Math.min(storageSize, size); | ||||
| 
 | ||||
|         int[] actualIndex = new int[actualBatchSize]; | ||||
|         ThreadLocalRandom r = ThreadLocalRandom.current(); | ||||
|         IntSet set = new IntOpenHashSet(); | ||||
|         for( int i=0; i<actualBatchSize; i++ ){ | ||||
|             int next = r.nextInt(storageSize); | ||||
|             int next = rnd.nextInt(storageSize); | ||||
|             while(set.contains(next)){ | ||||
|                 next = r.nextInt(storageSize); | ||||
|                 next = rnd.nextInt(storageSize); | ||||
|             } | ||||
|             set.add(next); | ||||
|             actualIndex[i] = next; | ||||
|  | ||||
| @ -41,10 +41,6 @@ public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpac | ||||
| 
 | ||||
|     private final TrainingListenerList listeners = new TrainingListenerList(); | ||||
| 
 | ||||
|     public SyncLearning(LConfiguration conf) { | ||||
|         super(conf); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Add a listener at the end of the listener list. | ||||
|      * | ||||
|  | ||||
| @ -30,7 +30,8 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy; | ||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.deeplearning4j.rl4j.util.IDataManager.StatEntry; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| @ -53,8 +54,20 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A | ||||
|     protected IExpReplay<A> expReplay; | ||||
| 
 | ||||
|     public QLearning(QLConfiguration conf) { | ||||
|         super(conf); | ||||
|         expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), conf.getSeed()); | ||||
|         this(conf, getSeededRandom(conf.getSeed())); | ||||
|     } | ||||
| 
 | ||||
|     public QLearning(QLConfiguration conf, Random random) { | ||||
|         expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random); | ||||
|     } | ||||
| 
 | ||||
|     private static Random getSeededRandom(Integer seed) { | ||||
|         Random rnd = Nd4j.getRandom(); | ||||
|         if(seed != null) { | ||||
|             rnd.setSeed(seed); | ||||
|         } | ||||
| 
 | ||||
|         return rnd; | ||||
|     } | ||||
| 
 | ||||
|     protected abstract EpsGreedy<O, A, AS> getEgPolicy(); | ||||
| @ -160,7 +173,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A | ||||
|     @JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class) | ||||
|     public static class QLConfiguration implements LConfiguration { | ||||
| 
 | ||||
|         int seed; | ||||
|         Integer seed; | ||||
|         int maxEpochStep; | ||||
|         int maxStep; | ||||
|         int expRepMaxSize; | ||||
|  | ||||
| @ -31,7 +31,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy; | ||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| import org.nd4j.linalg.dataset.api.DataSet; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.util.ArrayUtil; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| @ -70,13 +72,18 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O | ||||
| 
 | ||||
|     public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf, | ||||
|                              int epsilonNbStep) { | ||||
|         this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed())); | ||||
|     } | ||||
| 
 | ||||
|     public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf, | ||||
|                              int epsilonNbStep, Random random) { | ||||
|         super(conf); | ||||
|         this.configuration = conf; | ||||
|         this.mdp = mdp; | ||||
|         qNetwork = dqn; | ||||
|         targetQNetwork = dqn.clone(); | ||||
|         policy = new DQNPolicy(getQNetwork()); | ||||
|         egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(), | ||||
|         egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(), | ||||
|                 this); | ||||
|         mdp.getActionSpace().setSeed(conf.getSeed()); | ||||
| 
 | ||||
|  | ||||
| @ -57,7 +57,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | ||||
|     private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0; | ||||
|     private static final double xThreshold = 2.4; | ||||
| 
 | ||||
|     private final Random rnd = new Random(); | ||||
|     private final Random rnd; | ||||
| 
 | ||||
|     @Getter @Setter | ||||
|     private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler; | ||||
| @ -76,6 +76,14 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | ||||
|     @Getter | ||||
|     private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES }); | ||||
| 
 | ||||
|     public CartpoleNative() { | ||||
|         rnd = new Random(); | ||||
|     } | ||||
| 
 | ||||
|     public CartpoleNative(int seed) { | ||||
|         rnd = new Random(seed); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public State reset() { | ||||
| 
 | ||||
|  | ||||
| @ -16,18 +16,16 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.rl4j.policy; | ||||
| 
 | ||||
| import org.deeplearning4j.nn.api.NeuralNetwork; | ||||
| import org.deeplearning4j.nn.graph.ComputationGraph; | ||||
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||||
| import org.deeplearning4j.rl4j.learning.Learning; | ||||
| import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph; | ||||
| import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate; | ||||
| import org.deeplearning4j.rl4j.network.ac.IActorCritic; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| import java.util.Random; | ||||
| 
 | ||||
| /** | ||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. | ||||
| @ -38,47 +36,41 @@ import java.util.Random; | ||||
|  */ | ||||
| public class ACPolicy<O extends Encodable> extends Policy<O, Integer> { | ||||
| 
 | ||||
|     final private IActorCritic IActorCritic; | ||||
|     Random rd; | ||||
|     final private IActorCritic actorCritic; | ||||
|     Random rnd; | ||||
| 
 | ||||
|     public ACPolicy(IActorCritic IActorCritic) { | ||||
|         this.IActorCritic = IActorCritic; | ||||
|         NeuralNetwork nn = IActorCritic.getNeuralNetworks()[0]; | ||||
|         if (nn instanceof ComputationGraph) { | ||||
|             rd = new Random(((ComputationGraph)nn).getConfiguration().getDefaultConfiguration().getSeed()); | ||||
|         } else if (nn instanceof MultiLayerNetwork) { | ||||
|             rd = new Random(((MultiLayerNetwork)nn).getDefaultConfiguration().getSeed()); | ||||
|         } | ||||
|     public ACPolicy(IActorCritic actorCritic) { | ||||
|         this(actorCritic, Nd4j.getRandom()); | ||||
|     } | ||||
|     public ACPolicy(IActorCritic IActorCritic, Random rd) { | ||||
|         this.IActorCritic = IActorCritic; | ||||
|         this.rd = rd; | ||||
|     public ACPolicy(IActorCritic actorCritic, Random rnd) { | ||||
|         this.actorCritic = actorCritic; | ||||
|         this.rnd = rnd; | ||||
|     } | ||||
| 
 | ||||
|     public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException { | ||||
|         return new ACPolicy<O>(ActorCriticCompGraph.load(path)); | ||||
|     } | ||||
|     public static <O extends Encodable> ACPolicy<O> load(String path, Random rd) throws IOException { | ||||
|         return new ACPolicy<O>(ActorCriticCompGraph.load(path), rd); | ||||
|     public static <O extends Encodable> ACPolicy<O> load(String path, Random rnd) throws IOException { | ||||
|         return new ACPolicy<O>(ActorCriticCompGraph.load(path), rnd); | ||||
|     } | ||||
| 
 | ||||
|     public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException { | ||||
|         return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy)); | ||||
|     } | ||||
|     public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rd) throws IOException { | ||||
|         return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rd); | ||||
|     public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rnd) throws IOException { | ||||
|         return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd); | ||||
|     } | ||||
| 
 | ||||
|     public IActorCritic getNeuralNet() { | ||||
|         return IActorCritic; | ||||
|         return actorCritic; | ||||
|     } | ||||
| 
 | ||||
|     public Integer nextAction(INDArray input) { | ||||
|         INDArray output = IActorCritic.outputAll(input)[1]; | ||||
|         if (rd == null) { | ||||
|         INDArray output = actorCritic.outputAll(input)[1]; | ||||
|         if (rnd == null) { | ||||
|             return Learning.getMaxAction(output); | ||||
|         } | ||||
|         float rVal = rd.nextFloat(); | ||||
|         float rVal = rnd.nextFloat(); | ||||
|         for (int i = 0; i < output.length(); i++) { | ||||
|             //System.out.println(i + " " + rVal + " " + output.getFloat(i)); | ||||
|             if (rVal < output.getFloat(i)) { | ||||
| @ -91,11 +83,11 @@ public class ACPolicy<O extends Encodable> extends Policy<O, Integer> { | ||||
|     } | ||||
| 
 | ||||
|     public void save(String filename) throws IOException { | ||||
|         IActorCritic.save(filename); | ||||
|         actorCritic.save(filename); | ||||
|     } | ||||
| 
 | ||||
|     public void save(String filenameValue, String filenamePolicy) throws IOException { | ||||
|         IActorCritic.save(filenameValue, filenamePolicy); | ||||
|         actorCritic.save(filenameValue, filenamePolicy); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
|  | ||||
| @ -16,12 +16,11 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.rl4j.policy; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.util.Random; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import static org.nd4j.linalg.ops.transforms.Transforms.exp; | ||||
| 
 | ||||
| @ -31,11 +30,15 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp; | ||||
|  * Boltzmann exploration is a stochastic policy wrt to the | ||||
|  * exponential Q-values as evaluated by the dqn model. | ||||
|  */ | ||||
| @AllArgsConstructor | ||||
| public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> { | ||||
| 
 | ||||
|     final private IDQN dqn; | ||||
|     final private Random rd = new Random(123); | ||||
|     final private Random rnd; | ||||
| 
 | ||||
|     public BoltzmannQ(IDQN dqn, Random random) { | ||||
|         this.dqn = dqn; | ||||
|         this.rnd = random; | ||||
|     } | ||||
| 
 | ||||
|     public IDQN getNeuralNet() { | ||||
|         return dqn; | ||||
| @ -47,7 +50,7 @@ public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> { | ||||
|         INDArray exp = exp(output); | ||||
| 
 | ||||
|         double sum = exp.sum(1).getDouble(0); | ||||
|         double picked = rd.nextDouble() * sum; | ||||
|         double picked = rnd.nextDouble() * sum; | ||||
|         for (int i = 0; i < exp.columns(); i++) { | ||||
|             if (picked < exp.getDouble(i)) | ||||
|                 return i; | ||||
|  | ||||
| @ -24,8 +24,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet; | ||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.util.Random; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| 
 | ||||
| /** | ||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16. | ||||
| @ -45,7 +44,7 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend | ||||
|     final private MDP<O, A, AS> mdp; | ||||
|     final private int updateStart; | ||||
|     final private int epsilonNbStep; | ||||
|     final private Random rd; | ||||
|     final private Random rnd; | ||||
|     final private float minEpsilon; | ||||
|     final private StepCountable learning; | ||||
| 
 | ||||
| @ -58,7 +57,7 @@ public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extend | ||||
|         float ep = getEpsilon(); | ||||
|         if (learning.getStepCounter() % 500 == 1) | ||||
|             log.info("EP: " + ep + " " + learning.getStepCounter()); | ||||
|         if (rd.nextFloat() > ep) | ||||
|         if (rnd.nextFloat() > ep) | ||||
|             return policy.nextAction(input); | ||||
|         else | ||||
|             return mdp.getActionSpace().randomAction(); | ||||
|  | ||||
| @ -87,7 +87,6 @@ public class AsyncLearningTest { | ||||
|         private final IPolicy<MockEncodable, Integer> policy; | ||||
| 
 | ||||
|         public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) { | ||||
|             super(conf); | ||||
|             this.conf = conf; | ||||
|             this.asyncGlobal = asyncGlobal; | ||||
|             this.policy = policy; | ||||
|  | ||||
| @ -0,0 +1,180 @@ | ||||
| package org.deeplearning4j.rl4j.learning.sync; | ||||
| 
 | ||||
| import org.deeplearning4j.rl4j.support.MockRandom; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class ExpReplayTest { | ||||
|     @Test | ||||
|     public void when_storingElementWithStorageNotFull_expect_elementStored() { | ||||
|         // Arrange | ||||
|         MockRandom randomMock = new MockRandom(null, new int[] { 0 }); | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 123, 234, false, Nd4j.create(1)); | ||||
|         sut.store(transition); | ||||
|         List<Transition<Integer>> results = sut.getBatch(1); | ||||
| 
 | ||||
|         // Assert | ||||
|         assertEquals(1, results.size()); | ||||
|         assertEquals(123, (int)results.get(0).getAction()); | ||||
|         assertEquals(234, (int)results.get(0).getReward()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void when_storingElementWithStorageFull_expect_oldestElementReplacedByStored() { | ||||
|         // Arrange | ||||
|         MockRandom randomMock = new MockRandom(null, new int[] { 0, 1 }); | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
|         List<Transition<Integer>> results = sut.getBatch(2); | ||||
| 
 | ||||
|         // Assert | ||||
|         assertEquals(2, results.size()); | ||||
| 
 | ||||
|         assertEquals(3, (int)results.get(0).getAction()); | ||||
|         assertEquals(4, (int)results.get(0).getReward()); | ||||
| 
 | ||||
|         assertEquals(5, (int)results.get(1).getAction()); | ||||
|         assertEquals(6, (int)results.get(1).getReward()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void when_askBatchSizeZeroAndStorageEmpty_expect_emptyBatch() { | ||||
|         // Arrange | ||||
|         MockRandom randomMock = new MockRandom(null, new int[] { 0 }); | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         List<Transition<Integer>> results = sut.getBatch(0); | ||||
| 
 | ||||
|         // Assert | ||||
|         assertEquals(0, results.size()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void when_askBatchSizeZeroAndStorageNotEmpty_expect_emptyBatch() { | ||||
|         // Arrange | ||||
|         MockRandom randomMock = new MockRandom(null, new int[] { 0 }); | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
|         List<Transition<Integer>> results = sut.getBatch(0); | ||||
| 
 | ||||
|         // Assert | ||||
|         assertEquals(0, results.size()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void when_askBatchSizeGreaterThanStoredCount_expect_batchWithStoredCountElements() { | ||||
|         // Arrange | ||||
|         MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2 }); | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
|         List<Transition<Integer>> results = sut.getBatch(10); | ||||
| 
 | ||||
|         // Assert | ||||
|         assertEquals(3, results.size()); | ||||
| 
 | ||||
|         assertEquals(1, (int)results.get(0).getAction()); | ||||
|         assertEquals(2, (int)results.get(0).getReward()); | ||||
| 
 | ||||
|         assertEquals(3, (int)results.get(1).getAction()); | ||||
|         assertEquals(4, (int)results.get(1).getReward()); | ||||
| 
 | ||||
|         assertEquals(5, (int)results.get(2).getAction()); | ||||
|         assertEquals(6, (int)results.get(2).getReward()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void when_askBatchSizeSmallerThanStoredCount_expect_batchWithAskedElements() { | ||||
|         // Arrange | ||||
|         MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 3, 4 }); | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1)); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
|         sut.store(transition4); | ||||
|         sut.store(transition5); | ||||
|         List<Transition<Integer>> results = sut.getBatch(3); | ||||
| 
 | ||||
|         // Assert | ||||
|         assertEquals(3, results.size()); | ||||
| 
 | ||||
|         assertEquals(1, (int)results.get(0).getAction()); | ||||
|         assertEquals(2, (int)results.get(0).getReward()); | ||||
| 
 | ||||
|         assertEquals(3, (int)results.get(1).getAction()); | ||||
|         assertEquals(4, (int)results.get(1).getReward()); | ||||
| 
 | ||||
|         assertEquals(5, (int)results.get(2).getAction()); | ||||
|         assertEquals(6, (int)results.get(2).getReward()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void when_randomGivesDuplicates_expect_noDuplicatesInBatch() { | ||||
|         // Arrange | ||||
|         MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 1, 3, 1, 4 }); | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition4 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1)); | ||||
|         Transition<Integer> transition5 = new Transition<Integer>(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1)); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
|         sut.store(transition4); | ||||
|         sut.store(transition5); | ||||
|         List<Transition<Integer>> results = sut.getBatch(3); | ||||
| 
 | ||||
|         // Assert | ||||
|         assertEquals(3, results.size()); | ||||
| 
 | ||||
|         assertEquals(1, (int)results.get(0).getAction()); | ||||
|         assertEquals(2, (int)results.get(0).getReward()); | ||||
| 
 | ||||
|         assertEquals(3, (int)results.get(1).getAction()); | ||||
|         assertEquals(4, (int)results.get(1).getReward()); | ||||
| 
 | ||||
|         assertEquals(5, (int)results.get(2).getAction()); | ||||
|         assertEquals(6, (int)results.get(2).getReward()); | ||||
| 
 | ||||
|     } | ||||
| } | ||||
| @ -89,7 +89,6 @@ public class SyncLearningTest { | ||||
|         private final LConfiguration conf; | ||||
| 
 | ||||
|         public MockSyncLearning(LConfiguration conf) { | ||||
|             super(conf); | ||||
|             this.conf = conf; | ||||
|         } | ||||
| 
 | ||||
|  | ||||
| @ -13,6 +13,7 @@ import org.deeplearning4j.rl4j.util.IDataManager; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.dataset.api.DataSet; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| @ -27,11 +28,12 @@ public class QLearningDiscreteTest { | ||||
|         MockObservationSpace observationSpace = new MockObservationSpace(); | ||||
|         MockMDP mdp = new MockMDP(observationSpace); | ||||
|         MockDQN dqn = new MockDQN(); | ||||
|         MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null); | ||||
|         QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, | ||||
|                 0, 1.0, 0, 0, 0, 0, true); | ||||
|         MockDataManager dataManager = new MockDataManager(false); | ||||
|         MockExpReplay expReplay = new MockExpReplay(); | ||||
|         TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10); | ||||
|         TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); | ||||
|         IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); | ||||
|         MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); | ||||
|         sut.setHistoryProcessor(hp); | ||||
| @ -130,10 +132,10 @@ public class QLearningDiscreteTest { | ||||
|     } | ||||
| 
 | ||||
|     public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> { | ||||
|         public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn, | ||||
|         public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn, | ||||
|                                      QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, | ||||
|                                      int epsilonNbStep) { | ||||
|             super(mdp, dqn, conf, epsilonNbStep); | ||||
|                                      int epsilonNbStep, Random rnd) { | ||||
|             super(mdp, dqn, conf, epsilonNbStep, rnd); | ||||
|             addListener(new DataManagerTrainingListener(dataManager)); | ||||
|             setExpReplay(expReplay); | ||||
|         } | ||||
|  | ||||
| @ -127,10 +127,10 @@ public class PolicyTest { | ||||
|                 .layer(0, new OutputLayer.Builder().nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build()).build()); | ||||
| 
 | ||||
|         ACPolicy policy = new ACPolicy(new DummyAC(cg)); | ||||
|         assertNotNull(policy.rd); | ||||
|         assertNotNull(policy.rnd); | ||||
| 
 | ||||
|         policy = new ACPolicy(new DummyAC(mln)); | ||||
|         assertNotNull(policy.rd); | ||||
|         assertNotNull(policy.rnd); | ||||
| 
 | ||||
|         INDArray input = Nd4j.create(new double[] {1.0, 0.0}, new long[]{1,2}); | ||||
|         for (int i = 0; i < 100; i++) { | ||||
|  | ||||
| @ -14,7 +14,7 @@ public class MockAsyncConfiguration implements AsyncConfiguration { | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int getSeed() { | ||||
|     public Integer getSeed() { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -0,0 +1,203 @@ | ||||
| package org.deeplearning4j.rl4j.support; | ||||
| 
 | ||||
| import org.bytedeco.javacpp.Pointer; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| public class MockRandom implements org.nd4j.linalg.api.rng.Random { | ||||
| 
 | ||||
|     private int randomDoubleValuesIdx = 0; | ||||
|     private final double[] randomDoubleValues; | ||||
| 
 | ||||
|     private int randomIntValuesIdx = 0; | ||||
|     private final int[] randomIntValues; | ||||
| 
 | ||||
|     public MockRandom(double[] randomDoubleValues, int[] randomIntValues) { | ||||
|         this.randomDoubleValues = randomDoubleValues; | ||||
|         this.randomIntValues = randomIntValues; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setSeed(int i) { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setSeed(int[] ints) { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setSeed(long l) { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public long getSeed() { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void nextBytes(byte[] bytes) { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int nextInt() { | ||||
|         return randomIntValues[randomIntValuesIdx++]; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int nextInt(int i) { | ||||
|         return randomIntValues[randomIntValuesIdx++]; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int nextInt(int i, int i1) { | ||||
|         return randomIntValues[randomIntValuesIdx++]; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public long nextLong() { | ||||
|         return randomIntValues[randomIntValuesIdx++]; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean nextBoolean() { | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public float nextFloat() { | ||||
|         return (float)randomDoubleValues[randomDoubleValuesIdx++]; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public double nextDouble() { | ||||
|         return randomDoubleValues[randomDoubleValuesIdx++]; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public double nextGaussian() { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextGaussian(int[] ints) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextGaussian(long[] longs) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextGaussian(char c, int[] ints) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextGaussian(char c, long[] longs) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextDouble(int[] ints) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextDouble(long[] longs) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextDouble(char c, int[] ints) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextDouble(char c, long[] longs) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextFloat(int[] ints) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextFloat(long[] longs) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextFloat(char c, int[] ints) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextFloat(char c, long[] longs) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextInt(int[] ints) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextInt(long[] longs) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextInt(int i, int[] ints) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray nextInt(int i, long[] longs) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Pointer getStatePointer() { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public long getPosition() { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void reSeed() { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void reSeed(long l) { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public long rootState() { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public long nodeState() { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setStates(long l, long l1) { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void close() throws Exception { | ||||
| 
 | ||||
|     } | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user