change pointer reference for cudnn (#220)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-02 12:40:32 +03:00 committed by GitHub
parent 7ded4416cb
commit 0e05cba2f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 12 additions and 12 deletions

View File

@ -271,7 +271,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
Pointer deltaData = allocator.getPointer(delta, context);
Pointer dstData = allocator.getPointer(epsNext, context);
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()));
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
@ -456,7 +456,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
Pointer biasData = allocator.getPointer(bias, context);
Pointer dstData = allocator.getPointer(z, context);
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()));
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
@ -546,7 +546,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
CudaContext context = allocator.getFlowController().prepareAction(z);
Pointer dstData = allocator.getPointer(z, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
switch (afn.toString()) {
case "identity":
break;

View File

@ -185,7 +185,7 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
Pointer zData = allocator.getPointer(reduced, context);
Pointer dstData = allocator.getPointer(outEpsilon, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.deltaTensorDesc,
zData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, beta,
cudnnContext.dstTensorDesc, dstData));
@ -259,7 +259,7 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
Pointer srcData = allocator.getPointer(input, context);
Pointer dstData = allocator.getPointer(reduced, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.srcTensorDesc,
srcData, beta, cudnnContext.dstTensorDesc, dstData));

View File

@ -194,7 +194,7 @@ public class CudnnDropoutHelper extends BaseCudnnHelper implements DropoutHelper
Pointer xPtr = allocator.getPointer(input, context);
Pointer yPtr = allocator.getPointer(resultArray, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnDropoutForward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.xTensorDesc, xPtr,
cudnnContext.yTensorDesc, yPtr, mask, mask.capacity()));

View File

@ -188,7 +188,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
Pointer meanCacheData = allocator.getPointer(meanCache, context);
Pointer varCacheData = allocator.getPointer(varCache, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha,
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
@ -268,7 +268,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
if (training) {
if(meanCache == null || meanCache.length() < mean.length()){
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {

View File

@ -171,7 +171,7 @@ public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper imple
Pointer zData = allocator.getPointer(activations, context);
Pointer dstData = allocator.getPointer(nextEpsilon, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData));
@ -215,7 +215,7 @@ public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper imple
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc,
dstData));

View File

@ -248,7 +248,7 @@ public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context);
Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context);
CUstream_st stream = new CUstream_st(context.getOldStream());
CUstream_st stream = new CUstream_st(context.getCublasStream());
checkCudnn(cudnnSetStream(cudnnContext, stream));
if (truncatedBPTT) {
@ -531,7 +531,7 @@ public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context);
Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context);
CUstream_st stream = new CUstream_st(context.getOldStream());
CUstream_st stream = new CUstream_st(context.getCublasStream());
checkCudnn(cudnnSetStream(cudnnContext, stream));
checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));