[WIP] Thread safety (#229)

* sync after cublas*gemm

Signed-off-by: raver119 <raver119@gmail.com>

* mutex for CublasHelper

Signed-off-by: raver119 <raver119@gmail.com>

* don't store cublasHandle in LaunchContext, it's per-device anyway

Signed-off-by: raver119 <raver119@gmail.com>

* some printout

Signed-off-by: raver119 <raver119@gmail.com>

* check for field instead

Signed-off-by: raver119 <raver119@gmail.com>

* pew-pew

Signed-off-by: raver119 <raver119@gmail.com>

* don't release ContextBuffers until device changed

Signed-off-by: raver119 <raver119@gmail.com>

* small tweak

Signed-off-by: raver119 <raver119@gmail.com>

* some logging in sgemm

Signed-off-by: raver119 <raver119@gmail.com>

* stream sync

Signed-off-by: raver119 <raver119@gmail.com>

* some more logging

Signed-off-by: raver119 <raver119@gmail.com>

* some more error checks

Signed-off-by: raver119 <raver119@gmail.com>

* one fancy test

Signed-off-by: raver119 <raver119@gmail.com>

* one fancy test

Signed-off-by: raver119 <raver119@gmail.com>

* minor AffinityManager fix

Signed-off-by: raver119 <raver119@gmail.com>

* cudaEvent error logging improvement

Signed-off-by: raver119 <raver119@gmail.com>

* ConstantHelper thread safety

Signed-off-by: raver119 <raver119@gmail.com>

* - minor corrections in ConstantTadHelper

Signed-off-by: Yurii <yurii@skymind.io>

* ConstantShapeHelper thread safety

Signed-off-by: raver119 <raver119@gmail.com>

* ConstantTadHelper.cu updated

Signed-off-by: raver119 <raver119@gmail.com>

* logging off

Signed-off-by: raver119 <raver119@gmail.com>

* logging off

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-03 22:00:38 +03:00 committed by GitHub
parent 5be43e7253
commit dddc8a1143
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 227 additions and 78 deletions

View File

@ -0,0 +1,63 @@
package org.deeplearning4j;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.concurrent.CountDownLatch;
@Ignore
public class RandomTests {
@Test
public void testReproduce() throws Exception {
final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10)
.activation(Activation.TANH).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10)
.activation(Activation.SOFTMAX).build())
.build();
for (int e = 0; e < 3; e++) {
int nThreads = 10;
final CountDownLatch l = new CountDownLatch(nThreads);
for (int i = 0; i < nThreads; i++) {
final int j = i;
Thread t = new Thread(new Runnable() {
@Override
public void run() {
try {
MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
net.init();
DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100);
net.fit(iter);
} catch (Throwable t) {
System.out.println("Thread failed: " + j);
t.printStackTrace();
} finally {
l.countDown();
}
}
});
t.start();
}
l.await();
System.out.println("DONE " + e + "\n");
}
}
}

View File

