[WIP] build fix (#124)

* AffinityManager changes

Signed-off-by: raver119 <raver119@gmail.com>

* build fixes

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-16 08:14:18 +03:00 committed by GitHub
parent 65ff18383a
commit 2f3d7330ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 162 additions and 77 deletions

View File

@ -380,19 +380,23 @@ public class VPTree implements Serializable {
private Node buildFromPoints(INDArray items) { private Node buildFromPoints(INDArray items) {
if (executorService == null && items == this.items && workers > 1) { if (executorService == null && items == this.items && workers > 1) {
final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() { executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() {
@Override @Override
public Thread newThread(Runnable r) { public Thread newThread(final Runnable r) {
Thread t = Executors.defaultThreadFactory().newThread(r); Thread t = new Thread(new Runnable() {
@Override
public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
r.run();
}
});
t.setDaemon(true); t.setDaemon(true);
t.setName("VPTree thread"); t.setName("VPTree thread");
// we don't want threads to be working on different devices
Nd4j.getAffinityManager().attachThreadToDevice(t,
Nd4j.getAffinityManager().getDeviceForCurrentThread());
return t; return t;
} }
}); });

View File

@ -132,9 +132,8 @@ public class ParallelInference {
boolean cRoot = !assignedRoot.get() && cDevice == currentDevice; boolean cRoot = !assignedRoot.get() && cDevice == currentDevice;
assignedRoot.compareAndSet(false, cRoot); assignedRoot.compareAndSet(false, cRoot);
zoo[i] = new InferenceWorker(i, model, observables, cRoot); zoo[i] = new InferenceWorker(i, model, observables, cRoot, cDevice);
Nd4j.getAffinityManager().attachThreadToDevice(zoo[i], cDevice);
zoo[i].setDaemon(true); zoo[i].setDaemon(true);
zoo[i].start(); zoo[i].start();
} }
@ -425,13 +424,15 @@ public class ParallelInference {
private Model replicatedModel; private Model replicatedModel;
private AtomicLong counter = new AtomicLong(0); private AtomicLong counter = new AtomicLong(0);
private boolean rootDevice; private boolean rootDevice;
private int deviceId;
private ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock(); private ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();
private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice) { private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice, int deviceId) {
this.inputQueue = inputQueue; this.inputQueue = inputQueue;
this.protoModel = model; this.protoModel = model;
this.rootDevice = rootDevice; this.rootDevice = rootDevice;
this.deviceId = deviceId;
this.setDaemon(true); this.setDaemon(true);
this.setName("InferenceThread-" + id); this.setName("InferenceThread-" + id);
@ -491,6 +492,7 @@ public class ParallelInference {
@Override @Override
public void run() { public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
try { try {
// model should be replicated & initialized here // model should be replicated & initialized here
initializeReplicaModel(); initializeReplicaModel();

View File

@ -151,18 +151,21 @@ public class ParallelWrapper implements AutoCloseable {
workerCounter.set(0); workerCounter.set(0);
this.executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, new ThreadFactory() { this.executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, new ThreadFactory() {
@Override @Override
public Thread newThread(@NonNull Runnable r) { public Thread newThread(@NonNull final Runnable r) {
Thread t = Executors.defaultThreadFactory().newThread(r); final int cThread = workerCounter.getAndIncrement();
int cThread = workerCounter.getAndIncrement(); Thread t = new Thread(new Runnable() {
@Override
public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(cThread % Nd4j.getAffinityManager().getNumberOfDevices());
r.run();
}
});
t.setName("ParallelWrapper training thread " + cThread); t.setName("ParallelWrapper training thread " + cThread);
t.setDaemon(true); t.setDaemon(true);
t.setUncaughtExceptionHandler(handler); t.setUncaughtExceptionHandler(handler);
Nd4j.getAffinityManager().attachThreadToDevice(t,
cThread % Nd4j.getAffinityManager().getNumberOfDevices());
return t; return t;
} }
}); });

View File

@ -108,7 +108,7 @@ public class EvaluationRunner {
INDArray p; INDArray p;
try{ try{
p = Nd4j.read(new ByteArrayInputStream(pBytes)); p = Nd4j.read(new ByteArrayInputStream(pBytes));
} catch (IOException e){ } catch (RuntimeException e){
throw new RuntimeException(e); //Should never happen throw new RuntimeException(e); //Should never happen
} }
DeviceLocalNDArray dlp = new DeviceLocalNDArray(p); DeviceLocalNDArray dlp = new DeviceLocalNDArray(p);

View File

@ -97,13 +97,12 @@ public class SparkADSI extends AsyncDataSetIterator {
context = TaskContext.get(); context = TaskContext.get();
this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null); this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null, Nd4j.getAffinityManager().getDeviceForCurrentThread());
/** /**
* We want to ensure, that background thread will have the same thread->device affinity, as master thread * We want to ensure, that background thread will have the same thread->device affinity, as master thread
*/ */
Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
thread.setDaemon(true); thread.setDaemon(true);
thread.start(); thread.start();
} }
@ -116,9 +115,8 @@ public class SparkADSI extends AsyncDataSetIterator {
public class SparkPrefetchThread extends AsyncPrefetchThread { public class SparkPrefetchThread extends AsyncPrefetchThread {
protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator, protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) {
MemoryWorkspace workspace) { super(queue, iterator, terminator, workspace, deviceId);
super(queue, iterator, terminator, workspace);
} }

View File

@ -97,15 +97,10 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator {
if (iterator.resetSupported()) if (iterator.resetSupported())
this.backedIterator.reset(); this.backedIterator.reset();
this.thread = new SparkPrefetchThread(buffer, iterator, terminator); this.thread = new SparkPrefetchThread(buffer, iterator, terminator, Nd4j.getAffinityManager().getDeviceForCurrentThread());
context = TaskContext.get(); context = TaskContext.get();
/**
* We want to ensure, that background thread will have the same thread->device affinity, as master thread
*/
Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
thread.setDaemon(true); thread.setDaemon(true);
thread.start(); thread.start();
} }
@ -117,9 +112,8 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator {
protected class SparkPrefetchThread extends AsyncPrefetchThread { protected class SparkPrefetchThread extends AsyncPrefetchThread {
protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue, protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue, @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) {
@NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator) { super(queue, iterator, terminator, deviceId);
super(queue, iterator, terminator);
} }
} }
} }

