Events removed from Java (#219)

* replace mutex with lock_guards

Signed-off-by: raver119 <raver119@gmail.com>

* Events ditched from Java CUDA logic

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-07 12:34:55 +03:00 committed by GitHub
parent 937a27ae27
commit a0da5a9e47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 49 additions and 50 deletions

View File

@ -92,7 +92,7 @@ namespace nd4j {
} }
void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) { void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) {
_mutex.lock(); std::lock_guard<std::mutex> lock(_mutex);
auto deviceId = getCurrentDevice(); auto deviceId = getCurrentDevice();
Nd4jPointer constantPtr = nullptr; Nd4jPointer constantPtr = nullptr;
@ -116,7 +116,6 @@ namespace nd4j {
if (res != 0) if (res != 0)
throw cuda_exception::build("cudaMemcpy failed", res); throw cuda_exception::build("cudaMemcpy failed", res);
_mutex.unlock();
return ptr; return ptr;
} else { } else {
auto originalBytes = numBytes; auto originalBytes = numBytes;
@ -130,7 +129,6 @@ namespace nd4j {
if (res != 0) if (res != 0)
throw cuda_exception::build("cudaMemcpyToSymbol failed", res); throw cuda_exception::build("cudaMemcpyToSymbol failed", res);
_mutex.unlock();
return reinterpret_cast<int8_t *>(constantPtr) + constantOffset; return reinterpret_cast<int8_t *>(constantPtr) + constantOffset;
} }
} }
@ -152,7 +150,7 @@ namespace nd4j {
ConstantDataBuffer* result; ConstantDataBuffer* result;
// access to this holder instance is synchronous // access to this holder instance is synchronous
holder->mutex()->lock(); std::lock_guard<std::mutex> lock(*holder->mutex());
if (holder->hasBuffer(dataType)) { if (holder->hasBuffer(dataType)) {
result = holder->getConstantDataBuffer(dataType); result = holder->getConstantDataBuffer(dataType);
@ -175,8 +173,6 @@ namespace nd4j {
holder->addBuffer(dataBuffer, dataType); holder->addBuffer(dataBuffer, dataType);
result = holder->getConstantDataBuffer(dataType); result = holder->getConstantDataBuffer(dataType);
} }
// release holder lock
holder->mutex()->unlock();
return result; return result;
} }

View File