@ -24,11 +24,13 @@
#include <map> #include <map>
#include <array/ConstantDescriptor.h> #include <array/ConstantDescriptor.h>
#include <array/ConstantDataBuffer.h> #include <array/ConstantDataBuffer.h>
#include <mutex>
namespace nd4j { namespace nd4j {
class ConstantHolder { class ConstantHolder {
private: private:
int _deviceId = 0; int _deviceId = 0;
std::mutex _mutex;
std::map<nd4j::DataType, ConstantDataBuffer> _buffers; std::map<nd4j::DataType, ConstantDataBuffer> _buffers;
public: public:
@ -53,6 +55,8 @@ namespace nd4j {
template <typename T> template <typename T>
ConstantDataBuffer* getConstantDataBuffer(); ConstantDataBuffer* getConstantDataBuffer();
std::mutex* mutex();
}; };
} }

View File

@ -16,6 +16,10 @@ namespace nd4j {
return _buffers.count(dataType) > 0; return _buffers.count(dataType) > 0;
} }
std::mutex* ConstantHolder::mutex() {
return &_mutex;
}
template <typename T> template <typename T>
bool ConstantHolder::hasBuffer() { bool ConstantHolder::hasBuffer() {
return hasBuffer(DataTypeUtils::fromT<T>()); return hasBuffer(DataTypeUtils::fromT<T>());

View File

@ -47,7 +47,7 @@ namespace nd4j {
_currentMutex.unlock(); _currentMutex.unlock();
setCurrentDevice(globalThreadToDevice); setCurrentNativeDevice(globalThreadToDevice);
} }
// if we already know affinity - just return it // if we already know affinity - just return it
@ -92,6 +92,8 @@ namespace nd4j {
void AffinityManager::setCurrentNativeDevice(int deviceId) { void AffinityManager::setCurrentNativeDevice(int deviceId) {
auto res = cudaSetDevice(deviceId); auto res = cudaSetDevice(deviceId);
if (res != 0)
throw cuda_exception::build("setCurrentDevice failed", res);
} }
void AffinityManager::setCurrentDevice(int deviceId) { void AffinityManager::setCurrentDevice(int deviceId) {
@ -104,17 +106,22 @@ namespace nd4j {
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream()); res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
if (res != 0) if (res != 0)
throw cuda_exception::build("setCurrentDevice -> specialSync failed", res); throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
if (deviceId != previousDeviceId) {
// discard existing stuff
nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", "");
LaunchContext::releaseBuffers();
}
} }
auto res = cudaSetDevice(deviceId); if (deviceId != previousDeviceId) {
if (res != 0) auto res = cudaSetDevice(deviceId);
throw cuda_exception::build("cudaSetDevice failed", res); if (res != 0)
throw cuda_exception::build("cudaSetDevice failed", res);
}
// update thread-device affinity // update thread-device affinity
globalThreadToDevice = deviceId; globalThreadToDevice = deviceId;
// discard existing stuff
LaunchContext::releaseBuffers();
} }
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV); std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);

View File