View File

@ -34,6 +34,12 @@ public interface AffinityManager {
*/ */
Integer getDeviceForCurrentThread(); Integer getDeviceForCurrentThread();
/**
* This method returns deviceId for a given thread
* @return
*/
Integer getDeviceForThread(long threadId);
/** /**
* This method returns id of current device for a given INDArray * This method returns id of current device for a given INDArray

View File

@ -28,6 +28,11 @@ public abstract class BasicAffinityManager implements AffinityManager {
return 0; return 0;
} }
@Override
public Integer getDeviceForThread(long threadId) {
return 0;
}
@Override @Override
public Integer getDeviceForArray(INDArray array) { public Integer getDeviceForArray(INDArray array) {
return 0; return 0;

View File

@ -68,7 +68,24 @@ public class CudaAffinityManager extends BasicAffinityManager {
*/ */
@Override @Override
public Integer getDeviceForCurrentThread() { public Integer getDeviceForCurrentThread() {
return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice(); val id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
if (!affinityMap.containsKey(Thread.currentThread().getId()))
affinityMap.put(Thread.currentThread().getId(), id);
return id;
}
/**
* This method returns deviceId for a given thread
* @return
*/
@Override
public Integer getDeviceForThread(long threadId) {
val id = affinityMap.get(threadId);
if (id == null)
throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet");
return id;
} }

View File

@ -11131,7 +11131,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// REGISTER_C(NAME) // REGISTER_C(NAME)
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
// auto shapeList = SHAPELIST(); // auto shapeList = SHAPELIST();
// for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { // for (int e = 0; e < block.width(); e++) {
// auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); // auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
// shapeList->push_back(newshape); // shapeList->push_back(newshape);
// } // }
@ -11168,7 +11168,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// REGISTER_C(NAME) // REGISTER_C(NAME)
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
// auto shapeList = SHAPELIST(); // auto shapeList = SHAPELIST();
// for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { // for (int e = 0; e < block.width(); e++) {
// Nd4jLong* newshape; // Nd4jLong* newshape;
// COPY_SHAPE(inputShape->at(0), newshape); // COPY_SHAPE(inputShape->at(0), newshape);
// shapeList->push_back(CONSTANT(newshape)); // shapeList->push_back(CONSTANT(newshape));
@ -11191,7 +11191,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// REGISTER_C(NAME) // REGISTER_C(NAME)
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
// auto shapeList = SHAPELIST(); // auto shapeList = SHAPELIST();
// for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { // for (int e = 0; e < block.width(); e++) {
// auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); // auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
// shapeList->push_back(newshape); // shapeList->push_back(newshape);
// } // }
@ -16282,9 +16282,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
/*
* Complete tensor with max indices merged from all input tensors list
*
* INPUT: tensors with the same shape
* OUTPUT: integer tensor with the same shape
* INT_ARG: result type (one of int), INT32 by default
*/
// #if NOT_EXCLUDED(OP_mergemaxindex) // #if NOT_EXCLUDED(OP_mergemaxindex)
@Namespace("nd4j::ops") public static class mergemaxindex extends DeclarableOp { @Namespace("nd4j::ops") public static class mergemaxindex extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public mergemaxindex(Pointer p) { super(p); } public mergemaxindex(Pointer p) { super(p); }
@ -21746,7 +21752,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/** /**
* This operation shift individual bits of each element in array * This operation shift individual bits of each element in array to the left: <<
* *
* PLEASE NOTE: This operation is applicable only to integer data types * PLEASE NOTE: This operation is applicable only to integer data types
* *
@ -21771,7 +21777,32 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #endif // #endif
/** /**
* This operation shift individual bits of each element in array * This operation shift individual bits of each element in array to the right: >>
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_rshift_bits)
@Namespace("nd4j::ops") public static class rshift_bits extends DeclarableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public rshift_bits(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public rshift_bits(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public rshift_bits position(long position) {
return (rshift_bits)super.position(position);
}
public rshift_bits() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
* This operation shift individual bits of each element in array, shifting to the left
* *
* PLEASE NOTE: This operation is applicable only to integer data types * PLEASE NOTE: This operation is applicable only to integer data types
* *
@ -21795,6 +21826,31 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/**
* This operation shift individual bits of each element in array, shifting to the right
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_cyclic_rshift_bits)
@Namespace("nd4j::ops") public static class cyclic_rshift_bits extends DeclarableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public cyclic_rshift_bits(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public cyclic_rshift_bits(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public cyclic_rshift_bits position(long position) {
return (cyclic_rshift_bits)super.position(position);
}
public cyclic_rshift_bits() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
// #endif // #endif
@ -22545,7 +22601,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* PLEASE NOTE: This op is disabled atm, and reserved for future releases. * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/ */
// #if NOT_EXCLUDED(OP_to_double) // #if NOT_EXCLUDED(OP_to_double)
@Namespace("nd4j::ops") public static class to_double extends DeclarableOp { @Namespace("nd4j::ops") public static class to_double extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public to_double(Pointer p) { super(p); } public to_double(Pointer p) { super(p); }
@ -22568,7 +22624,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* PLEASE NOTE: This op is disabled atm, and reserved for future releases. * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/ */
// #if NOT_EXCLUDED(OP_to_float16) // #if NOT_EXCLUDED(OP_to_float16)
@Namespace("nd4j::ops") public static class to_float16 extends DeclarableOp { @Namespace("nd4j::ops") public static class to_float16 extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public to_float16(Pointer p) { super(p); } public to_float16(Pointer p) { super(p); }
@ -22591,7 +22647,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* PLEASE NOTE: This op is disabled atm, and reserved for future releases. * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/ */
// #if NOT_EXCLUDED(OP_to_float32) // #if NOT_EXCLUDED(OP_to_float32)
@Namespace("nd4j::ops") public static class to_float32 extends DeclarableOp { @Namespace("nd4j::ops") public static class to_float32 extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public to_float32(Pointer p) { super(p); } public to_float32(Pointer p) { super(p); }
@ -22614,7 +22670,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* PLEASE NOTE: This op is disabled atm, and reserved for future releases. * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/ */
// #if NOT_EXCLUDED(OP_to_int32) // #if NOT_EXCLUDED(OP_to_int32)
@Namespace("nd4j::ops") public static class to_int32 extends DeclarableOp { @Namespace("nd4j::ops") public static class to_int32 extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public to_int32(Pointer p) { super(p); } public to_int32(Pointer p) { super(p); }
@ -22637,7 +22693,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* PLEASE NOTE: This op is disabled atm, and reserved for future releases. * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/ */
// #if NOT_EXCLUDED(OP_to_int64) // #if NOT_EXCLUDED(OP_to_int64)
@Namespace("nd4j::ops") public static class to_int64 extends DeclarableOp { @Namespace("nd4j::ops") public static class to_int64 extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public to_int64(Pointer p) { super(p); } public to_int64(Pointer p) { super(p); }
@ -22660,7 +22716,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* PLEASE NOTE: This op is disabled atm, and reserved for future releases. * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/ */
// #if NOT_EXCLUDED(OP_to_uint32) // #if NOT_EXCLUDED(OP_to_uint32)
@Namespace("nd4j::ops") public static class to_uint32 extends DeclarableOp { @Namespace("nd4j::ops") public static class to_uint32 extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public to_uint32(Pointer p) { super(p); } public to_uint32(Pointer p) { super(p); }
@ -22683,7 +22739,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* PLEASE NOTE: This op is disabled atm, and reserved for future releases. * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/ */
// #if NOT_EXCLUDED(OP_to_uint64) // #if NOT_EXCLUDED(OP_to_uint64)
@Namespace("nd4j::ops") public static class to_uint64 extends DeclarableOp { @Namespace("nd4j::ops") public static class to_uint64 extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public to_uint64(Pointer p) { super(p); } public to_uint64(Pointer p) { super(p); }