one more build fix
Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									3ba616a161
								
							
						
					
					
						commit
						e9b72e78ae
					
				@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import lombok.val;
 | 
			
		||||
import org.nd4j.config.ND4JEnvironmentVars;
 | 
			
		||||
import org.nd4j.linalg.primitives.Pair;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
@ -276,9 +277,12 @@ public class VoidParameterServer {
 | 
			
		||||
                    processingThreads = new Thread[numThreads];
 | 
			
		||||
                    processingRunnables = new Runnable[numThreads];
 | 
			
		||||
 | 
			
		||||
                    val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
 | 
			
		||||
 | 
			
		||||
                    for (int x = 0; x < numThreads; x++) {
 | 
			
		||||
                        processingThreads[x] = new Thread(() -> {
 | 
			
		||||
                            runner.set(true);
 | 
			
		||||
                            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
 | 
			
		||||
                            while (runner.get()) {
 | 
			
		||||
                                try {
 | 
			
		||||
                                    //VoidMessage message = transport.takeMessage();
 | 
			
		||||
@ -296,11 +300,6 @@ public class VoidParameterServer {
 | 
			
		||||
                            }
 | 
			
		||||
                        });
 | 
			
		||||
 | 
			
		||||
                        //executor.submit(processingRunnables[x);
 | 
			
		||||
 | 
			
		||||
                        // TODO: maybe find the way to guarantee affinity in some other way, to make different devices usable as well?
 | 
			
		||||
                        Nd4j.getAffinityManager().attachThreadToDevice(processingThreads[x],
 | 
			
		||||
                                        Nd4j.getAffinityManager().getDeviceForCurrentThread());
 | 
			
		||||
                        processingThreads[x].setDaemon(true);
 | 
			
		||||
                        processingThreads[x].setName("VoidParameterServer messages handling thread");
 | 
			
		||||
                        processingThreads[x].start();
 | 
			
		||||
 | 
			
		||||
@ -25,6 +25,7 @@ import io.aeron.logbuffer.Header;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import lombok.val;
 | 
			
		||||
import org.agrona.CloseHelper;
 | 
			
		||||
import org.agrona.DirectBuffer;
 | 
			
		||||
import org.agrona.concurrent.IdleStrategy;
 | 
			
		||||
@ -349,6 +350,8 @@ public abstract class BaseTransport implements Transport {
 | 
			
		||||
                 * only because we want code to be obvious for people
 | 
			
		||||
                 */
 | 
			
		||||
                final AtomicBoolean localRunner = new AtomicBoolean(false);
 | 
			
		||||
                val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
 | 
			
		||||
 | 
			
		||||
                if (nodeRole == NodeRole.NONE) {
 | 
			
		||||
                    throw new ND4JIllegalStateException("No role is set for current node!");
 | 
			
		||||
                } else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) {
 | 
			
		||||
@ -357,6 +360,7 @@ public abstract class BaseTransport implements Transport {
 | 
			
		||||
                    // setting up thread for shard->client communication listener
 | 
			
		||||
                    if (messageHandlerForShards != null) {
 | 
			
		||||
                        threadB = new Thread(() -> {
 | 
			
		||||
                            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
 | 
			
		||||
                            while (runner.get())
 | 
			
		||||
                                idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512));
 | 
			
		||||
 | 
			
		||||
@ -368,13 +372,13 @@ public abstract class BaseTransport implements Transport {
 | 
			
		||||
                    // setting up thread for inter-shard communication listener
 | 
			
		||||
                    threadA = new Thread(() -> {
 | 
			
		||||
                        localRunner.set(true);
 | 
			
		||||
                        Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
 | 
			
		||||
                        while (runner.get())
 | 
			
		||||
                            idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
 | 
			
		||||
                    });
 | 
			
		||||
 | 
			
		||||
                    if (threadB != null) {
 | 
			
		||||
                        Nd4j.getAffinityManager().attachThreadToDevice(threadB,
 | 
			
		||||
                                        Nd4j.getAffinityManager().getDeviceForCurrentThread());
 | 
			
		||||
 | 
			
		||||
                        threadB.setDaemon(true);
 | 
			
		||||
                        threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]");
 | 
			
		||||
                        threadB.start();
 | 
			
		||||
@ -389,8 +393,7 @@ public abstract class BaseTransport implements Transport {
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                // all roles have threadA anyway
 | 
			
		||||
                Nd4j.getAffinityManager().attachThreadToDevice(threadA,
 | 
			
		||||
                                Nd4j.getAffinityManager().getDeviceForCurrentThread());
 | 
			
		||||
                //Nd4j.getAffinityManager().attachThreadToDevice(threadA,                                Nd4j.getAffinityManager().getDeviceForCurrentThread());
 | 
			
		||||
                threadA.setDaemon(true);
 | 
			
		||||
                threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]");
 | 
			
		||||
                threadA.start();
 | 
			
		||||
 | 
			
		||||
@ -150,11 +150,19 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable {
 | 
			
		||||
    // this executor service han
 | 
			
		||||
    protected ExecutorService messagesExecutorService = Executors.newFixedThreadPool(SENDER_THREADS + MESSAGE_THREADS + SUBSCRIPTION_THREADS, new ThreadFactory() {
 | 
			
		||||
        @Override
 | 
			
		||||
        public Thread newThread(@NonNull Runnable r) {
 | 
			
		||||
            val t = Executors.defaultThreadFactory().newThread(r);
 | 
			
		||||
        public Thread newThread(@NonNull final Runnable r) {
 | 
			
		||||
            val t = new Thread(new Runnable() {
 | 
			
		||||
 | 
			
		||||
                @Override
 | 
			
		||||
                public void run() {
 | 
			
		||||
                    //TODO implement support for multi-GPU masters
 | 
			
		||||
                    Nd4j.getAffinityManager().unsafeSetDevice(0);   //Associate thread with device 0 (no-op for CPU)
 | 
			
		||||
                    r.run();
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
 | 
			
		||||
            t.setDaemon(true);
 | 
			
		||||
            //TODO implement support for multi-GPU masters
 | 
			
		||||
            Nd4j.getAffinityManager().attachThreadToDevice(t, 0);   //Associate thread with device 0 (no-op for CPU)
 | 
			
		||||
            t.setName("MessagesExecutorService thread");
 | 
			
		||||
            return t;
 | 
			
		||||
        }
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
@ -107,10 +107,16 @@ public abstract  class BaseTransport  implements Transport {
 | 
			
		||||
 | 
			
		||||
    protected final ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Math.max(2, Runtime.getRuntime().availableProcessors()), new ThreadFactory() {
 | 
			
		||||
        @Override
 | 
			
		||||
        public Thread newThread(@NonNull Runnable r) {
 | 
			
		||||
            val t = Executors.defaultThreadFactory().newThread(r);
 | 
			
		||||
        public Thread newThread(@NonNull final Runnable r) {
 | 
			
		||||
            val t = new Thread(new Runnable() {
 | 
			
		||||
                @Override
 | 
			
		||||
                public void run() {
 | 
			
		||||
                    Nd4j.getAffinityManager().unsafeSetDevice(0);
 | 
			
		||||
                    r.run();
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
 | 
			
		||||
            t.setDaemon(true);
 | 
			
		||||
            Nd4j.getAffinityManager().attachThreadToDevice(t, 0);
 | 
			
		||||
            return t;
 | 
			
		||||
        }
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user