From e42c34ca5582995ace8756d75dd7ce84a619f9d6 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 2 Sep 2019 11:25:48 +0300 Subject: [PATCH] [WIP] minor (#218) * - initial docs commit - merge* cuda fix Signed-off-by: raver119 * one more fix Signed-off-by: raver119 * one more fix Signed-off-by: raver119 --- libnd4j/blas/NDArray.h | 2 ++ .../loops/cuda/transform/transform_bool.cu | 4 +-- .../loops/cuda/transform/transform_same.cu | 4 +-- .../loops/cuda/transform/transform_strict.cu | 4 +-- .../declarable/helpers/cuda/compare_elem.cu | 9 +++--- .../declarable/helpers/cuda/convolutions.cu | 6 ++-- .../ops/declarable/helpers/cuda/diag.cu | 4 +-- .../ops/declarable/helpers/cuda/dropout.cu | 19 ++++++++++-- .../ops/declarable/helpers/cuda/dynamic.cu | 8 ++++- .../ops/declarable/helpers/cuda/flatten.cu | 4 +++ .../ops/declarable/helpers/cuda/gradient.cu | 1 + .../ops/declarable/helpers/cuda/histogram.cu | 7 ++--- .../ops/declarable/helpers/cuda/ismax.cu | 1 + .../ops/declarable/helpers/cuda/lrn.cu | 6 ++-- .../ops/declarable/helpers/cuda/merge.cu | 31 ++++++++++++++++--- .../declarable/helpers/cuda/nth_element.cu | 10 +++--- .../ops/declarable/helpers/cuda/polyGamma.cu | 7 ++--- .../ops/declarable/helpers/cuda/range.cu | 5 +-- .../declarable/helpers/cuda/toggle_bits.cu | 2 -- 19 files changed, 89 insertions(+), 45 deletions(-) diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index 3ef3716b3..3a57fc92b 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -273,9 +273,11 @@ namespace nd4j { * @param writeList * @param readList */ + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list static void registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList); static void prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list static void registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList); static void preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); diff --git a/libnd4j/include/loops/cuda/transform/transform_bool.cu b/libnd4j/include/loops/cuda/transform/transform_bool.cu index a01221cfa..52e6b4a10 100644 --- a/libnd4j/include/loops/cuda/transform/transform_bool.cu +++ b/libnd4j/include/loops/cuda/transform/transform_bool.cu @@ -96,13 +96,13 @@ namespace functions { } else { if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); z[xOffset] = OpType::op(x[xOffset], params); } } else { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto zOffset = shape::getIndexOffset(i, zShapeInfo, length); z[zOffset] = OpType::op(x[xOffset], params); diff --git a/libnd4j/include/loops/cuda/transform/transform_same.cu b/libnd4j/include/loops/cuda/transform/transform_same.cu index a0d137d64..6c533ac3a 100644 --- a/libnd4j/include/loops/cuda/transform/transform_same.cu +++ b/libnd4j/include/loops/cuda/transform/transform_same.cu @@ -94,13 +94,13 @@ namespace functions { } else { if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); z[xOffset] = OpType::op(x[xOffset], params); } } else { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto zOffset = shape::getIndexOffset(i, zShapeInfo, length); z[zOffset] = OpType::op(x[xOffset], params); diff --git a/libnd4j/include/loops/cuda/transform/transform_strict.cu b/libnd4j/include/loops/cuda/transform/transform_strict.cu index 10385812d..a0989b0e6 100644 --- a/libnd4j/include/loops/cuda/transform/transform_strict.cu +++ b/libnd4j/include/loops/cuda/transform/transform_strict.cu @@ -96,13 +96,13 @@ namespace functions { } else { if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); z[xOffset] = OpType::op(x[xOffset], params); } } else { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto zOffset = shape::getIndexOffset(i, zShapeInfo, length); z[zOffset] = OpType::op(x[xOffset], params); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu b/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu index 54f518ad9..545d7c668 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu @@ -30,7 +30,7 @@ namespace nd4j { shared[threadIdx.x] = 0; - + // each thread will compare 2 elements: E and E+1 for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) { auto val0 = x[shape::getIndexOffset(e, xShapeInfo, length)]; auto val1 = x[shape::getIndexOffset(e+1, xShapeInfo, length)]; @@ -41,11 +41,12 @@ namespace nd4j { else v = val1 >= val0; + // store comparison result in shared memory shared[threadIdx.x] += v ? 0 : 1; } __syncthreads(); - // aggregate sum + // aggregate sums in shared memory for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { if (threadIdx.x < activeThreads) shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; @@ -53,7 +54,7 @@ namespace nd4j { } - // store over the grid + // store over the grid if we have more than 1 block if (gridDim.x > 1) { auto tc = reinterpret_cast(reductionBuffer); @@ -96,7 +97,7 @@ namespace nd4j { } } else { - + // if we have only 1 block, we just store results right away if (threadIdx.x == 0) { auto tc = reinterpret_cast(reductionBuffer); tc[16384] = 0; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 8b58ac38e..87e7c4f08 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -424,7 +424,7 @@ static __global__ void avgPooling2dCuda(const void *vx, const Nd4jLong *xShapeIn } __syncthreads(); - int tid = blockIdx.x * gridDim.x + threadIdx.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; for (int index = tid; index < length; index += blockDim.x * gridDim.x) { @@ -519,7 +519,7 @@ static __global__ void pnormPooling2dCuda(const void *vx, const Nd4jLong *xShape } __syncthreads(); - int tid = blockIdx.x * gridDim.x + threadIdx.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; for (int index = tid; index < length; index += blockDim.x * gridDim.x) { @@ -610,7 +610,7 @@ static __global__ void maxPooling2dCuda(const void *vx, const Nd4jLong *xShapeIn } __syncthreads(); - int tid = blockIdx.x * gridDim.x + threadIdx.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; for (int index = tid; index < length; index += blockDim.x * gridDim.x) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu index f4dff2279..0e861b866 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu @@ -39,7 +39,7 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha } __syncthreads(); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (int t = tid; t < inputLength; t += step) { z[shape::getIndexOffset(t * (inputLength + 1), outputShape, outputLength)] = x[shape::getIndexOffset(t, inputShape, inputLength)]; //tX]; @@ -59,7 +59,7 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha } __syncthreads(); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; Nd4jLong i = threadIdx.x * (outputLength + 1); for (int t = tid; t < outputLength && i < inputLength; t += step) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu index 40251fb82..5b4c27bd0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu @@ -35,9 +35,11 @@ namespace helpers { T const* input = reinterpret_cast(inputBuf); T* output = reinterpret_cast(outputBuf); + // trivial idea: loop through all elements, get independent probability for each element to be nullified for (Nd4jLong e = 0; e < inLen; ++e) { T val = nodeRng->relativeT(e, T(0.f), T(1.f)); + // if probability is ok - we're saving scaled value if (double(val) < probVal) output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal); } @@ -80,7 +82,7 @@ namespace helpers { std::vector dims(reduceShape->lengthOf()); reduceShape->syncToHost(); // to ensure that follows are actual bool fit = true; -// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit)) + for( int i = 0; i < dims.size(); i++ ) { if (fit) { dims[i] = reduceShape->e(i); @@ -96,8 +98,7 @@ namespace helpers { REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), context.launchContext())); chunk->assign(1.f); - //chunk->applyRandom>(rng, nullptr, chunk.get(), &probValue); - //NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob); + dropoutSimple(context.launchContext(), chunk.get(), chunk.get(), probValue, seed); // broadcast chunk to full matrix std::unique_ptr dropOutMultiplier(new NDArray(*input)); @@ -105,6 +106,7 @@ namespace helpers { *dropOutMultiplier += *chunk; + // FIXME: we could do this in one step, aren't we? output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); } @@ -113,8 +115,11 @@ namespace helpers { int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { auto xType = input->dataType(); + NDArray::prepareSpecialUse({output}, {input}); BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); + + NDArray::registerSpecialUse({output}, {input}); } /////////////////////////////////// backrpopagations /////////////////////////////////////////////// @@ -136,6 +141,8 @@ namespace helpers { for (int e = tid; e < len; e += step) { const auto zOffset = shape::getIndexOffset(e, outputShape, len); + + // if probability was non-zero on FF step, we'll scale grads back if (output[zOffset] != T(0.)) output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue); @@ -143,12 +150,17 @@ namespace helpers { } template static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { + // we're making additional FF run to see how probabilities played out with given seeds int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue); auto stream = context.launchContext()->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, gradOut}); + if (ND4J_STATUS_OK == res) dropoutBPKernel<<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue); + NDArray::registerSpecialUse({output}, {input, gradOut}); + return res; } @@ -239,6 +251,7 @@ namespace helpers { int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta); if (res == ND4J_STATUS_OK) { + // FIXME: can we make it single-loop? (*output) *= alpha; (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, output, nullptr); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index 92e5b38b4..7d520478e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -43,7 +43,7 @@ namespace nd4j { } __syncthreads(); - + // we run things in blocks, 1 partition per block of threads for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) { auto z = reinterpret_cast(vz[o]); @@ -89,9 +89,11 @@ namespace nd4j { auto x = reinterpret_cast(vx); auto indices = reinterpret_cast(vindices); + // we run things in blocks, 1 partition per block of threads for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) { auto z = reinterpret_cast(vz[i]); + // each thread has own counter for partitions int outCnt = 0; for (Nd4jLong e = 0; e < iLength; e++) { @@ -145,6 +147,7 @@ namespace nd4j { tadOffsets[i] = packZ.platformOffsets(); } + // we copy pointers to device auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); @@ -248,6 +251,7 @@ namespace nd4j { indicesShapes[e] = indices.at(e)->getSpecialShapeInfo(); } + // copying pointers to buffers to device auto dInputBuffers = reinterpret_cast(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); auto dIndicesBuffers = reinterpret_cast(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); auto dInputShapes = reinterpret_cast(pm.replicatePointer(inputShapes.data(), inputSize * sizeof(Nd4jLong *))); @@ -283,6 +287,7 @@ namespace nd4j { inputTadOffsets[e] = packX.platformOffsets(); } + // copying pointers to buffers to device auto dInputBuffers = reinterpret_cast(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); auto dInputTadShapes = reinterpret_cast(pm.replicatePointer(inputTadShapes.data(), inputSize * sizeof(Nd4jLong *))); auto dInputTadOffsets = reinterpret_cast(pm.replicatePointer(inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *))); @@ -313,6 +318,7 @@ namespace nd4j { NDArray::registerSpecialUse({}, {indices, input}); + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list for (auto v:outputList) { v->tickWriteDevice(); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu b/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu index 9baab5b36..6a818a2cd 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu @@ -29,6 +29,7 @@ namespace nd4j { Nd4jLong xCoord[MAX_RANK]; + // each block of threads works on 1 input array for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) { auto z = reinterpret_cast(zBuffer) + offsets[e]; @@ -39,6 +40,7 @@ namespace nd4j { auto xRank = shape::rank(xShapeInfo); auto xLength = shape::length(xShapeInfo); + // each element of this input array has own place within common output array for (uint i = threadIdx.x; i < xLength; i += blockDim.x) { shape::index2coords(xRank, xShape, i, xLength, xCoord, order); auto xOffset = shape::getOffset(0, xShape, xStride, xCoord, xRank); @@ -65,6 +67,7 @@ namespace nd4j { hdShapes[e] = inputs[e]->specialShapeInfo(); } + // copying pointers to device auto dBuffers = (void **) pm.replicatePointer(hdBuffers.data(), inputs.size() * sizeof(void*)); auto dShapes = (Nd4jLong **)pm.replicatePointer(hdShapes.data(), inputs.size() * sizeof(Nd4jLong*)); auto dOffsets = (Nd4jLong *) pm.replicatePointer(hOffsets.data(), inputs.size() * sizeof(Nd4jLong)); @@ -76,6 +79,7 @@ namespace nd4j { } void flatten(nd4j::LaunchContext *context, std::vector &inputs, NDArray *output, char order) { + // FIXME: we want NDArrayFactory::prepareSpecialUse here eventually for (auto v:inputs) v->syncToDevice(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu index 86f5b4d5a..9d0e5e55b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu @@ -26,6 +26,7 @@ namespace ops { namespace helpers { template void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { + // classic one auto lambda = LAMBDA_TT(_x, _y, weight) { return _x - (_y * weight); }; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index 52b059dad..a4bcbb311 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -44,6 +44,7 @@ namespace nd4j { X binSize = X((*max_val - *min_val) / numBins); + // nullify bins for (int e = threadIdx.x; e < numBins; e += blockDim.x) { bins[e] = (Z) 0; } @@ -53,14 +54,12 @@ namespace nd4j { int idx = int((dx[e] - *min_val) / binSize); idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0); idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1)); - nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); -// bins[idx]++; + nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); } __syncthreads(); + // at this point all bins in shared memory are calculated, so we aggregate them now via threadfence trick // transfer shared memory to reduction memory - - if (gridDim.x > 1) { unsigned int *tc = (unsigned int *)reductionPointer; __shared__ bool amLast; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu index bc7ea1caa..a5d686dc2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu @@ -64,6 +64,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), copy.data(), copy.size()); + // we launch legacy IndexMax op, to get indices of max values along dimension auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions); dim3 launchDims(256, 256, 16384); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu index 239d280dc..ca2bda4a4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu @@ -41,12 +41,12 @@ namespace helpers { const T tbeta = static_cast(beta); const T talpha = static_cast(alpha); - + // one block of threads processes 1 example within batch for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { auto x = reinterpret_cast(vx) + xTadOffsets[i]; auto z = reinterpret_cast(vz) + zTadOffsets[i]; - // load everything into shared memory + // load everything into shared memory, so we'll operate on shared memory from now on shared[threadIdx.x] = x[threadIdx.x * xEws]; __syncthreads(); @@ -94,7 +94,7 @@ namespace helpers { sharedY[threadIdx.x] = 0.f; __syncthreads(); - + // we're operating in shared memory for (int s = begin; s < end; s++) sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s]; __syncthreads(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu index 082472fce..27c8fc630 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu @@ -37,7 +37,7 @@ namespace nd4j { static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) { auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { @@ -81,7 +81,13 @@ namespace nd4j { } void mergeMaxIndex(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, {}); + for (auto v:inArrs) + v->syncToDevice(); + BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES); + + NDArray::registerSpecialUse({&output}, {}); } @@ -90,7 +96,7 @@ namespace nd4j { static __global__ void global_mergeMax_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) { auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { @@ -131,7 +137,12 @@ namespace nd4j { } void mergeMax(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, {}); + for (auto v:inArrs) + v->syncToDevice(); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {}); } ////////////////////////////////////////////////////////////////////////// @@ -139,7 +150,7 @@ namespace nd4j { static __global__ void global_mergeAvg_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) { auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { @@ -178,7 +189,13 @@ namespace nd4j { } void mergeAvg(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, {}); + for (auto v:inArrs) + v->syncToDevice(); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES); + + NDArray::registerSpecialUse({&output}, {}); } ////////////////////////////////////////////////////////////////////////// @@ -186,7 +203,7 @@ namespace nd4j { static __global__ void global_mergeAdd_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) { auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { @@ -226,7 +243,13 @@ namespace nd4j { BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), NUMERIC_TYPES); void mergeAdd(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, {}); + for (auto v:inArrs) + v->syncToDevice(); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES); + + NDArray::registerSpecialUse({&output}, {}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu index aeddd3b97..3b80f3df9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu @@ -31,18 +31,18 @@ namespace helpers { template static __global__ void fillUpElementKernel(void* outputBuffer, Nd4jLong* outputShapeInfo, void* inputBuffer, Nd4jLong* inputShapeInfo, Nd4jLong* pTadShape, Nd4jLong* pTadOffsets, Nd4jLong n) { - __shared__ T *z, *x; __shared__ Nd4jLong bufferLength, arrLen; + auto z = reinterpret_cast(outputBuffer); + auto x = reinterpret_cast(inputBuffer); + if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(inputBuffer); arrLen = shape::length(pTadShape); bufferLength = shape::length(outputShapeInfo); } __syncthreads(); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (int t = tid; t < bufferLength; t += step) { auto tX = x + pTadOffsets[t]; @@ -77,8 +77,6 @@ namespace helpers { // manager.synchronize(); sortedVals.tickWriteDevice(); sortedVals.syncToHost(); - sortedVals.printIndexedBuffer("Hello"); - sortedVals.printBuffer("Hello line"); auto stream = context->getCudaStream(); fillUpElementKernel<<<32, 64, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, pTadOffsets, n); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu index 94d3c02ea..bddaf65e3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu @@ -74,17 +74,14 @@ static void polyGammaCudaLauncher(const int blocksPerGrid, const int threadsPerB /////////////////////////////////////////////////////////////////// void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& z) { - if(!n.isActualOnDeviceSide()) n.syncToDevice(); - if(!x.isActualOnDeviceSide()) x.syncToDevice(); + NDArray::prepareSpecialUse({&z}, {&n, &x}); int threadsPerBlock = MAX_NUM_THREADS; int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES); - n.tickReadHost(); - x.tickReadHost(); - z.tickWriteDevice(); + NDArray::registerSpecialUse({&z}, {&n, &x}); } BUILD_SINGLE_TEMPLATE(template void polyGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vn, const Nd4jLong *nShapeInfo, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/range.cu b/libnd4j/include/ops/declarable/helpers/cuda/range.cu index 7e8ddb2a7..3a5504905 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/range.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/range.cu @@ -28,7 +28,7 @@ namespace helpers { template static __global__ void global_range(void *output, Nd4jLong length, T start, T delta) { auto buff = reinterpret_cast(output); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for(Nd4jLong i = tid; i < length; i += step) @@ -43,10 +43,11 @@ namespace helpers { } void range(nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { + NDArray::prepareSpecialUse({&outVector}, {&start, &delta}); BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (context, start, delta, outVector), LIBND4J_TYPES); + NDArray::registerSpecialUse({&outVector}, {&start, &delta}); } - BUILD_SINGLE_TEMPLATE(template void _range, (nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector), NUMERIC_TYPES); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu index f90c9f77f..8c67cbf1b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu @@ -26,13 +26,11 @@ namespace nd4j { namespace helpers { template void toggle_bits__(NDArray &in, NDArray &out) { - NDArray::prepareSpecialUse({&out}, {&in}); auto lambda = LAMBDA_T(_x) { return ~_x;//eUtils::flip_bits(_x); }; in.applyLambda(lambda, &out); - NDArray::registerSpecialUse({&out}, {&in}); } BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES);