diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index 03e5df160..aaccf9a34 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -353,18 +353,8 @@ public class AtomicAllocator implements Allocator { */ @Override public void synchronizeHostData(DataBuffer buffer) { - // we don't want non-committed ops left behind - Nd4j.getExecutioner().commit(); - - val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); - // we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); - - val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); - - //assert oPtr.address() == cPtr.address(); - //assert buffer.address() == oPtr.address(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index 6f0944c5a..a8f3a0a3b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -102,6 +102,8 @@ public class CudaZeroHandler implements MemoryHandler { private final AffinityManager affinityManager = Nd4j.getAffinityManager(); + private final transient ThreadLocal tlContext = new ThreadLocal<>(); + /* table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap */ @@ -1018,18 +1020,25 @@ public class CudaZeroHandler implements MemoryHandler { * @return */ public CudaContext getCudaContext() { - val lc = nativeOps.defaultLaunchContext(); + var ctx = tlContext.get(); + if (ctx == null) { + val lc = nativeOps.defaultLaunchContext(); - return CudaContext.builder() - .bufferScalar(nativeOps.lcScalarPointer(lc)) - .bufferReduction(nativeOps.lcReductionPointer(lc)) - .bufferAllocation(nativeOps.lcAllocationPointer(lc)) - .bufferSpecial(nativeOps.lcScalarPointer(lc)) - .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) - .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) - .cublasHandle(getCudaCublasHandle(lc)) - .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) - .build(); + ctx = CudaContext.builder() + .bufferScalar(nativeOps.lcScalarPointer(lc)) + .bufferReduction(nativeOps.lcReductionPointer(lc)) + .bufferAllocation(nativeOps.lcAllocationPointer(lc)) + .bufferSpecial(nativeOps.lcScalarPointer(lc)) + .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) + .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) + .cublasHandle(getCudaCublasHandle(lc)) + .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) + .build(); + + tlContext.set(ctx); + return ctx; + } else + return ctx; } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 1615e4843..04b86dc02 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1610,8 +1610,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void commit() { - AtomicAllocator.getInstance().getDeviceContext().syncOldStream(); - AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + ctx.syncOldStream(); + ctx.syncSpecialStream(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java index a62dc631e..816003009 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java @@ -738,52 +738,52 @@ public class DataSetTest extends BaseNd4jTest { @Test public void testShuffleNd() { - int numDims = 7; - int nLabels = 3; - Random r = new Random(); + int numDims = 7; + int nLabels = 3; + Random r = new Random(); - int[] shape = new int[numDims]; - int entries = 1; - for (int i = 0; i < numDims; i++) { - //randomly generating shapes bigger than 1 - shape[i] = r.nextInt(4) + 2; - entries *= shape[i]; - } - int labels = shape[0] * nLabels; + int[] shape = new int[numDims]; + int entries = 1; + for (int i = 0; i < numDims; i++) { + //randomly generating shapes bigger than 1 + shape[i] = r.nextInt(4) + 2; + entries *= shape[i]; + } + int labels = shape[0] * nLabels; - INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape); - INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels); + INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape); + INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels); - DataSet ds = new DataSet(ds_data, ds_labels); - ds.shuffle(); + DataSet ds = new DataSet(ds_data, ds_labels); + ds.shuffle(); - //Checking Nd dataset which is the data - for (int dim = 1; dim < numDims; dim++) { - //get tensor along dimension - the order in every dimension but zero should be preserved - for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) { - //the difference between consecutive elements should be equal to the stride - for (int i = 0, j = 1; j < shape[dim]; i++, j++) { - int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i); - int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j); - int f_element_diff = f_next_element - f_element; - assertEquals(f_element_diff, ds_data.stride(dim)); + //Checking Nd dataset which is the data + for (int dim = 1; dim < numDims; dim++) { + //get tensor along dimension - the order in every dimension but zero should be preserved + for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) { + //the difference between consecutive elements should be equal to the stride + for (int i = 0, j = 1; j < shape[dim]; i++, j++) { + int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i); + int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j); + int f_element_diff = f_next_element - f_element; + assertEquals(f_element_diff, ds_data.stride(dim)); + } } } - } - //Checking 2d, features - int dim = 1; - //get tensor along dimension - the order in every dimension but zero should be preserved - for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) { - //the difference between consecutive elements should be equal to the stride - for (int i = 0, j = 1; j < nLabels; i++, j++) { - int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i); - int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j); - int l_element_diff = l_next_element - l_element; - assertEquals(l_element_diff, ds_labels.stride(dim)); + //Checking 2d, features + int dim = 1; + //get tensor along dimension - the order in every dimension but zero should be preserved + for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) { + //the difference between consecutive elements should be equal to the stride + for (int i = 0, j = 1; j < nLabels; i++, j++) { + int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i); + int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j); + int l_element_diff = l_next_element - l_element; + assertEquals(l_element_diff, ds_labels.stride(dim)); + } } - } } @Test