Small build fixes (#127)

* Small build fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix RL4J

Signed-off-by: Alex Black <blacka101@gmail.com>

* Test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another fix

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-08-17 14:13:31 +10:00 committed by GitHub
parent e9b72e78ae
commit 95100ffd8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 54 additions and 43 deletions

View File

@@ -80,14 +80,9 @@ public class Word2VecPerformer implements VoidFunction<Pair<List<VocabWord>, Ato
initExpTable(); initExpTable();
if (negative > 0 && conf.contains(Word2VecVariables.TABLE)) { if (negative > 0 && conf.contains(Word2VecVariables.TABLE)) {
try {
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes()); ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes());
DataInputStream dis = new DataInputStream(bis); DataInputStream dis = new DataInputStream(bis);
table = Nd4j.read(dis); table = Nd4j.read(dis);
} catch (IOException e) {
e.printStackTrace();
}
} }
} }

View File

@@ -95,16 +95,10 @@ public class Word2VecPerformerVoid implements VoidFunction<Pair<List<VocabWord>,
initExpTable(); initExpTable();
if (negative > 0 && conf.contains(TABLE)) { if (negative > 0 && conf.contains(TABLE)) {
try {
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes()); ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes());
DataInputStream dis = new DataInputStream(bis); DataInputStream dis = new DataInputStream(bis);
table = Nd4j.read(dis); table = Nd4j.read(dis);
} catch (IOException e) {
e.printStackTrace();
} }
}
} }

View File

