- memcpy fix + validation for CUDA: skip memcpu if length < 1 (#375)

- Reset cached context after device affinity change

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-04-14 18:47:46 +03:00 committed by GitHub
parent 133223e865
commit 75af392671
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 0 deletions

View File

@ -294,7 +294,11 @@ public class CudaAffinityManager extends BasicAffinityManager {
@Override @Override
public void unsafeSetDevice(Integer deviceId) { public void unsafeSetDevice(Integer deviceId) {
// actually set device
NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId);
// reset saved context, so it will be recreated on first call
AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext();
} }
@Override @Override

View File

@ -304,4 +304,6 @@ public interface MemoryHandler {
boolean promoteObject(DataBuffer buffer); boolean promoteObject(DataBuffer buffer);
void relocateObject(DataBuffer buffer); void relocateObject(DataBuffer buffer);
void resetCachedContext();
} }

View File

@ -17,6 +17,8 @@
package org.nd4j.jita.handler.impl; package org.nd4j.jita.handler.impl;
import lombok.var; import lombok.var;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.nativeblas.OpaqueLaunchContext;
import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table; import org.nd4j.shade.guava.collect.Table;
@ -325,6 +327,11 @@ public class CudaZeroHandler implements MemoryHandler {
*/ */
@Override @Override
public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) { public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
if (length < 1)
return;
Preconditions.checkArgument(length <= (dstBuffer.length() * Nd4j.sizeOfDataType(dstBuffer.dataType())), "Length requested is bigger than target DataBuffer length");
val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
CudaContext tContext = null; CudaContext tContext = null;
@ -1041,6 +1048,11 @@ public class CudaZeroHandler implements MemoryHandler {
return ctx; return ctx;
} }
@Override
public void resetCachedContext() {
tlContext.remove();
}
/** /**
* This method returns if this MemoryHandler instance is device-dependant (i.e. CUDA) * This method returns if this MemoryHandler instance is device-dependant (i.e. CUDA)
* *