parent
7ded4416cb
commit
0e05cba2f9
|
@ -271,7 +271,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
|||
Pointer deltaData = allocator.getPointer(delta, context);
|
||||
Pointer dstData = allocator.getPointer(epsNext, context);
|
||||
|
||||
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()));
|
||||
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
|
||||
checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||
|
||||
code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
|
||||
|
@ -456,7 +456,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
|||
Pointer biasData = allocator.getPointer(bias, context);
|
||||
Pointer dstData = allocator.getPointer(z, context);
|
||||
|
||||
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()));
|
||||
code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()));
|
||||
checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
||||
|
||||
code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
|
||||
|
@ -546,7 +546,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
|||
CudaContext context = allocator.getFlowController().prepareAction(z);
|
||||
Pointer dstData = allocator.getPointer(z, context);
|
||||
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||
switch (afn.toString()) {
|
||||
case "identity":
|
||||
break;
|
||||
|
|
|
@ -185,7 +185,7 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
|
|||
Pointer zData = allocator.getPointer(reduced, context);
|
||||
Pointer dstData = allocator.getPointer(outEpsilon, context);
|
||||
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||
checkCudnn(cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.deltaTensorDesc,
|
||||
zData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, beta,
|
||||
cudnnContext.dstTensorDesc, dstData));
|
||||
|
@ -259,7 +259,7 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
|
|||
Pointer srcData = allocator.getPointer(input, context);
|
||||
Pointer dstData = allocator.getPointer(reduced, context);
|
||||
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||
checkCudnn(cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.srcTensorDesc,
|
||||
srcData, beta, cudnnContext.dstTensorDesc, dstData));
|
||||
|
||||
|
|
|
@ -194,7 +194,7 @@ public class CudnnDropoutHelper extends BaseCudnnHelper implements DropoutHelper
|
|||
Pointer xPtr = allocator.getPointer(input, context);
|
||||
Pointer yPtr = allocator.getPointer(resultArray, context);
|
||||
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||
checkCudnn(cudnnDropoutForward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.xTensorDesc, xPtr,
|
||||
cudnnContext.yTensorDesc, yPtr, mask, mask.capacity()));
|
||||
|
||||
|
|
|
@ -188,7 +188,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
|
|||
Pointer meanCacheData = allocator.getPointer(meanCache, context);
|
||||
Pointer varCacheData = allocator.getPointer(varCache, context);
|
||||
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha,
|
||||
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
|
||||
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
|
||||
|
@ -268,7 +268,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
|
|||
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
||||
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||
if (training) {
|
||||
if(meanCache == null || meanCache.length() < mean.length()){
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
|
|
|
@ -171,7 +171,7 @@ public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper imple
|
|||
Pointer zData = allocator.getPointer(activations, context);
|
||||
Pointer dstData = allocator.getPointer(nextEpsilon, context);
|
||||
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||
checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
|
||||
this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData,
|
||||
cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData));
|
||||
|
@ -215,7 +215,7 @@ public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper imple
|
|||
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
||||
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
|
||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||
checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
|
||||
this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc,
|
||||
dstData));
|
||||
|
|
|
@ -248,7 +248,7 @@ public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
|
|||
Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context);
|
||||
Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context);
|
||||
|
||||
CUstream_st stream = new CUstream_st(context.getOldStream());
|
||||
CUstream_st stream = new CUstream_st(context.getCublasStream());
|
||||
checkCudnn(cudnnSetStream(cudnnContext, stream));
|
||||
|
||||
if (truncatedBPTT) {
|
||||
|
@ -531,7 +531,7 @@ public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
|
|||
Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context);
|
||||
Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context);
|
||||
|
||||
CUstream_st stream = new CUstream_st(context.getOldStream());
|
||||
CUstream_st stream = new CUstream_st(context.getCublasStream());
|
||||
checkCudnn(cudnnSetStream(cudnnContext, stream));
|
||||
|
||||
checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream));
|
||||
|
|
Loading…
Reference in New Issue