change pointer reference for cudnn (#220)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-02 12:40:32 +03:00 committed by GitHub
parent 7ded4416cb
commit 0e05cba2f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 12 additions and 12 deletions

View File

@ -271,7 +271,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
Pointer deltaData = allocator.getPointer(delta, context); Pointer deltaData = allocator.getPointer(delta, context);
Pointer dstData = allocator.getPointer(epsNext, context); Pointer dstData = allocator.getPointer(epsNext, context);
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())); code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
@ -456,7 +456,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
Pointer biasData = allocator.getPointer(bias, context); Pointer biasData = allocator.getPointer(bias, context);
Pointer dstData = allocator.getPointer(z, context); Pointer dstData = allocator.getPointer(z, context);
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())); code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
@ -546,7 +546,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
CudaContext context = allocator.getFlowController().prepareAction(z); CudaContext context = allocator.getFlowController().prepareAction(z);
Pointer dstData = allocator.getPointer(z, context); Pointer dstData = allocator.getPointer(z, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
switch (afn.toString()) { switch (afn.toString()) {
case "identity": case "identity":
break; break;

View File

@ -185,7 +185,7 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
Pointer zData = allocator.getPointer(reduced, context); Pointer zData = allocator.getPointer(reduced, context);
Pointer dstData = allocator.getPointer(outEpsilon, context); Pointer dstData = allocator.getPointer(outEpsilon, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.deltaTensorDesc, checkCudnn(cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.deltaTensorDesc,
zData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, beta, zData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, beta,
cudnnContext.dstTensorDesc, dstData)); cudnnContext.dstTensorDesc, dstData));
@ -259,7 +259,7 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
Pointer srcData = allocator.getPointer(input, context); Pointer srcData = allocator.getPointer(input, context);
Pointer dstData = allocator.getPointer(reduced, context); Pointer dstData = allocator.getPointer(reduced, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.srcTensorDesc, checkCudnn(cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.srcTensorDesc,
srcData, beta, cudnnContext.dstTensorDesc, dstData)); srcData, beta, cudnnContext.dstTensorDesc, dstData));

View File

@ -194,7 +194,7 @@ public class CudnnDropoutHelper extends BaseCudnnHelper implements DropoutHelper
Pointer xPtr = allocator.getPointer(input, context); Pointer xPtr = allocator.getPointer(input, context);
Pointer yPtr = allocator.getPointer(resultArray, context); Pointer yPtr = allocator.getPointer(resultArray, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnDropoutForward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.xTensorDesc, xPtr, checkCudnn(cudnnDropoutForward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.xTensorDesc, xPtr,
cudnnContext.yTensorDesc, yPtr, mask, mask.capacity())); cudnnContext.yTensorDesc, yPtr, mask, mask.capacity()));

View File

@ -188,7 +188,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
Pointer meanCacheData = allocator.getPointer(meanCache, context); Pointer meanCacheData = allocator.getPointer(meanCache, context);
Pointer varCacheData = allocator.getPointer(varCache, context); Pointer varCacheData = allocator.getPointer(varCache, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha, checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha,
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData, cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
@ -268,7 +268,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
if (Nd4j.getExecutioner() instanceof GridExecutioner) if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
if (training) { if (training) {
if(meanCache == null || meanCache.length() < mean.length()){ if(meanCache == null || meanCache.length() < mean.length()){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {

View File

@ -171,7 +171,7 @@ public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper imple
Pointer zData = allocator.getPointer(activations, context); Pointer zData = allocator.getPointer(activations, context);
Pointer dstData = allocator.getPointer(nextEpsilon, context); Pointer dstData = allocator.getPointer(nextEpsilon, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1, checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData, this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData)); cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData));
@ -215,7 +215,7 @@ public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper imple
if (Nd4j.getExecutioner() instanceof GridExecutioner) if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1, checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc,
dstData)); dstData));

View File

@ -248,7 +248,7 @@ public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context); Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context);
Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context); Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context);
CUstream_st stream = new CUstream_st(context.getOldStream()); CUstream_st stream = new CUstream_st(context.getCublasStream());
checkCudnn(cudnnSetStream(cudnnContext, stream)); checkCudnn(cudnnSetStream(cudnnContext, stream));
if (truncatedBPTT) { if (truncatedBPTT) {
@ -531,7 +531,7 @@ public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context); Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context);
Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context); Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context);
CUstream_st stream = new CUstream_st(context.getOldStream()); CUstream_st stream = new CUstream_st(context.getCublasStream());
checkCudnn(cudnnSetStream(cudnnContext, stream)); checkCudnn(cudnnSetStream(cudnnContext, stream));
checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream)); checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));