CUDA sync tweaks (#194)
* ThreadLocal cache for CudaContext

Signed-off-by: raver119 <raver119@gmail.com>

* temp commit

Signed-off-by: raver119 <raver119@gmail.com>

* remove unwanted synchronization

Signed-off-by: raver119 <raver119@gmail.com>

parent 7ef0ef907e
commit 9f719488b9
@@ -353,18 +353,8 @@ public class AtomicAllocator implements Allocator {
      */
     @Override
     public void synchronizeHostData(DataBuffer buffer) {
-        // we don't want non-committed ops left behind
-        Nd4j.getExecutioner().commit();
-
-        val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
-
         // we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code
         NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
-
-        val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
-
-        //assert oPtr.address() == cPtr.address();
-        //assert buffer.address() == oPtr.address();
     }
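Note on this hunk: the old synchronizeHostData() issued a full Nd4j.getExecutioner().commit() (a device-wide fence) and fetched debug pointers on every host access; the new body delegates entirely to dbSyncToPrimary(), which the native side turns into a no-op when the host copy is already current. A minimal sketch of that lazy-sync idea, with illustrative names only (this is not the ND4J implementation; the real ownership tracking lives in native code):

    // Illustrative sketch: copy device -> host only when the host view is stale.
    final class SyncedBuffer {
        private final double[] host;   // host-side mirror of device memory
        private boolean deviceDirty;   // set whenever a device op writes the buffer

        SyncedBuffer(int length) { this.host = new double[length]; }

        void markDeviceWrite() { deviceDirty = true; }

        // Called before any host read; stands in for dbSyncToPrimary(...).
        double[] hostView() {
            if (deviceDirty) {
                copyDeviceToHost();    // a cudaMemcpy in a real CUDA backend
                deviceDirty = false;
            }
            return host;
        }

        private void copyDeviceToHost() { /* device-to-host copy goes here */ }
    }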
@@ -102,6 +102,8 @@ public class CudaZeroHandler implements MemoryHandler {
 
     private final AffinityManager affinityManager = Nd4j.getAffinityManager();
 
+    private final transient ThreadLocal<CudaContext> tlContext = new ThreadLocal<>();
+
     /*
         table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap
     */
@@ -1018,18 +1020,25 @@ public class CudaZeroHandler implements MemoryHandler {
      * @return
      */
     public CudaContext getCudaContext() {
-        val lc = nativeOps.defaultLaunchContext();
-        return CudaContext.builder()
-                .bufferScalar(nativeOps.lcScalarPointer(lc))
-                .bufferReduction(nativeOps.lcReductionPointer(lc))
-                .bufferAllocation(nativeOps.lcAllocationPointer(lc))
-                .bufferSpecial(nativeOps.lcScalarPointer(lc))
-                .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
-                .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
-                .cublasHandle(getCudaCublasHandle(lc))
-                .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
-                .build();
+        var ctx = tlContext.get();
+        if (ctx == null) {
+            val lc = nativeOps.defaultLaunchContext();
+
+            ctx = CudaContext.builder()
+                    .bufferScalar(nativeOps.lcScalarPointer(lc))
+                    .bufferReduction(nativeOps.lcReductionPointer(lc))
+                    .bufferAllocation(nativeOps.lcAllocationPointer(lc))
+                    .bufferSpecial(nativeOps.lcScalarPointer(lc))
+                    .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
+                    .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
+                    .cublasHandle(getCudaCublasHandle(lc))
+                    .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
+                    .build();
 
+            tlContext.set(ctx);
+            return ctx;
+        } else
+            return ctx;
     }
 
     /**
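The point of this hunk: a thread's CudaContext never changes, so rebuilding it from the native launch context on every call was pure overhead; it is now built once per thread and cached in the tlContext field added above. The same pattern can be written with ThreadLocal.withInitial, which folds the null check into the initializer. A sketch, assuming a hypothetical helper buildContext() that wraps the CudaContext.builder() chain above (buildContext is not part of this patch):

    // Per-thread caching via ThreadLocal.withInitial (Java 8+):
    // get() runs buildContext() at most once per thread, then returns the cache.
    private final transient ThreadLocal<CudaContext> tlContext =
            ThreadLocal.withInitial(this::buildContext);

    public CudaContext getCudaContext() {
        return tlContext.get();
    }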
@@ -1610,8 +1610,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
 
     @Override
     public void commit() {
-        AtomicAllocator.getInstance().getDeviceContext().syncOldStream();
-        AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream();
+        val ctx = AtomicAllocator.getInstance().getDeviceContext();
+        ctx.syncOldStream();
+        ctx.syncSpecialStream();
     }
 
     @Override
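For context, commit() is the fence the first hunk stopped invoking implicitly: it blocks the calling thread until both the execution stream and the special (copy) stream have drained. A hedged usage sketch, assuming an ND4J CUDA backend (shapes and timing are illustrative, not from this patch):

    // Fence queued device work before timing or reading results on the host.
    INDArray a = Nd4j.rand(1024, 1024);
    INDArray b = Nd4j.rand(1024, 1024);

    long start = System.nanoTime();
    INDArray c = a.mmul(b);          // enqueued asynchronously on the GPU
    Nd4j.getExecutioner().commit();  // waits on execution + copy streams
    long elapsedNanos = System.nanoTime() - start;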
@@ -738,52 +738,52 @@ public class DataSetTest extends BaseNd4jTest {
 
     @Test
     public void testShuffleNd() {
         int numDims = 7;
         int nLabels = 3;
         Random r = new Random();
 
 
         int[] shape = new int[numDims];
         int entries = 1;
         for (int i = 0; i < numDims; i++) {
             //randomly generating shapes bigger than 1
             shape[i] = r.nextInt(4) + 2;
             entries *= shape[i];
         }
         int labels = shape[0] * nLabels;
 
         INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
         INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
 
         DataSet ds = new DataSet(ds_data, ds_labels);
         ds.shuffle();
 
         //Checking Nd dataset which is the data
         for (int dim = 1; dim < numDims; dim++) {
             //get tensor along dimension - the order in every dimension but zero should be preserved
             for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
                 //the difference between consecutive elements should be equal to the stride
                 for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
                     int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
                     int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
                     int f_element_diff = f_next_element - f_element;
                     assertEquals(f_element_diff, ds_data.stride(dim));
+                }
             }
         }
-                }
 
         //Checking 2d, features
         int dim = 1;
         //get tensor along dimension - the order in every dimension but zero should be preserved
         for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
             //the difference between consecutive elements should be equal to the stride
             for (int i = 0, j = 1; j < nLabels; i++, j++) {
                 int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
                 int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
                 int l_element_diff = l_next_element - l_element;
                 assertEquals(l_element_diff, ds_labels.stride(dim));
+            }
         }
-            }
     }
 
     @Test
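Why the test's assertion holds: DataSet.shuffle() permutes examples only along dimension 0, so within any tensor taken along another dimension the original ordering survives, and for row-major linspace data the step between consecutive elements is exactly stride(dim). A tiny standalone illustration with a fixed shape (the shape and indices are chosen for this example, not taken from the test):

    // Stride invariant on a 2x3x4 row-major linspace array (strides 12, 4, 1).
    INDArray arr = Nd4j.linspace(1, 24, 24, DataType.INT).reshape(2, 3, 4);
    INDArray tad = arr.tensorAlongDimension(0, 2);  // first TAD along dim 2: 1, 2, 3, 4
    int diff = tad.getInt(1) - tad.getInt(0);       // consecutive step = 1
    assertEquals(diff, arr.stride(2));              // stride along dim 2 is 1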