[WIP] DeviceLocalNDArray updates (#149)
* ContextBuffers are released upon device change
  Signed-off-by: raver119 <raver119@gmail.com>
* DeviceLocalNDArray updates + tests
  Signed-off-by: raver119 <raver119@gmail.com>
* special array for delayed mode
  Signed-off-by: raver119 <raver119@gmail.com>
* additional detach()
  Signed-off-by: raver119 <raver119@gmail.com>
master
parent
c523aa792f
commit
930b49e87f
|
@ -33,15 +33,22 @@ namespace nd4j {
|
||||||
void* _execStream = nullptr;
|
void* _execStream = nullptr;
|
||||||
void* _specialStream = nullptr;
|
void* _specialStream = nullptr;
|
||||||
bool _allocated = false;
|
bool _allocated = false;
|
||||||
|
bool _initialized = false;
|
||||||
|
|
||||||
int _deviceId = -1;
|
int _deviceId = -1;
|
||||||
|
|
||||||
void initialize();
|
void initialize();
|
||||||
public:
|
public:
|
||||||
ContextBuffers();
|
ContextBuffers();
|
||||||
|
ContextBuffers(const ContextBuffers &other);
|
||||||
ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false);
|
ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false);
|
||||||
~ContextBuffers();
|
~ContextBuffers();
|
||||||
|
|
||||||
|
ContextBuffers& operator=(const ContextBuffers& other);
|
||||||
|
ContextBuffers& operator=(ContextBuffers&& other);
|
||||||
|
|
||||||
|
void release();
|
||||||
|
|
||||||
void* reductionBuffer();
|
void* reductionBuffer();
|
||||||
void* scalarBuffer();
|
void* scalarBuffer();
|
||||||
void* allocationBuffer();
|
void* allocationBuffer();
|
||||||
|
@ -56,6 +63,8 @@ namespace nd4j {
|
||||||
void triggerOwnership(bool isOwner);
|
void triggerOwnership(bool isOwner);
|
||||||
|
|
||||||
int deviceId();
|
int deviceId();
|
||||||
|
|
||||||
|
bool isInitialized();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -98,6 +98,8 @@ class ND4J_EXPORT LaunchContext {
|
||||||
int getDeviceID() const {return _deviceID;}
|
int getDeviceID() const {return _deviceID;}
|
||||||
void setDeviceID(int deviceID) { _deviceID = deviceID; }
|
void setDeviceID(int deviceID) { _deviceID = deviceID; }
|
||||||
|
|
||||||
|
static bool isInitialized();
|
||||||
|
static void releaseBuffers();
|
||||||
static LaunchContext* defaultContext();
|
static LaunchContext* defaultContext();
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,10 @@ namespace nd4j {
|
||||||
_allocated = isOwner;
|
_allocated = isOwner;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ContextBuffers::ContextBuffers(const ContextBuffers &other) {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
void ContextBuffers::initialize() {
|
void ContextBuffers::initialize() {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
@ -79,4 +83,20 @@ namespace nd4j {
|
||||||
void* ContextBuffers::specialStream() {
|
void* ContextBuffers::specialStream() {
|
||||||
return _specialStream;
|
return _specialStream;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ContextBuffers::isInitialized() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ContextBuffers::release() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -57,4 +57,12 @@ namespace nd4j {
|
||||||
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
|
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool LaunchContext::isInitialized() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LaunchContext::releaseBuffers() {
|
||||||
|
//
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -95,17 +95,26 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void AffinityManager::setCurrentDevice(int deviceId) {
|
void AffinityManager::setCurrentDevice(int deviceId) {
|
||||||
|
auto previousDeviceId = globalThreadToDevice;
|
||||||
|
if (previousDeviceId >= 0 && LaunchContext::isInitialized()) {
|
||||||
|
auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("setCurrentDevice -> sync failed", res);
|
||||||
|
|
||||||
|
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
|
||||||
|
}
|
||||||
|
|
||||||
auto res = cudaSetDevice(deviceId);
|
auto res = cudaSetDevice(deviceId);
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw cuda_exception::build("cudaSetDevice failed", res);
|
throw cuda_exception::build("cudaSetDevice failed", res);
|
||||||
|
|
||||||
auto previousDeviceId = globalThreadToDevice;
|
|
||||||
|
|
||||||
// update thread-device affinity
|
// update thread-device affinity
|
||||||
globalThreadToDevice = deviceId;
|
globalThreadToDevice = deviceId;
|
||||||
|
|
||||||
ContextBuffers newBuffers;
|
// discard existing stuff
|
||||||
LaunchContext::swapContextBuffers(newBuffers);
|
LaunchContext::releaseBuffers();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
|
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
|
||||||
|
|
|
@ -34,9 +34,55 @@ namespace nd4j {
|
||||||
_deviceId = AffinityManager::currentDeviceId();
|
_deviceId = AffinityManager::currentDeviceId();
|
||||||
}
|
}
|
||||||
|
|
||||||
ContextBuffers::~ContextBuffers() {
|
ContextBuffers::ContextBuffers(const ContextBuffers &other) {
|
||||||
|
release();
|
||||||
|
|
||||||
|
this->_initialized = other._initialized;
|
||||||
|
this->_allocated = other._allocated;
|
||||||
|
this->_deviceId = other._deviceId;
|
||||||
|
|
||||||
|
this->_specialStream = other._specialStream;
|
||||||
|
this->_execStream = other._execStream;
|
||||||
|
this->_allocationPointer = other._allocationPointer;
|
||||||
|
this->_reductionPointer = other._reductionPointer;
|
||||||
|
this->_scalarPointer = other._scalarPointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) {
|
||||||
|
release();
|
||||||
|
|
||||||
|
this->_initialized = other._initialized;
|
||||||
|
this->_allocated = other._allocated;
|
||||||
|
this->_deviceId = other._deviceId;
|
||||||
|
|
||||||
|
this->_specialStream = other._specialStream;
|
||||||
|
this->_execStream = other._execStream;
|
||||||
|
this->_allocationPointer = other._allocationPointer;
|
||||||
|
this->_reductionPointer = other._reductionPointer;
|
||||||
|
this->_scalarPointer = other._scalarPointer;
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) {
|
||||||
|
release();
|
||||||
|
|
||||||
|
this->_initialized = other._initialized;
|
||||||
|
this->_allocated = other._allocated;
|
||||||
|
this->_deviceId = other._deviceId;
|
||||||
|
|
||||||
|
this->_specialStream = other._specialStream;
|
||||||
|
this->_execStream = other._execStream;
|
||||||
|
this->_allocationPointer = other._allocationPointer;
|
||||||
|
this->_reductionPointer = other._reductionPointer;
|
||||||
|
this->_scalarPointer = other._scalarPointer;
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ContextBuffers::release() {
|
||||||
if (_allocated) {
|
if (_allocated) {
|
||||||
//nd4j_printf("Releasing ContextBuffers\n","");
|
//nd4j_printf("Releasing ContextBuffers on device [%i]\n", _deviceId);
|
||||||
|
|
||||||
if (_allocationPointer != nullptr)
|
if (_allocationPointer != nullptr)
|
||||||
cudaFree(_allocationPointer);
|
cudaFree(_allocationPointer);
|
||||||
|
@ -58,9 +104,24 @@ namespace nd4j {
|
||||||
|
|
||||||
delete _cudaStream;
|
delete _cudaStream;
|
||||||
delete _cudaSpecialStream;
|
delete _cudaSpecialStream;
|
||||||
|
|
||||||
|
//////
|
||||||
|
_allocated = false;
|
||||||
|
_initialized = false;
|
||||||
|
_deviceId = -1;
|
||||||
|
|
||||||
|
this->_specialStream = nullptr;
|
||||||
|
this->_execStream = nullptr;
|
||||||
|
this->_allocationPointer = nullptr;
|
||||||
|
this->_reductionPointer = nullptr;
|
||||||
|
this->_scalarPointer = nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ContextBuffers::~ContextBuffers() {
|
||||||
|
release();
|
||||||
|
}
|
||||||
|
|
||||||
ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
|
ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
|
||||||
_reductionPointer = rPointer;
|
_reductionPointer = rPointer;
|
||||||
_scalarPointer = sPointer;
|
_scalarPointer = sPointer;
|
||||||
|
@ -69,19 +130,20 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ContextBuffers::initialize() {
|
void ContextBuffers::initialize() {
|
||||||
//nd4j_printf("Initializing buffers on deviceId [%i]\n", AffinityManager::currentNativeDeviceId());
|
_deviceId = AffinityManager::currentNativeDeviceId();
|
||||||
|
//nd4j_printf("Initializing buffers on deviceId [%i]\n", _deviceId);
|
||||||
|
|
||||||
auto res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
|
auto res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw std::runtime_error("_reductionPointer allocation failed");
|
throw cuda_exception::build("_reductionPointer allocation failed", res);
|
||||||
|
|
||||||
res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16);
|
res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16);
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw std::runtime_error("_scalarPointer allocation failed");
|
throw cuda_exception::build("_scalarPointer allocation failed", res);
|
||||||
|
|
||||||
res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
|
res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw std::runtime_error("_allocationPointer allocation failed");
|
throw cuda_exception::build("_allocationPointer allocation failed", res);
|
||||||
|
|
||||||
_execStream = new cudaStream_t();
|
_execStream = new cudaStream_t();
|
||||||
_specialStream = new cudaStream_t();
|
_specialStream = new cudaStream_t();
|
||||||
|
@ -97,6 +159,7 @@ namespace nd4j {
|
||||||
throw cuda_exception::build("Failed to create special CUDA stream with launch context", res);
|
throw cuda_exception::build("Failed to create special CUDA stream with launch context", res);
|
||||||
|
|
||||||
_allocated = true;
|
_allocated = true;
|
||||||
|
_initialized = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void* ContextBuffers::reductionBuffer() {
|
void* ContextBuffers::reductionBuffer() {
|
||||||
|
@ -153,4 +216,9 @@ namespace nd4j {
|
||||||
|
|
||||||
return _specialStream;
|
return _specialStream;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ContextBuffers::isInitialized() {
|
||||||
|
return _initialized;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -160,4 +160,12 @@ LaunchContext::LaunchContext() {
|
||||||
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
|
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
|
||||||
contextBuffers = buffers;
|
contextBuffers = buffers;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void LaunchContext::releaseBuffers() {
|
||||||
|
contextBuffers.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool LaunchContext::isInitialized() {
|
||||||
|
return contextBuffers.isInitialized();
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.util;
|
package org.nd4j.linalg.util;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import edu.umd.cs.findbugs.annotations.Nullable;
|
import edu.umd.cs.findbugs.annotations.Nullable;
|
||||||
|
@ -24,6 +25,7 @@ import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -31,14 +33,23 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||||
*
|
*
|
||||||
* @author raver119@gmail.com
|
* @author raver119@gmail.com
|
||||||
*/
|
*/
|
||||||
public class DeviceLocal<T extends Object> {
|
public abstract class DeviceLocal<T extends Object> {
|
||||||
private Map<Integer, T> backingMap = new ConcurrentHashMap<>();
|
protected Map<Integer, T> backingMap = new ConcurrentHashMap<>();
|
||||||
private List<ReentrantReadWriteLock> locksMap = new ArrayList<>();
|
protected List<ReentrantReadWriteLock> locksMap = new ArrayList<>();
|
||||||
|
protected List<AtomicInteger> updatesMap = new ArrayList<>();
|
||||||
|
protected final boolean delayedMode;
|
||||||
|
|
||||||
|
protected volatile INDArray delayedArray;
|
||||||
|
|
||||||
|
protected int lastSettledDevice = -1;
|
||||||
|
|
||||||
|
public DeviceLocal(boolean delayedMode) {
|
||||||
|
this.delayedMode = delayedMode;
|
||||||
|
|
||||||
public DeviceLocal() {
|
|
||||||
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||||
for (int i = 0; i < numDevices; i++) {
|
for (int i = 0; i < numDevices; i++) {
|
||||||
locksMap.add(new ReentrantReadWriteLock());
|
locksMap.add(new ReentrantReadWriteLock());
|
||||||
|
updatesMap.add(new AtomicInteger(-1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,14 +16,20 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.util;
|
package org.nd4j.linalg.util;
|
||||||
|
|
||||||
|
import edu.umd.cs.findbugs.annotations.Nullable;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.profiler.OpProfiler;
|
import org.nd4j.linalg.profiler.OpProfiler;
|
||||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DeviceLocal implementation for INDArray, with special broadcast method
|
* DeviceLocal implementation for INDArray, with special broadcast method
|
||||||
* @author raver119@gmail.com
|
* @author raver119@gmail.com
|
||||||
|
@ -32,24 +38,71 @@ import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
public class DeviceLocalNDArray extends DeviceLocal<INDArray> {
|
public class DeviceLocalNDArray extends DeviceLocal<INDArray> {
|
||||||
|
|
||||||
public DeviceLocalNDArray() {
|
public DeviceLocalNDArray() {
|
||||||
super();
|
this(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DeviceLocalNDArray(boolean delayedMode) {
|
||||||
|
super(delayedMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
public DeviceLocalNDArray(INDArray array) {
|
public DeviceLocalNDArray(INDArray array) {
|
||||||
super();
|
this(array, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DeviceLocalNDArray(INDArray array, boolean delayedMode) {
|
||||||
|
super(delayedMode);
|
||||||
|
|
||||||
broadcast(array);
|
broadcast(array);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns object local to current deviceId
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
@Override
|
||||||
|
public synchronized INDArray get() {
|
||||||
|
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||||
|
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||||
|
val sourceId = updatesMap.get(deviceId).get();
|
||||||
|
if (sourceId >= 0 && sourceId != deviceId) {
|
||||||
|
// if updates map contains some deviceId - we should take updated array from there
|
||||||
|
val newArray = Nd4j.create(delayedArray.dataType(), delayedArray.shape(), delayedArray.stride(), delayedArray.ordering());
|
||||||
|
Nd4j.getMemoryManager().memcpy(newArray.data(), delayedArray.data());
|
||||||
|
backingMap.put(deviceId, newArray);
|
||||||
|
|
||||||
|
// reset updates flag
|
||||||
|
updatesMap.get(deviceId).set(deviceId);
|
||||||
|
|
||||||
|
|
||||||
|
// also check if all updates were consumed
|
||||||
|
boolean allUpdated = true;
|
||||||
|
for (int e = 0; e < numDevices; e++) {
|
||||||
|
if (updatesMap.get(e).get() != e) {
|
||||||
|
allUpdated = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (allUpdated)
|
||||||
|
delayedArray = null;
|
||||||
|
}
|
||||||
|
return get(deviceId);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method duplicates array, and stores it to all devices
|
* This method duplicates array, and stores it to all devices
|
||||||
*
|
*
|
||||||
|
* PLEASE NOTE: this method is NOT atomic, so you must be sure no other threads are using this instance during the update
|
||||||
* @param array
|
* @param array
|
||||||
*/
|
*/
|
||||||
public void broadcast(INDArray array) {
|
public synchronized void broadcast(INDArray array) {
|
||||||
if (array == null)
|
if (array == null)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray");
|
||||||
|
|
||||||
Nd4j.getExecutioner().commit();
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
val config = OpProfiler.getInstance().getConfig();
|
val config = OpProfiler.getInstance().getConfig();
|
||||||
|
@ -57,18 +110,76 @@ public class DeviceLocalNDArray extends DeviceLocal<INDArray> {
|
||||||
|
|
||||||
if (locality)
|
if (locality)
|
||||||
config.setCheckLocality(false);
|
config.setCheckLocality(false);
|
||||||
|
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||||
|
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||||
|
|
||||||
|
if (!delayedMode) {
|
||||||
|
// in immediate mode we put data in
|
||||||
|
|
||||||
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
|
||||||
for (int i = 0; i < numDevices; i++) {
|
for (int i = 0; i < numDevices; i++) {
|
||||||
// if current thread equal to this device - we just save it, without duplication
|
// if current thread equal to this device - we just save it, without duplication
|
||||||
if (Nd4j.getAffinityManager().getDeviceForCurrentThread() == i) {
|
if (deviceId == i) {
|
||||||
set(i, array);
|
set(i, array.detach());
|
||||||
} else {
|
} else {
|
||||||
set(i, Nd4j.getAffinityManager().replicateToDevice(i, array));
|
set(i, Nd4j.getAffinityManager().replicateToDevice(i, array));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// we're only updating this device
|
||||||
|
set(Nd4j.getAffinityManager().getDeviceForCurrentThread(), array);
|
||||||
|
delayedArray = array.dup(array.ordering()).detach();
|
||||||
|
|
||||||
|
// and marking all other devices as stale, and provide id of device with the most recent array
|
||||||
|
for (int i = 0; i < numDevices; i++) {
|
||||||
|
if (i != deviceId) {
|
||||||
|
updatesMap.get(i).set(deviceId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
config.setCheckLocality(locality);
|
config.setCheckLocality(locality);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method updates
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: this method is NOT atomic, so you must be sure no other threads are using this instance during the update
|
||||||
|
* @param array
|
||||||
|
*/
|
||||||
|
public synchronized void update(@NonNull INDArray array) {
|
||||||
|
Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray");
|
||||||
|
|
||||||
|
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||||
|
val device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||||
|
val currentArray = backingMap.get(device);
|
||||||
|
boolean wasDelayed = false;
|
||||||
|
|
||||||
|
if (Arrays.equals(currentArray.shapeInfoJava(), array.shapeInfoJava())) {
|
||||||
|
// if arrays are the same - we'll just issue memcpy
|
||||||
|
for (int k = 0; k < numDevices; k++) {
|
||||||
|
val lock = locksMap.get(k);
|
||||||
|
try {
|
||||||
|
lock.writeLock().lock();
|
||||||
|
val v = backingMap.get(k);
|
||||||
|
if (v == null) {
|
||||||
|
if (!wasDelayed) {
|
||||||
|
delayedArray = array.dup(array.ordering()).detach();
|
||||||
|
wasDelayed = true;
|
||||||
|
}
|
||||||
|
updatesMap.get(k).set(device);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4j.getMemoryManager().memcpy(v.data(), array.data());
|
||||||
|
Nd4j.getExecutioner().commit();
|
||||||
|
} finally {
|
||||||
|
lock.writeLock().unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// if arrays are not the same - we'll issue broadcast call
|
||||||
|
broadcast(array);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -659,8 +659,8 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
//log.info("Buffer MemCpy called");
|
//log.info("Buffer MemCpy called");
|
||||||
//log.info("Memcpy buffer: {} bytes ", dstBuffer.length() * dstBuffer.getElementSize());
|
//log.info("Memcpy buffer: {} bytes ", dstBuffer.length() * dstBuffer.getElementSize());
|
||||||
CudaContext context = getCudaContext();
|
CudaContext context = getCudaContext();
|
||||||
AllocationPoint dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
val dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
||||||
AllocationPoint srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint();
|
val srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint();
|
||||||
|
|
||||||
Pointer dP = null; //new CudaPointer(dstPoint.getPointers().getHostPointer().address());
|
Pointer dP = null; //new CudaPointer(dstPoint.getPointers().getHostPointer().address());
|
||||||
Pointer sP = null;
|
Pointer sP = null;
|
||||||
|
|
|
@ -3865,12 +3865,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create a new array by replicating current array by repeats times along given dimension
|
* create a new array by replicating current array by repeats times along given dimension
|
||||||
* dimension - dimension along which to repeat elements
|
* axis - axis along which to repeat elements
|
||||||
* repeats - number of repetitions
|
* repeats - number of repetitions
|
||||||
*/
|
*/
|
||||||
public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector LongPointer repeats);
|
public native NDArray repeat(int axis, @StdVector IntPointer repeats);
|
||||||
public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector LongBuffer repeats);
|
public native NDArray repeat(int axis, @StdVector IntBuffer repeats);
|
||||||
public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector long[] repeats);
|
public native NDArray repeat(int axis, @StdVector int[] repeats);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method fills this array with zeros
|
* This method fills this array with zeros
|
||||||
|
@ -3894,9 +3894,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* fill target array by repeating current array
|
* fill target array by repeating current array
|
||||||
* dimension - dimension along which to repeat elements
|
* axis - axis along which to repeat elements
|
||||||
|
* repeats - vector containing numbers of repetition for elements at given axis
|
||||||
*/
|
*/
|
||||||
public native void repeat(int dimension, @ByRef NDArray target);
|
public native void repeat(int axis, @StdVector IntPointer repeats, @ByRef NDArray target);
|
||||||
|
public native void repeat(int axis, @StdVector IntBuffer repeats, @ByRef NDArray target);
|
||||||
|
public native void repeat(int axis, @StdVector int[] repeats, @ByRef NDArray target);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
|
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
|
||||||
|
@ -9939,11 +9942,17 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
public ContextBuffers() { super((Pointer)null); allocate(); }
|
public ContextBuffers() { super((Pointer)null); allocate(); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
|
public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); }
|
||||||
|
private native void allocate(@Const @ByRef ContextBuffers other);
|
||||||
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
|
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
|
||||||
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
|
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
|
||||||
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
|
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
|
||||||
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);
|
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);
|
||||||
|
|
||||||
|
public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other);
|
||||||
|
|
||||||
|
public native void release();
|
||||||
|
|
||||||
public native Pointer reductionBuffer();
|
public native Pointer reductionBuffer();
|
||||||
public native Pointer scalarBuffer();
|
public native Pointer scalarBuffer();
|
||||||
public native Pointer allocationBuffer();
|
public native Pointer allocationBuffer();
|
||||||
|
@ -9958,6 +9967,8 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
public native void triggerOwnership(@Cast("bool") boolean isOwner);
|
public native void triggerOwnership(@Cast("bool") boolean isOwner);
|
||||||
|
|
||||||
public native int deviceId();
|
public native int deviceId();
|
||||||
|
|
||||||
|
public native @Cast("bool") boolean isInitialized();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -10036,6 +10047,8 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
public native int getDeviceID();
|
public native int getDeviceID();
|
||||||
public native void setDeviceID(int deviceID);
|
public native void setDeviceID(int deviceID);
|
||||||
|
|
||||||
|
public static native @Cast("bool") boolean isInitialized();
|
||||||
|
public static native void releaseBuffers();
|
||||||
public static native LaunchContext defaultContext();
|
public static native LaunchContext defaultContext();
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3865,12 +3865,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create a new array by replicating current array by repeats times along given dimension
|
* create a new array by replicating current array by repeats times along given dimension
|
||||||
* dimension - dimension along which to repeat elements
|
* axis - axis along which to repeat elements
|
||||||
* repeats - number of repetitions
|
* repeats - number of repetitions
|
||||||
*/
|
*/
|
||||||
public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector LongPointer repeats);
|
public native NDArray repeat(int axis, @StdVector IntPointer repeats);
|
||||||
public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector LongBuffer repeats);
|
public native NDArray repeat(int axis, @StdVector IntBuffer repeats);
|
||||||
public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector long[] repeats);
|
public native NDArray repeat(int axis, @StdVector int[] repeats);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method fills this array with zeros
|
* This method fills this array with zeros
|
||||||
|
@ -3894,9 +3894,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* fill target array by repeating current array
|
* fill target array by repeating current array
|
||||||
* dimension - dimension along which to repeat elements
|
* axis - axis along which to repeat elements
|
||||||
|
* repeats - vector containing numbers of repetition for elements at given axis
|
||||||
*/
|
*/
|
||||||
public native void repeat(int dimension, @ByRef NDArray target);
|
public native void repeat(int axis, @StdVector IntPointer repeats, @ByRef NDArray target);
|
||||||
|
public native void repeat(int axis, @StdVector IntBuffer repeats, @ByRef NDArray target);
|
||||||
|
public native void repeat(int axis, @StdVector int[] repeats, @ByRef NDArray target);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
|
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
|
||||||
|
@ -18209,6 +18212,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
// #if NOT_EXCLUDED(OP_space_to_batch_nd)
|
||||||
|
@Namespace("nd4j::ops") public static class space_to_batch_nd extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public space_to_batch_nd(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public space_to_batch_nd position(long position) {
|
||||||
|
return (space_to_batch_nd)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public space_to_batch_nd() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
|
@ -18230,6 +18251,23 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
// #if NOT_EXCLUDED(OP_batch_to_space_nd)
|
||||||
|
@Namespace("nd4j::ops") public static class batch_to_space_nd extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public batch_to_space_nd(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public batch_to_space_nd(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public batch_to_space_nd position(long position) {
|
||||||
|
return (batch_to_space_nd)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public batch_to_space_nd() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* top_k operation returns a vector of k top values for
|
* top_k operation returns a vector of k top values for
|
||||||
|
@ -22831,11 +22869,17 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
|
|
||||||
public ContextBuffers() { super((Pointer)null); allocate(); }
|
public ContextBuffers() { super((Pointer)null); allocate(); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
|
/**
 * Binding for the native copy constructor {@code ContextBuffers(const ContextBuffers &other)}:
 * starts from a null peer and lets the native {@code allocate(other)} create the instance.
 * NOTE(review): whether the underlying buffers are shared or duplicated is decided natively — confirm.
 */
public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); }
|
||||||
|
/** Native allocator used by the copy constructor; {@code other} is passed as a const reference. */
private native void allocate(@Const @ByRef ContextBuffers other);
|
||||||
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
|
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
|
||||||
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
|
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
|
||||||
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
|
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
|
||||||
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);
|
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);
|
||||||
|
|
||||||
|
/**
 * Binding for the native copy-assignment operator ({@code operator =}), exposed in Java as {@code put}.
 * Takes {@code other} as a const reference and returns a {@code ContextBuffers} by reference.
 */
public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other);
|
||||||
|
|
||||||
|
/**
 * Binding for the native {@code ContextBuffers::release()}.
 * NOTE(review): presumably frees the buffers held by this instance — confirm against the native implementation.
 */
public native void release();
|
||||||
|
|
||||||
public native Pointer reductionBuffer();
|
public native Pointer reductionBuffer();
|
||||||
public native Pointer scalarBuffer();
|
public native Pointer scalarBuffer();
|
||||||
public native Pointer allocationBuffer();
|
public native Pointer allocationBuffer();
|
||||||
|
@ -22850,6 +22894,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native void triggerOwnership(@Cast("bool") boolean isOwner);
|
public native void triggerOwnership(@Cast("bool") boolean isOwner);
|
||||||
|
|
||||||
public native int deviceId();
|
public native int deviceId();
|
||||||
|
|
||||||
|
/**
 * Binding for the native {@code ContextBuffers::isInitialized()}; the native {@code bool} is cast to a Java boolean.
 * NOTE(review): presumably reports the native {@code _initialized} flag — confirm.
 */
public native @Cast("bool") boolean isInitialized();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -22919,6 +22965,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native int getDeviceID();
|
public native int getDeviceID();
|
||||||
public native void setDeviceID(int deviceID);
|
public native void setDeviceID(int deviceID);
|
||||||
|
|
||||||
|
/**
 * Static native query; the native {@code bool} result is cast to a Java boolean.
 * NOTE(review): exactly what "initialized" covers is defined in native code not visible here — confirm.
 */
public static native @Cast("bool") boolean isInitialized();
|
||||||
|
/**
 * Static native call with no arguments and no return value.
 * NOTE(review): presumably releases the context buffers associated with the calling thread/device — confirm against the native implementation.
 */
public static native void releaseBuffers();
|
||||||
public static native LaunchContext defaultContext();
|
public static native LaunchContext defaultContext();
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@ -67,6 +68,105 @@ public class DeviceLocalNDArrayTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDeviceLocalUpdate_1() throws Exception {
|
||||||
|
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||||
|
if (numDevices < 2)
|
||||||
|
return;
|
||||||
|
|
||||||
|
val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f);
|
||||||
|
|
||||||
|
val deviceLocal = new DeviceLocalNDArray(array);
|
||||||
|
for (int e = 0; e < numDevices; e++) {
|
||||||
|
val t = new Thread(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
deviceLocal.get().add(1.f);
|
||||||
|
Nd4j.getExecutioner().commit();;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
t.start();
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
val counter = new AtomicInteger(0);
|
||||||
|
|
||||||
|
val update = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f);
|
||||||
|
deviceLocal.update(update);
|
||||||
|
|
||||||
|
for (int e = 0; e < numDevices; e++) {
|
||||||
|
val t = new Thread(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
assertEquals(5.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f);
|
||||||
|
counter.incrementAndGet();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
t.start();
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(numDevices, counter.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDelayedDeviceLocalUpdate_1() throws Exception {
|
||||||
|
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||||
|
if (numDevices < 2)
|
||||||
|
return;
|
||||||
|
|
||||||
|
val array = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f);
|
||||||
|
|
||||||
|
val deviceLocal = new DeviceLocalNDArray(array, true);
|
||||||
|
val counter = new AtomicInteger(0);
|
||||||
|
|
||||||
|
for (int e = 0; e < numDevices; e++) {
|
||||||
|
val t = new Thread(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
assertEquals(5.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f);
|
||||||
|
counter.incrementAndGet();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
t.start();
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(numDevices, counter.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDelayedDeviceLocalUpdate_2() throws Exception {
|
||||||
|
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||||
|
if (numDevices < 2)
|
||||||
|
return;
|
||||||
|
|
||||||
|
val array = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f);
|
||||||
|
|
||||||
|
val deviceLocal = new DeviceLocalNDArray(array, true);
|
||||||
|
val counter = new AtomicInteger(0);
|
||||||
|
|
||||||
|
deviceLocal.update(Nd4j.createFromArray(4.f, 4.f, 4.f, 4.f));
|
||||||
|
|
||||||
|
for (int e = 0; e < numDevices; e++) {
|
||||||
|
val t = new Thread(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
assertEquals(4.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f);
|
||||||
|
counter.incrementAndGet();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
t.start();
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(numDevices, counter.get());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
|
|
Loading…
Reference in New Issue