[WIP] build fix (#124)

* AffinityManager changes

Signed-off-by: raver119 <raver119@gmail.com>

* build fixes

Signed-off-by: raver119 <raver119@gmail.com>

parent 65ff18383a
commit 2f3d7330ce
@@ -380,19 +380,23 @@ public class VPTree implements Serializable {

     private Node buildFromPoints(INDArray items) {
         if (executorService == null && items == this.items && workers > 1) {
+            final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();

             executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() {
                 @Override
-                public Thread newThread(Runnable r) {
-                    Thread t = Executors.defaultThreadFactory().newThread(r);
+                public Thread newThread(final Runnable r) {
+                    Thread t = new Thread(new Runnable() {
+
+                        @Override
+                        public void run() {
+                            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
+                            r.run();
+                        }
+                    });

                     t.setDaemon(true);
                     t.setName("VPTree thread");

-                    // we don't want threads to be working on different devices
-                    Nd4j.getAffinityManager().attachThreadToDevice(t,
-                                    Nd4j.getAffinityManager().getDeviceForCurrentThread());
-
                     return t;
                 }
             });
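Note: the hunk above swaps attachThreadToDevice(...) for capturing the parent thread's device id and binding it with unsafeSetDevice(...) inside the worker itself, so the affinity is established before any task code runs. A minimal standalone sketch of that pattern, using the real Nd4j affinity API but with an illustrative class name and wiring that are not part of the commit:

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    import org.nd4j.linalg.factory.Nd4j;

    public class AffinityPinnedPool {
        public static ExecutorService create(int workers) {
            // capture the parent thread's device once, outside the factory
            final int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            return Executors.newFixedThreadPool(workers, r -> {
                Thread t = new Thread(() -> {
                    // first thing the worker does: bind itself to the captured device
                    Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
                    r.run();
                });
                t.setDaemon(true);
                return t;
            });
        }
    }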
@@ -132,9 +132,8 @@ public class ParallelInference {
             boolean cRoot = !assignedRoot.get() && cDevice == currentDevice;
             assignedRoot.compareAndSet(false, cRoot);

-            zoo[i] = new InferenceWorker(i, model, observables, cRoot);
+            zoo[i] = new InferenceWorker(i, model, observables, cRoot, cDevice);

-            Nd4j.getAffinityManager().attachThreadToDevice(zoo[i], cDevice);
             zoo[i].setDaemon(true);
             zoo[i].start();
         }
@@ -425,13 +424,15 @@ public class ParallelInference {
        private Model replicatedModel;
        private AtomicLong counter = new AtomicLong(0);
        private boolean rootDevice;
+       private int deviceId;

        private ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();

-       private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice) {
+       private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice, int deviceId) {
            this.inputQueue = inputQueue;
            this.protoModel = model;
            this.rootDevice = rootDevice;
+           this.deviceId = deviceId;

            this.setDaemon(true);
            this.setName("InferenceThread-" + id);
@@ -491,6 +492,7 @@ public class ParallelInference {

        @Override
        public void run() {
+           Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
            try {
                // model should be replicated & initialized here
                initializeReplicaModel();
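Note: the two InferenceWorker hunks are the constructor-side and run()-side halves of the same change: the device id now travels through the constructor and is bound at the top of run(), replacing the external attachThreadToDevice call. A compact sketch of the shape of that worker (names mirror the diff; the body is a placeholder, not the commit's code):

    import org.nd4j.linalg.factory.Nd4j;

    class DeviceBoundWorker extends Thread {
        private final int deviceId;

        DeviceBoundWorker(int deviceId) {
            this.deviceId = deviceId;
            setDaemon(true);
        }

        @Override
        public void run() {
            // bind this thread to its assigned device before doing any work
            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
            // ... inference loop would go here ...
        }
    }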
@@ -151,18 +151,21 @@ public class ParallelWrapper implements AutoCloseable {
        workerCounter.set(0);
        this.executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, new ThreadFactory() {
            @Override
-           public Thread newThread(@NonNull Runnable r) {
-               Thread t = Executors.defaultThreadFactory().newThread(r);
+           public Thread newThread(@NonNull final Runnable r) {
+               final int cThread = workerCounter.getAndIncrement();

-               int cThread = workerCounter.getAndIncrement();
+               Thread t = new Thread(new Runnable() {
+                   @Override
+                   public void run() {
+                       Nd4j.getAffinityManager().unsafeSetDevice(cThread % Nd4j.getAffinityManager().getNumberOfDevices());
+                       r.run();
+                   }
+               });

                t.setName("ParallelWrapper training thread " + cThread);
                t.setDaemon(true);
                t.setUncaughtExceptionHandler(handler);

-               Nd4j.getAffinityManager().attachThreadToDevice(t,
-                               cThread % Nd4j.getAffinityManager().getNumberOfDevices());
-
                return t;
            }
        });
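Note: the factory above spreads training workers across devices round-robin: worker cThread is pinned to device cThread % numberOfDevices. A small illustration of that assignment (demo class and worker count are illustrative only):

    import org.nd4j.linalg.factory.Nd4j;

    public class RoundRobinDemo {
        public static void main(String[] args) {
            // worker i lands on device i modulo the number of visible devices
            int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
            for (int i = 0; i < 8; i++) {
                System.out.println("worker " + i + " -> device " + (i % numDevices));
            }
        }
    }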
@@ -108,7 +108,7 @@ public class EvaluationRunner {
            INDArray p;
            try{
                p = Nd4j.read(new ByteArrayInputStream(pBytes));
-           } catch (IOException e){
+           } catch (RuntimeException e){
                throw new RuntimeException(e); //Should never happen
            }
            DeviceLocalNDArray dlp = new DeviceLocalNDArray(p);
@@ -97,13 +97,12 @@ public class SparkADSI extends AsyncDataSetIterator {

        context = TaskContext.get();

-       this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null);
+       this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null, Nd4j.getAffinityManager().getDeviceForCurrentThread());

        /**
         * We want to ensure, that background thread will have the same thread->device affinity, as master thread
         */
-       Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
        thread.setDaemon(true);
        thread.start();
    }

@@ -116,9 +115,8 @@ public class SparkADSI extends AsyncDataSetIterator {

    public class SparkPrefetchThread extends AsyncPrefetchThread {

-       protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator,
-                       MemoryWorkspace workspace) {
-           super(queue, iterator, terminator, workspace);
+       protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) {
+           super(queue, iterator, terminator, workspace, deviceId);
        }
@@ -97,15 +97,10 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator {
        if (iterator.resetSupported())
            this.backedIterator.reset();

-       this.thread = new SparkPrefetchThread(buffer, iterator, terminator);
+       this.thread = new SparkPrefetchThread(buffer, iterator, terminator, Nd4j.getAffinityManager().getDeviceForCurrentThread());

        context = TaskContext.get();

-       /**
-        * We want to ensure, that background thread will have the same thread->device affinity, as master thread
-        */
-       Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
-
        thread.setDaemon(true);
        thread.start();
    }

@@ -117,9 +112,8 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator {

    protected class SparkPrefetchThread extends AsyncPrefetchThread {

-       protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue,
-                       @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator) {
-           super(queue, iterator, terminator);
+       protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue, @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) {
+           super(queue, iterator, terminator, deviceId);
        }
    }
}
@@ -34,6 +34,12 @@ public interface AffinityManager {
     */
    Integer getDeviceForCurrentThread();

+   /**
+    * This method returns deviceId for a given thread
+    * @return
+    */
+   Integer getDeviceForThread(long threadId);
+
    /**
     * This method returns id of current device for a given INDArray
@@ -28,6 +28,11 @@ public abstract class BasicAffinityManager implements AffinityManager {
        return 0;
    }

+   @Override
+   public Integer getDeviceForThread(long threadId) {
+       return 0;
+   }
+
    @Override
    public Integer getDeviceForArray(INDArray array) {
        return 0;
@@ -68,7 +68,24 @@ public class CudaAffinityManager extends BasicAffinityManager {
     */
    @Override
    public Integer getDeviceForCurrentThread() {
-       return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
+       val id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
+       if (!affinityMap.containsKey(Thread.currentThread().getId()))
+           affinityMap.put(Thread.currentThread().getId(), id);
+
+       return id;
+   }
+
+   /**
+    * This method returns deviceId for a given thread
+    * @return
+    */
+   @Override
+   public Integer getDeviceForThread(long threadId) {
+       val id = affinityMap.get(threadId);
+       if (id == null)
+           throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet");
+
+       return id;
+   }
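Note: taken together, the interface, base-class, and CUDA hunks define a small contract: a thread's affinity is recorded the first time it calls getDeviceForCurrentThread(), after which getDeviceForThread(threadId) can resolve it by thread id; on the CUDA manager an unknown thread id throws RuntimeException, while the base implementation always returns 0. A hedged usage sketch of that contract (demo class is illustrative; behavior as described applies to the CUDA manager in this diff):

    import org.nd4j.linalg.factory.Nd4j;

    public class AffinityLookupDemo {
        public static void main(String[] args) {
            long tid = Thread.currentThread().getId();
            // registers this thread in the affinity map (per the CUDA hunk above)
            Integer direct = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            // now resolvable by thread id via the new interface method
            Integer byId = Nd4j.getAffinityManager().getDeviceForThread(tid);
            System.out.println("thread " + tid + " -> device " + direct + " / " + byId);
        }
    }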
@@ -11131,7 +11131,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 //             REGISTER_C(NAME)
 //         nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
 //             auto shapeList = SHAPELIST();
-//             for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) {
+//             for (int e = 0; e < block.width(); e++) {
 //                 auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
 //                 shapeList->push_back(newshape);
 //             }

@@ -11168,7 +11168,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 //             REGISTER_C(NAME)
 //         nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
 //             auto shapeList = SHAPELIST();
-//             for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) {
+//             for (int e = 0; e < block.width(); e++) {
 //                 Nd4jLong* newshape;
 //                 COPY_SHAPE(inputShape->at(0), newshape);
 //                 shapeList->push_back(CONSTANT(newshape));

@@ -11191,7 +11191,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 //             REGISTER_C(NAME)
 //         nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
 //             auto shapeList = SHAPELIST();
-//             for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) {
+//             for (int e = 0; e < block.width(); e++) {
 //                 auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
 //                 shapeList->push_back(newshape);
 //             }
@@ -16282,9 +16282,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
    }
 // #endif

+   /*
+    * Complete tensor with max indices merged from all input tensors list
+    *
+    * INPUT: tensors with the same shape
+    * OUTPUT: integer tensor with the same shape
+    * INT_ARG: result type (one of int), INT32 by default
+    */
 // #if NOT_EXCLUDED(OP_mergemaxindex)
-   @Namespace("nd4j::ops") public static class mergemaxindex extends DeclarableOp {
+   @Namespace("nd4j::ops") public static class mergemaxindex extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public mergemaxindex(Pointer p) { super(p); }
@@ -21746,7 +21752,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();


    /**
-    * This operation shift individual bits of each element in array
+    * This operation shift individual bits of each element in array to the left: <<
     *
     * PLEASE NOTE: This operation is applicable only to integer data types
     *
@@ -21771,7 +21777,32 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 // #endif

    /**
-    * This operation shift individual bits of each element in array
+    * This operation shift individual bits of each element in array to the right: >>
     *
     * PLEASE NOTE: This operation is applicable only to integer data types
     *
     * \tparam T
     */
+// #if NOT_EXCLUDED(OP_rshift_bits)
+   @Namespace("nd4j::ops") public static class rshift_bits extends DeclarableOp {
+       static { Loader.load(); }
+       /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+       public rshift_bits(Pointer p) { super(p); }
+       /** Native array allocator. Access with {@link Pointer#position(long)}. */
+       public rshift_bits(long size) { super((Pointer)null); allocateArray(size); }
+       private native void allocateArray(long size);
+       @Override public rshift_bits position(long position) {
+           return (rshift_bits)super.position(position);
+       }
+
+       public rshift_bits() { super((Pointer)null); allocate(); }
+       private native void allocate();
+       public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+   }
+// #endif
+
    /**
     * This operation shift individual bits of each element in array, shifting to the left
     *
     * PLEASE NOTE: This operation is applicable only to integer data types
     *
@@ -21795,6 +21826,31 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
    }
 // #endif

+   /**
+    * This operation shift individual bits of each element in array, shifting to the right
+    *
+    * PLEASE NOTE: This operation is applicable only to integer data types
+    *
+    * \tparam T
+    */
+// #if NOT_EXCLUDED(OP_cyclic_rshift_bits)
+   @Namespace("nd4j::ops") public static class cyclic_rshift_bits extends DeclarableOp {
+       static { Loader.load(); }
+       /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+       public cyclic_rshift_bits(Pointer p) { super(p); }
+       /** Native array allocator. Access with {@link Pointer#position(long)}. */
+       public cyclic_rshift_bits(long size) { super((Pointer)null); allocateArray(size); }
+       private native void allocateArray(long size);
+       @Override public cyclic_rshift_bits position(long position) {
+           return (cyclic_rshift_bits)super.position(position);
+       }
+
+       public cyclic_rshift_bits() { super((Pointer)null); allocate(); }
+       private native void allocate();
+       public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+   }
+// #endif
+

 // #endif
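Note: rshift_bits and cyclic_rshift_bits above are JavaCPP peer classes for the corresponding libnd4j ops. A heavily hedged sketch of invoking such an op through nd4j's generic custom-op API: the op name comes from the binding, but the input layout (an integer array plus a shift amount as a second input) is an assumption not confirmed by this diff:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class RShiftDemo {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(new int[]{2, 4, 8, 16});
            // hypothetical invocation by op name; argument layout is assumed
            INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("rshift_bits")
                    .addInputs(x, Nd4j.createFromArray(new int[]{1})) // shift right by 1 bit (assumed)
                    .build());
            System.out.println(out[0]); // expected [1, 2, 4, 8] if the assumption holds
        }
    }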
@@ -22545,7 +22601,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
     */
 // #if NOT_EXCLUDED(OP_to_double)
-   @Namespace("nd4j::ops") public static class to_double extends DeclarableOp {
+   @Namespace("nd4j::ops") public static class to_double extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public to_double(Pointer p) { super(p); }

@@ -22568,7 +22624,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
     */
 // #if NOT_EXCLUDED(OP_to_float16)
-   @Namespace("nd4j::ops") public static class to_float16 extends DeclarableOp {
+   @Namespace("nd4j::ops") public static class to_float16 extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public to_float16(Pointer p) { super(p); }

@@ -22591,7 +22647,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
     */
 // #if NOT_EXCLUDED(OP_to_float32)
-   @Namespace("nd4j::ops") public static class to_float32 extends DeclarableOp {
+   @Namespace("nd4j::ops") public static class to_float32 extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public to_float32(Pointer p) { super(p); }

@@ -22614,7 +22670,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
     */
 // #if NOT_EXCLUDED(OP_to_int32)
-   @Namespace("nd4j::ops") public static class to_int32 extends DeclarableOp {
+   @Namespace("nd4j::ops") public static class to_int32 extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public to_int32(Pointer p) { super(p); }

@@ -22637,7 +22693,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
     */
 // #if NOT_EXCLUDED(OP_to_int64)
-   @Namespace("nd4j::ops") public static class to_int64 extends DeclarableOp {
+   @Namespace("nd4j::ops") public static class to_int64 extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public to_int64(Pointer p) { super(p); }

@@ -22660,7 +22716,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
     */
 // #if NOT_EXCLUDED(OP_to_uint32)
-   @Namespace("nd4j::ops") public static class to_uint32 extends DeclarableOp {
+   @Namespace("nd4j::ops") public static class to_uint32 extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public to_uint32(Pointer p) { super(p); }

@@ -22683,7 +22739,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
     */
 // #if NOT_EXCLUDED(OP_to_uint64)
-   @Namespace("nd4j::ops") public static class to_uint64 extends DeclarableOp {
+   @Namespace("nd4j::ops") public static class to_uint64 extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public to_uint64(Pointer p) { super(p); }