@@ -86,7 +86,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResul
// This method will be called ONLY once, in master thread // This method will be called ONLY once, in master thread
//Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0 //Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0
Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0); Nd4j.getAffinityManager().unsafeSetDevice(0);
NetBroadcastTuple tuple = broadcastModel.getValue(); NetBroadcastTuple tuple = broadcastModel.getValue();
if (tuple.getConfiguration() != null) { if (tuple.getConfiguration() != null) {
@@ -109,7 +109,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResul
@Override @Override
public ComputationGraph getInitialModelGraph() { public ComputationGraph getInitialModelGraph() {
//Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0 //Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0
Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0); Nd4j.getAffinityManager().unsafeSetDevice(0);
NetBroadcastTuple tuple = broadcastModel.getValue(); NetBroadcastTuple tuple = broadcastModel.getValue();
if (tuple.getGraphConfiguration() != null) { if (tuple.getGraphConfiguration() != null) {
ComputationGraphConfiguration conf = tuple.getGraphConfiguration(); ComputationGraphConfiguration conf = tuple.getGraphConfiguration();

View File

@@ -45,7 +45,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
public abstract AsyncConfiguration getConfiguration(); public abstract AsyncConfiguration getConfiguration();
protected abstract AsyncThread newThread(int i); protected abstract AsyncThread newThread(int i, int deviceAffinity);
protected abstract IAsyncGlobal<NN> getAsyncGlobal(); protected abstract IAsyncGlobal<NN> getAsyncGlobal();
@@ -60,9 +60,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
public void launchThreads() { public void launchThreads() {
startGlobalThread(); startGlobalThread();
for (int i = 0; i < getConfiguration().getNumThread(); i++) { for (int i = 0; i < getConfiguration().getNumThread(); i++) {
Thread t = newThread(i); Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
Nd4j.getAffinityManager().attachThreadToDevice(t,
i % Nd4j.getAffinityManager().getNumberOfDevices());
t.start(); t.start();
} }

View File

@@ -32,6 +32,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.factory.Nd4j;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@@ -48,6 +49,8 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
extends Thread implements StepCountable { extends Thread implements StepCountable {
private int threadNumber; private int threadNumber;
@Getter
protected final int deviceNum;
@Getter @Setter @Getter @Setter
private int stepCounter = 0; private int stepCounter = 0;
@Getter @Setter @Getter @Setter
@@ -57,8 +60,9 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
@Getter @Getter
private int lastMonitor = -Constants.MONITOR_FREQ; private int lastMonitor = -Constants.MONITOR_FREQ;
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber) { public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
this.threadNumber = threadNumber; this.threadNumber = threadNumber;
this.deviceNum = deviceNum;
} }
public void setHistoryProcessor(IHistoryProcessor.Configuration conf) { public void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
@@ -87,6 +91,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
@Override @Override
public void run() { public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
try { try {

View File

@@ -44,8 +44,8 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
@Getter @Getter
private NN current; private NN current;
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber) { public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
super(asyncGlobal, threadNumber); super(asyncGlobal, threadNumber, deviceNum);
synchronized (asyncGlobal) { synchronized (asyncGlobal) {
current = (NN)asyncGlobal.getCurrent().clone(); current = (NN)asyncGlobal.getCurrent().clone();
} }

View File

@@ -62,9 +62,9 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
mdp.getActionSpace().setSeed(conf.getSeed()); mdp.getActionSpace().setSeed(conf.getSeed());
} }
@Override
protected AsyncThread newThread(int i) { protected AsyncThread newThread(int i, int deviceNum) {
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager); return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager, deviceNum);
} }
public IActorCritic getNeuralNet() { public IActorCritic getNeuralNet() {

View File

@@ -63,8 +63,8 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
} }
@Override @Override
public AsyncThread newThread(int i) { public AsyncThread newThread(int i, int deviceNum) {
AsyncThread at = super.newThread(i); AsyncThread at = super.newThread(i, deviceNum);
at.setHistoryProcessor(hpconf); at.setHistoryProcessor(hpconf);
return at; return at;
} }

View File

@@ -57,8 +57,8 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
final private Random random; final private Random random;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal, public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) { A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) {
super(asyncGlobal, threadNumber); super(asyncGlobal, threadNumber, deviceNum);
this.conf = a3cc; this.conf = a3cc;
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber; this.threadNumber = threadNumber;

View File

@@ -55,9 +55,9 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
mdp.getActionSpace().setSeed(conf.getSeed()); mdp.getActionSpace().setSeed(conf.getSeed());
} }
@Override
public AsyncThread newThread(int i) { public AsyncThread newThread(int i, int deviceNum) {
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager); return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum);
} }
public IDQN getNeuralNet() { public IDQN getNeuralNet() {

View File

@@ -53,8 +53,8 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
} }
@Override @Override
public AsyncThread newThread(int i) { public AsyncThread newThread(int i, int deviceNum) {
AsyncThread at = super.newThread(i); AsyncThread at = super.newThread(i, deviceNum);
at.setHistoryProcessor(hpconf); at.setHistoryProcessor(hpconf);
return at; return at;
} }

View File

@@ -56,8 +56,8 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal, public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
IDataManager dataManager) { IDataManager dataManager, int deviceNum) {
super(asyncGlobal, threadNumber); super(asyncGlobal, threadNumber, deviceNum);
this.conf = conf; this.conf = conf;
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber; this.threadNumber = threadNumber;

View File

@@ -16,6 +16,8 @@
package org.deeplearning4j.rl4j.learning.sync; package org.deeplearning4j.rl4j.learning.sync;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.queue.CircularFifoQueue; import org.apache.commons.collections4.queue.CircularFifoQueue;
@@ -50,7 +52,18 @@ public class ExpReplay<A> implements IExpReplay<A> {
ArrayList<Transition<A>> batch = new ArrayList<>(size); ArrayList<Transition<A>> batch = new ArrayList<>(size);
int storageSize = storage.size(); int storageSize = storage.size();
int actualBatchSize = Math.min(storageSize, size); int actualBatchSize = Math.min(storageSize, size);
int[] actualIndex = ThreadLocalRandom.current().ints(0, storageSize).distinct().limit(actualBatchSize).toArray();
int[] actualIndex = new int[actualBatchSize];
ThreadLocalRandom r = ThreadLocalRandom.current();
IntSet set = new IntOpenHashSet();
for( int i=0; i<actualBatchSize; i++ ){
int next = r.nextInt(storageSize);
while(set.contains(next)){
next = r.nextInt(storageSize);
}
set.add(next);
actualIndex[i] = next;
}
for (int i = 0; i < actualBatchSize; i ++) { for (int i = 0; i < actualBatchSize; i ++) {
Transition<A> trans = storage.get(actualIndex[i]); Transition<A> trans = storage.get(actualIndex[i]);

View File

@@ -50,7 +50,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
// final private IExpReplay<A> expReplay; // final private IExpReplay<A> expReplay;
@Getter @Getter
@Setter(AccessLevel.PACKAGE) @Setter(AccessLevel.PACKAGE)
private IExpReplay<A> expReplay; protected IExpReplay<A> expReplay;
public QLearning(QLConfiguration conf) { public QLearning(QLConfiguration conf) {
super(conf); super(conf);

View File

@@ -194,7 +194,7 @@ public class AsyncThreadTest {
private final IDataManager dataManager; private final IDataManager dataManager;
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) { public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
super(asyncGlobal, threadNumber); super(asyncGlobal, threadNumber, 0);
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
this.neuralNet = neuralNet; this.neuralNet = neuralNet;

View File

@@ -1,6 +1,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
@@ -138,5 +139,10 @@ public class QLearningDiscreteTest {
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) { protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
} }
public void setExpReplay(IExpReplay<Integer> exp){
this.expReplay = exp;
}
} }
} }