From 75af3926715a957551351a29cf933becf24f3de4 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 14 Apr 2020 18:47:46 +0300 Subject: [PATCH] - memcpy fix + validation for CUDA: skip memcpy if length < 1 (#375) - Reset cached context after device affinity change Signed-off-by: raver119 --- .../nd4j/jita/concurrency/CudaAffinityManager.java | 4 ++++ .../java/org/nd4j/jita/handler/MemoryHandler.java | 2 ++ .../org/nd4j/jita/handler/impl/CudaZeroHandler.java | 12 ++++++++++++ 3 files changed, 18 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index 356df88e0..415fa487f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -294,7 +294,11 @@ public class CudaAffinityManager extends BasicAffinityManager { @Override public void unsafeSetDevice(Integer deviceId) { + // actually set device NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); + + // reset saved context, so it will be recreated on first call + AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java index abb919e8c..44d8e2042 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java @@ -304,4 +304,6 @@ public interface MemoryHandler { boolean promoteObject(DataBuffer buffer); void 
relocateObject(DataBuffer buffer); + + void resetCachedContext(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index a8f3a0a3b..4c6e56bc9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -17,6 +17,8 @@ package org.nd4j.jita.handler.impl; import lombok.var; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.Table; @@ -325,6 +327,11 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) { + if (length < 1) + return; + + Preconditions.checkArgument(length <= (dstBuffer.length() * Nd4j.sizeOfDataType(dstBuffer.dataType())), "Length requested is bigger than target DataBuffer length"); + val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); CudaContext tContext = null; @@ -1041,6 +1048,11 @@ public class CudaZeroHandler implements MemoryHandler { return ctx; } + @Override + public void resetCachedContext() { + tlContext.remove(); + } + /** * This method returns if this MemoryHandler instance is device-dependant (i.e. CUDA) *