@ -107,7 +107,6 @@ namespace nd4j {
////// //////
_allocated = false; _allocated = false;
_initialized = false;
_deviceId = -1; _deviceId = -1;
this->_specialStream = nullptr; this->_specialStream = nullptr;
@ -116,6 +115,8 @@ namespace nd4j {
this->_reductionPointer = nullptr; this->_reductionPointer = nullptr;
this->_scalarPointer = nullptr; this->_scalarPointer = nullptr;
} }
_initialized = false;
} }
ContextBuffers::~ContextBuffers() { ContextBuffers::~ContextBuffers() {
@ -163,21 +164,21 @@ namespace nd4j {
} }
void* ContextBuffers::reductionBuffer() { void* ContextBuffers::reductionBuffer() {
if (_reductionPointer == nullptr) if (!_initialized)
initialize(); initialize();
return _reductionPointer; return _reductionPointer;
} }
void* ContextBuffers::scalarBuffer() { void* ContextBuffers::scalarBuffer() {
if (_scalarPointer == nullptr) if (!_initialized)
initialize(); initialize();
return _scalarPointer; return _scalarPointer;
} }
void* ContextBuffers::allocationBuffer() { void* ContextBuffers::allocationBuffer() {
if (_allocationPointer == nullptr) if (!_initialized)
initialize(); initialize();
return _allocationPointer; return _allocationPointer;
@ -204,15 +205,23 @@ namespace nd4j {
} }
void* ContextBuffers::execStream() { void* ContextBuffers::execStream() {
if (_execStream == nullptr) if (!_initialized) {
//nd4j_printf("execStream not initialized\n", "");
initialize(); initialize();
} else {
//nd4j_printf("execStream is initialized\n", "");
}
return _execStream; return _execStream;
} }
void* ContextBuffers::specialStream() { void* ContextBuffers::specialStream() {
if (_specialStream == nullptr) if (!_initialized) {
//nd4j_printf("specialStream not initialized\n", "");
initialize(); initialize();
} else {
//nd4j_printf("specialStream is initialized\n", "");
}
return _specialStream; return _specialStream;
} }

View File

@ -57,10 +57,6 @@ LaunchContext::LaunchContext() {
_deviceID = 0; _deviceID = 0;
_isAllocated = true; _isAllocated = true;
_cublasHandle = CublasHelper::getInstance()->handle();
_cusolverHandle = CublasHelper::getInstance()->solver();
} }
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
@ -89,13 +85,13 @@ LaunchContext::LaunchContext() {
_contexts.resize(numDevices); _contexts.resize(numDevices);
for (int e = 0; e < numDevices; e++) { for (int e = 0; e < numDevices; e++) {
AffinityManager::setCurrentDevice(e); AffinityManager::setCurrentNativeDevice(e);
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>(); LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
} }
// don't forget to restore device back again // don't forget to restore device back again
AffinityManager::setCurrentDevice(deviceId); AffinityManager::setCurrentNativeDevice(deviceId);
} }
_mutex.unlock(); _mutex.unlock();
@ -117,11 +113,11 @@ LaunchContext::LaunchContext() {
}; };
void* LaunchContext::getCublasHandle() const { void* LaunchContext::getCublasHandle() const {
return _cublasHandle; return CublasHelper::getInstance()->handle();
}; };
void* LaunchContext::getCusolverHandle() const { void* LaunchContext::getCusolverHandle() const {
return _cusolverHandle; return CublasHelper::getInstance()->solver();
}; };
cudaStream_t* LaunchContext::getCudaStream() const { cudaStream_t* LaunchContext::getCudaStream() const {
@ -162,6 +158,7 @@ LaunchContext::LaunchContext() {
}; };
void LaunchContext::releaseBuffers() { void LaunchContext::releaseBuffers() {
nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", "");
contextBuffers.release(); contextBuffers.release();
} }

View File

@ -38,12 +38,13 @@ namespace nd4j {
static ConstantHelper* _INSTANCE; static ConstantHelper* _INSTANCE;
ConstantHelper(); ConstantHelper();
std::vector<std::map<ConstantDescriptor, ConstantHolder>> _cache; std::vector<std::map<ConstantDescriptor, ConstantHolder*>> _cache;
// tracking of per-device constant memory buffers (CUDA only atm) // tracking of per-device constant memory buffers (CUDA only atm)
std::vector<Nd4jPointer> _devicePointers; std::vector<Nd4jPointer> _devicePointers;
std::vector<Nd4jLong> _deviceOffsets; std::vector<Nd4jLong> _deviceOffsets;
std::mutex _mutex; std::mutex _mutex;
std::mutex _mutexHolder;
std::vector<Nd4jLong> _counters; std::vector<Nd4jLong> _counters;
public: public:

View File

@ -48,10 +48,10 @@ namespace nd4j {
static ConstantShapeHelper* getInstance(); static ConstantShapeHelper* getInstance();
ConstantDataBuffer& bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape); ConstantDataBuffer bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
ConstantDataBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor); ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
ConstantDataBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo); ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
ConstantDataBuffer& bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); ConstantDataBuffer bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType); Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType);

View File

