From 930b49e87fee82b6f0ef46f846c6e2536a62db4f Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 22 Aug 2019 20:01:29 +0300 Subject: [PATCH] [WIP] DeviceLocalNDArray updates (#149) * ContextBuffers are released upon device change Signed-off-by: raver119 * DeviceLocalNDArray updates + tests Signed-off-by: raver119 * special array for delayed mode Signed-off-by: raver119 * additional detach() Signed-off-by: raver119 --- libnd4j/include/execution/ContextBuffers.h | 9 ++ libnd4j/include/execution/LaunchContext.h | 2 + .../include/execution/cpu/ContextBuffers.cpp | 20 +++ .../include/execution/cpu/LaunchContext.cpp | 8 ++ .../include/execution/cuda/AffinityManager.cu | 17 ++- .../include/execution/cuda/ContextBuffers.cu | 80 ++++++++++- .../include/execution/cuda/LaunchContext.cu | 8 ++ .../org/nd4j/linalg/util/DeviceLocal.java | 19 ++- .../nd4j/linalg/util/DeviceLocalNDArray.java | 131 ++++++++++++++++-- .../jita/handler/impl/CudaZeroHandler.java | 4 +- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 25 +++- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 60 +++++++- .../memory/DeviceLocalNDArrayTests.java | 100 +++++++++++++ 13 files changed, 445 insertions(+), 38 deletions(-) diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h index 39490d288..130354070 100644 --- a/libnd4j/include/execution/ContextBuffers.h +++ b/libnd4j/include/execution/ContextBuffers.h @@ -33,15 +33,22 @@ namespace nd4j { void* _execStream = nullptr; void* _specialStream = nullptr; bool _allocated = false; + bool _initialized = false; int _deviceId = -1; void initialize(); public: ContextBuffers(); + ContextBuffers(const ContextBuffers &other); ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false); ~ContextBuffers(); + ContextBuffers& operator=(const ContextBuffers& other); + ContextBuffers& operator=(ContextBuffers&& other); + + void release(); + void* reductionBuffer(); void* scalarBuffer(); void* allocationBuffer(); @@ -56,6 +63,8 @@ namespace nd4j { void triggerOwnership(bool isOwner); int deviceId(); + + bool isInitialized(); }; } diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index f165e1297..23165fa0e 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -98,6 +98,8 @@ class ND4J_EXPORT LaunchContext { int getDeviceID() const {return _deviceID;} void setDeviceID(int deviceID) { _deviceID = deviceID; } + static bool isInitialized(); + static void releaseBuffers(); static LaunchContext* defaultContext(); diff --git a/libnd4j/include/execution/cpu/ContextBuffers.cpp b/libnd4j/include/execution/cpu/ContextBuffers.cpp index 19ceb1b36..3bf0a01eb 100644 --- a/libnd4j/include/execution/cpu/ContextBuffers.cpp +++ b/libnd4j/include/execution/cpu/ContextBuffers.cpp @@ -36,6 +36,10 @@ namespace nd4j { _allocated = isOwner; } + ContextBuffers::ContextBuffers(const ContextBuffers &other) { + // + } + void ContextBuffers::initialize() { // no-op } @@ -79,4 +83,20 @@ namespace nd4j { void* ContextBuffers::specialStream() { return _specialStream; } + + bool ContextBuffers::isInitialized() { + return true; + } + + void ContextBuffers::release() { + // + } + + ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { + return *this; + } + + ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { + return *this; + } } \ No newline at end of file diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp index d16ea24d3..3ee460350 100644 --- a/libnd4j/include/execution/cpu/LaunchContext.cpp +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -57,4 +57,12 @@ namespace nd4j { void LaunchContext::swapContextBuffers(ContextBuffers &buffers) { // } + + bool LaunchContext::isInitialized() { + return true; + } + + void LaunchContext::releaseBuffers() { + // + } } \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/AffinityManager.cu b/libnd4j/include/execution/cuda/AffinityManager.cu index 522698c98..1f028b011 100644 --- a/libnd4j/include/execution/cuda/AffinityManager.cu +++ b/libnd4j/include/execution/cuda/AffinityManager.cu @@ -95,17 +95,26 @@ namespace nd4j { } void AffinityManager::setCurrentDevice(int deviceId) { + auto previousDeviceId = globalThreadToDevice; + if (previousDeviceId >= 0 && LaunchContext::isInitialized()) { + auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("setCurrentDevice -> sync failed", res); + + res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream()); + if (res != 0) + throw cuda_exception::build("setCurrentDevice -> specialSync failed", res); + } + auto res = cudaSetDevice(deviceId); if (res != 0) throw cuda_exception::build("cudaSetDevice failed", res); - auto previousDeviceId = globalThreadToDevice; - // update thread-device affinity globalThreadToDevice = deviceId; - ContextBuffers newBuffers; - LaunchContext::swapContextBuffers(newBuffers); + // discard existing stuff + LaunchContext::releaseBuffers(); } std::atomic AffinityManager::_lastDevice;// = std::atomic(initialV); diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu index ed84d511a..84db0c284 100644 --- a/libnd4j/include/execution/cuda/ContextBuffers.cu +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -34,9 +34,55 @@ namespace nd4j { _deviceId = AffinityManager::currentDeviceId(); } - ContextBuffers::~ContextBuffers() { + ContextBuffers::ContextBuffers(const ContextBuffers &other) { + release(); + + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; + + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; + } + + ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { + release(); + + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; + + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; + + return *this; + } + + ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { + release(); + + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; + + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; + + return *this; + } + + void ContextBuffers::release() { if (_allocated) { - //nd4j_printf("Releasing ContextBuffers\n",""); + //nd4j_printf("Releasing ContextBuffers on device [%i]\n", _deviceId); if (_allocationPointer != nullptr) cudaFree(_allocationPointer); @@ -58,9 +104,24 @@ namespace nd4j { delete _cudaStream; delete _cudaSpecialStream; + + ////// + _allocated = false; + _initialized = false; + _deviceId = -1; + + this->_specialStream = nullptr; + this->_execStream = nullptr; + this->_allocationPointer = nullptr; + this->_reductionPointer = nullptr; + this->_scalarPointer = nullptr; } } + ContextBuffers::~ContextBuffers() { + release(); + } + ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { _reductionPointer = rPointer; _scalarPointer = sPointer; @@ -69,19 +130,20 @@ namespace nd4j { } void ContextBuffers::initialize() { - //nd4j_printf("Initializing buffers on deviceId [%i]\n", AffinityManager::currentNativeDeviceId()); + _deviceId = AffinityManager::currentNativeDeviceId(); + //nd4j_printf("Initializing buffers on deviceId [%i]\n", _deviceId); auto res = cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); if (res != 0) - throw std::runtime_error("_reductionPointer allocation failed"); + throw cuda_exception::build("_reductionPointer allocation failed", res); res = cudaMalloc(reinterpret_cast(&_scalarPointer), 16); if (res != 0) - throw std::runtime_error("_scalarPointer allocation failed"); + throw cuda_exception::build("_scalarPointer allocation failed", res); res = cudaMalloc(reinterpret_cast(&_allocationPointer), 1024 * 1024 * 8); if (res != 0) - throw std::runtime_error("_allocationPointer allocation failed"); + throw cuda_exception::build("_allocationPointer allocation failed", res); _execStream = new cudaStream_t(); _specialStream = new cudaStream_t(); @@ -97,6 +159,7 @@ namespace nd4j { throw cuda_exception::build("Failed to create special CUDA stream with launch context", res); _allocated = true; + _initialized = true; } void* ContextBuffers::reductionBuffer() { @@ -153,4 +216,9 @@ namespace nd4j { return _specialStream; } + + bool ContextBuffers::isInitialized() { + return _initialized; + } } + diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index a881633a0..1292f756c 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -160,4 +160,12 @@ LaunchContext::LaunchContext() { void LaunchContext::swapContextBuffers(ContextBuffers &buffers) { contextBuffers = buffers; }; + + void LaunchContext::releaseBuffers() { + contextBuffers.release(); + } + + bool LaunchContext::isInitialized() { + return contextBuffers.isInitialized(); + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocal.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocal.java index 31154cd37..ceaa41e19 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocal.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocal.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.util; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import edu.umd.cs.findbugs.annotations.Nullable; @@ -24,6 +25,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantReadWriteLock; /** @@ -31,14 +33,23 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; * * @author raver119@gmail.com */ -public class DeviceLocal { - private Map backingMap = new ConcurrentHashMap<>(); - private List locksMap = new ArrayList<>(); +public abstract class DeviceLocal { + protected Map backingMap = new ConcurrentHashMap<>(); + protected List locksMap = new ArrayList<>(); + protected List updatesMap = new ArrayList<>(); + protected final boolean delayedMode; + + protected volatile INDArray delayedArray; + + protected int lastSettledDevice = -1; + + public DeviceLocal(boolean delayedMode) { + this.delayedMode = delayedMode; - public DeviceLocal() { int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); for (int i = 0; i < numDevices; i++) { locksMap.add(new ReentrantReadWriteLock()); + updatesMap.add(new AtomicInteger(-1)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocalNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocalNDArray.java index 2cf085383..eb9644a39 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocalNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocalNDArray.java @@ -16,14 +16,20 @@ package org.nd4j.linalg.util; +import edu.umd.cs.findbugs.annotations.Nullable; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.OpProfiler; import org.nd4j.linalg.profiler.ProfilerConfig; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; + /** * DeviceLocal implementation for INDArray, with special broadcast method * @author raver119@gmail.com @@ -32,24 +38,71 @@ import org.nd4j.linalg.profiler.ProfilerConfig; public class DeviceLocalNDArray extends DeviceLocal { public DeviceLocalNDArray() { - super(); + this(false); + } + + public DeviceLocalNDArray(boolean delayedMode) { + super(delayedMode); } public DeviceLocalNDArray(INDArray array) { - super(); + this(array, false); + } + + public DeviceLocalNDArray(INDArray array, boolean delayedMode) { + super(delayedMode); broadcast(array); } + /** + * This method returns object local to current deviceId + * + * @return + */ + @Nullable + @Override + public synchronized INDArray get() { + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); + val sourceId = updatesMap.get(deviceId).get(); + if (sourceId >= 0 && sourceId != deviceId) { + // if updates map contains some deviceId - we should take updated array from there + val newArray = Nd4j.create(delayedArray.dataType(), delayedArray.shape(), delayedArray.stride(), delayedArray.ordering()); + Nd4j.getMemoryManager().memcpy(newArray.data(), delayedArray.data()); + backingMap.put(deviceId, newArray); + + // reset updates flag + updatesMap.get(deviceId).set(deviceId); + + + // also check if all updates were consumed + boolean allUpdated = true; + for (int e = 0; e < numDevices; e++) { + if (updatesMap.get(e).get() != e) { + allUpdated = false; + break; + } + } + + if (allUpdated) + delayedArray = null; + } + return get(deviceId); + } + /** * This method duplicates array, and stores it to all devices * + * PLEASE NOTE: this method is NOT atomic, so you must be sure no other threads are using this instance during the update * @param array */ - public void broadcast(INDArray array) { + public synchronized void broadcast(INDArray array) { if (array == null) return; + Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray"); + Nd4j.getExecutioner().commit(); val config = OpProfiler.getInstance().getConfig(); @@ -57,18 +110,76 @@ public class DeviceLocalNDArray extends DeviceLocal { if (locality) config.setCheckLocality(false); + val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + + if (!delayedMode) { + // in immediate mode we put data in + + for (int i = 0; i < numDevices; i++) { + // if current thread equal to this device - we just save it, without duplication + if (deviceId == i) { + set(i, array.detach()); + } else { + set(i, Nd4j.getAffinityManager().replicateToDevice(i, array)); + } - int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); - for (int i = 0; i < numDevices; i++) { - // if current thread equal to this device - we just save it, without duplication - if (Nd4j.getAffinityManager().getDeviceForCurrentThread() == i) { - set(i, array); - } else { - set(i, Nd4j.getAffinityManager().replicateToDevice(i, array)); } + } else { + // we're only updating this device + set(Nd4j.getAffinityManager().getDeviceForCurrentThread(), array); + delayedArray = array.dup(array.ordering()).detach(); + // and marking all other devices as stale, and provide id of device with the most recent array + for (int i = 0; i < numDevices; i++) { + if (i != deviceId) { + updatesMap.get(i).set(deviceId); + } + } } config.setCheckLocality(locality); } + + /** + * This method updates + * + * PLEASE NOTE: this method is NOT atomic, so you must be sure no other threads are using this instance during the update + * @param array + */ + public synchronized void update(@NonNull INDArray array) { + Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray"); + + val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); + val device = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + val currentArray = backingMap.get(device); + boolean wasDelayed = false; + + if (Arrays.equals(currentArray.shapeInfoJava(), array.shapeInfoJava())) { + // if arrays are the same - we'll just issue memcpy + for (int k = 0; k < numDevices; k++) { + val lock = locksMap.get(k); + try { + lock.writeLock().lock(); + val v = backingMap.get(k); + if (v == null) { + if (!wasDelayed) { + delayedArray = array.dup(array.ordering()).detach(); + wasDelayed = true; + } + updatesMap.get(k).set(device); + continue; + } + + Nd4j.getMemoryManager().memcpy(v.data(), array.data()); + Nd4j.getExecutioner().commit(); + } finally { + lock.writeLock().unlock(); + } + } + } else { + // if arrays are not the same - we'll issue broadcast call + broadcast(array); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index 67f58ba94..f6a0eafc0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -659,8 +659,8 @@ public class CudaZeroHandler implements MemoryHandler { //log.info("Buffer MemCpy called"); //log.info("Memcpy buffer: {} bytes ", dstBuffer.length() * dstBuffer.getElementSize()); CudaContext context = getCudaContext(); - AllocationPoint dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); - AllocationPoint srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint(); + val dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); + val srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint(); Pointer dP = null; //new CudaPointer(dstPoint.getPointers().getHostPointer().address()); Pointer sP = null; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 33ba27069..0541f914a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3865,12 +3865,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc /** * create a new array by replicating current array by repeats times along given dimension - * dimension - dimension along which to repeat elements + * axis - axis along which to repeat elements * repeats - number of repetitions */ - public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector LongPointer repeats); - public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector LongBuffer repeats); - public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector long[] repeats); + public native NDArray repeat(int axis, @StdVector IntPointer repeats); + public native NDArray repeat(int axis, @StdVector IntBuffer repeats); + public native NDArray repeat(int axis, @StdVector int[] repeats); /** * This method fills this array with zeros @@ -3894,9 +3894,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc /** * fill target array by repeating current array - * dimension - dimension along which to repeat elements + * axis - axis along which to repeat elements + * repeats - vector containing numbers of repetition for elements at given axis */ - public native void repeat(int dimension, @ByRef NDArray target); + public native void repeat(int axis, @StdVector IntPointer repeats, @ByRef NDArray target); + public native void repeat(int axis, @StdVector IntBuffer repeats, @ByRef NDArray target); + public native void repeat(int axis, @StdVector int[] repeats, @ByRef NDArray target); /** * creates array which points on certain sub-range of this array, sub-range is defined by given indices @@ -9939,11 +9942,17 @@ public static final int PREALLOC_SIZE = 33554432; public ContextBuffers() { super((Pointer)null); allocate(); } private native void allocate(); + public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef ContextBuffers other); public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); } private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/); public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); } private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer); + public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other); + + public native void release(); + public native Pointer reductionBuffer(); public native Pointer scalarBuffer(); public native Pointer allocationBuffer(); @@ -9958,6 +9967,8 @@ public static final int PREALLOC_SIZE = 33554432; public native void triggerOwnership(@Cast("bool") boolean isOwner); public native int deviceId(); + + public native @Cast("bool") boolean isInitialized(); } @@ -10036,6 +10047,8 @@ public static final int PREALLOC_SIZE = 33554432; public native int getDeviceID(); public native void setDeviceID(int deviceID); + public static native @Cast("bool") boolean isInitialized(); + public static native void releaseBuffers(); public static native LaunchContext defaultContext(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 4fa484a23..4fee4b3b7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3865,12 +3865,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc /** * create a new array by replicating current array by repeats times along given dimension - * dimension - dimension along which to repeat elements + * axis - axis along which to repeat elements * repeats - number of repetitions */ - public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector LongPointer repeats); - public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector LongBuffer repeats); - public native NDArray repeat(int dimension, @Cast("Nd4jLong*") @StdVector long[] repeats); + public native NDArray repeat(int axis, @StdVector IntPointer repeats); + public native NDArray repeat(int axis, @StdVector IntBuffer repeats); + public native NDArray repeat(int axis, @StdVector int[] repeats); /** * This method fills this array with zeros @@ -3894,9 +3894,12 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc /** * fill target array by repeating current array - * dimension - dimension along which to repeat elements + * axis - axis along which to repeat elements + * repeats - vector containing numbers of repetition for elements at given axis */ - public native void repeat(int dimension, @ByRef NDArray target); + public native void repeat(int axis, @StdVector IntPointer repeats, @ByRef NDArray target); + public native void repeat(int axis, @StdVector IntBuffer repeats, @ByRef NDArray target); + public native void repeat(int axis, @StdVector int[] repeats, @ByRef NDArray target); /** * creates array which points on certain sub-range of this array, sub-range is defined by given indices @@ -18209,6 +18212,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +// #if NOT_EXCLUDED(OP_space_to_batch_nd) + @Namespace("nd4j::ops") public static class space_to_batch_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public space_to_batch_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public space_to_batch_nd position(long position) { + return (space_to_batch_nd)super.position(position); + } + + public space_to_batch_nd() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * * @@ -18230,6 +18251,23 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif +// #if NOT_EXCLUDED(OP_batch_to_space_nd) + @Namespace("nd4j::ops") public static class batch_to_space_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batch_to_space_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batch_to_space_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batch_to_space_nd position(long position) { + return (batch_to_space_nd)super.position(position); + } + + public batch_to_space_nd() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif /** * top_k operation returns a vector of k top values for @@ -22831,11 +22869,17 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public ContextBuffers() { super((Pointer)null); allocate(); } private native void allocate(); + public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef ContextBuffers other); public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); } private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/); public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); } private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer); + public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other); + + public native void release(); + public native Pointer reductionBuffer(); public native Pointer scalarBuffer(); public native Pointer allocationBuffer(); @@ -22850,6 +22894,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native void triggerOwnership(@Cast("bool") boolean isOwner); public native int deviceId(); + + public native @Cast("bool") boolean isInitialized(); } @@ -22919,6 +22965,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native int getDeviceID(); public native void setDeviceID(int deviceID); + public static native @Cast("bool") boolean isInitialized(); + public static native void releaseBuffers(); public static native LaunchContext defaultContext(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java index 2057bc98b..cf8154ab4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.util.DeviceLocalNDArray; import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; @@ -67,6 +68,105 @@ public class DeviceLocalNDArrayTests extends BaseNd4jTest { } } + @Test + public void testDeviceLocalUpdate_1() throws Exception { + val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); + if (numDevices < 2) + return; + + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f); + + val deviceLocal = new DeviceLocalNDArray(array); + for (int e = 0; e < numDevices; e++) { + val t = new Thread(new Runnable() { + @Override + public void run() { + deviceLocal.get().add(1.f); + Nd4j.getExecutioner().commit();; + } + }); + + t.start(); + t.join(); + } + + val counter = new AtomicInteger(0); + + val update = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f); + deviceLocal.update(update); + + for (int e = 0; e < numDevices; e++) { + val t = new Thread(new Runnable() { + @Override + public void run() { + assertEquals(5.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f); + counter.incrementAndGet(); + } + }); + + t.start(); + t.join(); + } + + assertEquals(numDevices, counter.get()); + } + + + @Test + public void testDelayedDeviceLocalUpdate_1() throws Exception { + val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); + if (numDevices < 2) + return; + + val array = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f); + + val deviceLocal = new DeviceLocalNDArray(array, true); + val counter = new AtomicInteger(0); + + for (int e = 0; e < numDevices; e++) { + val t = new Thread(new Runnable() { + @Override + public void run() { + assertEquals(5.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f); + counter.incrementAndGet(); + } + }); + + t.start(); + t.join(); + } + + assertEquals(numDevices, counter.get()); + } + + @Test + public void testDelayedDeviceLocalUpdate_2() throws Exception { + val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); + if (numDevices < 2) + return; + + val array = Nd4j.createFromArray(5.f, 5.f, 5.f, 5.f); + + val deviceLocal = new DeviceLocalNDArray(array, true); + val counter = new AtomicInteger(0); + + deviceLocal.update(Nd4j.createFromArray(4.f, 4.f, 4.f, 4.f)); + + for (int e = 0; e < numDevices; e++) { + val t = new Thread(new Runnable() { + @Override + public void run() { + assertEquals(4.f, deviceLocal.get().meanNumber().floatValue(), 1e-5f); + counter.incrementAndGet(); + } + }); + + t.start(); + t.join(); + } + + assertEquals(numDevices, counter.get()); + } @Override public char ordering() {