- memcpy fix + validation for CUDA: skip memcpu if length < 1 (#375)
- Reset cached context after device affinity change Signed-off-by: raver119 <raver119@gmail.com>master
parent
133223e865
commit
75af392671
|
@ -294,7 +294,11 @@ public class CudaAffinityManager extends BasicAffinityManager {
|
|||
|
||||
@Override
|
||||
public void unsafeSetDevice(Integer deviceId) {
|
||||
// actually set device
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId);
|
||||
|
||||
// reset saved context, so it will be recreated on first call
|
||||
AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -304,4 +304,6 @@ public interface MemoryHandler {
|
|||
boolean promoteObject(DataBuffer buffer);
|
||||
|
||||
void relocateObject(DataBuffer buffer);
|
||||
|
||||
void resetCachedContext();
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
package org.nd4j.jita.handler.impl;
|
||||
|
||||
import lombok.var;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.nativeblas.OpaqueLaunchContext;
|
||||
import org.nd4j.shade.guava.collect.HashBasedTable;
|
||||
import org.nd4j.shade.guava.collect.Table;
|
||||
|
@ -325,6 +327,11 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
*/
|
||||
@Override
|
||||
public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
|
||||
if (length < 1)
|
||||
return;
|
||||
|
||||
Preconditions.checkArgument(length <= (dstBuffer.length() * Nd4j.sizeOfDataType(dstBuffer.dataType())), "Length requested is bigger than target DataBuffer length");
|
||||
|
||||
val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
||||
CudaContext tContext = null;
|
||||
|
||||
|
@ -1041,6 +1048,11 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
return ctx;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void resetCachedContext() {
|
||||
tlContext.remove();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns if this MemoryHandler instance is device-dependant (i.e. CUDA)
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue