[WIP] Thread safety (#229)
* sync after cublas*gemm Signed-off-by: raver119 <raver119@gmail.com> * mutex for CublasHelper Signed-off-by: raver119 <raver119@gmail.com> * don't store cublasHandle in LaunchContext, it's per-device anyway Signed-off-by: raver119 <raver119@gmail.com> * some printout Signed-off-by: raver119 <raver119@gmail.com> * check for field instead Signed-off-by: raver119 <raver119@gmail.com> * pew-pew Signed-off-by: raver119 <raver119@gmail.com> * don't release ContextBuffers until device changed Signed-off-by: raver119 <raver119@gmail.com> * small tweak Signed-off-by: raver119 <raver119@gmail.com> * some logging in sgemm Signed-off-by: raver119 <raver119@gmail.com> * stream sync Signed-off-by: raver119 <raver119@gmail.com> * some more logging Signed-off-by: raver119 <raver119@gmail.com> * some more error checks Signed-off-by: raver119 <raver119@gmail.com> * one fancy test Signed-off-by: raver119 <raver119@gmail.com> * one fancy test Signed-off-by: raver119 <raver119@gmail.com> * minor AffinityManager fix Signed-off-by: raver119 <raver119@gmail.com> * cudaEvent error logging improvement Signed-off-by: raver119 <raver119@gmail.com> * ConstantHelper thread safety Signed-off-by: raver119 <raver119@gmail.com> * - minor corrections in ConstantTadHelper Signed-off-by: Yurii <yurii@skymind.io> * ConstantShapeHelper thread safety Signed-off-by: raver119 <raver119@gmail.com> * ConstantTadHelper.cu updated Signed-off-by: raver119 <raver119@gmail.com> * logging off Signed-off-by: raver119 <raver119@gmail.com> * logging off Signed-off-by: raver119 <raver119@gmail.com>master
parent
5be43e7253
commit
dddc8a1143
|
@ -0,0 +1,63 @@
|
|||
package org.deeplearning4j;
|
||||
|
||||
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.RmsProp;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
@Ignore
|
||||
public class RandomTests {
|
||||
|
||||
@Test
|
||||
public void testReproduce() throws Exception {
|
||||
|
||||
final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
|
||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
||||
.layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10)
|
||||
.activation(Activation.TANH).build())
|
||||
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
|
||||
LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10)
|
||||
.activation(Activation.SOFTMAX).build())
|
||||
.build();
|
||||
|
||||
for (int e = 0; e < 3; e++) {
|
||||
|
||||
int nThreads = 10;
|
||||
final CountDownLatch l = new CountDownLatch(nThreads);
|
||||
for (int i = 0; i < nThreads; i++) {
|
||||
final int j = i;
|
||||
Thread t = new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
|
||||
net.init();
|
||||
DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100);
|
||||
net.fit(iter);
|
||||
} catch (Throwable t) {
|
||||
System.out.println("Thread failed: " + j);
|
||||
t.printStackTrace();
|
||||
} finally {
|
||||
l.countDown();
|
||||
}
|
||||
}
|
||||
});
|
||||
t.start();
|
||||
}
|
||||
|
||||
l.await();
|
||||
System.out.println("DONE " + e + "\n");
|
||||
}
|
||||
}
|
||||
}
|
|
@ -24,11 +24,13 @@
|
|||
#include <map>
|
||||
#include <array/ConstantDescriptor.h>
|
||||
#include <array/ConstantDataBuffer.h>
|
||||
#include <mutex>
|
||||
|
||||
namespace nd4j {
|
||||
class ConstantHolder {
|
||||
private:
|
||||
int _deviceId = 0;
|
||||
std::mutex _mutex;
|
||||
|
||||
std::map<nd4j::DataType, ConstantDataBuffer> _buffers;
|
||||
public:
|
||||
|
@ -53,6 +55,8 @@ namespace nd4j {
|
|||
|
||||
template <typename T>
|
||||
ConstantDataBuffer* getConstantDataBuffer();
|
||||
|
||||
std::mutex* mutex();
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,10 @@ namespace nd4j {
|
|||
return _buffers.count(dataType) > 0;
|
||||
}
|
||||
|
||||
std::mutex* ConstantHolder::mutex() {
|
||||
return &_mutex;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ConstantHolder::hasBuffer() {
|
||||
return hasBuffer(DataTypeUtils::fromT<T>());
|
||||
|
|
|
@ -47,7 +47,7 @@ namespace nd4j {
|
|||
|
||||
_currentMutex.unlock();
|
||||
|
||||
setCurrentDevice(globalThreadToDevice);
|
||||
setCurrentNativeDevice(globalThreadToDevice);
|
||||
}
|
||||
|
||||
// if we already know affinity - just return it
|
||||
|
@ -92,6 +92,8 @@ namespace nd4j {
|
|||
|
||||
void AffinityManager::setCurrentNativeDevice(int deviceId) {
|
||||
auto res = cudaSetDevice(deviceId);
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("setCurrentDevice failed", res);
|
||||
}
|
||||
|
||||
void AffinityManager::setCurrentDevice(int deviceId) {
|
||||
|
@ -104,17 +106,22 @@ namespace nd4j {
|
|||
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
|
||||
|
||||
if (deviceId != previousDeviceId) {
|
||||
// discard existing stuff
|
||||
nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", "");
|
||||
LaunchContext::releaseBuffers();
|
||||
}
|
||||
}
|
||||
|
||||
auto res = cudaSetDevice(deviceId);
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("cudaSetDevice failed", res);
|
||||
if (deviceId != previousDeviceId) {
|
||||
auto res = cudaSetDevice(deviceId);
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("cudaSetDevice failed", res);
|
||||
}
|
||||
|
||||
// update thread-device affinity
|
||||
globalThreadToDevice = deviceId;
|
||||
|
||||
// discard existing stuff
|
||||
LaunchContext::releaseBuffers();
|
||||
}
|
||||
|
||||
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
|
||||
|
|
|
@ -107,7 +107,6 @@ namespace nd4j {
|
|||
|
||||
//////
|
||||
_allocated = false;
|
||||
_initialized = false;
|
||||
_deviceId = -1;
|
||||
|
||||
this->_specialStream = nullptr;
|
||||
|
@ -116,6 +115,8 @@ namespace nd4j {
|
|||
this->_reductionPointer = nullptr;
|
||||
this->_scalarPointer = nullptr;
|
||||
}
|
||||
|
||||
_initialized = false;
|
||||
}
|
||||
|
||||
ContextBuffers::~ContextBuffers() {
|
||||
|
@ -163,21 +164,21 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
void* ContextBuffers::reductionBuffer() {
|
||||
if (_reductionPointer == nullptr)
|
||||
if (!_initialized)
|
||||
initialize();
|
||||
|
||||
return _reductionPointer;
|
||||
}
|
||||
|
||||
void* ContextBuffers::scalarBuffer() {
|
||||
if (_scalarPointer == nullptr)
|
||||
if (!_initialized)
|
||||
initialize();
|
||||
|
||||
return _scalarPointer;
|
||||
}
|
||||
|
||||
void* ContextBuffers::allocationBuffer() {
|
||||
if (_allocationPointer == nullptr)
|
||||
if (!_initialized)
|
||||
initialize();
|
||||
|
||||
return _allocationPointer;
|
||||
|
@ -204,15 +205,23 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
void* ContextBuffers::execStream() {
|
||||
if (_execStream == nullptr)
|
||||
if (!_initialized) {
|
||||
//nd4j_printf("execStream not initialized\n", "");
|
||||
initialize();
|
||||
} else {
|
||||
//nd4j_printf("execStream is initialized\n", "");
|
||||
}
|
||||
|
||||
return _execStream;
|
||||
}
|
||||
|
||||
void* ContextBuffers::specialStream() {
|
||||
if (_specialStream == nullptr)
|
||||
if (!_initialized) {
|
||||
//nd4j_printf("specialStream not initialized\n", "");
|
||||
initialize();
|
||||
} else {
|
||||
//nd4j_printf("specialStream is initialized\n", "");
|
||||
}
|
||||
|
||||
return _specialStream;
|
||||
}
|
||||
|
|
|
@ -57,10 +57,6 @@ LaunchContext::LaunchContext() {
|
|||
_deviceID = 0;
|
||||
|
||||
_isAllocated = true;
|
||||
|
||||
_cublasHandle = CublasHelper::getInstance()->handle();
|
||||
|
||||
_cusolverHandle = CublasHelper::getInstance()->solver();
|
||||
}
|
||||
|
||||
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
|
||||
|
@ -89,13 +85,13 @@ LaunchContext::LaunchContext() {
|
|||
|
||||
_contexts.resize(numDevices);
|
||||
for (int e = 0; e < numDevices; e++) {
|
||||
AffinityManager::setCurrentDevice(e);
|
||||
AffinityManager::setCurrentNativeDevice(e);
|
||||
|
||||
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
|
||||
}
|
||||
|
||||
// don't forget to restore device back again
|
||||
AffinityManager::setCurrentDevice(deviceId);
|
||||
AffinityManager::setCurrentNativeDevice(deviceId);
|
||||
}
|
||||
_mutex.unlock();
|
||||
|
||||
|
@ -117,11 +113,11 @@ LaunchContext::LaunchContext() {
|
|||
};
|
||||
|
||||
void* LaunchContext::getCublasHandle() const {
|
||||
return _cublasHandle;
|
||||
return CublasHelper::getInstance()->handle();
|
||||
};
|
||||
|
||||
void* LaunchContext::getCusolverHandle() const {
|
||||
return _cusolverHandle;
|
||||
return CublasHelper::getInstance()->solver();
|
||||
};
|
||||
|
||||
cudaStream_t* LaunchContext::getCudaStream() const {
|
||||
|
@ -162,6 +158,7 @@ LaunchContext::LaunchContext() {
|
|||
};
|
||||
|
||||
void LaunchContext::releaseBuffers() {
|
||||
nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", "");
|
||||
contextBuffers.release();
|
||||
}
|
||||
|
||||
|
|
|
@ -38,12 +38,13 @@ namespace nd4j {
|
|||
static ConstantHelper* _INSTANCE;
|
||||
ConstantHelper();
|
||||
|
||||
std::vector<std::map<ConstantDescriptor, ConstantHolder>> _cache;
|
||||
std::vector<std::map<ConstantDescriptor, ConstantHolder*>> _cache;
|
||||
|
||||
// tracking of per-device constant memory buffers (CUDA only atm)
|
||||
std::vector<Nd4jPointer> _devicePointers;
|
||||
std::vector<Nd4jLong> _deviceOffsets;
|
||||
std::mutex _mutex;
|
||||
std::mutex _mutexHolder;
|
||||
|
||||
std::vector<Nd4jLong> _counters;
|
||||
public:
|
||||
|
|
|
@ -48,10 +48,10 @@ namespace nd4j {
|
|||
static ConstantShapeHelper* getInstance();
|
||||
|
||||
|
||||
ConstantDataBuffer& bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
|
||||
ConstantDataBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor);
|
||||
ConstantDataBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
|
||||
ConstantDataBuffer& bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
|
||||
ConstantDataBuffer bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
|
||||
ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
|
||||
ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
|
||||
ConstantDataBuffer bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
|
||||
|
||||
|
||||
Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType);
|
||||
|
|
|
@ -54,11 +54,11 @@ namespace nd4j {
|
|||
* @param keepUnitiesInShape
|
||||
* @return
|
||||
*/
|
||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(TadDescriptor &descriptor);
|
||||
TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
||||
TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack tadForDimensions(TadDescriptor &descriptor);
|
||||
|
||||
/**
|
||||
* This method returns number of cached TAD shapes/offsets on specific device
|
||||
|
|
|
@ -33,7 +33,8 @@ namespace nd4j {
|
|||
_cache.resize(numDevices);
|
||||
_counters.resize(numDevices);
|
||||
for (int e = 0; e < numDevices; e++) {
|
||||
std::map<ConstantDescriptor, ConstantHolder> map;
|
||||
std::map<ConstantDescriptor, ConstantHolder*> map;
|
||||
|
||||
_cache[e] = map;
|
||||
_counters[e] = 0L;
|
||||
}
|
||||
|
@ -70,15 +71,26 @@ namespace nd4j {
|
|||
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
|
||||
const auto deviceId = getCurrentDevice();
|
||||
|
||||
// we're locking away cache modification
|
||||
_mutexHolder.lock();
|
||||
|
||||
if (_cache[deviceId].count(descriptor) == 0) {
|
||||
ConstantHolder holder;
|
||||
_cache[deviceId][descriptor] = holder;
|
||||
_cache[deviceId][descriptor] = new ConstantHolder();
|
||||
}
|
||||
|
||||
ConstantHolder* holder = &_cache[deviceId][descriptor];
|
||||
auto holder = _cache[deviceId][descriptor];
|
||||
|
||||
// releasing cache lock
|
||||
_mutexHolder.unlock();
|
||||
|
||||
|
||||
ConstantDataBuffer* result;
|
||||
|
||||
// access to this holder instance is synchronous
|
||||
holder->mutex()->lock();
|
||||
|
||||
if (holder->hasBuffer(dataType))
|
||||
return holder->getConstantDataBuffer(dataType);
|
||||
result = holder->getConstantDataBuffer(dataType);
|
||||
else {
|
||||
auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType);
|
||||
auto cbuff = new int8_t[size];
|
||||
|
@ -94,8 +106,11 @@ namespace nd4j {
|
|||
ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType));
|
||||
holder->addBuffer(dataBuffer, dataType);
|
||||
|
||||
return holder->getConstantDataBuffer(dataType);
|
||||
result = holder->getConstantDataBuffer(dataType);
|
||||
}
|
||||
holder->mutex()->unlock();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
|
||||
|
|
|
@ -41,18 +41,18 @@ namespace nd4j {
|
|||
return _INSTANCE;
|
||||
}
|
||||
|
||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
|
||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
int deviceId = 0;
|
||||
|
||||
_mutex.lock();
|
||||
|
@ -62,19 +62,19 @@ namespace nd4j {
|
|||
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
|
||||
ShapeDescriptor descriptor1(descriptor);
|
||||
_cache[deviceId][descriptor1] = buffer;
|
||||
ConstantDataBuffer &r = _cache[deviceId][descriptor1];
|
||||
auto r = _cache[deviceId][descriptor1];
|
||||
_mutex.unlock();
|
||||
|
||||
return r;
|
||||
} else {
|
||||
ConstantDataBuffer &r = _cache[deviceId].at(descriptor);
|
||||
auto r = _cache[deviceId].at(descriptor);
|
||||
_mutex.unlock();
|
||||
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
ShapeDescriptor descriptor(shapeInfo);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
|
|
@ -38,25 +38,25 @@ namespace nd4j {
|
|||
return _INSTANCE;
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
||||
return tadForDimensions(tadDescriptor);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
|
||||
return tadForDimensions(tadDescriptor);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
||||
const int deviceId = 0;
|
||||
|
||||
_mutex.lock();
|
||||
|
@ -105,7 +105,7 @@ namespace nd4j {
|
|||
|
||||
return r;
|
||||
} else {
|
||||
TadPack &r = _cache[deviceId][descriptor];
|
||||
TadPack r = _cache[deviceId][descriptor];
|
||||
_mutex.unlock();
|
||||
|
||||
return r;
|
||||
|
|
|
@ -24,11 +24,13 @@
|
|||
#include <dll.h>
|
||||
#include <pointercast.h>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
namespace nd4j {
|
||||
class CublasHelper {
|
||||
private:
|
||||
static CublasHelper *_INSTANCE;
|
||||
static std::mutex _mutex;
|
||||
|
||||
std::vector<void*> _cache;
|
||||
std::vector<void*> _solvers;
|
||||
|
|
|
@ -68,7 +68,7 @@ namespace nd4j {
|
|||
throw cuda_exception::build("cudaSetDevice failed", res);
|
||||
auto constant = getConstantSpace();
|
||||
|
||||
std::map<ConstantDescriptor, ConstantHolder> devCache;
|
||||
std::map<ConstantDescriptor, ConstantHolder*> devCache;
|
||||
|
||||
_devicePointers[e] = constant;
|
||||
_deviceOffsets[e] = 0;
|
||||
|
@ -136,15 +136,24 @@ namespace nd4j {
|
|||
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
|
||||
const auto deviceId = getCurrentDevice();
|
||||
|
||||
if (_cache[deviceId].count(descriptor) == 0) {
|
||||
ConstantHolder holder;
|
||||
_cache[deviceId][descriptor] = holder;
|
||||
}
|
||||
// all cache modifications are synchronous
|
||||
_mutexHolder.lock();
|
||||
|
||||
ConstantHolder* holder = &_cache[deviceId][descriptor];
|
||||
if (_cache[deviceId].count(descriptor) == 0) {
|
||||
_cache[deviceId][descriptor] = new ConstantHolder();
|
||||
}
|
||||
auto holder = _cache[deviceId][descriptor];
|
||||
|
||||
// release cache lock
|
||||
_mutexHolder.unlock();
|
||||
|
||||
ConstantDataBuffer* result;
|
||||
|
||||
// access to this holder instance is synchronous
|
||||
holder->mutex()->lock();
|
||||
|
||||
if (holder->hasBuffer(dataType)) {
|
||||
return holder->getConstantDataBuffer(dataType);
|
||||
result = holder->getConstantDataBuffer(dataType);
|
||||
} else {
|
||||
auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
|
||||
auto cbuff = new int8_t[numBytes];
|
||||
|
@ -160,10 +169,14 @@ namespace nd4j {
|
|||
auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));
|
||||
|
||||
ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
|
||||
holder->addBuffer(dataBuffer, dataType);
|
||||
|
||||
return holder->getConstantDataBuffer(dataType);
|
||||
holder->addBuffer(dataBuffer, dataType);
|
||||
result = holder->getConstantDataBuffer(dataType);
|
||||
}
|
||||
// release holder lock
|
||||
holder->mutex()->unlock();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
|
||||
|
|
|
@ -44,17 +44,17 @@ namespace nd4j {
|
|||
return _INSTANCE;
|
||||
}
|
||||
|
||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
int deviceId = AffinityManager::currentDeviceId();
|
||||
|
||||
_mutex.lock();
|
||||
|
@ -65,19 +65,19 @@ namespace nd4j {
|
|||
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
|
||||
ShapeDescriptor descriptor1(descriptor);
|
||||
_cache[deviceId][descriptor1] = buffer;
|
||||
ConstantDataBuffer &r = _cache[deviceId][descriptor1];
|
||||
auto r = _cache[deviceId][descriptor1];
|
||||
_mutex.unlock();
|
||||
|
||||
return r;
|
||||
} else {
|
||||
ConstantDataBuffer &r = _cache[deviceId].at(descriptor);
|
||||
ConstantDataBuffer r = _cache[deviceId].at(descriptor);
|
||||
_mutex.unlock();
|
||||
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
ShapeDescriptor descriptor(shapeInfo);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
|
|
@ -43,25 +43,25 @@ namespace nd4j {
|
|||
return _INSTANCE;
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
||||
return tadForDimensions(tadDescriptor);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
|
||||
return tadForDimensions(tadDescriptor);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
||||
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
||||
const int deviceId = AffinityManager::currentDeviceId();
|
||||
|
||||
_mutex.lock();
|
||||
|
@ -96,14 +96,14 @@ namespace nd4j {
|
|||
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
|
||||
_cache[deviceId][descriptor] = t;
|
||||
|
||||
TadPack &r = _cache[deviceId][descriptor];
|
||||
TadPack r = _cache[deviceId][descriptor];
|
||||
_mutex.unlock();
|
||||
|
||||
delete[] shapeInfo;
|
||||
|
||||
return r;
|
||||
} else {
|
||||
TadPack &r = _cache[deviceId][descriptor];
|
||||
TadPack r = _cache[deviceId][descriptor];
|
||||
_mutex.unlock();
|
||||
|
||||
return r;
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include <execution/AffinityManager.h>
|
||||
|
||||
namespace nd4j {
|
||||
std::mutex CublasHelper::_mutex;
|
||||
|
||||
static void* handle_() {
|
||||
auto _handle = new cublasHandle_t();
|
||||
|
@ -56,22 +57,24 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
CublasHelper::CublasHelper() {
|
||||
//nd4j_printf("Initializing cuBLAS\n","");
|
||||
auto numDevices = AffinityManager::numberOfDevices();
|
||||
auto currentDevice = AffinityManager::currentDeviceId();
|
||||
_cache.resize(numDevices);
|
||||
_solvers.resize(numDevices);
|
||||
for (int e = 0; e < numDevices; e++) {
|
||||
AffinityManager::setCurrentDevice(e);
|
||||
AffinityManager::setCurrentNativeDevice(e);
|
||||
|
||||
_cache[e] = handle_();
|
||||
_solvers[e] = solver_();
|
||||
}
|
||||
|
||||
// don't forget to restore back original device
|
||||
AffinityManager::setCurrentDevice(currentDevice);
|
||||
AffinityManager::setCurrentNativeDevice(currentDevice);
|
||||
}
|
||||
|
||||
CublasHelper::~CublasHelper() {
|
||||
nd4j_printf("Releasing cuBLAS\n","");
|
||||
auto numDevices = AffinityManager::numberOfDevices();
|
||||
|
||||
for (int e = 0; e < numDevices; e++)
|
||||
|
@ -79,8 +82,10 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
CublasHelper* CublasHelper::getInstance() {
|
||||
_mutex.lock();
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new nd4j::CublasHelper();
|
||||
_mutex.unlock();
|
||||
|
||||
return _INSTANCE;
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.jita.allocator.pointers.cuda;
|
|||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
||||
import org.nd4j.linalg.exception.ND4JException;
|
||||
|
@ -69,8 +70,9 @@ public class cudaEvent_t extends CudaPointer {
|
|||
if (res == 0)
|
||||
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]");
|
||||
|
||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
|
||||
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
|
||||
val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||
if (code != 0)
|
||||
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,8 +80,9 @@ public class cudaEvent_t extends CudaPointer {
|
|||
if (!isDestroyed()) {
|
||||
int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);
|
||||
|
||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
|
||||
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
|
||||
val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||
if (code != 0)
|
||||
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.nd4j.linalg.jcublas.blas;
|
||||
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.DoublePointer;
|
||||
import org.bytedeco.javacpp.FloatPointer;
|
||||
|
@ -52,6 +53,7 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.*;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@Slf4j
|
||||
public class JcublasLevel3 extends BaseLevel3 {
|
||||
private Allocator allocator = AtomicAllocator.getInstance();
|
||||
private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas();
|
||||
|
@ -78,7 +80,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
|||
|
||||
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
|
||||
|
||||
if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch == 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
|
||||
if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch >= 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
|
||||
// on these selected archs we run with cublasHgemm
|
||||
__half alphaHalf = new __half();
|
||||
__half betaHalf = new __half();
|
||||
|
@ -96,7 +98,11 @@ public class JcublasLevel3 extends BaseLevel3 {
|
|||
new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda,
|
||||
(ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta),
|
||||
(ShortPointer) cCPointer.getDevicePointer(), 2, ldc);
|
||||
|
||||
|
||||
}
|
||||
|
||||
ctx.getOldStream().synchronize();
|
||||
}
|
||||
|
||||
allocator.registerAction(ctx, C, A, B);
|
||||
|
@ -114,18 +120,24 @@ public class JcublasLevel3 extends BaseLevel3 {
|
|||
|
||||
val ctx = allocator.getFlowController().prepareAction(C, A, B);
|
||||
|
||||
//log.info("Synchronizing CUDA stream");
|
||||
ctx.getOldStream().synchronize();
|
||||
|
||||
val cAPointer = new CublasPointer(A, ctx);
|
||||
val cBPointer = new CublasPointer(B, ctx);
|
||||
val cCPointer = new CublasPointer(C, ctx);
|
||||
|
||||
val handle = ctx.getCublasHandle();
|
||||
synchronized (handle) {
|
||||
//log.info("Handle: {}; Stream: {}", handle.address(), ctx.getCublasStream().address());
|
||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||
|
||||
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
|
||||
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
|
||||
(FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta),
|
||||
(FloatPointer) cCPointer.getDevicePointer(), ldc);
|
||||
|
||||
ctx.getOldStream().synchronize();
|
||||
}
|
||||
|
||||
allocator.registerAction(ctx, C, A, B);
|
||||
|
@ -244,6 +256,8 @@ public class JcublasLevel3 extends BaseLevel3 {
|
|||
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
|
||||
(DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
|
||||
(DoublePointer) cCPointer.getDevicePointer(), ldc);
|
||||
|
||||
ctx.getOldStream().synchronize();
|
||||
}
|
||||
|
||||
allocator.registerAction(ctx, C, A, B);
|
||||
|
|
|
@ -2548,6 +2548,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
|
@ -2562,6 +2565,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
|
@ -2577,6 +2583,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
|
@ -2590,6 +2599,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
|
|
Loading…
Reference in New Issue