Events removed from Java (#219)
* replace mutex with lock_guards
  Signed-off-by: raver119 <raver119@gmail.com>
* Events ditched from Java CUDA logic
  Signed-off-by: raver119 <raver119@gmail.com>
master
parent
937a27ae27
commit
a0da5a9e47
|
@ -92,7 +92,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) {
|
void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) {
|
||||||
_mutex.lock();
|
std::lock_guard<std::mutex> lock(_mutex);
|
||||||
|
|
||||||
auto deviceId = getCurrentDevice();
|
auto deviceId = getCurrentDevice();
|
||||||
Nd4jPointer constantPtr = nullptr;
|
Nd4jPointer constantPtr = nullptr;
|
||||||
|
@ -116,7 +116,6 @@ namespace nd4j {
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw cuda_exception::build("cudaMemcpy failed", res);
|
throw cuda_exception::build("cudaMemcpy failed", res);
|
||||||
|
|
||||||
_mutex.unlock();
|
|
||||||
return ptr;
|
return ptr;
|
||||||
} else {
|
} else {
|
||||||
auto originalBytes = numBytes;
|
auto originalBytes = numBytes;
|
||||||
|
@ -130,7 +129,6 @@ namespace nd4j {
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw cuda_exception::build("cudaMemcpyToSymbol failed", res);
|
throw cuda_exception::build("cudaMemcpyToSymbol failed", res);
|
||||||
|
|
||||||
_mutex.unlock();
|
|
||||||
return reinterpret_cast<int8_t *>(constantPtr) + constantOffset;
|
return reinterpret_cast<int8_t *>(constantPtr) + constantOffset;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -152,7 +150,7 @@ namespace nd4j {
|
||||||
ConstantDataBuffer* result;
|
ConstantDataBuffer* result;
|
||||||
|
|
||||||
// access to this holder instance is synchronous
|
// access to this holder instance is synchronous
|
||||||
holder->mutex()->lock();
|
std::lock_guard<std::mutex> lock(*holder->mutex());
|
||||||
|
|
||||||
if (holder->hasBuffer(dataType)) {
|
if (holder->hasBuffer(dataType)) {
|
||||||
result = holder->getConstantDataBuffer(dataType);
|
result = holder->getConstantDataBuffer(dataType);
|
||||||
|
@ -175,8 +173,6 @@ namespace nd4j {
|
||||||
holder->addBuffer(dataBuffer, dataType);
|
holder->addBuffer(dataBuffer, dataType);
|
||||||
result = holder->getConstantDataBuffer(dataType);
|
result = holder->getConstantDataBuffer(dataType);
|
||||||
}
|
}
|
||||||
// release holder lock
|
|
||||||
holder->mutex()->unlock();
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,7 @@ namespace nd4j {
|
||||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||||
int deviceId = AffinityManager::currentDeviceId();
|
int deviceId = AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
_mutex.lock();
|
std::lock_guard<std::mutex> lock(_mutex);
|
||||||
|
|
||||||
if (_cache[deviceId].count(descriptor) == 0) {
|
if (_cache[deviceId].count(descriptor) == 0) {
|
||||||
auto hPtr = descriptor.toShapeInfo();
|
auto hPtr = descriptor.toShapeInfo();
|
||||||
|
@ -65,15 +65,9 @@ namespace nd4j {
|
||||||
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
|
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
|
||||||
ShapeDescriptor descriptor1(descriptor);
|
ShapeDescriptor descriptor1(descriptor);
|
||||||
_cache[deviceId][descriptor1] = buffer;
|
_cache[deviceId][descriptor1] = buffer;
|
||||||
auto r = _cache[deviceId][descriptor1];
|
return _cache[deviceId][descriptor1];
|
||||||
_mutex.unlock();
|
|
||||||
|
|
||||||
return r;
|
|
||||||
} else {
|
} else {
|
||||||
ConstantDataBuffer r = _cache[deviceId].at(descriptor);
|
return _cache[deviceId].at(descriptor);
|
||||||
_mutex.unlock();
|
|
||||||
|
|
||||||
return r;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,18 +77,10 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
|
bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
|
||||||
bool result;
|
|
||||||
auto deviceId = AffinityManager::currentDeviceId();
|
auto deviceId = AffinityManager::currentDeviceId();
|
||||||
_mutex.lock();
|
std::lock_guard<std::mutex> lock(_mutex);
|
||||||
|
|
||||||
if (_cache[deviceId].count(descriptor) == 0)
|
return _cache[deviceId].count(descriptor) != 0;
|
||||||
result = false;
|
|
||||||
else
|
|
||||||
result = true;
|
|
||||||
|
|
||||||
_mutex.unlock();
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||||
|
|
|
@ -64,7 +64,7 @@ namespace nd4j {
|
||||||
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
||||||
const int deviceId = AffinityManager::currentDeviceId();
|
const int deviceId = AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
_mutex.lock();
|
std::lock_guard<std::mutex> lock(_mutex);
|
||||||
|
|
||||||
if (_cache[deviceId].count(descriptor) == 0) {
|
if (_cache[deviceId].count(descriptor) == 0) {
|
||||||
const auto shapeInfo = descriptor.originalShape().toShapeInfo();
|
const auto shapeInfo = descriptor.originalShape().toShapeInfo();
|
||||||
|
@ -97,14 +97,12 @@ namespace nd4j {
|
||||||
_cache[deviceId][descriptor] = t;
|
_cache[deviceId][descriptor] = t;
|
||||||
|
|
||||||
TadPack r = _cache[deviceId][descriptor];
|
TadPack r = _cache[deviceId][descriptor];
|
||||||
_mutex.unlock();
|
|
||||||
|
|
||||||
delete[] shapeInfo;
|
delete[] shapeInfo;
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
} else {
|
} else {
|
||||||
TadPack r = _cache[deviceId][descriptor];
|
TadPack r = _cache[deviceId][descriptor];
|
||||||
_mutex.unlock();
|
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
|
@ -469,8 +469,8 @@ public class AtomicAllocator implements Allocator {
|
||||||
|
|
||||||
memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback);
|
memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback);
|
||||||
|
|
||||||
getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
|
//getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
|
||||||
getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
|
//getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -26,11 +26,11 @@ import java.util.concurrent.ConcurrentLinkedQueue;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @author raver119@gmail.com
|
* @author raver119@gmail.com
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public class EventsProvider {
|
public class EventsProvider {
|
||||||
//private static final EventsProvider INSTANCE = new EventsProvider();
|
|
||||||
|
|
||||||
private List<ConcurrentLinkedQueue<cudaEvent_t>> queue = new ArrayList<>();
|
private List<ConcurrentLinkedQueue<cudaEvent_t>> queue = new ArrayList<>();
|
||||||
private AtomicLong newCounter = new AtomicLong(0);
|
private AtomicLong newCounter = new AtomicLong(0);
|
||||||
private AtomicLong cacheCounter = new AtomicLong(0);
|
private AtomicLong cacheCounter = new AtomicLong(0);
|
||||||
|
|
|
@ -72,12 +72,7 @@ public class SynchronousFlowController implements FlowController {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void waitTillFinished(AllocationPoint point) {
|
public void waitTillFinished(AllocationPoint point) {
|
||||||
/*CudaContext context = point.getCurrentContext(); //(CudaContext) allocator.getDeviceContext().getContext();
|
// this should be always null, since synchronization happens in C++ now
|
||||||
if (context == null)
|
|
||||||
context = (CudaContext) allocator.getDeviceContext().getContext();
|
|
||||||
context.syncOldStream();
|
|
||||||
*/
|
|
||||||
|
|
||||||
if (point.getLastWriteEvent() != null) {
|
if (point.getLastWriteEvent() != null) {
|
||||||
point.getLastWriteEvent().synchronize();
|
point.getLastWriteEvent().synchronize();
|
||||||
}
|
}
|
||||||
|
@ -181,8 +176,8 @@ public class SynchronousFlowController implements FlowController {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) {
|
public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) {
|
||||||
|
// this method is irrelevant now, everything happens in C++ now
|
||||||
|
/*
|
||||||
eventsProvider.storeEvent(result.getLastWriteEvent());
|
eventsProvider.storeEvent(result.getLastWriteEvent());
|
||||||
result.setLastWriteEvent(eventsProvider.getEvent());
|
result.setLastWriteEvent(eventsProvider.getEvent());
|
||||||
result.getLastWriteEvent().register(context.getOldStream());
|
result.getLastWriteEvent().register(context.getOldStream());
|
||||||
|
@ -194,6 +189,7 @@ public class SynchronousFlowController implements FlowController {
|
||||||
operand.getLastReadEvent().register(context.getOldStream());
|
operand.getLastReadEvent().register(context.getOldStream());
|
||||||
}
|
}
|
||||||
// context.syncOldStream();
|
// context.syncOldStream();
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -204,9 +200,6 @@ public class SynchronousFlowController implements FlowController {
|
||||||
|
|
||||||
val pointOperand = allocator.getAllocationPoint(operand);
|
val pointOperand = allocator.getAllocationPoint(operand);
|
||||||
pointOperand.tickDeviceWrite();
|
pointOperand.tickDeviceWrite();
|
||||||
eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
|
|
||||||
pointOperand.setLastWriteEvent(eventsProvider.getEvent());
|
|
||||||
pointOperand.getLastWriteEvent().register(context.getOldStream());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -216,18 +209,13 @@ public class SynchronousFlowController implements FlowController {
|
||||||
|
|
||||||
val point = allocator.getAllocationPoint(result);
|
val point = allocator.getAllocationPoint(result);
|
||||||
point.tickDeviceWrite();
|
point.tickDeviceWrite();
|
||||||
eventsProvider.storeEvent(point.getLastWriteEvent());
|
|
||||||
point.setLastWriteEvent(eventsProvider.getEvent());
|
|
||||||
point.getLastWriteEvent().register(context.getOldStream());
|
|
||||||
|
|
||||||
for (INDArray operand : operands) {
|
for (INDArray operand : operands) {
|
||||||
if (operand == null || operand.isEmpty())
|
if (operand == null || operand.isEmpty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
val pointOperand = allocator.getAllocationPoint(operand);
|
val pointOperand = allocator.getAllocationPoint(operand);
|
||||||
eventsProvider.storeEvent(pointOperand.getLastReadEvent());
|
pointOperand.tickDeviceRead();
|
||||||
pointOperand.setLastReadEvent(eventsProvider.getEvent());
|
|
||||||
pointOperand.getLastReadEvent().register(context.getOldStream());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -307,7 +307,6 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
if (allocationPoint.getHostPointer() == null) {
|
if (allocationPoint.getHostPointer() == null) {
|
||||||
val location = allocationPoint.getAllocationStatus();
|
val location = allocationPoint.getAllocationStatus();
|
||||||
if (parentWorkspace == null) {
|
if (parentWorkspace == null) {
|
||||||
//log.info("dbAllocate step");
|
|
||||||
// let cpp allocate primary buffer
|
// let cpp allocate primary buffer
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -19177,6 +19177,38 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* solve op. - solve systems of linear equations - general method.
|
||||||
|
*
|
||||||
|
* input params:
|
||||||
|
* 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
|
||||||
|
* 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
|
||||||
|
*
|
||||||
|
* boolean args:
|
||||||
|
* 0 - adjoint - default is false (optional) - indicates whether the input matrix or its adjoint (Hermitian conjugate) should be used
|
||||||
|
*
|
||||||
|
* return value:
|
||||||
|
* tensor with dimension (x * y * z * ::: * M * K) with solutions
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_solve)
|
||||||
|
@Namespace("nd4j::ops") public static class solve extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public solve(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public solve(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public solve position(long position) {
|
||||||
|
return (solve)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public solve() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* lu op. - make LUP decomposition of given batch of 2D square matrices
|
* lu op. - make LUP decomposition of given batch of 2D square matrices
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue