cavis/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu

136 lines
5.5 KiB
Plaintext

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
#include <ops/declarable/helpers/compare_elem.h>
namespace nd4j {
namespace ops {
namespace helpers {
/**
 * Device kernel that decides whether a buffer is ordered by counting adjacent
 * pairs (x[e], x[e + 1]) that violate the requested ordering. z[0] is set to
 * true only when the total number of violations over the whole grid is zero.
 *
 * Launch requirements that follow from the code below:
 *  - dynamic shared memory must hold at least blockDim.x uint32_t counters;
 *  - the tree reductions halve the active thread count each step, so
 *    blockDim.x has to be a power of two for every counter to be summed;
 *  - reductionBuffer is addressed as uint32_t: slots [0, gridDim.x) receive
 *    the per-block counts and slot 16384 is the block-completion ticket
 *    counter. The counter has to be 0 on entry; it is reset to 0 before the
 *    kernel finishes.
 *
 * @param vx              device-accessible buffer holding elements of type T
 * @param xShapeInfo      shape descriptor handed to shape::getIndexOffset to map a linear index to a buffer offset
 * @param length          number of elements to examine
 * @param isStrict        true: require x[e + 1] > x[e]; false: require x[e + 1] >= x[e]
 * @param reductionBuffer scratch buffer, see the requirements above
 * @param z               z[0] receives the verdict
 */
template <typename T>
static _CUDA_G void comparator(void *vx, const Nd4jLong *xShapeInfo, Nd4jLong length, const bool isStrict, void *reductionBuffer, bool *z) {
    auto x = reinterpret_cast<T*>(vx);
    auto reduction = reinterpret_cast<uint32_t*>(reductionBuffer);
    // index inside reductionBuffer (viewed as 32-bit words) of the ticket counter
    const unsigned int ticketSlot = 16384;

    extern __shared__ uint32_t shared[];

    // The index and the stride are 64-bit on purpose: length is Nd4jLong, and a
    // 32-bit loop counter would overflow for buffers with 2^31 or more elements.
    const Nd4jLong first = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong stride = static_cast<Nd4jLong>(blockDim.x) * gridDim.x;

    shared[threadIdx.x] = 0;

    // each iteration compares 2 elements: E and E+1
    for (Nd4jLong e = first; e < length - 1; e += stride) {
        auto val0 = x[shape::getIndexOffset(e, xShapeInfo)];
        auto val1 = x[shape::getIndexOffset(e + 1, xShapeInfo)];

        const bool ordered = isStrict ? (val1 > val0) : (val1 >= val0);

        // count violations in this thread's shared memory slot
        shared[threadIdx.x] += ordered ? 0 : 1;
    }
    __syncthreads();

    // aggregate the per-thread counts of this block into shared[0]
    for (unsigned int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
        if (threadIdx.x < activeThreads)
            shared[threadIdx.x] += shared[threadIdx.x + activeThreads];
        __syncthreads();
    }

    // store over the grid if we have more than 1 block
    if (gridDim.x > 1) {
        auto tc = reinterpret_cast<unsigned int *>(reductionBuffer);
        __shared__ bool amLast;

        // publish this block's partial count
        if (threadIdx.x == 0)
            reduction[blockIdx.x] = shared[0];

        // make the partial count visible to other blocks before taking a ticket
        __threadfence();
        __syncthreads();

        // the block that draws ticket gridDim.x - 1 is the last one to finish
        if (threadIdx.x == 0) {
            unsigned int ticket = atomicInc(&tc[ticketSlot], gridDim.x);
            amLast = (ticket == gridDim.x - 1);
        }
        __syncthreads();

        // amLast is uniform across the block, so the barriers below are reached by all of its threads
        if (amLast) {
            // reset the ticket counter
            tc[ticketSlot] = 0;
            shared[threadIdx.x] = 0;

            // fold the per-block counts into this block's shared memory
            for (unsigned int i = threadIdx.x; i < gridDim.x; i += blockDim.x)
                shared[threadIdx.x] += reduction[i];
            __syncthreads();

            for (unsigned int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
                if (threadIdx.x < activeThreads)
                    shared[threadIdx.x] += shared[threadIdx.x + activeThreads];
                __syncthreads();
            }

            if (threadIdx.x == 0) {
                z[0] = shared[0] == 0;
            }
        }
    }
    else {
        // if we have only 1 block, we just store results right away
        if (threadIdx.x == 0) {
            auto tc = reinterpret_cast<unsigned int*>(reductionBuffer);
            tc[ticketSlot] = 0;
            z[0] = shared[0] == 0;
        }
    }
}
/**
 * Host-side launcher for the comparator<T> kernel.
 *
 * Creates a boolean scalar for the kernel to fill, launches the kernel on the
 * context's CUDA stream with 256 threads per block and between 1 and 128
 * blocks (input length / 256, clamped), then reads the scalar back into
 * output.
 *
 * @param context               launch context providing the CUDA stream and the reduction scratch pointer
 * @param input                 array whose special buffer and special shape info are handed to the kernel
 * @param isStrictlyIncreasing  forwarded to the kernel as its isStrict flag
 * @param output                receives the boolean value produced by the kernel
 */
template<typename T>
static void _compare_elem(nd4j::LaunchContext * context, NDArray *input, bool isStrictlyIncreasing, bool& output) {
    // scalar the kernel writes its verdict into
    auto verdict = NDArrayFactory::create<bool>(false, context);

    // launch geometry
    const int threadsPerBlock = 256;
    const int blocksForLength = nd4j::math::nd4j_max<int>(1, input->lengthOf() / threadsPerBlock);
    const int blockCount = nd4j::math::nd4j_min<int>(128, blocksForLength);
    // 4 bytes per thread for the kernel's counters, plus 1024 extra bytes
    const int sharedMemBytes = threadsPerBlock * 4 + 1024;

    auto stream = context->getCudaStream();

    comparator<T><<<blockCount, threadsPerBlock, sharedMemBytes, *stream>>>(
            input->specialBuffer(),
            input->specialShapeInfo(),
            input->lengthOf(),
            isStrictlyIncreasing,
            context->getReductionPointer(),
            reinterpret_cast<bool *>(verdict.specialBuffer()));

    // record that the device side of the scalar has been written
    verdict.tickWriteDevice();
    nd4j::DebugHelper::checkErrorCode(stream, "is_strictly_increasing");

    output = verdict.e<bool>(0);
}
/**
 * Public entry point: dispatches to the _compare_elem<T> instantiation that
 * matches the data type of input.
 *
 * @param context               launch context forwarded to the typed implementation
 * @param input                 array to check; syncToDevice() is called on it before dispatch
 * @param isStrictlyIncreasing  true for a strict (>) comparison of neighbours, false for >=
 * @param output                receives the result of the check
 */
void compare_elem(nd4j::LaunchContext * context, NDArray *input, bool isStrictlyIncreasing, bool& output) {
    // presumably refreshes the device copy so the kernel reads current data -- confirm against NDArray
    input->syncToDevice();

    const auto inputType = input->dataType();
    BUILD_SINGLE_SELECTOR(inputType, _compare_elem, (context, input, isStrictlyIncreasing, output), LIBND4J_TYPES);
}
// Presumably emits an explicit instantiation of _compare_elem for every type in
// LIBND4J_TYPES -- confirm against the BUILD_SINGLE_TEMPLATE macro definition.
BUILD_SINGLE_TEMPLATE(template void _compare_elem, (nd4j::LaunchContext * context, NDArray *A, bool isStrictlyIncreasing, bool& output);, LIBND4J_TYPES);
}
}
}