parent
3ba616a161
commit
e9b72e78ae
|
@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed;
|
|||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.nd4j.config.ND4JEnvironmentVars;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -276,9 +277,12 @@ public class VoidParameterServer {
|
|||
processingThreads = new Thread[numThreads];
|
||||
processingRunnables = new Runnable[numThreads];
|
||||
|
||||
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||
|
||||
for (int x = 0; x < numThreads; x++) {
|
||||
processingThreads[x] = new Thread(() -> {
|
||||
runner.set(true);
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
|
||||
while (runner.get()) {
|
||||
try {
|
||||
//VoidMessage message = transport.takeMessage();
|
||||
|
@ -296,11 +300,6 @@ public class VoidParameterServer {
|
|||
}
|
||||
});
|
||||
|
||||
//executor.submit(processingRunnables[x);
|
||||
|
||||
// TODO: maybe find the way to guarantee affinity in some other way, to make different devices usable as well?
|
||||
Nd4j.getAffinityManager().attachThreadToDevice(processingThreads[x],
|
||||
Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
||||
processingThreads[x].setDaemon(true);
|
||||
processingThreads[x].setName("VoidParameterServer messages handling thread");
|
||||
processingThreads[x].start();
|
||||
|
|
|
@ -25,6 +25,7 @@ import io.aeron.logbuffer.Header;
|
|||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.agrona.CloseHelper;
|
||||
import org.agrona.DirectBuffer;
|
||||
import org.agrona.concurrent.IdleStrategy;
|
||||
|
@ -349,6 +350,8 @@ public abstract class BaseTransport implements Transport {
|
|||
* only because we want code to be obvious for people
|
||||
*/
|
||||
final AtomicBoolean localRunner = new AtomicBoolean(false);
|
||||
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||
|
||||
if (nodeRole == NodeRole.NONE) {
|
||||
throw new ND4JIllegalStateException("No role is set for current node!");
|
||||
} else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) {
|
||||
|
@ -357,6 +360,7 @@ public abstract class BaseTransport implements Transport {
|
|||
// setting up thread for shard->client communication listener
|
||||
if (messageHandlerForShards != null) {
|
||||
threadB = new Thread(() -> {
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
|
||||
while (runner.get())
|
||||
idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512));
|
||||
|
||||
|
@ -368,13 +372,13 @@ public abstract class BaseTransport implements Transport {
|
|||
// setting up thread for inter-shard communication listener
|
||||
threadA = new Thread(() -> {
|
||||
localRunner.set(true);
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
|
||||
while (runner.get())
|
||||
idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
|
||||
});
|
||||
|
||||
if (threadB != null) {
|
||||
Nd4j.getAffinityManager().attachThreadToDevice(threadB,
|
||||
Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
||||
|
||||
threadB.setDaemon(true);
|
||||
threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]");
|
||||
threadB.start();
|
||||
|
@ -389,8 +393,7 @@ public abstract class BaseTransport implements Transport {
|
|||
}
|
||||
|
||||
// all roles have threadA anyway
|
||||
Nd4j.getAffinityManager().attachThreadToDevice(threadA,
|
||||
Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
||||
//Nd4j.getAffinityManager().attachThreadToDevice(threadA, Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
||||
threadA.setDaemon(true);
|
||||
threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]");
|
||||
threadA.start();
|
||||
|
|
|
@ -150,11 +150,19 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable {
|
|||
// this executor service han
|
||||
protected ExecutorService messagesExecutorService = Executors.newFixedThreadPool(SENDER_THREADS + MESSAGE_THREADS + SUBSCRIPTION_THREADS, new ThreadFactory() {
|
||||
@Override
|
||||
public Thread newThread(@NonNull Runnable r) {
|
||||
val t = Executors.defaultThreadFactory().newThread(r);
|
||||
public Thread newThread(@NonNull final Runnable r) {
|
||||
val t = new Thread(new Runnable() {
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
//TODO implement support for multi-GPU masters
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(0); //Associate thread with device 0 (no-op for CPU)
|
||||
r.run();
|
||||
}
|
||||
});
|
||||
|
||||
t.setDaemon(true);
|
||||
//TODO implement support for multi-GPU masters
|
||||
Nd4j.getAffinityManager().attachThreadToDevice(t, 0); //Associate thread with device 0 (no-op for CPU)
|
||||
t.setName("MessagesExecutorService thread");
|
||||
return t;
|
||||
}
|
||||
});
|
||||
|
|
|
@ -107,10 +107,16 @@ public abstract class BaseTransport implements Transport {
|
|||
|
||||
protected final ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Math.max(2, Runtime.getRuntime().availableProcessors()), new ThreadFactory() {
|
||||
@Override
|
||||
public Thread newThread(@NonNull Runnable r) {
|
||||
val t = Executors.defaultThreadFactory().newThread(r);
|
||||
public Thread newThread(@NonNull final Runnable r) {
|
||||
val t = new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
Nd4j.getAffinityManager().unsafeSetDevice(0);
|
||||
r.run();
|
||||
}
|
||||
});
|
||||
|
||||
t.setDaemon(true);
|
||||
Nd4j.getAffinityManager().attachThreadToDevice(t, 0);
|
||||
return t;
|
||||
}
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue