[WIP] build fix (#124)

* AffinityManager changes

Signed-off-by: raver119 <raver119@gmail.com>

* build fixes

Signed-off-by: raver119 <raver119@gmail.com>
Branch: master
raver119 authored 2019-08-16 08:14:18 +03:00, committed by GitHub
parent 65ff18383a
commit 2f3d7330ce
10 changed files with 162 additions and 77 deletions
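
The recurring change across these files: device affinity used to be assigned from the outside via AffinityManager.attachThreadToDevice(thread, deviceId) before a worker thread was started; now the device id is captured on the constructing (master) thread and bound from inside the worker via unsafeSetDevice(deviceId) as the first thing run() does. A minimal sketch of that pattern, assuming only the Nd4j/AffinityManager calls that appear in the diffs below (the sketch class and method names are illustrative, not from this commit):

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ThreadFactory;

    import org.nd4j.linalg.factory.Nd4j;

    public class DevicePinnedPoolSketch {
        public static ExecutorService newPinnedPool(int workers) {
            // Resolve the device on the *calling* (master) thread.
            final int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            return Executors.newFixedThreadPool(workers, new ThreadFactory() {
                @Override
                public Thread newThread(final Runnable r) {
                    // Bind the device on the worker thread itself, before any work runs.
                    Thread t = new Thread(new Runnable() {
                        @Override
                        public void run() {
                            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
                            r.run();
                        }
                    });
                    t.setDaemon(true);
                    return t;
                }
            });
        }
    }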

File: VPTree.java

@@ -380,19 +380,23 @@ public class VPTree implements Serializable {
     private Node buildFromPoints(INDArray items) {
         if (executorService == null && items == this.items && workers > 1) {
+            final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
             executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() {
                 @Override
-                public Thread newThread(Runnable r) {
-                    Thread t = Executors.defaultThreadFactory().newThread(r);
+                public Thread newThread(final Runnable r) {
+                    Thread t = new Thread(new Runnable() {
+                        @Override
+                        public void run() {
+                            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
+                            r.run();
+                        }
+                    });
                     t.setDaemon(true);
                     t.setName("VPTree thread");
-                    // we don't want threads to be working on different devices
-                    Nd4j.getAffinityManager().attachThreadToDevice(t,
-                            Nd4j.getAffinityManager().getDeviceForCurrentThread());
                     return t;
                 }
             });
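
The removed attachThreadToDevice call recorded the affinity against the Thread object before start(); the replacement runs unsafeSetDevice(deviceId) on the worker thread itself, as the first statement it executes, with deviceId captured from the master thread at pool-creation time. The binding therefore happens on the thread that will actually issue the device calls, and the device chosen is the one the constructing thread saw.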

File: ParallelInference.java

@@ -132,9 +132,8 @@ public class ParallelInference {
             boolean cRoot = !assignedRoot.get() && cDevice == currentDevice;
             assignedRoot.compareAndSet(false, cRoot);

-            zoo[i] = new InferenceWorker(i, model, observables, cRoot);
-            Nd4j.getAffinityManager().attachThreadToDevice(zoo[i], cDevice);
+            zoo[i] = new InferenceWorker(i, model, observables, cRoot, cDevice);
             zoo[i].setDaemon(true);
             zoo[i].start();
         }
@@ -425,13 +424,15 @@ public class ParallelInference {
         private Model replicatedModel;
         private AtomicLong counter = new AtomicLong(0);
         private boolean rootDevice;
+        private int deviceId;

         private ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();

-        private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice) {
+        private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice, int deviceId) {
             this.inputQueue = inputQueue;
             this.protoModel = model;
             this.rootDevice = rootDevice;
+            this.deviceId = deviceId;

             this.setDaemon(true);
             this.setName("InferenceThread-" + id);
@@ -491,6 +492,7 @@ public class ParallelInference {
         @Override
         public void run() {
+            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
             try {
                 // model should be replicated & initialized here
                 initializeReplicaModel();
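
InferenceWorker follows the same capture-then-bind shape: cDevice is taken on the master thread, stored as deviceId, then applied at the top of run(). A compact illustration of why the capture has to happen in the constructor (hypothetical class, for illustration only):

    import org.nd4j.linalg.factory.Nd4j;

    // Illustrative worker: capture on the spawning thread, bind on the worker thread.
    class PinnedWorker extends Thread {
        private final int deviceId;

        PinnedWorker() {
            // Constructor runs on the master thread, so this captures *its* device.
            this.deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        }

        @Override
        public void run() {
            // run() executes on the worker thread: bind before any INDArray work.
            // Calling getDeviceForCurrentThread() here instead would resolve against
            // the worker thread, not the master that created it.
            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
            // ... inference work ...
        }
    }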

File: ParallelWrapper.java

@@ -151,18 +151,21 @@ public class ParallelWrapper implements AutoCloseable {
         workerCounter.set(0);
         this.executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, new ThreadFactory() {
             @Override
-            public Thread newThread(@NonNull Runnable r) {
-                Thread t = Executors.defaultThreadFactory().newThread(r);
-                int cThread = workerCounter.getAndIncrement();
+            public Thread newThread(@NonNull final Runnable r) {
+                final int cThread = workerCounter.getAndIncrement();
+                Thread t = new Thread(new Runnable() {
+                    @Override
+                    public void run() {
+                        Nd4j.getAffinityManager().unsafeSetDevice(cThread % Nd4j.getAffinityManager().getNumberOfDevices());
+                        r.run();
+                    }
+                });

                 t.setName("ParallelWrapper training thread " + cThread);
                 t.setDaemon(true);
                 t.setUncaughtExceptionHandler(handler);
-                Nd4j.getAffinityManager().attachThreadToDevice(t,
-                        cThread % Nd4j.getAffinityManager().getNumberOfDevices());
                 return t;
             }
         });
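
ParallelWrapper spreads its training threads across devices round-robin: worker number modulo device count. With two devices, workers 0..4 land on devices 0, 1, 0, 1, 0. A standalone check of the arithmetic (plain Java, no Nd4j required):

    public class RoundRobinDemo {
        public static void main(String[] args) {
            int numDevices = 2; // stand-in for Nd4j.getAffinityManager().getNumberOfDevices()
            for (int worker = 0; worker < 5; worker++) {
                System.out.println("worker " + worker + " -> device " + (worker % numDevices));
            }
        }
    }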

File: EvaluationRunner.java

@@ -108,7 +108,7 @@ public class EvaluationRunner {
             INDArray p;
             try{
                 p = Nd4j.read(new ByteArrayInputStream(pBytes));
-            } catch (IOException e){
+            } catch (RuntimeException e){
                 throw new RuntimeException(e); //Should never happen
             }
             DeviceLocalNDArray dlp = new DeviceLocalNDArray(p);

File: SparkADSI.java

@@ -97,13 +97,12 @@ public class SparkADSI extends AsyncDataSetIterator {
             context = TaskContext.get();

-            this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null);
+            this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null, Nd4j.getAffinityManager().getDeviceForCurrentThread());

             /**
              * We want to ensure, that background thread will have the same thread->device affinity, as master thread
              */
-            Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
             thread.setDaemon(true);
             thread.start();
         }
@@ -116,9 +115,8 @@ public class SparkADSI extends AsyncDataSetIterator {
     public class SparkPrefetchThread extends AsyncPrefetchThread {

-        protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator,
-                        MemoryWorkspace workspace) {
-            super(queue, iterator, terminator, workspace);
+        protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) {
+            super(queue, iterator, terminator, workspace, deviceId);
         }

File: SparkAMDSI.java

@@ -97,15 +97,10 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator {
             if (iterator.resetSupported())
                 this.backedIterator.reset();

-            this.thread = new SparkPrefetchThread(buffer, iterator, terminator);
+            this.thread = new SparkPrefetchThread(buffer, iterator, terminator, Nd4j.getAffinityManager().getDeviceForCurrentThread());

             context = TaskContext.get();

-            /**
-             * We want to ensure, that background thread will have the same thread->device affinity, as master thread
-             */
-            Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
             thread.setDaemon(true);
             thread.start();
         }

@@ -117,9 +112,8 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator {
     protected class SparkPrefetchThread extends AsyncPrefetchThread {

-        protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue,
-                        @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator) {
-            super(queue, iterator, terminator);
+        protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue, @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) {
+            super(queue, iterator, terminator, deviceId);
         }
     }
 }
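
Both Spark prefetch iterators get the same treatment: the device is resolved with getDeviceForCurrentThread() on the constructing (master) thread and handed down as an explicit deviceId constructor argument, replacing the external attachThreadToDevice call. The super(..., deviceId) calls imply the AsyncPrefetchThread base constructors gained a matching parameter; that change lives outside the files shown in this view.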

File: AffinityManager.java

@@ -34,6 +34,12 @@ public interface AffinityManager {
      */
     Integer getDeviceForCurrentThread();

+    /**
+     * This method returns deviceId for a given thread
+     * @return
+     */
+    Integer getDeviceForThread(long threadId);

     /**
      * This method returns id of current device for a given INDArray

File: BasicAffinityManager.java

@@ -28,6 +28,11 @@ public abstract class BasicAffinityManager implements AffinityManager {
         return 0;
     }

+    @Override
+    public Integer getDeviceForThread(long threadId) {
+        return 0;
+    }

     @Override
     public Integer getDeviceForArray(INDArray array) {
         return 0;

File: CudaAffinityManager.java

@@ -68,7 +68,24 @@ public class CudaAffinityManager extends BasicAffinityManager {
      */
     @Override
     public Integer getDeviceForCurrentThread() {
-        return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
+        val id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
+        if (!affinityMap.containsKey(Thread.currentThread().getId()))
+            affinityMap.put(Thread.currentThread().getId(), id);
+
+        return id;
+    }
+
+    /**
+     * This method returns deviceId for a given thread
+     * @return
+     */
+    @Override
+    public Integer getDeviceForThread(long threadId) {
+        val id = affinityMap.get(threadId);
+        if (id == null)
+            throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet");
+
+        return id;
+    }
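
CudaAffinityManager backs the new getDeviceForThread(long) with a lazily populated map: a thread's device is recorded the first time that thread calls getDeviceForCurrentThread(), and lookups for threads never seen throw. A usage sketch under that reading (behavior notes are per this diff; the "stranger" thread is illustrative):

    import org.nd4j.linalg.api.concurrency.AffinityManager;
    import org.nd4j.linalg.factory.Nd4j;

    public class AffinityLookupSketch {
        public static void main(String[] args) {
            AffinityManager am = Nd4j.getAffinityManager();

            // On the CUDA backend this also records the thread -> device mapping.
            Integer myDevice = am.getDeviceForCurrentThread();
            System.out.println("current thread is on device " + myDevice);

            // The mapping now exists, so looking up our own thread id succeeds.
            System.out.println(am.getDeviceForThread(Thread.currentThread().getId()));

            // A thread that never touched the AffinityManager has no entry:
            // CudaAffinityManager throws "Affinity for thread [...] wasn't defined yet",
            // while the BasicAffinityManager default simply returns 0.
            Thread stranger = new Thread();
            try {
                am.getDeviceForThread(stranger.getId());
            } catch (RuntimeException e) {
                System.out.println("no affinity recorded yet: " + e.getMessage());
            }
        }
    }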

File: generated JavaCPP op bindings (class not named in these hunks)

@@ -11131,7 +11131,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 //         REGISTER_C(NAME)
 //         nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
 //             auto shapeList = SHAPELIST();
-//             for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) {
+//             for (int e = 0; e < block.width(); e++) {
 //                 auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
 //                 shapeList->push_back(newshape);
 //             }

@@ -11168,7 +11168,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 //         REGISTER_C(NAME)
 //         nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
 //             auto shapeList = SHAPELIST();
-//             for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) {
+//             for (int e = 0; e < block.width(); e++) {
 //                 Nd4jLong* newshape;
 //                 COPY_SHAPE(inputShape->at(0), newshape);
 //                 shapeList->push_back(CONSTANT(newshape));

@@ -11191,7 +11191,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 //         REGISTER_C(NAME)
 //         nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
 //             auto shapeList = SHAPELIST();
-//             for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) {
+//             for (int e = 0; e < block.width(); e++) {
 //                 auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
 //                 shapeList->push_back(newshape);
 //             }
@@ -16282,9 +16282,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
         public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
     }
 // #endif

+    /*
+     * Complete tensor with max indices merged from all input tensors list
+     *
+     * INPUT: tensors with the same shape
+     * OUTPUT: integer tensor with the same shape
+     * INT_ARG: result type (one of int), INT32 by default
+     */
 // #if NOT_EXCLUDED(OP_mergemaxindex)
-        @Namespace("nd4j::ops") public static class mergemaxindex extends DeclarableOp {
+        @Namespace("nd4j::ops") public static class mergemaxindex extends DeclarableCustomOp {
             static { Loader.load(); }
             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
             public mergemaxindex(Pointer p) { super(p); }

@@ -16295,10 +16301,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
                 return (mergemaxindex)super.position(position);
             }
-            public mergemaxindex() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
+        public mergemaxindex() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
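
Per the new doc comment, mergemaxindex takes N same-shape tensors and emits, element-wise, the index of the input holding the maximum. A plain-Java rendering of that semantics (illustrative, not the native op):

    import java.util.Arrays;

    public class MergeMaxIndexDemo {
        public static void main(String[] args) {
            float[][] inputs = { {1f, 5f, 3f},
                                 {4f, 2f, 3f} }; // two same-shape inputs
            int[] out = new int[inputs[0].length];
            for (int i = 0; i < out.length; i++) {
                int best = 0;
                for (int t = 1; t < inputs.length; t++)
                    if (inputs[t][i] > inputs[best][i])
                        best = t;
                out[i] = best;
            }
            // [1, 0, 0]: input 1 wins at position 0; ties keep the first input.
            System.out.println(Arrays.toString(out));
        }
    }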
// #if NOT_EXCLUDED(OP_mergeadd)
@@ -21746,7 +21752,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     /**
-     * This operation shift individual bits of each element in array
+     * This operation shift individual bits of each element in array to the left: <<
      *
      * PLEASE NOTE: This operation is applicable only to integer data types
      *
@@ -21771,7 +21777,32 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 // #endif

     /**
-     * This operation shift individual bits of each element in array
+     * This operation shift individual bits of each element in array to the right: >>
      *
      * PLEASE NOTE: This operation is applicable only to integer data types
      *
      * \tparam T
      */
+// #if NOT_EXCLUDED(OP_rshift_bits)
+        @Namespace("nd4j::ops") public static class rshift_bits extends DeclarableOp {
+            static { Loader.load(); }
+            /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+            public rshift_bits(Pointer p) { super(p); }
+            /** Native array allocator. Access with {@link Pointer#position(long)}. */
+            public rshift_bits(long size) { super((Pointer)null); allocateArray(size); }
+            private native void allocateArray(long size);
+            @Override public rshift_bits position(long position) {
+                return (rshift_bits)super.position(position);
+            }
+            public rshift_bits() { super((Pointer)null); allocate(); }
+            private native void allocate();
+            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+        }
+// #endif
+
+    /**
+     * This operation shift individual bits of each element in array, shifting to the left
+     *
+     * PLEASE NOTE: This operation is applicable only to integer data types
+     *
@@ -21794,6 +21825,31 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
         public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
     }
 // #endif

+    /**
+     * This operation shift individual bits of each element in array, shifting to the right
+     *
+     * PLEASE NOTE: This operation is applicable only to integer data types
+     *
+     * \tparam T
+     */
+// #if NOT_EXCLUDED(OP_cyclic_rshift_bits)
+        @Namespace("nd4j::ops") public static class cyclic_rshift_bits extends DeclarableOp {
+            static { Loader.load(); }
+            /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+            public cyclic_rshift_bits(Pointer p) { super(p); }
+            /** Native array allocator. Access with {@link Pointer#position(long)}. */
+            public cyclic_rshift_bits(long size) { super((Pointer)null); allocateArray(size); }
+            private native void allocateArray(long size);
+            @Override public cyclic_rshift_bits position(long position) {
+                return (cyclic_rshift_bits)super.position(position);
+            }
+            public cyclic_rshift_bits() { super((Pointer)null); allocate(); }
+            private native void allocate();
+            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+        }
+// #endif
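
For reference, the four shift families in these bindings map onto familiar integer semantics: shift_bits and rshift_bits are << and >> per their doc comments, and the cyclic variants are bit rotations. One 32-bit element worked in plain Java (the native ops apply this element-wise to integer arrays):

    public class ShiftSemanticsDemo {
        public static void main(String[] args) {
            int x = 0b1011; // 11

            System.out.println(x << 2);                    // shift_bits:        44
            System.out.println(x >> 1);                    // rshift_bits:        5
            System.out.println(Integer.rotateLeft(x, 2));  // cyclic_shift_bits:  44 here (no high bits set)
            System.out.println(Integer.rotateRight(x, 2)); // cyclic_rshift_bits: low bits wrap to the top
        }
    }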
@@ -22545,7 +22601,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
      * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
      */
 // #if NOT_EXCLUDED(OP_to_double)
-        @Namespace("nd4j::ops") public static class to_double extends DeclarableOp {
+        @Namespace("nd4j::ops") public static class to_double extends DeclarableCustomOp {
             static { Loader.load(); }
             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
             public to_double(Pointer p) { super(p); }

@@ -22556,10 +22612,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
                 return (to_double)super.position(position);
             }
-            public to_double() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
+        public to_double() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
/**
@@ -22568,7 +22624,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
      * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
      */
 // #if NOT_EXCLUDED(OP_to_float16)
-        @Namespace("nd4j::ops") public static class to_float16 extends DeclarableOp {
+        @Namespace("nd4j::ops") public static class to_float16 extends DeclarableCustomOp {
             static { Loader.load(); }
             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
             public to_float16(Pointer p) { super(p); }

@@ -22579,10 +22635,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
                 return (to_float16)super.position(position);
             }
-            public to_float16() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
+        public to_float16() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
/**
@@ -22591,7 +22647,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
      * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
      */
 // #if NOT_EXCLUDED(OP_to_float32)
-        @Namespace("nd4j::ops") public static class to_float32 extends DeclarableOp {
+        @Namespace("nd4j::ops") public static class to_float32 extends DeclarableCustomOp {
             static { Loader.load(); }
             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
             public to_float32(Pointer p) { super(p); }

@@ -22602,10 +22658,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
                 return (to_float32)super.position(position);
             }
-            public to_float32() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
+        public to_float32() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
/**
@@ -22614,7 +22670,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
      * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
      */
 // #if NOT_EXCLUDED(OP_to_int32)
-        @Namespace("nd4j::ops") public static class to_int32 extends DeclarableOp {
+        @Namespace("nd4j::ops") public static class to_int32 extends DeclarableCustomOp {
             static { Loader.load(); }
             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
             public to_int32(Pointer p) { super(p); }

@@ -22625,10 +22681,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
                 return (to_int32)super.position(position);
             }
-            public to_int32() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
+        public to_int32() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
/**
@@ -22637,7 +22693,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
      * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
      */
 // #if NOT_EXCLUDED(OP_to_int64)
-        @Namespace("nd4j::ops") public static class to_int64 extends DeclarableOp {
+        @Namespace("nd4j::ops") public static class to_int64 extends DeclarableCustomOp {
             static { Loader.load(); }
             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
             public to_int64(Pointer p) { super(p); }

@@ -22648,10 +22704,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
                 return (to_int64)super.position(position);
             }
-            public to_int64() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
+        public to_int64() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
/**
@@ -22660,7 +22716,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
      * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
      */
 // #if NOT_EXCLUDED(OP_to_uint32)
-        @Namespace("nd4j::ops") public static class to_uint32 extends DeclarableOp {
+        @Namespace("nd4j::ops") public static class to_uint32 extends DeclarableCustomOp {
             static { Loader.load(); }
             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
             public to_uint32(Pointer p) { super(p); }

@@ -22671,10 +22727,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
                 return (to_uint32)super.position(position);
             }
-            public to_uint32() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
+        public to_uint32() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
/**
@@ -22683,7 +22739,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
      * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
      */
 // #if NOT_EXCLUDED(OP_to_uint64)
-        @Namespace("nd4j::ops") public static class to_uint64 extends DeclarableOp {
+        @Namespace("nd4j::ops") public static class to_uint64 extends DeclarableCustomOp {
             static { Loader.load(); }
             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
             public to_uint64(Pointer p) { super(p); }

@@ -22694,10 +22750,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
                 return (to_uint64)super.position(position);
             }
-            public to_uint64() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
+        public to_uint64() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
/**