parent
3ba616a161
commit
e9b72e78ae
|
@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
import org.nd4j.config.ND4JEnvironmentVars;
|
import org.nd4j.config.ND4JEnvironmentVars;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -276,9 +277,12 @@ public class VoidParameterServer {
|
||||||
processingThreads = new Thread[numThreads];
|
processingThreads = new Thread[numThreads];
|
||||||
processingRunnables = new Runnable[numThreads];
|
processingRunnables = new Runnable[numThreads];
|
||||||
|
|
||||||
|
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||||
|
|
||||||
for (int x = 0; x < numThreads; x++) {
|
for (int x = 0; x < numThreads; x++) {
|
||||||
processingThreads[x] = new Thread(() -> {
|
processingThreads[x] = new Thread(() -> {
|
||||||
runner.set(true);
|
runner.set(true);
|
||||||
|
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
|
||||||
while (runner.get()) {
|
while (runner.get()) {
|
||||||
try {
|
try {
|
||||||
//VoidMessage message = transport.takeMessage();
|
//VoidMessage message = transport.takeMessage();
|
||||||
|
@ -296,11 +300,6 @@ public class VoidParameterServer {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
//executor.submit(processingRunnables[x);
|
|
||||||
|
|
||||||
// TODO: maybe find the way to guarantee affinity in some other way, to make different devices usable as well?
|
|
||||||
Nd4j.getAffinityManager().attachThreadToDevice(processingThreads[x],
|
|
||||||
Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
|
||||||
processingThreads[x].setDaemon(true);
|
processingThreads[x].setDaemon(true);
|
||||||
processingThreads[x].setName("VoidParameterServer messages handling thread");
|
processingThreads[x].setName("VoidParameterServer messages handling thread");
|
||||||
processingThreads[x].start();
|
processingThreads[x].start();
|
||||||
|
|
|
@ -25,6 +25,7 @@ import io.aeron.logbuffer.Header;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
import org.agrona.CloseHelper;
|
import org.agrona.CloseHelper;
|
||||||
import org.agrona.DirectBuffer;
|
import org.agrona.DirectBuffer;
|
||||||
import org.agrona.concurrent.IdleStrategy;
|
import org.agrona.concurrent.IdleStrategy;
|
||||||
|
@ -349,6 +350,8 @@ public abstract class BaseTransport implements Transport {
|
||||||
* only because we want code to be obvious for people
|
* only because we want code to be obvious for people
|
||||||
*/
|
*/
|
||||||
final AtomicBoolean localRunner = new AtomicBoolean(false);
|
final AtomicBoolean localRunner = new AtomicBoolean(false);
|
||||||
|
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||||
|
|
||||||
if (nodeRole == NodeRole.NONE) {
|
if (nodeRole == NodeRole.NONE) {
|
||||||
throw new ND4JIllegalStateException("No role is set for current node!");
|
throw new ND4JIllegalStateException("No role is set for current node!");
|
||||||
} else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) {
|
} else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) {
|
||||||
|
@ -357,6 +360,7 @@ public abstract class BaseTransport implements Transport {
|
||||||
// setting up thread for shard->client communication listener
|
// setting up thread for shard->client communication listener
|
||||||
if (messageHandlerForShards != null) {
|
if (messageHandlerForShards != null) {
|
||||||
threadB = new Thread(() -> {
|
threadB = new Thread(() -> {
|
||||||
|
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
|
||||||
while (runner.get())
|
while (runner.get())
|
||||||
idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512));
|
idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512));
|
||||||
|
|
||||||
|
@ -368,13 +372,13 @@ public abstract class BaseTransport implements Transport {
|
||||||
// setting up thread for inter-shard communication listener
|
// setting up thread for inter-shard communication listener
|
||||||
threadA = new Thread(() -> {
|
threadA = new Thread(() -> {
|
||||||
localRunner.set(true);
|
localRunner.set(true);
|
||||||
|
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
|
||||||
while (runner.get())
|
while (runner.get())
|
||||||
idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
|
idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
|
||||||
});
|
});
|
||||||
|
|
||||||
if (threadB != null) {
|
if (threadB != null) {
|
||||||
Nd4j.getAffinityManager().attachThreadToDevice(threadB,
|
|
||||||
Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
|
||||||
threadB.setDaemon(true);
|
threadB.setDaemon(true);
|
||||||
threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]");
|
threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]");
|
||||||
threadB.start();
|
threadB.start();
|
||||||
|
@ -389,8 +393,7 @@ public abstract class BaseTransport implements Transport {
|
||||||
}
|
}
|
||||||
|
|
||||||
// all roles have threadA anyway
|
// all roles have threadA anyway
|
||||||
Nd4j.getAffinityManager().attachThreadToDevice(threadA,
|
//Nd4j.getAffinityManager().attachThreadToDevice(threadA, Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
||||||
Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
|
||||||
threadA.setDaemon(true);
|
threadA.setDaemon(true);
|
||||||
threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]");
|
threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]");
|
||||||
threadA.start();
|
threadA.start();
|
||||||
|
|
|
@ -150,11 +150,19 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable {
|
||||||
// this executor service han
|
// this executor service han
|
||||||
protected ExecutorService messagesExecutorService = Executors.newFixedThreadPool(SENDER_THREADS + MESSAGE_THREADS + SUBSCRIPTION_THREADS, new ThreadFactory() {
|
protected ExecutorService messagesExecutorService = Executors.newFixedThreadPool(SENDER_THREADS + MESSAGE_THREADS + SUBSCRIPTION_THREADS, new ThreadFactory() {
|
||||||
@Override
|
@Override
|
||||||
public Thread newThread(@NonNull Runnable r) {
|
public Thread newThread(@NonNull final Runnable r) {
|
||||||
val t = Executors.defaultThreadFactory().newThread(r);
|
val t = new Thread(new Runnable() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
//TODO implement support for multi-GPU masters
|
||||||
|
Nd4j.getAffinityManager().unsafeSetDevice(0); //Associate thread with device 0 (no-op for CPU)
|
||||||
|
r.run();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
t.setDaemon(true);
|
t.setDaemon(true);
|
||||||
//TODO implement support for multi-GPU masters
|
t.setName("MessagesExecutorService thread");
|
||||||
Nd4j.getAffinityManager().attachThreadToDevice(t, 0); //Associate thread with device 0 (no-op for CPU)
|
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -107,10 +107,16 @@ public abstract class BaseTransport implements Transport {
|
||||||
|
|
||||||
protected final ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Math.max(2, Runtime.getRuntime().availableProcessors()), new ThreadFactory() {
|
protected final ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Math.max(2, Runtime.getRuntime().availableProcessors()), new ThreadFactory() {
|
||||||
@Override
|
@Override
|
||||||
public Thread newThread(@NonNull Runnable r) {
|
public Thread newThread(@NonNull final Runnable r) {
|
||||||
val t = Executors.defaultThreadFactory().newThread(r);
|
val t = new Thread(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
Nd4j.getAffinityManager().unsafeSetDevice(0);
|
||||||
|
r.run();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
t.setDaemon(true);
|
t.setDaemon(true);
|
||||||
Nd4j.getAffinityManager().attachThreadToDevice(t, 0);
|
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue