CUDA sync tweaks (#194)

* ThreadLocal cache for CudaContext

Signed-off-by: raver119 <raver119@gmail.com>

* temp commit

Signed-off-by: raver119 <raver119@gmail.com>

* remove unwanted synchronization

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-28 10:55:06 +03:00 committed by GitHub
parent 7ef0ef907e
commit 9f719488b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 60 deletions

View File

@ -353,18 +353,8 @@ public class AtomicAllocator implements Allocator {
*/ */
@Override @Override
public void synchronizeHostData(DataBuffer buffer) { public void synchronizeHostData(DataBuffer buffer) {
// we don't want non-committed ops left behind
Nd4j.getExecutioner().commit();
val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
// we actually need synchronization only in device-dependent environment. no-op otherwise. managed by native code // we actually need synchronization only in device-dependent environment. no-op otherwise. managed by native code
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
//assert oPtr.address() == cPtr.address();
//assert buffer.address() == oPtr.address();
} }

View File

@ -102,6 +102,8 @@ public class CudaZeroHandler implements MemoryHandler {
private final AffinityManager affinityManager = Nd4j.getAffinityManager(); private final AffinityManager affinityManager = Nd4j.getAffinityManager();
private final transient ThreadLocal<CudaContext> tlContext = new ThreadLocal<>();
/* /*
table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap
*/ */
@ -1018,18 +1020,25 @@ public class CudaZeroHandler implements MemoryHandler {
* @return * @return
*/ */
public CudaContext getCudaContext() { public CudaContext getCudaContext() {
val lc = nativeOps.defaultLaunchContext(); var ctx = tlContext.get();
if (ctx == null) {
val lc = nativeOps.defaultLaunchContext();
return CudaContext.builder() ctx = CudaContext.builder()
.bufferScalar(nativeOps.lcScalarPointer(lc)) .bufferScalar(nativeOps.lcScalarPointer(lc))
.bufferReduction(nativeOps.lcReductionPointer(lc)) .bufferReduction(nativeOps.lcReductionPointer(lc))
.bufferAllocation(nativeOps.lcAllocationPointer(lc)) .bufferAllocation(nativeOps.lcAllocationPointer(lc))
.bufferSpecial(nativeOps.lcScalarPointer(lc)) .bufferSpecial(nativeOps.lcScalarPointer(lc))
.oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
.specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
.cublasHandle(getCudaCublasHandle(lc)) .cublasHandle(getCudaCublasHandle(lc))
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
.build(); .build();
tlContext.set(ctx);
return ctx;
} else
return ctx;
} }
/** /**

View File

@ -1610,8 +1610,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public void commit() { public void commit() {
AtomicAllocator.getInstance().getDeviceContext().syncOldStream(); val ctx = AtomicAllocator.getInstance().getDeviceContext();
AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream(); ctx.syncOldStream();
ctx.syncSpecialStream();
} }
@Override @Override

View File

@ -738,52 +738,52 @@ public class DataSetTest extends BaseNd4jTest {
@Test @Test
public void testShuffleNd() { public void testShuffleNd() {
int numDims = 7; int numDims = 7;
int nLabels = 3; int nLabels = 3;
Random r = new Random(); Random r = new Random();
int[] shape = new int[numDims]; int[] shape = new int[numDims];
int entries = 1; int entries = 1;
for (int i = 0; i < numDims; i++) { for (int i = 0; i < numDims; i++) {
//randomly generating shapes bigger than 1 //randomly generating shapes bigger than 1
shape[i] = r.nextInt(4) + 2; shape[i] = r.nextInt(4) + 2;
entries *= shape[i]; entries *= shape[i];
} }
int labels = shape[0] * nLabels; int labels = shape[0] * nLabels;
INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape); INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels); INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
DataSet ds = new DataSet(ds_data, ds_labels); DataSet ds = new DataSet(ds_data, ds_labels);
ds.shuffle(); ds.shuffle();
//Checking Nd dataset which is the data //Checking Nd dataset which is the data
for (int dim = 1; dim < numDims; dim++) { for (int dim = 1; dim < numDims; dim++) {
//get tensor along dimension - the order in every dimension but zero should be preserved //get tensor along dimension - the order in every dimension but zero should be preserved
for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) { for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride //the difference between consecutive elements should be equal to the stride
for (int i = 0, j = 1; j < shape[dim]; i++, j++) { for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i); int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j); int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
int f_element_diff = f_next_element - f_element; int f_element_diff = f_next_element - f_element;
assertEquals(f_element_diff, ds_data.stride(dim)); assertEquals(f_element_diff, ds_data.stride(dim));
}
} }
} }
}
//Checking 2d, features //Checking 2d, features
int dim = 1; int dim = 1;
//get tensor along dimension - the order in every dimension but zero should be preserved //get tensor along dimension - the order in every dimension but zero should be preserved
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) { for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride //the difference between consecutive elements should be equal to the stride
for (int i = 0, j = 1; j < nLabels; i++, j++) { for (int i = 0, j = 1; j < nLabels; i++, j++) {
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i); int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j); int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
int l_element_diff = l_next_element - l_element; int l_element_diff = l_next_element - l_element;
assertEquals(l_element_diff, ds_labels.stride(dim)); assertEquals(l_element_diff, ds_labels.stride(dim));
}
} }
}
} }
@Test @Test