[WIP] Thread safety (#229)
* sync after cublas*gemm Signed-off-by: raver119 <raver119@gmail.com> * mutex for CublasHelper Signed-off-by: raver119 <raver119@gmail.com> * don't store cublasHandle in LaunchContext, it's per-device anyway Signed-off-by: raver119 <raver119@gmail.com> * some printout Signed-off-by: raver119 <raver119@gmail.com> * check for field instead Signed-off-by: raver119 <raver119@gmail.com> * pew-pew Signed-off-by: raver119 <raver119@gmail.com> * don't release ContextBuffers until device changed Signed-off-by: raver119 <raver119@gmail.com> * small tweak Signed-off-by: raver119 <raver119@gmail.com> * some logging in sgemm Signed-off-by: raver119 <raver119@gmail.com> * stream sync Signed-off-by: raver119 <raver119@gmail.com> * some more logging Signed-off-by: raver119 <raver119@gmail.com> * some more error checks Signed-off-by: raver119 <raver119@gmail.com> * one fancy test Signed-off-by: raver119 <raver119@gmail.com> * one fancy test Signed-off-by: raver119 <raver119@gmail.com> * minor AffinityManager fix Signed-off-by: raver119 <raver119@gmail.com> * cudaEvent error logging improvement Signed-off-by: raver119 <raver119@gmail.com> * ConstantHelper thread safety Signed-off-by: raver119 <raver119@gmail.com> * - minor corrections in ConstantTadHelper Signed-off-by: Yurii <yurii@skymind.io> * ConstantShapeHelper thread safety Signed-off-by: raver119 <raver119@gmail.com> * ConstantTadHelper.cu updated Signed-off-by: raver119 <raver119@gmail.com> * logging off Signed-off-by: raver119 <raver119@gmail.com> * logging off Signed-off-by: raver119 <raver119@gmail.com>master
parent
5be43e7253
commit
dddc8a1143
|
@ -0,0 +1,63 @@
|
||||||
|
package org.deeplearning4j;
|
||||||
|
|
||||||
|
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.junit.Ignore;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.learning.config.RmsProp;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
|
@Ignore
|
||||||
|
public class RandomTests {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testReproduce() throws Exception {
|
||||||
|
|
||||||
|
final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
|
||||||
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
||||||
|
.layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10)
|
||||||
|
.activation(Activation.TANH).build())
|
||||||
|
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
|
||||||
|
LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10)
|
||||||
|
.activation(Activation.SOFTMAX).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
for (int e = 0; e < 3; e++) {
|
||||||
|
|
||||||
|
int nThreads = 10;
|
||||||
|
final CountDownLatch l = new CountDownLatch(nThreads);
|
||||||
|
for (int i = 0; i < nThreads; i++) {
|
||||||
|
final int j = i;
|
||||||
|
Thread t = new Thread(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
try {
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
|
||||||
|
net.init();
|
||||||
|
DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100);
|
||||||
|
net.fit(iter);
|
||||||
|
} catch (Throwable t) {
|
||||||
|
System.out.println("Thread failed: " + j);
|
||||||
|
t.printStackTrace();
|
||||||
|
} finally {
|
||||||
|
l.countDown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
t.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
l.await();
|
||||||
|
System.out.println("DONE " + e + "\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,11 +24,13 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <array/ConstantDescriptor.h>
|
#include <array/ConstantDescriptor.h>
|
||||||
#include <array/ConstantDataBuffer.h>
|
#include <array/ConstantDataBuffer.h>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
class ConstantHolder {
|
class ConstantHolder {
|
||||||
private:
|
private:
|
||||||
int _deviceId = 0;
|
int _deviceId = 0;
|
||||||
|
std::mutex _mutex;
|
||||||
|
|
||||||
std::map<nd4j::DataType, ConstantDataBuffer> _buffers;
|
std::map<nd4j::DataType, ConstantDataBuffer> _buffers;
|
||||||
public:
|
public:
|
||||||
|
@ -53,6 +55,8 @@ namespace nd4j {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ConstantDataBuffer* getConstantDataBuffer();
|
ConstantDataBuffer* getConstantDataBuffer();
|
||||||
|
|
||||||
|
std::mutex* mutex();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,10 @@ namespace nd4j {
|
||||||
return _buffers.count(dataType) > 0;
|
return _buffers.count(dataType) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::mutex* ConstantHolder::mutex() {
|
||||||
|
return &_mutex;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool ConstantHolder::hasBuffer() {
|
bool ConstantHolder::hasBuffer() {
|
||||||
return hasBuffer(DataTypeUtils::fromT<T>());
|
return hasBuffer(DataTypeUtils::fromT<T>());
|
||||||
|
|
|
@ -47,7 +47,7 @@ namespace nd4j {
|
||||||
|
|
||||||
_currentMutex.unlock();
|
_currentMutex.unlock();
|
||||||
|
|
||||||
setCurrentDevice(globalThreadToDevice);
|
setCurrentNativeDevice(globalThreadToDevice);
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we already know affinity - just return it
|
// if we already know affinity - just return it
|
||||||
|
@ -92,6 +92,8 @@ namespace nd4j {
|
||||||
|
|
||||||
void AffinityManager::setCurrentNativeDevice(int deviceId) {
|
void AffinityManager::setCurrentNativeDevice(int deviceId) {
|
||||||
auto res = cudaSetDevice(deviceId);
|
auto res = cudaSetDevice(deviceId);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("setCurrentDevice failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AffinityManager::setCurrentDevice(int deviceId) {
|
void AffinityManager::setCurrentDevice(int deviceId) {
|
||||||
|
@ -104,17 +106,22 @@ namespace nd4j {
|
||||||
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
|
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
|
throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
|
||||||
|
|
||||||
|
if (deviceId != previousDeviceId) {
|
||||||
|
// discard existing stuff
|
||||||
|
nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", "");
|
||||||
|
LaunchContext::releaseBuffers();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto res = cudaSetDevice(deviceId);
|
if (deviceId != previousDeviceId) {
|
||||||
if (res != 0)
|
auto res = cudaSetDevice(deviceId);
|
||||||
throw cuda_exception::build("cudaSetDevice failed", res);
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("cudaSetDevice failed", res);
|
||||||
|
}
|
||||||
|
|
||||||
// update thread-device affinity
|
// update thread-device affinity
|
||||||
globalThreadToDevice = deviceId;
|
globalThreadToDevice = deviceId;
|
||||||
|
|
||||||
// discard existing stuff
|
|
||||||
LaunchContext::releaseBuffers();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
|
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
|
||||||
|
|
|
@ -107,7 +107,6 @@ namespace nd4j {
|
||||||
|
|
||||||
//////
|
//////
|
||||||
_allocated = false;
|
_allocated = false;
|
||||||
_initialized = false;
|
|
||||||
_deviceId = -1;
|
_deviceId = -1;
|
||||||
|
|
||||||
this->_specialStream = nullptr;
|
this->_specialStream = nullptr;
|
||||||
|
@ -116,6 +115,8 @@ namespace nd4j {
|
||||||
this->_reductionPointer = nullptr;
|
this->_reductionPointer = nullptr;
|
||||||
this->_scalarPointer = nullptr;
|
this->_scalarPointer = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_initialized = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ContextBuffers::~ContextBuffers() {
|
ContextBuffers::~ContextBuffers() {
|
||||||
|
@ -163,21 +164,21 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void* ContextBuffers::reductionBuffer() {
|
void* ContextBuffers::reductionBuffer() {
|
||||||
if (_reductionPointer == nullptr)
|
if (!_initialized)
|
||||||
initialize();
|
initialize();
|
||||||
|
|
||||||
return _reductionPointer;
|
return _reductionPointer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void* ContextBuffers::scalarBuffer() {
|
void* ContextBuffers::scalarBuffer() {
|
||||||
if (_scalarPointer == nullptr)
|
if (!_initialized)
|
||||||
initialize();
|
initialize();
|
||||||
|
|
||||||
return _scalarPointer;
|
return _scalarPointer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void* ContextBuffers::allocationBuffer() {
|
void* ContextBuffers::allocationBuffer() {
|
||||||
if (_allocationPointer == nullptr)
|
if (!_initialized)
|
||||||
initialize();
|
initialize();
|
||||||
|
|
||||||
return _allocationPointer;
|
return _allocationPointer;
|
||||||
|
@ -204,15 +205,23 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void* ContextBuffers::execStream() {
|
void* ContextBuffers::execStream() {
|
||||||
if (_execStream == nullptr)
|
if (!_initialized) {
|
||||||
|
//nd4j_printf("execStream not initialized\n", "");
|
||||||
initialize();
|
initialize();
|
||||||
|
} else {
|
||||||
|
//nd4j_printf("execStream is initialized\n", "");
|
||||||
|
}
|
||||||
|
|
||||||
return _execStream;
|
return _execStream;
|
||||||
}
|
}
|
||||||
|
|
||||||
void* ContextBuffers::specialStream() {
|
void* ContextBuffers::specialStream() {
|
||||||
if (_specialStream == nullptr)
|
if (!_initialized) {
|
||||||
|
//nd4j_printf("specialStream not initialized\n", "");
|
||||||
initialize();
|
initialize();
|
||||||
|
} else {
|
||||||
|
//nd4j_printf("specialStream is initialized\n", "");
|
||||||
|
}
|
||||||
|
|
||||||
return _specialStream;
|
return _specialStream;
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,10 +57,6 @@ LaunchContext::LaunchContext() {
|
||||||
_deviceID = 0;
|
_deviceID = 0;
|
||||||
|
|
||||||
_isAllocated = true;
|
_isAllocated = true;
|
||||||
|
|
||||||
_cublasHandle = CublasHelper::getInstance()->handle();
|
|
||||||
|
|
||||||
_cusolverHandle = CublasHelper::getInstance()->solver();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
|
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
|
||||||
|
@ -89,13 +85,13 @@ LaunchContext::LaunchContext() {
|
||||||
|
|
||||||
_contexts.resize(numDevices);
|
_contexts.resize(numDevices);
|
||||||
for (int e = 0; e < numDevices; e++) {
|
for (int e = 0; e < numDevices; e++) {
|
||||||
AffinityManager::setCurrentDevice(e);
|
AffinityManager::setCurrentNativeDevice(e);
|
||||||
|
|
||||||
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
|
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't forget to restore device back again
|
// don't forget to restore device back again
|
||||||
AffinityManager::setCurrentDevice(deviceId);
|
AffinityManager::setCurrentNativeDevice(deviceId);
|
||||||
}
|
}
|
||||||
_mutex.unlock();
|
_mutex.unlock();
|
||||||
|
|
||||||
|
@ -117,11 +113,11 @@ LaunchContext::LaunchContext() {
|
||||||
};
|
};
|
||||||
|
|
||||||
void* LaunchContext::getCublasHandle() const {
|
void* LaunchContext::getCublasHandle() const {
|
||||||
return _cublasHandle;
|
return CublasHelper::getInstance()->handle();
|
||||||
};
|
};
|
||||||
|
|
||||||
void* LaunchContext::getCusolverHandle() const {
|
void* LaunchContext::getCusolverHandle() const {
|
||||||
return _cusolverHandle;
|
return CublasHelper::getInstance()->solver();
|
||||||
};
|
};
|
||||||
|
|
||||||
cudaStream_t* LaunchContext::getCudaStream() const {
|
cudaStream_t* LaunchContext::getCudaStream() const {
|
||||||
|
@ -162,6 +158,7 @@ LaunchContext::LaunchContext() {
|
||||||
};
|
};
|
||||||
|
|
||||||
void LaunchContext::releaseBuffers() {
|
void LaunchContext::releaseBuffers() {
|
||||||
|
nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", "");
|
||||||
contextBuffers.release();
|
contextBuffers.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,12 +38,13 @@ namespace nd4j {
|
||||||
static ConstantHelper* _INSTANCE;
|
static ConstantHelper* _INSTANCE;
|
||||||
ConstantHelper();
|
ConstantHelper();
|
||||||
|
|
||||||
std::vector<std::map<ConstantDescriptor, ConstantHolder>> _cache;
|
std::vector<std::map<ConstantDescriptor, ConstantHolder*>> _cache;
|
||||||
|
|
||||||
// tracking of per-device constant memory buffers (CUDA only atm)
|
// tracking of per-device constant memory buffers (CUDA only atm)
|
||||||
std::vector<Nd4jPointer> _devicePointers;
|
std::vector<Nd4jPointer> _devicePointers;
|
||||||
std::vector<Nd4jLong> _deviceOffsets;
|
std::vector<Nd4jLong> _deviceOffsets;
|
||||||
std::mutex _mutex;
|
std::mutex _mutex;
|
||||||
|
std::mutex _mutexHolder;
|
||||||
|
|
||||||
std::vector<Nd4jLong> _counters;
|
std::vector<Nd4jLong> _counters;
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -48,10 +48,10 @@ namespace nd4j {
|
||||||
static ConstantShapeHelper* getInstance();
|
static ConstantShapeHelper* getInstance();
|
||||||
|
|
||||||
|
|
||||||
ConstantDataBuffer& bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
|
ConstantDataBuffer bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
|
||||||
ConstantDataBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor);
|
ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
|
||||||
ConstantDataBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
|
ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
|
||||||
ConstantDataBuffer& bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
|
ConstantDataBuffer bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
|
||||||
|
|
||||||
|
|
||||||
Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType);
|
Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType);
|
||||||
|
|
|
@ -54,11 +54,11 @@ namespace nd4j {
|
||||||
* @param keepUnitiesInShape
|
* @param keepUnitiesInShape
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(TadDescriptor &descriptor);
|
TadPack tadForDimensions(TadDescriptor &descriptor);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns number of cached TAD shapes/offsets on specific device
|
* This method returns number of cached TAD shapes/offsets on specific device
|
||||||
|
|
|
@ -33,7 +33,8 @@ namespace nd4j {
|
||||||
_cache.resize(numDevices);
|
_cache.resize(numDevices);
|
||||||
_counters.resize(numDevices);
|
_counters.resize(numDevices);
|
||||||
for (int e = 0; e < numDevices; e++) {
|
for (int e = 0; e < numDevices; e++) {
|
||||||
std::map<ConstantDescriptor, ConstantHolder> map;
|
std::map<ConstantDescriptor, ConstantHolder*> map;
|
||||||
|
|
||||||
_cache[e] = map;
|
_cache[e] = map;
|
||||||
_counters[e] = 0L;
|
_counters[e] = 0L;
|
||||||
}
|
}
|
||||||
|
@ -70,15 +71,26 @@ namespace nd4j {
|
||||||
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
|
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
|
||||||
const auto deviceId = getCurrentDevice();
|
const auto deviceId = getCurrentDevice();
|
||||||
|
|
||||||
|
// we're locking away cache modification
|
||||||
|
_mutexHolder.lock();
|
||||||
|
|
||||||
if (_cache[deviceId].count(descriptor) == 0) {
|
if (_cache[deviceId].count(descriptor) == 0) {
|
||||||
ConstantHolder holder;
|
_cache[deviceId][descriptor] = new ConstantHolder();
|
||||||
_cache[deviceId][descriptor] = holder;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstantHolder* holder = &_cache[deviceId][descriptor];
|
auto holder = _cache[deviceId][descriptor];
|
||||||
|
|
||||||
|
// releasing cache lock
|
||||||
|
_mutexHolder.unlock();
|
||||||
|
|
||||||
|
|
||||||
|
ConstantDataBuffer* result;
|
||||||
|
|
||||||
|
// access to this holder instance is synchronous
|
||||||
|
holder->mutex()->lock();
|
||||||
|
|
||||||
if (holder->hasBuffer(dataType))
|
if (holder->hasBuffer(dataType))
|
||||||
return holder->getConstantDataBuffer(dataType);
|
result = holder->getConstantDataBuffer(dataType);
|
||||||
else {
|
else {
|
||||||
auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType);
|
auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType);
|
||||||
auto cbuff = new int8_t[size];
|
auto cbuff = new int8_t[size];
|
||||||
|
@ -94,8 +106,11 @@ namespace nd4j {
|
||||||
ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType));
|
ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType));
|
||||||
holder->addBuffer(dataBuffer, dataType);
|
holder->addBuffer(dataBuffer, dataType);
|
||||||
|
|
||||||
return holder->getConstantDataBuffer(dataType);
|
result = holder->getConstantDataBuffer(dataType);
|
||||||
}
|
}
|
||||||
|
holder->mutex()->unlock();
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
|
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
|
||||||
|
|
|
@ -41,18 +41,18 @@ namespace nd4j {
|
||||||
return _INSTANCE;
|
return _INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||||
ShapeDescriptor descriptor(dataType, order, shape);
|
ShapeDescriptor descriptor(dataType, order, shape);
|
||||||
return bufferForShapeInfo(descriptor);
|
return bufferForShapeInfo(descriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||||
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
||||||
return bufferForShapeInfo(descriptor);
|
return bufferForShapeInfo(descriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||||
int deviceId = 0;
|
int deviceId = 0;
|
||||||
|
|
||||||
_mutex.lock();
|
_mutex.lock();
|
||||||
|
@ -62,19 +62,19 @@ namespace nd4j {
|
||||||
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
|
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
|
||||||
ShapeDescriptor descriptor1(descriptor);
|
ShapeDescriptor descriptor1(descriptor);
|
||||||
_cache[deviceId][descriptor1] = buffer;
|
_cache[deviceId][descriptor1] = buffer;
|
||||||
ConstantDataBuffer &r = _cache[deviceId][descriptor1];
|
auto r = _cache[deviceId][descriptor1];
|
||||||
_mutex.unlock();
|
_mutex.unlock();
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
} else {
|
} else {
|
||||||
ConstantDataBuffer &r = _cache[deviceId].at(descriptor);
|
auto r = _cache[deviceId].at(descriptor);
|
||||||
_mutex.unlock();
|
_mutex.unlock();
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||||
ShapeDescriptor descriptor(shapeInfo);
|
ShapeDescriptor descriptor(shapeInfo);
|
||||||
return bufferForShapeInfo(descriptor);
|
return bufferForShapeInfo(descriptor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,25 +38,25 @@ namespace nd4j {
|
||||||
return _INSTANCE;
|
return _INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||||
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||||
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||||
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
||||||
return tadForDimensions(tadDescriptor);
|
return tadForDimensions(tadDescriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||||
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
|
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
|
||||||
return tadForDimensions(tadDescriptor);
|
return tadForDimensions(tadDescriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
||||||
const int deviceId = 0;
|
const int deviceId = 0;
|
||||||
|
|
||||||
_mutex.lock();
|
_mutex.lock();
|
||||||
|
@ -105,7 +105,7 @@ namespace nd4j {
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
} else {
|
} else {
|
||||||
TadPack &r = _cache[deviceId][descriptor];
|
TadPack r = _cache[deviceId][descriptor];
|
||||||
_mutex.unlock();
|
_mutex.unlock();
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
|
|
|
@ -24,11 +24,13 @@
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
#include <pointercast.h>
|
#include <pointercast.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
class CublasHelper {
|
class CublasHelper {
|
||||||
private:
|
private:
|
||||||
static CublasHelper *_INSTANCE;
|
static CublasHelper *_INSTANCE;
|
||||||
|
static std::mutex _mutex;
|
||||||
|
|
||||||
std::vector<void*> _cache;
|
std::vector<void*> _cache;
|
||||||
std::vector<void*> _solvers;
|
std::vector<void*> _solvers;
|
||||||
|
|
|
@ -68,7 +68,7 @@ namespace nd4j {
|
||||||
throw cuda_exception::build("cudaSetDevice failed", res);
|
throw cuda_exception::build("cudaSetDevice failed", res);
|
||||||
auto constant = getConstantSpace();
|
auto constant = getConstantSpace();
|
||||||
|
|
||||||
std::map<ConstantDescriptor, ConstantHolder> devCache;
|
std::map<ConstantDescriptor, ConstantHolder*> devCache;
|
||||||
|
|
||||||
_devicePointers[e] = constant;
|
_devicePointers[e] = constant;
|
||||||
_deviceOffsets[e] = 0;
|
_deviceOffsets[e] = 0;
|
||||||
|
@ -136,15 +136,24 @@ namespace nd4j {
|
||||||
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
|
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
|
||||||
const auto deviceId = getCurrentDevice();
|
const auto deviceId = getCurrentDevice();
|
||||||
|
|
||||||
if (_cache[deviceId].count(descriptor) == 0) {
|
// all cache modifications are synchronous
|
||||||
ConstantHolder holder;
|
_mutexHolder.lock();
|
||||||
_cache[deviceId][descriptor] = holder;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConstantHolder* holder = &_cache[deviceId][descriptor];
|
if (_cache[deviceId].count(descriptor) == 0) {
|
||||||
|
_cache[deviceId][descriptor] = new ConstantHolder();
|
||||||
|
}
|
||||||
|
auto holder = _cache[deviceId][descriptor];
|
||||||
|
|
||||||
|
// release cache lock
|
||||||
|
_mutexHolder.unlock();
|
||||||
|
|
||||||
|
ConstantDataBuffer* result;
|
||||||
|
|
||||||
|
// access to this holder instance is synchronous
|
||||||
|
holder->mutex()->lock();
|
||||||
|
|
||||||
if (holder->hasBuffer(dataType)) {
|
if (holder->hasBuffer(dataType)) {
|
||||||
return holder->getConstantDataBuffer(dataType);
|
result = holder->getConstantDataBuffer(dataType);
|
||||||
} else {
|
} else {
|
||||||
auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
|
auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
|
||||||
auto cbuff = new int8_t[numBytes];
|
auto cbuff = new int8_t[numBytes];
|
||||||
|
@ -160,10 +169,14 @@ namespace nd4j {
|
||||||
auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));
|
auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));
|
||||||
|
|
||||||
ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
|
ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
|
||||||
holder->addBuffer(dataBuffer, dataType);
|
|
||||||
|
|
||||||
return holder->getConstantDataBuffer(dataType);
|
holder->addBuffer(dataBuffer, dataType);
|
||||||
|
result = holder->getConstantDataBuffer(dataType);
|
||||||
}
|
}
|
||||||
|
// release holder lock
|
||||||
|
holder->mutex()->unlock();
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
|
Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
|
||||||
|
|
|
@ -44,17 +44,17 @@ namespace nd4j {
|
||||||
return _INSTANCE;
|
return _INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||||
ShapeDescriptor descriptor(dataType, order, shape);
|
ShapeDescriptor descriptor(dataType, order, shape);
|
||||||
return bufferForShapeInfo(descriptor);
|
return bufferForShapeInfo(descriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||||
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
||||||
return bufferForShapeInfo(descriptor);
|
return bufferForShapeInfo(descriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||||
int deviceId = AffinityManager::currentDeviceId();
|
int deviceId = AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
_mutex.lock();
|
_mutex.lock();
|
||||||
|
@ -65,19 +65,19 @@ namespace nd4j {
|
||||||
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
|
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
|
||||||
ShapeDescriptor descriptor1(descriptor);
|
ShapeDescriptor descriptor1(descriptor);
|
||||||
_cache[deviceId][descriptor1] = buffer;
|
_cache[deviceId][descriptor1] = buffer;
|
||||||
ConstantDataBuffer &r = _cache[deviceId][descriptor1];
|
auto r = _cache[deviceId][descriptor1];
|
||||||
_mutex.unlock();
|
_mutex.unlock();
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
} else {
|
} else {
|
||||||
ConstantDataBuffer &r = _cache[deviceId].at(descriptor);
|
ConstantDataBuffer r = _cache[deviceId].at(descriptor);
|
||||||
_mutex.unlock();
|
_mutex.unlock();
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||||
ShapeDescriptor descriptor(shapeInfo);
|
ShapeDescriptor descriptor(shapeInfo);
|
||||||
return bufferForShapeInfo(descriptor);
|
return bufferForShapeInfo(descriptor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,25 +43,25 @@ namespace nd4j {
|
||||||
return _INSTANCE;
|
return _INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||||
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||||
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||||
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
||||||
return tadForDimensions(tadDescriptor);
|
return tadForDimensions(tadDescriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||||
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
|
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
|
||||||
return tadForDimensions(tadDescriptor);
|
return tadForDimensions(tadDescriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
||||||
const int deviceId = AffinityManager::currentDeviceId();
|
const int deviceId = AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
_mutex.lock();
|
_mutex.lock();
|
||||||
|
@ -96,14 +96,14 @@ namespace nd4j {
|
||||||
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
|
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
|
||||||
_cache[deviceId][descriptor] = t;
|
_cache[deviceId][descriptor] = t;
|
||||||
|
|
||||||
TadPack &r = _cache[deviceId][descriptor];
|
TadPack r = _cache[deviceId][descriptor];
|
||||||
_mutex.unlock();
|
_mutex.unlock();
|
||||||
|
|
||||||
delete[] shapeInfo;
|
delete[] shapeInfo;
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
} else {
|
} else {
|
||||||
TadPack &r = _cache[deviceId][descriptor];
|
TadPack r = _cache[deviceId][descriptor];
|
||||||
_mutex.unlock();
|
_mutex.unlock();
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include <execution/AffinityManager.h>
|
#include <execution/AffinityManager.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
std::mutex CublasHelper::_mutex;
|
||||||
|
|
||||||
static void* handle_() {
|
static void* handle_() {
|
||||||
auto _handle = new cublasHandle_t();
|
auto _handle = new cublasHandle_t();
|
||||||
|
@ -56,22 +57,24 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
CublasHelper::CublasHelper() {
|
CublasHelper::CublasHelper() {
|
||||||
|
//nd4j_printf("Initializing cuBLAS\n","");
|
||||||
auto numDevices = AffinityManager::numberOfDevices();
|
auto numDevices = AffinityManager::numberOfDevices();
|
||||||
auto currentDevice = AffinityManager::currentDeviceId();
|
auto currentDevice = AffinityManager::currentDeviceId();
|
||||||
_cache.resize(numDevices);
|
_cache.resize(numDevices);
|
||||||
_solvers.resize(numDevices);
|
_solvers.resize(numDevices);
|
||||||
for (int e = 0; e < numDevices; e++) {
|
for (int e = 0; e < numDevices; e++) {
|
||||||
AffinityManager::setCurrentDevice(e);
|
AffinityManager::setCurrentNativeDevice(e);
|
||||||
|
|
||||||
_cache[e] = handle_();
|
_cache[e] = handle_();
|
||||||
_solvers[e] = solver_();
|
_solvers[e] = solver_();
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't forget to restore back original device
|
// don't forget to restore back original device
|
||||||
AffinityManager::setCurrentDevice(currentDevice);
|
AffinityManager::setCurrentNativeDevice(currentDevice);
|
||||||
}
|
}
|
||||||
|
|
||||||
CublasHelper::~CublasHelper() {
|
CublasHelper::~CublasHelper() {
|
||||||
|
nd4j_printf("Releasing cuBLAS\n","");
|
||||||
auto numDevices = AffinityManager::numberOfDevices();
|
auto numDevices = AffinityManager::numberOfDevices();
|
||||||
|
|
||||||
for (int e = 0; e < numDevices; e++)
|
for (int e = 0; e < numDevices; e++)
|
||||||
|
@ -79,8 +82,10 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
CublasHelper* CublasHelper::getInstance() {
|
CublasHelper* CublasHelper::getInstance() {
|
||||||
|
_mutex.lock();
|
||||||
if (!_INSTANCE)
|
if (!_INSTANCE)
|
||||||
_INSTANCE = new nd4j::CublasHelper();
|
_INSTANCE = new nd4j::CublasHelper();
|
||||||
|
_mutex.unlock();
|
||||||
|
|
||||||
return _INSTANCE;
|
return _INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.jita.allocator.pointers.cuda;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
||||||
import org.nd4j.linalg.exception.ND4JException;
|
import org.nd4j.linalg.exception.ND4JException;
|
||||||
|
@ -69,8 +70,9 @@ public class cudaEvent_t extends CudaPointer {
|
||||||
if (res == 0)
|
if (res == 0)
|
||||||
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]");
|
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]");
|
||||||
|
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
|
val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||||
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
|
if (code != 0)
|
||||||
|
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,8 +80,9 @@ public class cudaEvent_t extends CudaPointer {
|
||||||
if (!isDestroyed()) {
|
if (!isDestroyed()) {
|
||||||
int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);
|
int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);
|
||||||
|
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
|
val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||||
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
|
if (code != 0)
|
||||||
|
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
package org.nd4j.linalg.jcublas.blas;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.DoublePointer;
|
import org.bytedeco.javacpp.DoublePointer;
|
||||||
import org.bytedeco.javacpp.FloatPointer;
|
import org.bytedeco.javacpp.FloatPointer;
|
||||||
|
@ -52,6 +53,7 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.*;
|
||||||
*
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class JcublasLevel3 extends BaseLevel3 {
|
public class JcublasLevel3 extends BaseLevel3 {
|
||||||
private Allocator allocator = AtomicAllocator.getInstance();
|
private Allocator allocator = AtomicAllocator.getInstance();
|
||||||
private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas();
|
private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas();
|
||||||
|
@ -78,7 +80,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
|
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
|
||||||
|
|
||||||
if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch == 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
|
if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch >= 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) {
|
||||||
// on these selected archs we run with cublasHgemm
|
// on these selected archs we run with cublasHgemm
|
||||||
__half alphaHalf = new __half();
|
__half alphaHalf = new __half();
|
||||||
__half betaHalf = new __half();
|
__half betaHalf = new __half();
|
||||||
|
@ -96,7 +98,11 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda,
|
new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda,
|
||||||
(ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta),
|
(ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta),
|
||||||
(ShortPointer) cCPointer.getDevicePointer(), 2, ldc);
|
(ShortPointer) cCPointer.getDevicePointer(), 2, ldc);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx.getOldStream().synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
allocator.registerAction(ctx, C, A, B);
|
allocator.registerAction(ctx, C, A, B);
|
||||||
|
@ -114,18 +120,24 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
val ctx = allocator.getFlowController().prepareAction(C, A, B);
|
val ctx = allocator.getFlowController().prepareAction(C, A, B);
|
||||||
|
|
||||||
|
//log.info("Synchronizing CUDA stream");
|
||||||
|
ctx.getOldStream().synchronize();
|
||||||
|
|
||||||
val cAPointer = new CublasPointer(A, ctx);
|
val cAPointer = new CublasPointer(A, ctx);
|
||||||
val cBPointer = new CublasPointer(B, ctx);
|
val cBPointer = new CublasPointer(B, ctx);
|
||||||
val cCPointer = new CublasPointer(C, ctx);
|
val cCPointer = new CublasPointer(C, ctx);
|
||||||
|
|
||||||
val handle = ctx.getCublasHandle();
|
val handle = ctx.getCublasHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
|
//log.info("Handle: {}; Stream: {}", handle.address(), ctx.getCublasStream().address());
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
|
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
|
||||||
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
|
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
|
||||||
(FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta),
|
(FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta),
|
||||||
(FloatPointer) cCPointer.getDevicePointer(), ldc);
|
(FloatPointer) cCPointer.getDevicePointer(), ldc);
|
||||||
|
|
||||||
|
ctx.getOldStream().synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
allocator.registerAction(ctx, C, A, B);
|
allocator.registerAction(ctx, C, A, B);
|
||||||
|
@ -244,6 +256,8 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
|
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
|
||||||
(DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
|
(DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta),
|
||||||
(DoublePointer) cCPointer.getDevicePointer(), ldc);
|
(DoublePointer) cCPointer.getDevicePointer(), ldc);
|
||||||
|
|
||||||
|
ctx.getOldStream().synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
allocator.registerAction(ctx, C, A, B);
|
allocator.registerAction(ctx, C, A, B);
|
||||||
|
|
|
@ -2548,6 +2548,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
||||||
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
|
||||||
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
@ -2562,6 +2565,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
||||||
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
|
||||||
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
@ -2577,6 +2583,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
|
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
|
||||||
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
|
||||||
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
|
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
@ -2590,6 +2599,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
|
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
|
||||||
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
|
||||||
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
|
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
|
Loading…
Reference in New Issue