[WIP] Thread safety (#229)

* sync after cublas*gemm

Signed-off-by: raver119 <raver119@gmail.com>

* mutex for CublasHelper

Signed-off-by: raver119 <raver119@gmail.com>

* don't store cublasHandle in LaunchContext, it's per-device anyway

Signed-off-by: raver119 <raver119@gmail.com>

* some printout

Signed-off-by: raver119 <raver119@gmail.com>

* check for field instead

Signed-off-by: raver119 <raver119@gmail.com>

* pew-pew

Signed-off-by: raver119 <raver119@gmail.com>

* don't release ContextBuffers until device changed

Signed-off-by: raver119 <raver119@gmail.com>

* small tweak

Signed-off-by: raver119 <raver119@gmail.com>

* some logging in sgemm

Signed-off-by: raver119 <raver119@gmail.com>

* stream sync

Signed-off-by: raver119 <raver119@gmail.com>

* some more logging

Signed-off-by: raver119 <raver119@gmail.com>

* some more error checks

Signed-off-by: raver119 <raver119@gmail.com>

* one fancy test

Signed-off-by: raver119 <raver119@gmail.com>

* one fancy test

Signed-off-by: raver119 <raver119@gmail.com>

* minor AffinityManager fix

Signed-off-by: raver119 <raver119@gmail.com>

* cudaEvent error logging improvement

Signed-off-by: raver119 <raver119@gmail.com>

* ConstantHelper thread safety

Signed-off-by: raver119 <raver119@gmail.com>

* minor corrections in ConstantTadHelper

Signed-off-by: Yurii <yurii@skymind.io>

* ConstantShapeHelper thread safety

Signed-off-by: raver119 <raver119@gmail.com>

* ConstantTadHelper.cu updated

Signed-off-by: raver119 <raver119@gmail.com>

* logging off

Signed-off-by: raver119 <raver119@gmail.com>

* logging off

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-03 22:00:38 +03:00 committed by GitHub
parent 5be43e7253
commit dddc8a1143
20 changed files with 227 additions and 78 deletions

View File

@@ -0,0 +1,63 @@
package org.deeplearning4j;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.concurrent.CountDownLatch;
@Ignore
public class RandomTests {
@Test
public void testReproduce() throws Exception {
final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10)
.activation(Activation.TANH).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10)
.activation(Activation.SOFTMAX).build())
.build();
for (int e = 0; e < 3; e++) {
int nThreads = 10;
final CountDownLatch l = new CountDownLatch(nThreads);
for (int i = 0; i < nThreads; i++) {
final int j = i;
Thread t = new Thread(new Runnable() {
@Override
public void run() {
try {
MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
net.init();
DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100);
net.fit(iter);
} catch (Throwable t) {
System.out.println("Thread failed: " + j);
t.printStackTrace();
} finally {
l.countDown();
}
}
});
t.start();
}
l.await();
System.out.println("DONE " + e + "\n");
}
}
}

View File

@@ -24,11 +24,13 @@
#include <map>
#include <array/ConstantDescriptor.h>
#include <array/ConstantDataBuffer.h>
#include <mutex>
namespace nd4j {
class ConstantHolder {
private:
int _deviceId = 0;
std::mutex _mutex;
std::map<nd4j::DataType, ConstantDataBuffer> _buffers;
public:
@@ -53,6 +55,8 @@ namespace nd4j {
template <typename T>
ConstantDataBuffer* getConstantDataBuffer();
std::mutex* mutex();
};
}

View File

@@ -16,6 +16,10 @@ namespace nd4j {
return _buffers.count(dataType) > 0;
}
std::mutex* ConstantHolder::mutex() {
return &_mutex;
}
template <typename T>
bool ConstantHolder::hasBuffer() {
return hasBuffer(DataTypeUtils::fromT<T>());

View File

@@ -47,7 +47,7 @@ namespace nd4j {
_currentMutex.unlock();
setCurrentDevice(globalThreadToDevice);
setCurrentNativeDevice(globalThreadToDevice);
}
// if we already know affinity - just return it
@@ -92,6 +92,8 @@ namespace nd4j {
void AffinityManager::setCurrentNativeDevice(int deviceId) {
auto res = cudaSetDevice(deviceId);
if (res != 0)
throw cuda_exception::build("setCurrentDevice failed", res);
}
void AffinityManager::setCurrentDevice(int deviceId) {
@@ -104,17 +106,22 @@ namespace nd4j {
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
if (res != 0)
throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
if (deviceId != previousDeviceId) {
// discard existing stuff
nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", "");
LaunchContext::releaseBuffers();
}
}
auto res = cudaSetDevice(deviceId);
if (res != 0)
throw cuda_exception::build("cudaSetDevice failed", res);
if (deviceId != previousDeviceId) {
auto res = cudaSetDevice(deviceId);
if (res != 0)
throw cuda_exception::build("cudaSetDevice failed", res);
}
// update thread-device affinity
globalThreadToDevice = deviceId;
// discard existing stuff
LaunchContext::releaseBuffers();
}
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
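
The hunk above makes the device switch conditional: cudaSetDevice and the buffer release only happen when the requested device differs from the thread's current one, so repeated setCurrentDevice() calls with the same id become cheap no-ops. A minimal standalone sketch of that guard, with stub setDevice()/releaseBuffers() functions standing in for cudaSetDevice and LaunchContext::releaseBuffers (both stand-ins are hypothetical):

    #include <cstdio>

    static thread_local int globalThreadToDevice = -1;  // per-thread affinity, mirroring AffinityManager

    void setDevice(int id) { std::printf("switching to device %d\n", id); }    // stand-in for cudaSetDevice
    void releaseBuffers()  { std::printf("releasing thread-local buffers\n"); } // stand-in for LaunchContext::releaseBuffers

    void setCurrentDevice(int deviceId) {
        const int previousDeviceId = globalThreadToDevice;
        if (deviceId != previousDeviceId) {
            setDevice(deviceId);          // the real code checks the CUDA error and throws
            releaseBuffers();             // streams/buffers created on the old device are now stale
        }
        globalThreadToDevice = deviceId;  // update thread-device affinity either way
    }

    int main() {
        setCurrentDevice(0);   // switches
        setCurrentDevice(0);   // no-op
        setCurrentDevice(1);   // switches again
    }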

View File

@@ -107,7 +107,6 @@ namespace nd4j {
//////
_allocated = false;
_initialized = false;
_deviceId = -1;
this->_specialStream = nullptr;
@@ -116,6 +115,8 @@ namespace nd4j {
this->_reductionPointer = nullptr;
this->_scalarPointer = nullptr;
}
_initialized = false;
}
ContextBuffers::~ContextBuffers() {
@@ -163,21 +164,21 @@ namespace nd4j {
}
void* ContextBuffers::reductionBuffer() {
if (_reductionPointer == nullptr)
if (!_initialized)
initialize();
return _reductionPointer;
}
void* ContextBuffers::scalarBuffer() {
if (_scalarPointer == nullptr)
if (!_initialized)
initialize();
return _scalarPointer;
}
void* ContextBuffers::allocationBuffer() {
if (_allocationPointer == nullptr)
if (!_initialized)
initialize();
return _allocationPointer;
@@ -204,15 +205,23 @@ namespace nd4j {
}
void* ContextBuffers::execStream() {
if (_execStream == nullptr)
if (!_initialized) {
//nd4j_printf("execStream not initialized\n", "");
initialize();
} else {
//nd4j_printf("execStream is initialized\n", "");
}
return _execStream;
}
void* ContextBuffers::specialStream() {
if (_specialStream == nullptr)
if (!_initialized) {
//nd4j_printf("specialStream not initialized\n", "");
initialize();
} else {
//nd4j_printf("specialStream is initialized\n", "");
}
return _specialStream;
}
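
The recurring change in this file replaces per-pointer null checks (_reductionPointer == nullptr, _execStream == nullptr, ...) with a single _initialized flag, so one initialize() call sets up every lazily-created resource together and a half-initialized ContextBuffers can no longer be observed. A condensed sketch of that pattern, with member names shortened (ContextBuffers is thread-local, so no lock is needed at this level):

    class Buffers {
        bool  _initialized = false;
        char* _reduction   = nullptr;
        char* _scalar      = nullptr;

        void initialize() {
            // all resources are created in one place, then the flag flips once
            _reduction = new char[1024];
            _scalar    = new char[64];
            _initialized = true;
        }

    public:
        void* reductionBuffer() {
            if (!_initialized)   // one flag instead of checking each pointer separately
                initialize();
            return _reduction;
        }

        void* scalarBuffer() {
            if (!_initialized)
                initialize();
            return _scalar;
        }

        ~Buffers() { delete[] _reduction; delete[] _scalar; }
    };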

View File

@@ -57,10 +57,6 @@ LaunchContext::LaunchContext() {
_deviceID = 0;
_isAllocated = true;
_cublasHandle = CublasHelper::getInstance()->handle();
_cusolverHandle = CublasHelper::getInstance()->solver();
}
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
@@ -89,13 +85,13 @@ LaunchContext::LaunchContext() {
_contexts.resize(numDevices);
for (int e = 0; e < numDevices; e++) {
AffinityManager::setCurrentDevice(e);
AffinityManager::setCurrentNativeDevice(e);
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
}
// don't forget to restore device back again
AffinityManager::setCurrentDevice(deviceId);
AffinityManager::setCurrentNativeDevice(deviceId);
}
_mutex.unlock();
@@ -117,11 +113,11 @@ LaunchContext::LaunchContext() {
};
void* LaunchContext::getCublasHandle() const {
return _cublasHandle;
return CublasHelper::getInstance()->handle();
};
void* LaunchContext::getCusolverHandle() const {
return _cusolverHandle;
return CublasHelper::getInstance()->solver();
};
cudaStream_t* LaunchContext::getCudaStream() const {
@@ -162,6 +158,7 @@ LaunchContext::LaunchContext() {
};
void LaunchContext::releaseBuffers() {
nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", "");
contextBuffers.release();
}
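
With the handles owned per device by CublasHelper, getCublasHandle()/getCusolverHandle() become thin forwarders resolved at call time, so a LaunchContext can never hold on to a handle for the wrong device. Roughly, under simplified stand-ins (the helper, the int* "handles", and currentDeviceId() below are all hypothetical):

    #include <vector>

    struct PerDeviceHelper {
        static PerDeviceHelper& instance() {
            static PerDeviceHelper helper{{new int(0), new int(1)}};  // one "handle" per device (stand-in)
            return helper;
        }
        void* handle(int deviceId) { return _handles.at(deviceId); }
        std::vector<int*> _handles;
    };

    int currentDeviceId() { return 0; }  // stand-in for AffinityManager::currentDeviceId()

    struct Context {
        // no cached member: resolve the per-device handle on every call
        void* getCublasHandle() const {
            return PerDeviceHelper::instance().handle(currentDeviceId());
        }
    };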

View File

@@ -38,12 +38,13 @@ namespace nd4j {
static ConstantHelper* _INSTANCE;
ConstantHelper();
std::vector<std::map<ConstantDescriptor, ConstantHolder>> _cache;
std::vector<std::map<ConstantDescriptor, ConstantHolder*>> _cache;
// tracking of per-device constant memory buffers (CUDA only atm)
std::vector<Nd4jPointer> _devicePointers;
std::vector<Nd4jLong> _deviceOffsets;
std::mutex _mutex;
std::mutex _mutexHolder;
std::vector<Nd4jLong> _counters;
public:

View File

@@ -48,10 +48,10 @@ namespace nd4j {
static ConstantShapeHelper* getInstance();
ConstantDataBuffer& bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
ConstantDataBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor);
ConstantDataBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
ConstantDataBuffer& bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
ConstantDataBuffer bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
ConstantDataBuffer bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType);

View File

@@ -54,11 +54,11 @@ namespace nd4j {
* @param keepUnitiesInShape
* @return
*/
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(TadDescriptor &descriptor);
TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack tadForDimensions(TadDescriptor &descriptor);
/**
* This method returns number of cached TAD shapes/offsets on specific device

View File

@@ -33,7 +33,8 @@ namespace nd4j {
_cache.resize(numDevices);
_counters.resize(numDevices);
for (int e = 0; e < numDevices; e++) {
std::map<ConstantDescriptor, ConstantHolder> map;
std::map<ConstantDescriptor, ConstantHolder*> map;
_cache[e] = map;
_counters[e] = 0L;
}
@@ -70,15 +71,26 @@ namespace nd4j {
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
const auto deviceId = getCurrentDevice();
// we're locking away cache modification
_mutexHolder.lock();
if (_cache[deviceId].count(descriptor) == 0) {
ConstantHolder holder;
_cache[deviceId][descriptor] = holder;
_cache[deviceId][descriptor] = new ConstantHolder();
}
ConstantHolder* holder = &_cache[deviceId][descriptor];
auto holder = _cache[deviceId][descriptor];
// releasing cache lock
_mutexHolder.unlock();
ConstantDataBuffer* result;
// access to this holder instance is synchronous
holder->mutex()->lock();
if (holder->hasBuffer(dataType))
return holder->getConstantDataBuffer(dataType);
result = holder->getConstantDataBuffer(dataType);
else {
auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType);
auto cbuff = new int8_t[size];
@@ -94,8 +106,11 @@ namespace nd4j {
ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType));
holder->addBuffer(dataBuffer, dataType);
return holder->getConstantDataBuffer(dataType);
result = holder->getConstantDataBuffer(dataType);
}
holder->mutex()->unlock();
return result;
}
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
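
The locking scheme above is two-level: the global _mutexHolder guards only the find-or-insert on the cache map, while each ConstantHolder carries its own mutex, so expensive buffer creation for one descriptor never blocks lookups for another. Storing ConstantHolder* rather than holders by value also keeps each holder's address stable while the map grows, and replacing the early return with a result variable ensures the holder mutex is always released. A self-contained sketch of the same shape, with std::string keys and int "buffers" standing in for ConstantDescriptor/ConstantDataBuffer (and std::lock_guard doing the unlock automatically):

    #include <map>
    #include <mutex>
    #include <string>

    struct Holder {
        std::mutex mutex;            // serializes buffer creation per descriptor
        std::map<int, int> buffers;  // dataType -> buffer (stand-in)
    };

    static std::mutex cacheMutex;                  // guards only the cache map's structure
    static std::map<std::string, Holder*> cache;

    int* constantBuffer(const std::string& descriptor, int dataType) {
        // 1) short critical section: find-or-insert the holder
        cacheMutex.lock();
        if (cache.count(descriptor) == 0)
            cache[descriptor] = new Holder();
        auto holder = cache[descriptor];
        cacheMutex.unlock();

        // 2) per-holder critical section: build the buffer once, then hand it out
        std::lock_guard<std::mutex> lock(holder->mutex);
        if (holder->buffers.count(dataType) == 0)
            holder->buffers[dataType] = 42;        // real code allocates/replicates memory here
        return &holder->buffers[dataType];
    }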

View File

@@ -41,18 +41,18 @@ namespace nd4j {
return _INSTANCE;
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor);
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor);
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = 0;
_mutex.lock();
@@ -62,19 +62,19 @@ namespace nd4j {
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer;
ConstantDataBuffer &r = _cache[deviceId][descriptor1];
auto r = _cache[deviceId][descriptor1];
_mutex.unlock();
return r;
} else {
ConstantDataBuffer &r = _cache[deviceId].at(descriptor);
auto r = _cache[deviceId].at(descriptor);
_mutex.unlock();
return r;
}
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
ShapeDescriptor descriptor(shapeInfo);
return bufferForShapeInfo(descriptor);
}

View File

@@ -38,25 +38,25 @@ namespace nd4j {
return _INSTANCE;
}
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
}
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
}
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor);
}
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
return tadForDimensions(tadDescriptor);
}
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = 0;
_mutex.lock();
@@ -105,7 +105,7 @@ namespace nd4j {
return r;
} else {
TadPack &r = _cache[deviceId][descriptor];
TadPack r = _cache[deviceId][descriptor];
_mutex.unlock();
return r;

View File

@@ -24,11 +24,13 @@
#include <dll.h>
#include <pointercast.h>
#include <vector>
#include <mutex>
namespace nd4j {
class CublasHelper {
private:
static CublasHelper *_INSTANCE;
static std::mutex _mutex;
std::vector<void*> _cache;
std::vector<void*> _solvers;

View File

@@ -68,7 +68,7 @@ namespace nd4j {
throw cuda_exception::build("cudaSetDevice failed", res);
auto constant = getConstantSpace();
std::map<ConstantDescriptor, ConstantHolder> devCache;
std::map<ConstantDescriptor, ConstantHolder*> devCache;
_devicePointers[e] = constant;
_deviceOffsets[e] = 0;
@@ -136,15 +136,24 @@ namespace nd4j {
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
const auto deviceId = getCurrentDevice();
if (_cache[deviceId].count(descriptor) == 0) {
ConstantHolder holder;
_cache[deviceId][descriptor] = holder;
}
// all cache modifications are synchronous
_mutexHolder.lock();
ConstantHolder* holder = &_cache[deviceId][descriptor];
if (_cache[deviceId].count(descriptor) == 0) {
_cache[deviceId][descriptor] = new ConstantHolder();
}
auto holder = _cache[deviceId][descriptor];
// release cache lock
_mutexHolder.unlock();
ConstantDataBuffer* result;
// access to this holder instance is synchronous
holder->mutex()->lock();
if (holder->hasBuffer(dataType)) {
return holder->getConstantDataBuffer(dataType);
result = holder->getConstantDataBuffer(dataType);
} else {
auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
auto cbuff = new int8_t[numBytes];
@@ -160,10 +169,14 @@ namespace nd4j {
auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));
ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
holder->addBuffer(dataBuffer, dataType);
return holder->getConstantDataBuffer(dataType);
holder->addBuffer(dataBuffer, dataType);
result = holder->getConstantDataBuffer(dataType);
}
// release holder lock
holder->mutex()->unlock();
return result;
}
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {

View File

@@ -44,17 +44,17 @@ namespace nd4j {
return _INSTANCE;
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
ShapeDescriptor descriptor(dataType, order, shape);
return bufferForShapeInfo(descriptor);
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
ShapeDescriptor descriptor(dataType, order, shape, rank);
return bufferForShapeInfo(descriptor);
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = AffinityManager::currentDeviceId();
_mutex.lock();
@@ -65,19 +65,19 @@ namespace nd4j {
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
ShapeDescriptor descriptor1(descriptor);
_cache[deviceId][descriptor1] = buffer;
ConstantDataBuffer &r = _cache[deviceId][descriptor1];
auto r = _cache[deviceId][descriptor1];
_mutex.unlock();
return r;
} else {
ConstantDataBuffer &r = _cache[deviceId].at(descriptor);
ConstantDataBuffer r = _cache[deviceId].at(descriptor);
_mutex.unlock();
return r;
}
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
ShapeDescriptor descriptor(shapeInfo);
return bufferForShapeInfo(descriptor);
}

View File

@@ -43,25 +43,25 @@ namespace nd4j {
return _INSTANCE;
}
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
}
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
}
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor);
}
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
return tadForDimensions(tadDescriptor);
}
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = AffinityManager::currentDeviceId();
_mutex.lock();
@@ -96,14 +96,14 @@ namespace nd4j {
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
_cache[deviceId][descriptor] = t;
TadPack &r = _cache[deviceId][descriptor];
TadPack r = _cache[deviceId][descriptor];
_mutex.unlock();
delete[] shapeInfo;
return r;
} else {
TadPack &r = _cache[deviceId][descriptor];
TadPack r = _cache[deviceId][descriptor];
_mutex.unlock();
return r;
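
The signature changes in these helpers (TadPack& to TadPack, ConstantDataBuffer& to ConstantDataBuffer) all serve one purpose: once the mutex is released, a reference into the shared cache is unsafe to keep, since another thread may mutate the map or the entry it points at; so a copy is taken while the lock is still held. A reduced sketch of the safe shape, with Entry standing in for TadPack/ConstantDataBuffer (hypothetical type and values):

    #include <map>
    #include <mutex>

    struct Entry { int payload = 0; };

    static std::mutex cacheMutex;
    static std::map<int, Entry> cache;

    // risky variant: Entry& lookup(...) would hand the caller a reference into
    // shared state that outlives the critical section

    Entry lookup(int key) {
        std::lock_guard<std::mutex> lock(cacheMutex);
        if (cache.count(key) == 0)
            cache[key] = Entry{42};   // real code builds shapes/offsets here
        return cache[key];            // copied out before the guard releases the mutex
    }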

View File

@@ -27,6 +27,7 @@
#include <execution/AffinityManager.h>
namespace nd4j {
std::mutex CublasHelper::_mutex;
static void* handle_() {
auto _handle = new cublasHandle_t();
@@ -56,22 +57,24 @@ namespace nd4j {
}
CublasHelper::CublasHelper() {
//nd4j_printf("Initializing cuBLAS\n","");
auto numDevices = AffinityManager::numberOfDevices();
auto currentDevice = AffinityManager::currentDeviceId();
_cache.resize(numDevices);
_solvers.resize(numDevices);
for (int e = 0; e < numDevices; e++) {
AffinityManager::setCurrentDevice(e);
AffinityManager::setCurrentNativeDevice(e);
_cache[e] = handle_();
_solvers[e] = solver_();
}
// don't forget to restore back original device
AffinityManager::setCurrentDevice(currentDevice);
AffinityManager::setCurrentNativeDevice(currentDevice);
}
CublasHelper::~CublasHelper() {
nd4j_printf("Releasing cuBLAS\n","");
auto numDevices = AffinityManager::numberOfDevices();
for (int e = 0; e < numDevices; e++)
@@ -79,8 +82,10 @@ namespace nd4j {
}
CublasHelper* CublasHelper::getInstance() {
_mutex.lock();
if (!_INSTANCE)
_INSTANCE = new nd4j::CublasHelper();
_mutex.unlock();
return _INSTANCE;
}
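
The guarded getInstance() above closes a classic race: two threads could both observe _INSTANCE == nullptr and construct the helper twice, each run re-initializing handles on every device. The essentials of the fix, with names mirroring the diff (std::lock_guard used here in place of raw lock/unlock):

    #include <mutex>

    class CublasHelperSketch {
        static CublasHelperSketch* _INSTANCE;
        static std::mutex _mutex;
        CublasHelperSketch() = default;   // the real ctor builds handles for every device
    public:
        static CublasHelperSketch* getInstance() {
            std::lock_guard<std::mutex> lock(_mutex);  // at most one thread constructs
            if (_INSTANCE == nullptr)
                _INSTANCE = new CublasHelperSketch();
            return _INSTANCE;
        }
    };

    CublasHelperSketch* CublasHelperSketch::_INSTANCE = nullptr;
    std::mutex CublasHelperSketch::_mutex;

A C++11 function-local static would give the same once-only guarantee without an explicit mutex; the explicit-mutex form shown here matches what the commit actually does.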

View File

@@ -18,6 +18,7 @@ package org.nd4j.jita.allocator.pointers.cuda;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.exception.ND4JException;
@@ -69,8 +70,9 @@ public class cudaEvent_t extends CudaPointer {
if (res == 0)
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]");
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (code != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
}
}
@@ -78,8 +80,9 @@ public class cudaEvent_t extends CudaPointer {
if (!isDestroyed()) {
int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (code != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
}
}
}
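
Besides appending the code to the message, this Java change fixes a read-twice hazard: lastErrorCode() was consulted once for the check, and a concurrent failure on another thread could change what the subsequent message referred to. Capturing the code once and reusing it is the fix; the same idiom sketched in C++, with errno-style lastErrorCode()/lastErrorMessage() as hypothetical stand-ins for the native ops:

    #include <stdexcept>
    #include <string>

    int lastErrorCode() { return 0; }                 // stand-in for NativeOps.lastErrorCode()
    std::string lastErrorMessage() { return "ok"; }   // stand-in for NativeOps.lastErrorMessage()

    void checkLastError() {
        const int code = lastErrorCode();             // read exactly once
        if (code != 0)
            throw std::runtime_error(lastErrorMessage() + "; Error code: " + std::to_string(code));
    }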

View File

@@ -17,6 +17,7 @@
package org.nd4j.linalg.jcublas.blas;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
@@ -52,6 +53,7 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.*;
*
* @author Adam Gibson
*/
@Slf4j
public class JcublasLevel3 extends BaseLevel3 {
private Allocator allocator = AtomicAllocator.getInstance();
private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas();
@@ -78,7 +80,7 @@ public class JcublasLevel3 extends BaseLevel3 {
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch == 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch >= 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
// on these selected archs we run with cublasHgemm
__half alphaHalf = new __half();
__half betaHalf = new __half();
@@ -96,7 +98,11 @@ public class JcublasLevel3 extends BaseLevel3 {
new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda,
(ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta),
(ShortPointer) cCPointer.getDevicePointer(), 2, ldc);
}
ctx.getOldStream().synchronize();
}
allocator.registerAction(ctx, C, A, B);
@@ -114,18 +120,24 @@ public class JcublasLevel3 extends BaseLevel3 {
val ctx = allocator.getFlowController().prepareAction(C, A, B);
//log.info("Synchronizing CUDA stream");
ctx.getOldStream().synchronize();
val cAPointer = new CublasPointer(A, ctx);
val cBPointer = new CublasPointer(B, ctx);
val cCPointer = new CublasPointer(C, ctx);
val handle = ctx.getCublasHandle();
synchronized (handle) {
//log.info("Handle: {}; Stream: {}", handle.address(), ctx.getCublasStream().address());
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
(FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta),
(FloatPointer) cCPointer.getDevicePointer(), ldc);
ctx.getOldStream().synchronize();
}
allocator.registerAction(ctx, C, A, B);
@@ -244,6 +256,8 @@ public class JcublasLevel3 extends BaseLevel3 {
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
(DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
(DoublePointer) cCPointer.getDevicePointer(), ldc);
ctx.getOldStream().synchronize();
}
allocator.registerAction(ctx, C, A, B);
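
All three gemm paths (hgemm, sgemm, dgemm) now follow the same discipline: take the per-handle lock, bind the context's stream to the handle, launch the gemm, and synchronize the stream before releasing the lock, because cublasSetStream mutates shared handle state and an unsynchronized second caller could rebind it mid-flight. The Java code expresses this as synchronized (handle); the equivalent shape in C++ terms, with all three callee functions below as hypothetical stubs for the cuBLAS/CUDA calls:

    #include <mutex>

    struct Handle { std::mutex mutex; };

    void bindStream(Handle&, void*) {}   // stand-in for cublasSetStream_v2
    void launchGemm(Handle&)        {}   // stand-in for cublasSgemm_v2
    void streamSync(void*)          {}   // stand-in for cudaStreamSynchronize

    void sgemm(Handle& handle, void* stream) {
        std::lock_guard<std::mutex> lock(handle.mutex); // one gemm per handle at a time
        bindStream(handle, stream);  // shared, mutable handle state: must not interleave
        launchGemm(handle);
        streamSync(stream);          // results are ready before the handle is handed over
    }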

View File

@@ -2548,6 +2548,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
if (nativeOps.lastErrorCode() != 0)
@@ -2562,6 +2565,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
if (nativeOps.lastErrorCode() != 0)
@@ -2577,6 +2583,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
if (nativeOps.lastErrorCode() != 0)
@@ -2590,6 +2599,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
if (nativeOps.lastErrorCode() != 0)