[WIP] DeviceLocalNDArray updates (#149)

* ContextBuffers are released upon device change

Signed-off-by: raver119 <raver119@gmail.com>

* DeviceLocalNDArray updates + tests

Signed-off-by: raver119 <raver119@gmail.com>

* special array for delayed mode

Signed-off-by: raver119 <raver119@gmail.com>

* additional detach()

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-22 20:01:29 +03:00 committed by GitHub
parent c523aa792f
commit 930b49e87f
13 changed files with 445 additions and 38 deletions
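
A minimal usage sketch of the delayed mode added in this commit, assembled from the constructors and tests in the diff below; the wrapper class name and array values are illustrative:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.util.DeviceLocalNDArray;

    public class DelayedModeSketch {
        public static void main(String[] args) throws Exception {
            INDArray array = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f);

            // delayedMode = true: only the current thread's device gets a copy up front;
            // every other device receives its copy lazily, on the first get() made by
            // a thread affined to that device
            DeviceLocalNDArray deviceLocal = new DeviceLocalNDArray(array, true);

            int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
            for (int e = 0; e < numDevices; e++) {
                // nd4j affines new threads to devices in round-robin fashion
                Thread t = new Thread(() -> {
                    deviceLocal.get().add(1.f);   // materializes the local copy on demand
                    Nd4j.getExecutioner().commit();
                });
                t.start();
                t.join();
            }
        }
    }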

View File

@@ -33,15 +33,22 @@ namespace nd4j {
void* _execStream = nullptr;
void* _specialStream = nullptr;
bool _allocated = false;
bool _initialized = false;
int _deviceId = -1;
void initialize();
public:
ContextBuffers();
ContextBuffers(const ContextBuffers &other);
ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false);
~ContextBuffers();
ContextBuffers& operator=(const ContextBuffers& other);
ContextBuffers& operator=(ContextBuffers&& other);
void release();
void* reductionBuffer();
void* scalarBuffer();
void* allocationBuffer();
@@ -56,6 +63,8 @@ namespace nd4j {
void triggerOwnership(bool isOwner);
int deviceId();
bool isInitialized();
};
}

View File

@@ -98,6 +98,8 @@ class ND4J_EXPORT LaunchContext {
int getDeviceID() const {return _deviceID;}
void setDeviceID(int deviceID) { _deviceID = deviceID; }
static bool isInitialized();
static void releaseBuffers();
static LaunchContext* defaultContext();

View File

@@ -36,6 +36,10 @@ namespace nd4j {
_allocated = isOwner;
}
ContextBuffers::ContextBuffers(const ContextBuffers &other) {
//
}
void ContextBuffers::initialize() {
// no-op
}
@@ -79,4 +83,20 @@ namespace nd4j {
void* ContextBuffers::specialStream() {
return _specialStream;
}
bool ContextBuffers::isInitialized() {
return true;
}
void ContextBuffers::release() {
//
}
ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) {
return *this;
}
ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) {
return *this;
}
}

View File

@@ -57,4 +57,12 @@ namespace nd4j {
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
//
}
bool LaunchContext::isInitialized() {
return true;
}
void LaunchContext::releaseBuffers() {
//
}
}

View File

@@ -95,17 +95,26 @@ namespace nd4j {
}
void AffinityManager::setCurrentDevice(int deviceId) {
auto previousDeviceId = globalThreadToDevice;
if (previousDeviceId >= 0 && LaunchContext::isInitialized()) {
auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
if (res != 0)
throw cuda_exception::build("setCurrentDevice -> sync failed", res);
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream());
if (res != 0)
throw cuda_exception::build("setCurrentDevice -> specialSync failed", res);
}
auto res = cudaSetDevice(deviceId);
if (res != 0)
throw cuda_exception::build("cudaSetDevice failed", res);
// update thread-device affinity
globalThreadToDevice = deviceId;
// discard existing stuff
LaunchContext::releaseBuffers();
}
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);

View File

@@ -34,9 +34,55 @@ namespace nd4j {
_deviceId = AffinityManager::currentDeviceId();
}
ContextBuffers::ContextBuffers(const ContextBuffers &other) {
release();
this->_initialized = other._initialized;
this->_allocated = other._allocated;
this->_deviceId = other._deviceId;
this->_specialStream = other._specialStream;
this->_execStream = other._execStream;
this->_allocationPointer = other._allocationPointer;
this->_reductionPointer = other._reductionPointer;
this->_scalarPointer = other._scalarPointer;
}
ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) {
release();
this->_initialized = other._initialized;
this->_allocated = other._allocated;
this->_deviceId = other._deviceId;
this->_specialStream = other._specialStream;
this->_execStream = other._execStream;
this->_allocationPointer = other._allocationPointer;
this->_reductionPointer = other._reductionPointer;
this->_scalarPointer = other._scalarPointer;
return *this;
}
ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) {
release();
this->_initialized = other._initialized;
this->_allocated = other._allocated;
this->_deviceId = other._deviceId;
this->_specialStream = other._specialStream;
this->_execStream = other._execStream;
this->_allocationPointer = other._allocationPointer;
this->_reductionPointer = other._reductionPointer;
this->_scalarPointer = other._scalarPointer;
return *this;
}
void ContextBuffers::release() {
if (_allocated) {
//nd4j_printf("Releasing ContextBuffers\n","");
//nd4j_printf("Releasing ContextBuffers on device [%i]\n", _deviceId);
if (_allocationPointer != nullptr)
cudaFree(_allocationPointer);
@@ -58,9 +104,24 @@ namespace nd4j {
delete _cudaStream;
delete _cudaSpecialStream;
//////
_allocated = false;
_initialized = false;
_deviceId = -1;
this->_specialStream = nullptr;
this->_execStream = nullptr;
this->_allocationPointer = nullptr;
this->_reductionPointer = nullptr;
this->_scalarPointer = nullptr;
}
}
ContextBuffers::~ContextBuffers() {
release();
}
ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
_reductionPointer = rPointer;
_scalarPointer = sPointer;
@@ -69,19 +130,20 @@ namespace nd4j {
}
void ContextBuffers::initialize() {
//nd4j_printf("Initializing buffers on deviceId [%i]\n", AffinityManager::currentNativeDeviceId());
_deviceId = AffinityManager::currentNativeDeviceId();
//nd4j_printf("Initializing buffers on deviceId [%i]\n", _deviceId);
auto res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
if (res != 0)
throw std::runtime_error("_reductionPointer allocation failed");
throw cuda_exception::build("_reductionPointer allocation failed", res);
res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16);
if (res != 0)
throw std::runtime_error("_scalarPointer allocation failed");
throw cuda_exception::build("_scalarPointer allocation failed", res);
res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
if (res != 0)
throw std::runtime_error("_allocationPointer allocation failed");
throw cuda_exception::build("_allocationPointer allocation failed", res);
_execStream = new cudaStream_t();
_specialStream = new cudaStream_t();
@@ -97,6 +159,7 @@ namespace nd4j {
throw cuda_exception::build("Failed to create special CUDA stream with launch context", res);
_allocated = true;
_initialized = true;
}
void* ContextBuffers::reductionBuffer() {
@@ -153,4 +216,9 @@ namespace nd4j {
return _specialStream;
}
bool ContextBuffers::isInitialized() {
return _initialized;
}
}

View File

@@ -160,4 +160,12 @@ LaunchContext::LaunchContext() {
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
contextBuffers = buffers;
};
void LaunchContext::releaseBuffers() {
contextBuffers.release();
}
bool LaunchContext::isInitialized() {
return contextBuffers.isInitialized();
}
}

View File

@@ -16,6 +16,7 @@
package org.nd4j.linalg.util;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import edu.umd.cs.findbugs.annotations.Nullable;
@@ -24,6 +25,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
@@ -31,14 +33,23 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
*
* @author raver119@gmail.com
*/
public abstract class DeviceLocal<T extends Object> {
protected Map<Integer, T> backingMap = new ConcurrentHashMap<>();
protected List<ReentrantReadWriteLock> locksMap = new ArrayList<>();
protected List<AtomicInteger> updatesMap = new ArrayList<>();
protected final boolean delayedMode;
protected volatile INDArray delayedArray;
protected int lastSettledDevice = -1;
public DeviceLocal(boolean delayedMode) {
this.delayedMode = delayedMode;
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
for (int i = 0; i < numDevices; i++) {
locksMap.add(new ReentrantReadWriteLock());
updatesMap.add(new AtomicInteger(-1));
}
}
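
The AtomicInteger slots in updatesMap, initialized to -1 above, act as per-device stale markers for the delayed mode: -1 means no delayed broadcast has happened, the device's own id means its copy is current, and a foreign id names the device holding the freshest data. A standalone sketch of that protocol, independent of Nd4j (all names here are illustrative, not from the commit):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.atomic.AtomicInteger;

    // Illustrative model of the stale-marker protocol used by DeviceLocalNDArray below:
    // marker == ownId   -> the local copy is current
    // marker == otherId -> the local copy is stale; `otherId` holds the source
    // marker == -1      -> nothing was broadcast in delayed mode yet
    public class StaleMarkers {
        private final List<AtomicInteger> updates = new ArrayList<>();

        public StaleMarkers(int numDevices) {
            for (int i = 0; i < numDevices; i++)
                updates.add(new AtomicInteger(-1));
        }

        // broadcast() on device `source` marks every other device as stale
        public void markOthersStale(int source) {
            for (int i = 0; i < updates.size(); i++)
                if (i != source)
                    updates.get(i).set(source);
        }

        // get() on device `device` consumes a pending update, if there is one
        public boolean consume(int device) {
            int source = updates.get(device).get();
            if (source >= 0 && source != device) {
                // ...copy the data from the delayed source array here...
                updates.get(device).set(device);  // own copy is current again
                return true;
            }
            return false;
        }
    }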

View File

@@ -16,14 +16,20 @@
package org.nd4j.linalg.util;
import edu.umd.cs.findbugs.annotations.Nullable;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
/**
* DeviceLocal implementation for INDArray, with special broadcast method
* @author raver119@gmail.com
@@ -32,24 +38,71 @@ import org.nd4j.linalg.profiler.ProfilerConfig;
public class DeviceLocalNDArray extends DeviceLocal<INDArray> {
public DeviceLocalNDArray() {
this(false);
}
public DeviceLocalNDArray(boolean delayedMode) {
super(delayedMode);
}
public DeviceLocalNDArray(INDArray array) {
this(array, false);
}
public DeviceLocalNDArray(INDArray array, boolean delayedMode) {
super(delayedMode);
broadcast(array);
}
/**
* This method returns the object local to the current device
*
* @return array local to the current device, or null if none is set
*/
@Nullable
@Override
public synchronized INDArray get() {
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
val sourceId = updatesMap.get(deviceId).get();
if (sourceId >= 0 && sourceId != deviceId) {
// if the updates map points at another device - the local copy is stale, so refresh it from the delayed array
val newArray = Nd4j.create(delayedArray.dataType(), delayedArray.shape(), delayedArray.stride(), delayedArray.ordering());
Nd4j.getMemoryManager().memcpy(newArray.data(), delayedArray.data());
backingMap.put(deviceId, newArray);
// reset updates flag
updatesMap.get(deviceId).set(deviceId);
// also check if all updates were consumed
boolean allUpdated = true;
for (int e = 0; e < numDevices; e++) {
if (updatesMap.get(e).get() != e) {
allUpdated = false;
break;
}
}
if (allUpdated)
delayedArray = null;
}
return get(deviceId);
}
/**
* This method duplicates the array, and stores copies on all devices
*
* PLEASE NOTE: this method is NOT atomic, so you must be sure no other threads are using this instance during the update
* @param array array to be propagated to all devices
*/
public synchronized void broadcast(INDArray array) {
if (array == null)
return;
Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray");
Nd4j.getExecutioner().commit();
val config = OpProfiler.getInstance().getConfig();
@@ -57,18 +110,76 @@ public class DeviceLocalNDArray extends DeviceLocal<INDArray> {
if (locality)
config.setCheckLocality(false);
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
if (!delayedMode) {
// in immediate mode we replicate data to all devices right away
for (int i = 0; i < numDevices; i++) {
// if the current thread is affined to this device - we just save the array, without duplication
if (deviceId == i) {
set(i, array.detach());
} else {
set(i, Nd4j.getAffinityManager().replicateToDevice(i, array));
}
}
} else {
// we're only updating this device
set(Nd4j.getAffinityManager().getDeviceForCurrentThread(), array);
delayedArray = array.dup(array.ordering()).detach();
// and mark all other devices as stale, recording the id of the device with the most recent array
for (int i = 0; i < numDevices; i++) {
if (i != deviceId) {
updatesMap.get(i).set(deviceId);
}
}
}
config.setCheckLocality(locality);
}
/**
* This method updates the arrays stored on all devices with the values of the given array
*
* PLEASE NOTE: this method is NOT atomic, so you must be sure no other threads are using this instance during the update
* @param array array with the new values
*/
public synchronized void update(@NonNull INDArray array) {
Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray");
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
val device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
val currentArray = backingMap.get(device);
boolean wasDelayed = false;
if (Arrays.equals(currentArray.shapeInfoJava(), array.shapeInfoJava())) {
// if shape infos match - we'll just issue a memcpy into each existing copy
for (int k = 0; k < numDevices; k++) {
val lock = locksMap.get(k);
try {
lock.writeLock().lock();
val v = backingMap.get(k);
if (v == null) {
if (!wasDelayed) {
delayedArray = array.dup(array.ordering()).detach();
wasDelayed = true;
}
updatesMap.get(k).set(device);
continue;
}
Nd4j.getMemoryManager().memcpy(v.data(), array.data());
Nd4j.getExecutioner().commit();
} finally {
lock.writeLock().unlock();
}
}
} else {
// if shapes differ - we'll fall back to a full broadcast call
broadcast(array);
}
}
}
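
A short usage sketch for the new update() path, assuming a multi-device box: when the incoming array's shape info matches the stored copies, the values are memcpy'd in place under each device's write lock; otherwise the call falls back to broadcast(). Array values here are illustrative:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.util.DeviceLocalNDArray;

    public class UpdateSketch {
        public static void main(String[] args) {
            DeviceLocalNDArray deviceLocal =
                    new DeviceLocalNDArray(Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f));

            // same shape: fast path - in-place memcpy on every device that has a copy
            deviceLocal.update(Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f));

            // different shape: degrades to broadcast(), which re-replicates the array
            deviceLocal.update(Nd4j.createFromArray(1.f, 2.f));

            // views can trip the precondition above; dup() first to pass a standalone array
            INDArray base = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f).reshape(2, 2);
            deviceLocal.update(base.getRow(0).dup());
        }
    }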

View File

@@ -659,8 +659,8 @@ public class CudaZeroHandler implements MemoryHandler {
//log.info("Buffer MemCpy called");
//log.info("Memcpy buffer: {} bytes ", dstBuffer.length() * dstBuffer.getElementSize());
CudaContext context = getCudaContext();
val dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
val srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint();
Pointer dP = null; //new CudaPointer(dstPoint.getPointers().getHostPointer().address());
Pointer sP = null;

View File

@@ -3865,12 +3865,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
/**
* create a new array by replicating current array by repeats times along given dimension
* axis - axis along which to repeat elements
* repeats - number of repetitions
*/
public native NDArray repeat(int axis, @StdVector IntPointer repeats);
public native NDArray repeat(int axis, @StdVector IntBuffer repeats);
public native NDArray repeat(int axis, @StdVector int[] repeats);
/**
* This method fills this array with zeros
@@ -3894,9 +3894,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
/**
* fill target array by repeating current array
* axis - axis along which to repeat elements
* repeats - vector containing numbers of repetition for elements at given axis
*/
public native void repeat(int axis, @StdVector IntPointer repeats, @ByRef NDArray target);
public native void repeat(int axis, @StdVector IntBuffer repeats, @ByRef NDArray target);
public native void repeat(int axis, @StdVector int[] repeats, @ByRef NDArray target);
/**
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
@@ -9939,11 +9942,17 @@ public static final int PREALLOC_SIZE = 33554432;
public ContextBuffers() { super((Pointer)null); allocate(); }
private native void allocate();
public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); }
private native void allocate(@Const @ByRef ContextBuffers other);
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);
public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other);
public native void release();
public native Pointer reductionBuffer();
public native Pointer scalarBuffer();
public native Pointer allocationBuffer();
@@ -9958,6 +9967,8 @@ public static final int PREALLOC_SIZE = 33554432;
public native void triggerOwnership(@Cast("bool") boolean isOwner);
public native int deviceId();
public native @Cast("bool") boolean isInitialized();
}
@@ -10036,6 +10047,8 @@ public static final int PREALLOC_SIZE = 33554432;
public native int getDeviceID();
public native void setDeviceID(int deviceID);
public static native @Cast("bool") boolean isInitialized();
public static native void releaseBuffers();
public static native LaunchContext defaultContext();

View File

@@ -3865,12 +3865,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
/**
* create a new array by replicating current array by repeats times along given dimension
* axis - axis along which to repeat elements
* repeats - number of repetitions
*/
public native NDArray repeat(int axis, @StdVector IntPointer repeats);
public native NDArray repeat(int axis, @StdVector IntBuffer repeats);
public native NDArray repeat(int axis, @StdVector int[] repeats);
/**
* This method fills this array with zeros
@@ -3894,9 +3894,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
/**
* fill target array by repeating current array
* axis - axis along which to repeat elements
* repeats - vector containing numbers of repetition for elements at given axis
*/
public native void repeat(int axis, @StdVector IntPointer repeats, @ByRef NDArray target);
public native void repeat(int axis, @StdVector IntBuffer repeats, @ByRef NDArray target);
public native void repeat(int axis, @StdVector int[] repeats, @ByRef NDArray target);
/**
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
@@ -18209,6 +18212,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
}
// #endif
// #if NOT_EXCLUDED(OP_space_to_batch_nd)
@Namespace("nd4j::ops") public static class space_to_batch_nd extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public space_to_batch_nd(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public space_to_batch_nd position(long position) {
return (space_to_batch_nd)super.position(position);
}
public space_to_batch_nd() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
*
*
@@ -18230,6 +18251,23 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
// #if NOT_EXCLUDED(OP_batch_to_space_nd)
@Namespace("nd4j::ops") public static class batch_to_space_nd extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public batch_to_space_nd(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public batch_to_space_nd(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public batch_to_space_nd position(long position) {
return (batch_to_space_nd)super.position(position);
}
public batch_to_space_nd() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
* top_k operation returns a vector of k top values for
@@ -22831,11 +22869,17 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public ContextBuffers() { super((Pointer)null); allocate(); }
private native void allocate();
public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); }
private native void allocate(@Const @ByRef ContextBuffers other);
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);
public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other);
public native void release();
public native Pointer reductionBuffer();
public native Pointer scalarBuffer();
public native Pointer allocationBuffer();
@@ -22850,6 +22894,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native void triggerOwnership(@Cast("bool") boolean isOwner);
public native int deviceId();
public native @Cast("bool") boolean isInitialized();
}
@@ -22919,6 +22965,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native int getDeviceID();
public native void setDeviceID(int deviceID);
public static native @Cast("bool") boolean isInitialized();
public static native void releaseBuffers();
public static native LaunchContext defaultContext();

View File

@@ -29,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
@@ -67,6 +68,105 @@ public class DeviceLocalNDArrayTests extends BaseNd4jTest {
}
}
@Test
public void testDeviceLocalUpdate_1() throws Exception {
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
if (numDevices < 2)
return;
val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f);
val deviceLocal = new DeviceLocalNDArray(array);
for (int e = 0; e < numDevices; e++) {
val t = new Thread(new Runnable() {
@Override
public void run() {
deviceLocal.get().add(1.f);
Nd4j.getExecutioner().commit();
}
});
t.start();
t.join();
}
val counter = new AtomicInteger(0);
val update = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f);
deviceLocal.update(update);
for (int e = 0; e < numDevices; e++) {
val t = new Thread(new Runnable() {
@Override
public void run() {
assertEquals(5.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f);
counter.incrementAndGet();
}
});
t.start();
t.join();
}
assertEquals(numDevices, counter.get());
}
@Test
public void testDelayedDeviceLocalUpdate_1() throws Exception {
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
if (numDevices < 2)
return;
val array = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f);
val deviceLocal = new DeviceLocalNDArray(array, true);
val counter = new AtomicInteger(0);
for (int e = 0; e < numDevices; e++) {
val t = new Thread(new Runnable() {
@Override
public void run() {
assertEquals(5.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f);
counter.incrementAndGet();
}
});
t.start();
t.join();
}
assertEquals(numDevices, counter.get());
}
@Test
public void testDelayedDeviceLocalUpdate_2() throws Exception {
val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
if (numDevices < 2)
return;
val array = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f);
val deviceLocal = new DeviceLocalNDArray(array, true);
val counter = new AtomicInteger(0);
deviceLocal.update(Nd4j.createFromArray(4.f, 4.f, 4.f, 4.f));
for (int e = 0; e < numDevices; e++) {
val t = new Thread(new Runnable() {
@Override
public void run() {
assertEquals(4.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f);
counter.incrementAndGet();
}
});
t.start();
t.join();
}
assertEquals(numDevices, counter.get());
}
@Override
public char ordering() {