@ -54,11 +54,11 @@ namespace nd4j {
* @param keepUnitiesInShape * @param keepUnitiesInShape
* @return * @return
*/ */
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false); TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false); TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(TadDescriptor &descriptor); TadPack tadForDimensions(TadDescriptor &descriptor);
/** /**
* This method returns number of cached TAD shapes/offsets on specific device * This method returns number of cached TAD shapes/offsets on specific device

View File

@ -33,7 +33,8 @@ namespace nd4j {
_cache.resize(numDevices); _cache.resize(numDevices);
_counters.resize(numDevices); _counters.resize(numDevices);
for (int e = 0; e < numDevices; e++) { for (int e = 0; e < numDevices; e++) {
std::map<ConstantDescriptor, ConstantHolder> map; std::map<ConstantDescriptor, ConstantHolder*> map;
_cache[e] = map; _cache[e] = map;
_counters[e] = 0L; _counters[e] = 0L;
} }
@ -70,15 +71,26 @@ namespace nd4j {
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
const auto deviceId = getCurrentDevice(); const auto deviceId = getCurrentDevice();
// we're locking away cache modification
_mutexHolder.lock();
if (_cache[deviceId].count(descriptor) == 0) { if (_cache[deviceId].count(descriptor) == 0) {
ConstantHolder holder; _cache[deviceId][descriptor] = new ConstantHolder();
_cache[deviceId][descriptor] = holder;
} }
ConstantHolder* holder = &_cache[deviceId][descriptor]; auto holder = _cache[deviceId][descriptor];
// releasing cache lock
_mutexHolder.unlock();
ConstantDataBuffer* result;
// access to this holder instance is synchronous
holder->mutex()->lock();
if (holder->hasBuffer(dataType)) if (holder->hasBuffer(dataType))
return holder->getConstantDataBuffer(dataType); result = holder->getConstantDataBuffer(dataType);
else { else {
auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType); auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType);
auto cbuff = new int8_t[size]; auto cbuff = new int8_t[size];
@ -94,8 +106,11 @@ namespace nd4j {
ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType)); ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType));
holder->addBuffer(dataBuffer, dataType); holder->addBuffer(dataBuffer, dataType);
return holder->getConstantDataBuffer(dataType); result = holder->getConstantDataBuffer(dataType);
} }
holder->mutex()->unlock();
return result;
} }
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {

View File

@ -41,18 +41,18 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape); ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank); ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = 0; int deviceId = 0;
_mutex.lock(); _mutex.lock();
@ -62,19 +62,19 @@ namespace nd4j {
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor); ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer; _cache[deviceId][descriptor1] = buffer;
ConstantDataBuffer &r = _cache[deviceId][descriptor1]; auto r = _cache[deviceId][descriptor1];
_mutex.unlock(); _mutex.unlock();
return r; return r;
} else { } else {
ConstantDataBuffer &r = _cache[deviceId].at(descriptor); auto r = _cache[deviceId].at(descriptor);
_mutex.unlock(); _mutex.unlock();
return r; return r;
} }
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
ShapeDescriptor descriptor(shapeInfo); ShapeDescriptor descriptor(shapeInfo);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }

View File

@ -38,25 +38,25 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape); return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = 0; const int deviceId = 0;
_mutex.lock(); _mutex.lock();
@ -105,7 +105,7 @@ namespace nd4j {
return r; return r;
} else { } else {
TadPack &r = _cache[deviceId][descriptor]; TadPack r = _cache[deviceId][descriptor];
_mutex.unlock(); _mutex.unlock();
return r; return r;

View File

@ -24,11 +24,13 @@
#include <dll.h> #include <dll.h>
#include <pointercast.h> #include <pointercast.h>
#include <vector> #include <vector>
#include <mutex>
namespace nd4j { namespace nd4j {
class CublasHelper { class CublasHelper {
private: private:
static CublasHelper *_INSTANCE; static CublasHelper *_INSTANCE;
static std::mutex _mutex;
std::vector<void*> _cache; std::vector<void*> _cache;
std::vector<void*> _solvers; std::vector<void*> _solvers;

View File

@ -68,7 +68,7 @@ namespace nd4j {
throw cuda_exception::build("cudaSetDevice failed", res); throw cuda_exception::build("cudaSetDevice failed", res);
auto constant = getConstantSpace(); auto constant = getConstantSpace();
std::map<ConstantDescriptor, ConstantHolder> devCache; std::map<ConstantDescriptor, ConstantHolder*> devCache;
_devicePointers[e] = constant; _devicePointers[e] = constant;
_deviceOffsets[e] = 0; _deviceOffsets[e] = 0;
@ -136,15 +136,24 @@ namespace nd4j {
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
const auto deviceId = getCurrentDevice(); const auto deviceId = getCurrentDevice();
if (_cache[deviceId].count(descriptor) == 0) { // all cache modifications are synchronous
ConstantHolder holder; _mutexHolder.lock();
_cache[deviceId][descriptor] = holder;
}
ConstantHolder* holder = &_cache[deviceId][descriptor]; if (_cache[deviceId].count(descriptor) == 0) {
_cache[deviceId][descriptor] = new ConstantHolder();
}
auto holder = _cache[deviceId][descriptor];
// release cache lock
_mutexHolder.unlock();
ConstantDataBuffer* result;
// access to this holder instance is synchronous
holder->mutex()->lock();
if (holder->hasBuffer(dataType)) { if (holder->hasBuffer(dataType)) {
return holder->getConstantDataBuffer(dataType); result = holder->getConstantDataBuffer(dataType);
} else { } else {
auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType); auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
auto cbuff = new int8_t[numBytes]; auto cbuff = new int8_t[numBytes];
@ -160,10 +169,14 @@ namespace nd4j {
auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType)); auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));
ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType)); ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
holder->addBuffer(dataBuffer, dataType);
return holder->getConstantDataBuffer(dataType); holder->addBuffer(dataBuffer, dataType);
result = holder->getConstantDataBuffer(dataType);
} }
// release holder lock
holder->mutex()->unlock();
return result;
} }
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {

View File

@ -44,17 +44,17 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape); ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank); ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = AffinityManager::currentDeviceId(); int deviceId = AffinityManager::currentDeviceId();
_mutex.lock(); _mutex.lock();
@ -65,19 +65,19 @@ namespace nd4j {
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor); ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer; _cache[deviceId][descriptor1] = buffer;
ConstantDataBuffer &r = _cache[deviceId][descriptor1]; auto r = _cache[deviceId][descriptor1];
_mutex.unlock(); _mutex.unlock();
return r; return r;
} else { } else {
ConstantDataBuffer &r = _cache[deviceId].at(descriptor); ConstantDataBuffer r = _cache[deviceId].at(descriptor);
_mutex.unlock(); _mutex.unlock();
return r; return r;
} }
} }
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
ShapeDescriptor descriptor(shapeInfo); ShapeDescriptor descriptor(shapeInfo);
return bufferForShapeInfo(descriptor); return bufferForShapeInfo(descriptor);
} }

View File

@ -43,25 +43,25 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape); return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = AffinityManager::currentDeviceId(); const int deviceId = AffinityManager::currentDeviceId();
_mutex.lock(); _mutex.lock();
@ -96,14 +96,14 @@ namespace nd4j {
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
_cache[deviceId][descriptor] = t; _cache[deviceId][descriptor] = t;
TadPack &r = _cache[deviceId][descriptor]; TadPack r = _cache[deviceId][descriptor];
_mutex.unlock(); _mutex.unlock();
delete[] shapeInfo; delete[] shapeInfo;
return r; return r;
} else { } else {
TadPack &r = _cache[deviceId][descriptor]; TadPack r = _cache[deviceId][descriptor];
_mutex.unlock(); _mutex.unlock();
return r; return r;

View File

@ -27,6 +27,7 @@
#include <execution/AffinityManager.h> #include <execution/AffinityManager.h>
namespace nd4j { namespace nd4j {
std::mutex CublasHelper::_mutex;
static void* handle_() { static void* handle_() {
auto _handle = new cublasHandle_t(); auto _handle = new cublasHandle_t();
@ -56,22 +57,24 @@ namespace nd4j {
} }
CublasHelper::CublasHelper() { CublasHelper::CublasHelper() {
//nd4j_printf("Initializing cuBLAS\n","");
auto numDevices = AffinityManager::numberOfDevices(); auto numDevices = AffinityManager::numberOfDevices();
auto currentDevice = AffinityManager::currentDeviceId(); auto currentDevice = AffinityManager::currentDeviceId();
_cache.resize(numDevices); _cache.resize(numDevices);
_solvers.resize(numDevices); _solvers.resize(numDevices);
for (int e = 0; e < numDevices; e++) { for (int e = 0; e < numDevices; e++) {
AffinityManager::setCurrentDevice(e); AffinityManager::setCurrentNativeDevice(e);
_cache[e] = handle_(); _cache[e] = handle_();
_solvers[e] = solver_(); _solvers[e] = solver_();
} }
// don't forget to restore back original device // don't forget to restore back original device
AffinityManager::setCurrentDevice(currentDevice); AffinityManager::setCurrentNativeDevice(currentDevice);
} }
CublasHelper::~CublasHelper() { CublasHelper::~CublasHelper() {
nd4j_printf("Releasing cuBLAS\n","");
auto numDevices = AffinityManager::numberOfDevices(); auto numDevices = AffinityManager::numberOfDevices();
for (int e = 0; e < numDevices; e++) for (int e = 0; e < numDevices; e++)
@ -79,8 +82,10 @@ namespace nd4j {
} }
CublasHelper* CublasHelper::getInstance() { CublasHelper* CublasHelper::getInstance() {
_mutex.lock();
if (!_INSTANCE) if (!_INSTANCE)
_INSTANCE = new nd4j::CublasHelper(); _INSTANCE = new nd4j::CublasHelper();
_mutex.unlock();
return _INSTANCE; return _INSTANCE;
} }

View File

@ -18,6 +18,7 @@ package org.nd4j.jita.allocator.pointers.cuda;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.val;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.exception.ND4JException; import org.nd4j.linalg.exception.ND4JException;
@ -69,8 +70,9 @@ public class cudaEvent_t extends CudaPointer {
if (res == 0) if (res == 0)
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]"); throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]");
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0) val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); if (code != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
} }
} }
@ -78,8 +80,9 @@ public class cudaEvent_t extends CudaPointer {
if (!isDestroyed()) { if (!isDestroyed()) {
int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream); int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0) val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); if (code != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
} }
} }
} }

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.jcublas.blas; package org.nd4j.linalg.jcublas.blas;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.FloatPointer;
@ -52,6 +53,7 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j
public class JcublasLevel3 extends BaseLevel3 { public class JcublasLevel3 extends BaseLevel3 {
private Allocator allocator = AtomicAllocator.getInstance(); private Allocator allocator = AtomicAllocator.getInstance();
private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas();
@ -78,7 +80,7 @@ public class JcublasLevel3 extends BaseLevel3 {
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture(); int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch == 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) { if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch >= 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
// on these selected archs we run with cublasHgemm // on these selected archs we run with cublasHgemm
__half alphaHalf = new __half(); __half alphaHalf = new __half();
__half betaHalf = new __half(); __half betaHalf = new __half();
@ -96,7 +98,11 @@ public class JcublasLevel3 extends BaseLevel3 {
new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda, new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda,
(ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta), (ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta),
(ShortPointer) cCPointer.getDevicePointer(), 2, ldc); (ShortPointer) cCPointer.getDevicePointer(), 2, ldc);
} }
ctx.getOldStream().synchronize();
} }
allocator.registerAction(ctx, C, A, B); allocator.registerAction(ctx, C, A, B);
@ -114,18 +120,24 @@ public class JcublasLevel3 extends BaseLevel3 {
val ctx = allocator.getFlowController().prepareAction(C, A, B); val ctx = allocator.getFlowController().prepareAction(C, A, B);
//log.info("Synchronizing CUDA stream");
ctx.getOldStream().synchronize();
val cAPointer = new CublasPointer(A, ctx); val cAPointer = new CublasPointer(A, ctx);
val cBPointer = new CublasPointer(B, ctx); val cBPointer = new CublasPointer(B, ctx);
val cCPointer = new CublasPointer(C, ctx); val cCPointer = new CublasPointer(C, ctx);
val handle = ctx.getCublasHandle(); val handle = ctx.getCublasHandle();
synchronized (handle) { synchronized (handle) {
//log.info("Handle: {}; Stream: {}", handle.address(), ctx.getCublasStream().address());
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K, cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda, new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
(FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta), (FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta),
(FloatPointer) cCPointer.getDevicePointer(), ldc); (FloatPointer) cCPointer.getDevicePointer(), ldc);
ctx.getOldStream().synchronize();
} }
allocator.registerAction(ctx, C, A, B); allocator.registerAction(ctx, C, A, B);
@ -244,6 +256,8 @@ public class JcublasLevel3 extends BaseLevel3 {
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda, new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
(DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta), (DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
(DoublePointer) cCPointer.getDevicePointer(), ldc); (DoublePointer) cCPointer.getDevicePointer(), ldc);
ctx.getOldStream().synchronize();
} }
allocator.registerAction(ctx, C, A, B); allocator.registerAction(ctx, C, A, B);

View File

@ -2548,6 +2548,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
@ -2562,6 +2565,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
@ -2577,6 +2583,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) { public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
@ -2590,6 +2599,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) { public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)