@ -57,7 +57,7 @@ namespace nd4j {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = AffinityManager::currentDeviceId(); int deviceId = AffinityManager::currentDeviceId();
_mutex.lock(); std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0) { if (_cache[deviceId].count(descriptor) == 0) {
auto hPtr = descriptor.toShapeInfo(); auto hPtr = descriptor.toShapeInfo();
@ -65,15 +65,9 @@ namespace nd4j {
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor); ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer; _cache[deviceId][descriptor1] = buffer;
auto r = _cache[deviceId][descriptor1]; return _cache[deviceId][descriptor1];
_mutex.unlock();
return r;
} else { } else {
ConstantDataBuffer r = _cache[deviceId].at(descriptor); return _cache[deviceId].at(descriptor);
_mutex.unlock();
return r;
} }
} }
@ -83,18 +77,10 @@ namespace nd4j {
} }
bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) { bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
bool result;
auto deviceId = AffinityManager::currentDeviceId(); auto deviceId = AffinityManager::currentDeviceId();
_mutex.lock(); std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0) return _cache[deviceId].count(descriptor) != 0;
result = false;
else
result = true;
_mutex.unlock();
return result;
} }
Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {

View File

@ -64,7 +64,7 @@ namespace nd4j {
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = AffinityManager::currentDeviceId(); const int deviceId = AffinityManager::currentDeviceId();
_mutex.lock(); std::lock_guard<std::mutex> lock(_mutex);
if (_cache[deviceId].count(descriptor) == 0) { if (_cache[deviceId].count(descriptor) == 0) {
const auto shapeInfo = descriptor.originalShape().toShapeInfo(); const auto shapeInfo = descriptor.originalShape().toShapeInfo();
@ -97,14 +97,12 @@ namespace nd4j {
_cache[deviceId][descriptor] = t; _cache[deviceId][descriptor] = t;
TadPack r = _cache[deviceId][descriptor]; TadPack r = _cache[deviceId][descriptor];
_mutex.unlock();
delete[] shapeInfo; delete[] shapeInfo;
return r; return r;
} else { } else {
TadPack r = _cache[deviceId][descriptor]; TadPack r = _cache[deviceId][descriptor];
_mutex.unlock();
return r; return r;
} }

View File

@ -469,8 +469,8 @@ public class AtomicAllocator implements Allocator {
memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback); memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback);
getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent()); //getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent()); //getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
} }
/** /**

View File

@ -26,11 +26,11 @@ import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
/** /**
*
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@Deprecated
public class EventsProvider { public class EventsProvider {
//private static final EventsProvider INSTANCE = new EventsProvider();
private List<ConcurrentLinkedQueue<cudaEvent_t>> queue = new ArrayList<>(); private List<ConcurrentLinkedQueue<cudaEvent_t>> queue = new ArrayList<>();
private AtomicLong newCounter = new AtomicLong(0); private AtomicLong newCounter = new AtomicLong(0);
private AtomicLong cacheCounter = new AtomicLong(0); private AtomicLong cacheCounter = new AtomicLong(0);

View File

@ -72,12 +72,7 @@ public class SynchronousFlowController implements FlowController {
@Override @Override
public void waitTillFinished(AllocationPoint point) { public void waitTillFinished(AllocationPoint point) {
/*CudaContext context = point.getCurrentContext(); //(CudaContext) allocator.getDeviceContext().getContext(); // this should be always null, since synchronization happens in C++ now
if (context == null)
context = (CudaContext) allocator.getDeviceContext().getContext();
context.syncOldStream();
*/
if (point.getLastWriteEvent() != null) { if (point.getLastWriteEvent() != null) {
point.getLastWriteEvent().synchronize(); point.getLastWriteEvent().synchronize();
} }
@ -181,8 +176,8 @@ public class SynchronousFlowController implements FlowController {
@Override @Override
public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) { public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) {
// this method is irrelevant now, everything happens in C++ now
/*
eventsProvider.storeEvent(result.getLastWriteEvent()); eventsProvider.storeEvent(result.getLastWriteEvent());
result.setLastWriteEvent(eventsProvider.getEvent()); result.setLastWriteEvent(eventsProvider.getEvent());
result.getLastWriteEvent().register(context.getOldStream()); result.getLastWriteEvent().register(context.getOldStream());
@ -194,6 +189,7 @@ public class SynchronousFlowController implements FlowController {
operand.getLastReadEvent().register(context.getOldStream()); operand.getLastReadEvent().register(context.getOldStream());
} }
// context.syncOldStream(); // context.syncOldStream();
*/
} }
@Override @Override
@ -204,9 +200,6 @@ public class SynchronousFlowController implements FlowController {
val pointOperand = allocator.getAllocationPoint(operand); val pointOperand = allocator.getAllocationPoint(operand);
pointOperand.tickDeviceWrite(); pointOperand.tickDeviceWrite();
eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
pointOperand.setLastWriteEvent(eventsProvider.getEvent());
pointOperand.getLastWriteEvent().register(context.getOldStream());
} }
} }
@ -216,18 +209,13 @@ public class SynchronousFlowController implements FlowController {
val point = allocator.getAllocationPoint(result); val point = allocator.getAllocationPoint(result);
point.tickDeviceWrite(); point.tickDeviceWrite();
eventsProvider.storeEvent(point.getLastWriteEvent());
point.setLastWriteEvent(eventsProvider.getEvent());
point.getLastWriteEvent().register(context.getOldStream());
for (INDArray operand : operands) { for (INDArray operand : operands) {
if (operand == null || operand.isEmpty()) if (operand == null || operand.isEmpty())
continue; continue;
val pointOperand = allocator.getAllocationPoint(operand); val pointOperand = allocator.getAllocationPoint(operand);
eventsProvider.storeEvent(pointOperand.getLastReadEvent()); pointOperand.tickDeviceRead();
pointOperand.setLastReadEvent(eventsProvider.getEvent());
pointOperand.getLastReadEvent().register(context.getOldStream());
} }
} }

View File

@ -307,7 +307,6 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
if (allocationPoint.getHostPointer() == null) { if (allocationPoint.getHostPointer() == null) {
val location = allocationPoint.getAllocationStatus(); val location = allocationPoint.getAllocationStatus();
if (parentWorkspace == null) { if (parentWorkspace == null) {
//log.info("dbAllocate step");
// let cpp allocate primary buffer // let cpp allocate primary buffer
NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer); NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer);
} else { } else {

View File

@ -19177,6 +19177,38 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/**
* solve op. - solves systems of linear equations (general method).
*
* input params:
* 0 - tensor with dimension (x * y * z * ::: * M * M) - left-hand sides of the equations
* 1 - tensor with dimension (x * y * z * ::: * M * K) - right-hand sides of the equations
*
* boolean args:
* 0 - adjoint (optional, default is false) - indicates whether the input matrix or its adjoint (Hermitian conjugate) should be used
*
* return value:
* tensor with dimension (x * y * z * ::: * M * K) holding the solutions
*
*/
// #if NOT_EXCLUDED(OP_solve)
@Namespace("nd4j::ops") public static class solve extends DeclarableCustomOp {
// Calls Loader.load() once, when this class is initialized.
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public solve(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public solve(long size) { super((Pointer)null); allocateArray(size); }
// Native method; its implementation is not visible in this file.
private native void allocateArray(long size);
// Covariant override: delegates to super.position(position) and casts the result to solve.
@Override public solve position(long position) {
return (solve)super.position(position);
}
// Default constructor: starts from a null Pointer, then calls the native allocate().
public solve() { super((Pointer)null); allocate(); }
// Native method; its implementation is not visible in this file.
private native void allocate();
// Native method: takes the input ShapeList and the op Context (annotated @ByRef) and returns a ShapeList.
// NOTE(review): presumably the returned shapes follow the "return value" contract in the class comment - confirm against the C++ op.
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/** /**
* lu op. - make LUP decomposition of given batch of 2D square matricies * lu op. - make LUP decomposition of given batch of 2D square matricies
* *