diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index 356df88e0..415fa487f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -294,7 +294,11 @@ public class CudaAffinityManager extends BasicAffinityManager { @Override public void unsafeSetDevice(Integer deviceId) { + // actually set device NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); + + // reset saved context, so it will be recreated on first call + AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java index abb919e8c..44d8e2042 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java @@ -304,4 +304,6 @@ public interface MemoryHandler { boolean promoteObject(DataBuffer buffer); void relocateObject(DataBuffer buffer); + + void resetCachedContext(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index a8f3a0a3b..4c6e56bc9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -17,6 +17,8 @@ package org.nd4j.jita.handler.impl; import lombok.var; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.Table; @@ -325,6 +327,11 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) { + if (length < 1) + return; + + Preconditions.checkArgument(length <= (dstBuffer.length() * Nd4j.sizeOfDataType(dstBuffer.dataType())), "Length requested is bigger than target DataBuffer length"); + val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); CudaContext tContext = null; @@ -1041,6 +1048,11 @@ public class CudaZeroHandler implements MemoryHandler { return ctx; } + @Override + public void resetCachedContext() { + tlContext.remove(); + } + /** * This method returns if this MemoryHandler instance is device-dependant (i.e. CUDA) *