diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/VoidParameterServer.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/VoidParameterServer.java index 751fa79c3..26e46949f 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/VoidParameterServer.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/VoidParameterServer.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.nd4j.config.ND4JEnvironmentVars; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.api.ndarray.INDArray; @@ -276,9 +277,12 @@ public class VoidParameterServer { processingThreads = new Thread[numThreads]; processingRunnables = new Runnable[numThreads]; + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + for (int x = 0; x < numThreads; x++) { processingThreads[x] = new Thread(() -> { runner.set(true); + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); while (runner.get()) { try { //VoidMessage message = transport.takeMessage(); @@ -296,11 +300,6 @@ public class VoidParameterServer { } }); - //executor.submit(processingRunnables[x); - - // TODO: maybe find the way to guarantee affinity in some other way, to make different devices usable as well? - Nd4j.getAffinityManager().attachThreadToDevice(processingThreads[x], - Nd4j.getAffinityManager().getDeviceForCurrentThread()); processingThreads[x].setDaemon(true); processingThreads[x].setName("VoidParameterServer messages handling thread"); processingThreads[x].start(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java index a1de47a15..dbd6545d6 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java @@ -25,6 +25,7 @@ import io.aeron.logbuffer.Header; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.agrona.CloseHelper; import org.agrona.DirectBuffer; import org.agrona.concurrent.IdleStrategy; @@ -349,6 +350,8 @@ public abstract class BaseTransport implements Transport { * only because we want code to be obvious for people */ final AtomicBoolean localRunner = new AtomicBoolean(false); + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + if (nodeRole == NodeRole.NONE) { throw new ND4JIllegalStateException("No role is set for current node!"); } else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) { @@ -357,6 +360,7 @@ public abstract class BaseTransport implements Transport { // setting up thread for shard->client communication listener if (messageHandlerForShards != null) { threadB = new Thread(() -> { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); while (runner.get()) idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512)); @@ -368,13 +372,13 @@ public abstract class BaseTransport implements Transport { // setting up thread for inter-shard communication listener threadA = new Thread(() -> { localRunner.set(true); + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); while (runner.get()) idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512)); }); if (threadB != null) { - Nd4j.getAffinityManager().attachThreadToDevice(threadB, - Nd4j.getAffinityManager().getDeviceForCurrentThread()); + threadB.setDaemon(true); threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]"); threadB.start(); @@ -389,8 +393,7 @@ public abstract class BaseTransport implements Transport { } // all roles have threadA anyway - Nd4j.getAffinityManager().attachThreadToDevice(threadA, - Nd4j.getAffinityManager().getDeviceForCurrentThread()); + //Nd4j.getAffinityManager().attachThreadToDevice(threadA, Nd4j.getAffinityManager().getDeviceForCurrentThread()); threadA.setDaemon(true); threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]"); threadA.start(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java index b609d7378..a34720093 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java @@ -150,11 +150,19 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable { // this executor service han protected ExecutorService messagesExecutorService = Executors.newFixedThreadPool(SENDER_THREADS + MESSAGE_THREADS + SUBSCRIPTION_THREADS, new ThreadFactory() { @Override - public Thread newThread(@NonNull Runnable r) { - val t = Executors.defaultThreadFactory().newThread(r); + public Thread newThread(@NonNull final Runnable r) { + val t = new Thread(new Runnable() { + + @Override + public void run() { + //TODO implement support for multi-GPU masters + Nd4j.getAffinityManager().unsafeSetDevice(0); //Associate thread with device 0 (no-op for CPU) + r.run(); + } + }); + t.setDaemon(true); - //TODO implement support for multi-GPU masters - Nd4j.getAffinityManager().attachThreadToDevice(t, 0); //Associate thread with device 0 (no-op for CPU) + t.setName("MessagesExecutorService thread"); return t; } }); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java index 07d8ef096..34bceb6c2 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java @@ -107,10 +107,16 @@ public abstract class BaseTransport implements Transport { protected final ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Math.max(2, Runtime.getRuntime().availableProcessors()), new ThreadFactory() { @Override - public Thread newThread(@NonNull Runnable r) { - val t = Executors.defaultThreadFactory().newThread(r); + public Thread newThread(@NonNull final Runnable r) { + val t = new Thread(new Runnable() { + @Override + public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(0); + r.run(); + } + }); + t.setDaemon(true); - Nd4j.getAffinityManager().attachThreadToDevice(t, 0); return t; } });