- memcpy fix + validation for CUDA: skip memcpu if length < 1 (#375)
- Reset cached context after device affinity change Signed-off-by: raver119 <raver119@gmail.com>master
parent
133223e865
commit
75af392671
|
@ -294,7 +294,11 @@ public class CudaAffinityManager extends BasicAffinityManager {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void unsafeSetDevice(Integer deviceId) {
|
public void unsafeSetDevice(Integer deviceId) {
|
||||||
|
// actually set device
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId);
|
||||||
|
|
||||||
|
// reset saved context, so it will be recreated on first call
|
||||||
|
AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -304,4 +304,6 @@ public interface MemoryHandler {
|
||||||
boolean promoteObject(DataBuffer buffer);
|
boolean promoteObject(DataBuffer buffer);
|
||||||
|
|
||||||
void relocateObject(DataBuffer buffer);
|
void relocateObject(DataBuffer buffer);
|
||||||
|
|
||||||
|
void resetCachedContext();
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
package org.nd4j.jita.handler.impl;
|
package org.nd4j.jita.handler.impl;
|
||||||
|
|
||||||
import lombok.var;
|
import lombok.var;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.nativeblas.OpaqueLaunchContext;
|
import org.nd4j.nativeblas.OpaqueLaunchContext;
|
||||||
import org.nd4j.shade.guava.collect.HashBasedTable;
|
import org.nd4j.shade.guava.collect.HashBasedTable;
|
||||||
import org.nd4j.shade.guava.collect.Table;
|
import org.nd4j.shade.guava.collect.Table;
|
||||||
|
@ -325,6 +327,11 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
|
public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
|
||||||
|
if (length < 1)
|
||||||
|
return;
|
||||||
|
|
||||||
|
Preconditions.checkArgument(length <= (dstBuffer.length() * Nd4j.sizeOfDataType(dstBuffer.dataType())), "Length requested is bigger than target DataBuffer length");
|
||||||
|
|
||||||
val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
||||||
CudaContext tContext = null;
|
CudaContext tContext = null;
|
||||||
|
|
||||||
|
@ -1041,6 +1048,11 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void resetCachedContext() {
|
||||||
|
tlContext.remove();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns if this MemoryHandler instance is device-dependant (i.e. CUDA)
|
* This method returns if this MemoryHandler instance is device-dependant (i.e. CUDA)
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue