one more build fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-16 15:02:47 +03:00
parent 3ba616a161
commit e9b72e78ae
4 changed files with 32 additions and 16 deletions

View File

@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.config.ND4JEnvironmentVars;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -276,9 +277,12 @@ public class VoidParameterServer {
processingThreads = new Thread[numThreads];
processingRunnables = new Runnable[numThreads];
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
for (int x = 0; x < numThreads; x++) {
processingThreads[x] = new Thread(() -> {
runner.set(true);
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
while (runner.get()) {
try {
//VoidMessage message = transport.takeMessage();
@ -296,11 +300,6 @@ public class VoidParameterServer {
}
});
//executor.submit(processingRunnables[x);
// TODO: maybe find the way to guarantee affinity in some other way, to make different devices usable as well?
Nd4j.getAffinityManager().attachThreadToDevice(processingThreads[x],
Nd4j.getAffinityManager().getDeviceForCurrentThread());
processingThreads[x].setDaemon(true);
processingThreads[x].setName("VoidParameterServer messages handling thread");
processingThreads[x].start();

View File

@ -25,6 +25,7 @@ import io.aeron.logbuffer.Header;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.agrona.CloseHelper;
import org.agrona.DirectBuffer;
import org.agrona.concurrent.IdleStrategy;
@ -349,6 +350,8 @@ public abstract class BaseTransport implements Transport {
* only because we want code to be obvious for people
*/
final AtomicBoolean localRunner = new AtomicBoolean(false);
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
if (nodeRole == NodeRole.NONE) {
throw new ND4JIllegalStateException("No role is set for current node!");
} else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) {
@ -357,6 +360,7 @@ public abstract class BaseTransport implements Transport {
// setting up thread for shard->client communication listener
if (messageHandlerForShards != null) {
threadB = new Thread(() -> {
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
while (runner.get())
idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512));
@ -368,13 +372,13 @@ public abstract class BaseTransport implements Transport {
// setting up thread for inter-shard communication listener
threadA = new Thread(() -> {
localRunner.set(true);
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
while (runner.get())
idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
});
if (threadB != null) {
Nd4j.getAffinityManager().attachThreadToDevice(threadB,
Nd4j.getAffinityManager().getDeviceForCurrentThread());
threadB.setDaemon(true);
threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]");
threadB.start();
@ -389,8 +393,7 @@ public abstract class BaseTransport implements Transport {
}
// all roles have threadA anyway
Nd4j.getAffinityManager().attachThreadToDevice(threadA,
Nd4j.getAffinityManager().getDeviceForCurrentThread());
//Nd4j.getAffinityManager().attachThreadToDevice(threadA, Nd4j.getAffinityManager().getDeviceForCurrentThread());
threadA.setDaemon(true);
threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]");
threadA.start();

View File

@ -150,11 +150,19 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable {
// this executor service han
protected ExecutorService messagesExecutorService = Executors.newFixedThreadPool(SENDER_THREADS + MESSAGE_THREADS + SUBSCRIPTION_THREADS, new ThreadFactory() {
@Override
public Thread newThread(@NonNull Runnable r) {
val t = Executors.defaultThreadFactory().newThread(r);
t.setDaemon(true);
public Thread newThread(@NonNull final Runnable r) {
val t = new Thread(new Runnable() {
@Override
public void run() {
//TODO implement support for multi-GPU masters
Nd4j.getAffinityManager().attachThreadToDevice(t, 0); //Associate thread with device 0 (no-op for CPU)
Nd4j.getAffinityManager().unsafeSetDevice(0); //Associate thread with device 0 (no-op for CPU)
r.run();
}
});
t.setDaemon(true);
t.setName("MessagesExecutorService thread");
return t;
}
});

View File

@ -107,10 +107,16 @@ public abstract class BaseTransport implements Transport {
protected final ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Math.max(2, Runtime.getRuntime().availableProcessors()), new ThreadFactory() {
@Override
public Thread newThread(@NonNull Runnable r) {
val t = Executors.defaultThreadFactory().newThread(r);
public Thread newThread(@NonNull final Runnable r) {
val t = new Thread(new Runnable() {
@Override
public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(0);
r.run();
}
});
t.setDaemon(true);
Nd4j.getAffinityManager().attachThreadToDevice(t, 0);
return t;
}
});