Small build fixes (#127)

* Small build fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix RL4J

Signed-off-by: Alex Black <blacka101@gmail.com>

* Test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another fix

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-08-17 14:13:31 +10:00 committed by GitHub
parent e9b72e78ae
commit 95100ffd8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 54 additions and 43 deletions

View File

@@ -80,14 +80,9 @@ public class Word2VecPerformer implements VoidFunction<Pair<List<VocabWord>, Ato
initExpTable();
if (negative > 0 && conf.contains(Word2VecVariables.TABLE)) {
try {
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes());
DataInputStream dis = new DataInputStream(bis);
table = Nd4j.read(dis);
} catch (IOException e) {
e.printStackTrace();
}
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes());
DataInputStream dis = new DataInputStream(bis);
table = Nd4j.read(dis);
}
}

View File

@@ -95,16 +95,10 @@ public class Word2VecPerformerVoid implements VoidFunction<Pair<List<VocabWord>,
initExpTable();
if (negative > 0 && conf.contains(TABLE)) {
try {
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes());
DataInputStream dis = new DataInputStream(bis);
table = Nd4j.read(dis);
} catch (IOException e) {
e.printStackTrace();
}
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes());
DataInputStream dis = new DataInputStream(bis);
table = Nd4j.read(dis);
}
}

View File

@@ -86,7 +86,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResul
// This method will be called ONLY once, in master thread
//Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0
Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0);
Nd4j.getAffinityManager().unsafeSetDevice(0);
NetBroadcastTuple tuple = broadcastModel.getValue();
if (tuple.getConfiguration() != null) {
@@ -109,7 +109,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResul
@Override
public ComputationGraph getInitialModelGraph() {
//Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0
Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0);
Nd4j.getAffinityManager().unsafeSetDevice(0);
NetBroadcastTuple tuple = broadcastModel.getValue();
if (tuple.getGraphConfiguration() != null) {
ComputationGraphConfiguration conf = tuple.getGraphConfiguration();

View File

@@ -45,7 +45,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
public abstract AsyncConfiguration getConfiguration();
protected abstract AsyncThread newThread(int i);
protected abstract AsyncThread newThread(int i, int deviceAffinity);
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
@@ -60,9 +60,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
public void launchThreads() {
startGlobalThread();
for (int i = 0; i < getConfiguration().getNumThread(); i++) {
Thread t = newThread(i);
Nd4j.getAffinityManager().attachThreadToDevice(t,
i % Nd4j.getAffinityManager().getNumberOfDevices());
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
t.start();
}

View File

@@ -32,6 +32,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.Constants;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@@ -48,6 +49,8 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
extends Thread implements StepCountable {
private int threadNumber;
@Getter
protected final int deviceNum;
@Getter @Setter
private int stepCounter = 0;
@Getter @Setter
@@ -57,8 +60,9 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
@Getter
private int lastMonitor = -Constants.MONITOR_FREQ;
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
this.threadNumber = threadNumber;
this.deviceNum = deviceNum;
}
public void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
@@ -87,6 +91,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
@Override
public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
try {

View File

@@ -44,8 +44,8 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
@Getter
private NN current;
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
super(asyncGlobal, threadNumber);
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
super(asyncGlobal, threadNumber, deviceNum);
synchronized (asyncGlobal) {
current = (NN)asyncGlobal.getCurrent().clone();
}

View File

@@ -62,9 +62,9 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
mdp.getActionSpace().setSeed(conf.getSeed());
}
protected AsyncThread newThread(int i) {
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager);
@Override
protected AsyncThread newThread(int i, int deviceNum) {
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager, deviceNum);
}
public IActorCritic getNeuralNet() {

View File

@@ -63,8 +63,8 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
}
@Override
public AsyncThread newThread(int i) {
AsyncThread at = super.newThread(i);
public AsyncThread newThread(int i, int deviceNum) {
AsyncThread at = super.newThread(i, deviceNum);
at.setHistoryProcessor(hpconf);
return at;
}

View File

@@ -57,8 +57,8 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
final private Random random;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) {
super(asyncGlobal, threadNumber);
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) {
super(asyncGlobal, threadNumber, deviceNum);
this.conf = a3cc;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;

View File

@@ -55,9 +55,9 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
mdp.getActionSpace().setSeed(conf.getSeed());
}
public AsyncThread newThread(int i) {
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager);
@Override
public AsyncThread newThread(int i, int deviceNum) {
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum);
}
public IDQN getNeuralNet() {

View File

@@ -53,8 +53,8 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
}
@Override
public AsyncThread newThread(int i) {
AsyncThread at = super.newThread(i);
public AsyncThread newThread(int i, int deviceNum) {
AsyncThread at = super.newThread(i, deviceNum);
at.setHistoryProcessor(hpconf);
return at;
}

View File

@@ -56,8 +56,8 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
IDataManager dataManager) {
super(asyncGlobal, threadNumber);
IDataManager dataManager, int deviceNum) {
super(asyncGlobal, threadNumber, deviceNum);
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;

View File

@@ -16,6 +16,8 @@
package org.deeplearning4j.rl4j.learning.sync;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.queue.CircularFifoQueue;
@@ -50,7 +52,18 @@ public class ExpReplay<A> implements IExpReplay<A> {
ArrayList<Transition<A>> batch = new ArrayList<>(size);
int storageSize = storage.size();
int actualBatchSize = Math.min(storageSize, size);
int[] actualIndex = ThreadLocalRandom.current().ints(0, storageSize).distinct().limit(actualBatchSize).toArray();
int[] actualIndex = new int[actualBatchSize];
ThreadLocalRandom r = ThreadLocalRandom.current();
IntSet set = new IntOpenHashSet();
for( int i=0; i<actualBatchSize; i++ ){
int next = r.nextInt(storageSize);
while(set.contains(next)){
next = r.nextInt(storageSize);
}
set.add(next);
actualIndex[i] = next;
}
for (int i = 0; i < actualBatchSize; i ++) {
Transition<A> trans = storage.get(actualIndex[i]);

View File

@@ -50,7 +50,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
// final private IExpReplay<A> expReplay;
@Getter
@Setter(AccessLevel.PACKAGE)
private IExpReplay<A> expReplay;
protected IExpReplay<A> expReplay;
public QLearning(QLConfiguration conf) {
super(conf);

View File

@@ -194,7 +194,7 @@ public class AsyncThreadTest {
private final IDataManager dataManager;
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
super(asyncGlobal, threadNumber);
super(asyncGlobal, threadNumber, 0);
this.asyncGlobal = asyncGlobal;
this.neuralNet = neuralNet;

View File

@@ -1,6 +1,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
@@ -138,5 +139,10 @@ public class QLearningDiscreteTest {
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
}
public void setExpReplay(IExpReplay<Integer> exp){
this.expReplay = exp;
}
}
}