Small build fixes (#127)
* Small build fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fix RL4J Signed-off-by: Alex Black <blacka101@gmail.com> * Test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Another fix Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
e9b72e78ae
commit
95100ffd8c
|
@ -80,14 +80,9 @@ public class Word2VecPerformer implements VoidFunction<Pair<List<VocabWord>, Ato
|
|||
initExpTable();
|
||||
|
||||
if (negative > 0 && conf.contains(Word2VecVariables.TABLE)) {
|
||||
try {
|
||||
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes());
|
||||
DataInputStream dis = new DataInputStream(bis);
|
||||
table = Nd4j.read(dis);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes());
|
||||
DataInputStream dis = new DataInputStream(bis);
|
||||
table = Nd4j.read(dis);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -95,16 +95,10 @@ public class Word2VecPerformerVoid implements VoidFunction<Pair<List<VocabWord>,
|
|||
initExpTable();
|
||||
|
||||
if (negative > 0 && conf.contains(TABLE)) {
|
||||
try {
|
||||
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes());
|
||||
DataInputStream dis = new DataInputStream(bis);
|
||||
table = Nd4j.read(dis);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes());
|
||||
DataInputStream dis = new DataInputStream(bis);
|
||||
table = Nd4j.read(dis);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResul
|
|||
|
||||
// This method will be called ONLY once, in master thread
|
||||
//Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0
|
||||
Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0);
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(0);
|
||||
|
||||
NetBroadcastTuple tuple = broadcastModel.getValue();
|
||||
if (tuple.getConfiguration() != null) {
|
||||
|
@ -109,7 +109,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResul
|
|||
@Override
|
||||
public ComputationGraph getInitialModelGraph() {
|
||||
//Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0
|
||||
Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0);
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(0);
|
||||
NetBroadcastTuple tuple = broadcastModel.getValue();
|
||||
if (tuple.getGraphConfiguration() != null) {
|
||||
ComputationGraphConfiguration conf = tuple.getGraphConfiguration();
|
||||
|
|
|
@ -45,7 +45,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
|||
|
||||
public abstract AsyncConfiguration getConfiguration();
|
||||
|
||||
protected abstract AsyncThread newThread(int i);
|
||||
protected abstract AsyncThread newThread(int i, int deviceAffinity);
|
||||
|
||||
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||
|
||||
|
@ -60,9 +60,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
|||
public void launchThreads() {
|
||||
startGlobalThread();
|
||||
for (int i = 0; i < getConfiguration().getNumThread(); i++) {
|
||||
Thread t = newThread(i);
|
||||
Nd4j.getAffinityManager().attachThreadToDevice(t,
|
||||
i % Nd4j.getAffinityManager().getNumberOfDevices());
|
||||
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
|
||||
t.start();
|
||||
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
|
|||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.Constants;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
|
@ -48,6 +49,8 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
extends Thread implements StepCountable {
|
||||
|
||||
private int threadNumber;
|
||||
@Getter
|
||||
protected final int deviceNum;
|
||||
@Getter @Setter
|
||||
private int stepCounter = 0;
|
||||
@Getter @Setter
|
||||
|
@ -57,8 +60,9 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
@Getter
|
||||
private int lastMonitor = -Constants.MONITOR_FREQ;
|
||||
|
||||
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
||||
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
|
||||
this.threadNumber = threadNumber;
|
||||
this.deviceNum = deviceNum;
|
||||
}
|
||||
|
||||
public void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
|
||||
|
@ -87,6 +91,7 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
|
|||
|
||||
@Override
|
||||
public void run() {
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
|
||||
|
||||
|
||||
try {
|
||||
|
|
|
@ -44,8 +44,8 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
|||
@Getter
|
||||
private NN current;
|
||||
|
||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber) {
|
||||
super(asyncGlobal, threadNumber);
|
||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
|
||||
super(asyncGlobal, threadNumber, deviceNum);
|
||||
synchronized (asyncGlobal) {
|
||||
current = (NN)asyncGlobal.getCurrent().clone();
|
||||
}
|
||||
|
|
|
@ -62,9 +62,9 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
|||
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||
}
|
||||
|
||||
|
||||
protected AsyncThread newThread(int i) {
|
||||
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager);
|
||||
@Override
|
||||
protected AsyncThread newThread(int i, int deviceNum) {
|
||||
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), i, dataManager, deviceNum);
|
||||
}
|
||||
|
||||
public IActorCritic getNeuralNet() {
|
||||
|
|
|
@ -63,8 +63,8 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
|||
}
|
||||
|
||||
@Override
|
||||
public AsyncThread newThread(int i) {
|
||||
AsyncThread at = super.newThread(i);
|
||||
public AsyncThread newThread(int i, int deviceNum) {
|
||||
AsyncThread at = super.newThread(i, deviceNum);
|
||||
at.setHistoryProcessor(hpconf);
|
||||
return at;
|
||||
}
|
||||
|
|
|
@ -57,8 +57,8 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
final private Random random;
|
||||
|
||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
|
||||
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) {
|
||||
super(asyncGlobal, threadNumber);
|
||||
A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) {
|
||||
super(asyncGlobal, threadNumber, deviceNum);
|
||||
this.conf = a3cc;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.threadNumber = threadNumber;
|
||||
|
|
|
@ -55,9 +55,9 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
|||
mdp.getActionSpace().setSeed(conf.getSeed());
|
||||
}
|
||||
|
||||
|
||||
public AsyncThread newThread(int i) {
|
||||
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager);
|
||||
@Override
|
||||
public AsyncThread newThread(int i, int deviceNum) {
|
||||
return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum);
|
||||
}
|
||||
|
||||
public IDQN getNeuralNet() {
|
||||
|
|
|
@ -53,8 +53,8 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
|
|||
}
|
||||
|
||||
@Override
|
||||
public AsyncThread newThread(int i) {
|
||||
AsyncThread at = super.newThread(i);
|
||||
public AsyncThread newThread(int i, int deviceNum) {
|
||||
AsyncThread at = super.newThread(i, deviceNum);
|
||||
at.setHistoryProcessor(hpconf);
|
||||
return at;
|
||||
}
|
||||
|
|
|
@ -56,8 +56,8 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
|
||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber,
|
||||
IDataManager dataManager) {
|
||||
super(asyncGlobal, threadNumber);
|
||||
IDataManager dataManager, int deviceNum) {
|
||||
super(asyncGlobal, threadNumber, deviceNum);
|
||||
this.conf = conf;
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.threadNumber = threadNumber;
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.learning.sync;
|
||||
|
||||
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
|
||||
import it.unimi.dsi.fastutil.ints.IntSet;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||
|
||||
|
@ -50,7 +52,18 @@ public class ExpReplay<A> implements IExpReplay<A> {
|
|||
ArrayList<Transition<A>> batch = new ArrayList<>(size);
|
||||
int storageSize = storage.size();
|
||||
int actualBatchSize = Math.min(storageSize, size);
|
||||
int[] actualIndex = ThreadLocalRandom.current().ints(0, storageSize).distinct().limit(actualBatchSize).toArray();
|
||||
|
||||
int[] actualIndex = new int[actualBatchSize];
|
||||
ThreadLocalRandom r = ThreadLocalRandom.current();
|
||||
IntSet set = new IntOpenHashSet();
|
||||
for( int i=0; i<actualBatchSize; i++ ){
|
||||
int next = r.nextInt(storageSize);
|
||||
while(set.contains(next)){
|
||||
next = r.nextInt(storageSize);
|
||||
}
|
||||
set.add(next);
|
||||
actualIndex[i] = next;
|
||||
}
|
||||
|
||||
for (int i = 0; i < actualBatchSize; i ++) {
|
||||
Transition<A> trans = storage.get(actualIndex[i]);
|
||||
|
|
|
@ -50,7 +50,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
// final private IExpReplay<A> expReplay;
|
||||
@Getter
|
||||
@Setter(AccessLevel.PACKAGE)
|
||||
private IExpReplay<A> expReplay;
|
||||
protected IExpReplay<A> expReplay;
|
||||
|
||||
public QLearning(QLConfiguration conf) {
|
||||
super(conf);
|
||||
|
|
|
@ -194,7 +194,7 @@ public class AsyncThreadTest {
|
|||
private final IDataManager dataManager;
|
||||
|
||||
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) {
|
||||
super(asyncGlobal, threadNumber);
|
||||
super(asyncGlobal, threadNumber, 0);
|
||||
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.neuralNet = neuralNet;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
|
@ -138,5 +139,10 @@ public class QLearningDiscreteTest {
|
|||
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||
return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
|
||||
}
|
||||
|
||||
public void setExpReplay(IExpReplay<Integer> exp){
|
||||
this.expReplay = exp;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue