cavis/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu

/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/threshold.h>
#include <loops/type_conversions.h>
#include <helpers/PointersManager.h>
#include <vector>
namespace sd {
namespace ops {
namespace helpers {
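//
// prescanArrayRecursive implements the classic work-efficient exclusive prefix sum
// (the GPU Gems 3 / CUDA SDK "scan" pattern): each block scans up to 2 * blockSize
// elements in shared memory, the per-block totals are written to g_scanBlockSums[level],
// that (much smaller) array is scanned recursively, and uniformAdd then adds each
// block's scanned total back to every element of the corresponding block.
//
// As a small illustration (values are hypothetical), an exclusive scan turns
//   [3, 1, 7, 0, 4, 1, 6, 3]
// into
//   [0, 3, 4, 11, 11, 15, 16, 22]
// so entry i holds the sum of all entries before i. In this file the scanned values are
// used as write offsets: block i of the encoder knows where its matches start.
//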
void prescanArrayRecursive(int** g_scanBlockSums, int *dZ, int *dX, int numElements, int level) {
    auto stream = LaunchContext::defaultContext()->getCudaStream();

    int blockSize = 512; // max size of the thread blocks
    int numBlocks = sd::math::nd4j_max<int>(1, static_cast<int>(ceil(static_cast<float>(numElements) / (2.f * blockSize))));
    int numThreads;

    if (numBlocks > 1)
        numThreads = blockSize;
    else if (sd::isPowerOfTwo(numElements))
        numThreads = numElements / 2;
    else
        numThreads = sd::floorPow2(numElements);

    int numEltsPerBlock = numThreads * 2;

    // if this is a non-power-of-2 array, the last block will be non-full;
    // compute the smallest power of 2 able to compute its scan
    int numEltsLastBlock = numElements - (numBlocks - 1) * numEltsPerBlock;
    int numThreadsLastBlock = sd::math::nd4j_max<int>(1, numEltsLastBlock / 2);
    int np2LastBlock = 0;
    int sharedMemLastBlock = 0;

    if (numEltsLastBlock != numEltsPerBlock) {
        np2LastBlock = 1;

        if (!isPowerOfTwo(numEltsLastBlock))
            numThreadsLastBlock = floorPow2(numEltsLastBlock);

        unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS;
        sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace);
    }

    // padding space is used to avoid shared memory bank conflicts
    int extraSpace = numEltsPerBlock / NUM_BANKS;
    int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace);

    // setup execution parameters
    // if NP2, we process the last block separately
    dim3 grid(sd::math::nd4j_max<int>(1, numBlocks - np2LastBlock), 1, 1);
    dim3 threads(numThreads, 1, 1);
    dim3 gridOnes(1, 1, 1);
    dim3 threadsOnes(numThreadsLastBlock, 1, 1);

    if (sharedMemSize < 2048)
        sharedMemSize = 2048;

    if (sharedMemLastBlock < 2048)
        sharedMemLastBlock = 2048;

    // execute the scan
    if (numBlocks > 1) {
        sd::prescanLauncher<true, false>(grid, threads, sharedMemSize, stream, dZ, dX, g_scanBlockSums[level], numThreads * 2, 0, 0);
        if (np2LastBlock) {
            sd::prescanLauncher<true, true>(gridOnes, threadsOnes, sharedMemLastBlock, stream, dZ, dX, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock);
        }

        // After scanning all the sub-blocks, we are mostly done. But now we
        // need to take all of the last values of the sub-blocks and scan those.
        // This will give us a new value that must be added to each block to
        // get the final results.

        // recursive (CPU) call
        prescanArrayRecursive(g_scanBlockSums, g_scanBlockSums[level], g_scanBlockSums[level], numBlocks, level + 1);

        sd::uniformAdd<<<grid, threads, 1024, *stream>>>(dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0);
        if (np2LastBlock) {
            sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>(dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock);
        }
    } else if (isPowerOfTwo(numElements)) {
        sd::prescanLauncher<false, false>(grid, threads, sharedMemSize, stream, dZ, dX, 0, numThreads * 2, 0, 0);
    } else {
        sd::prescanLauncher<false, true>(grid, threads, sharedMemSize, stream, dZ, dX, 0, numElements, 0, 0);
    }

    sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed");
}

static void encodeThresholdP2Int_(void **prs, int *dx, Nd4jLong N, int *dz) {
    auto stream = LaunchContext::defaultContext()->getCudaStream();

    // dx + 1: element 0 of the estimate buffer holds the total match count,
    // the per-block counts start at element 1 and are what gets scanned here
    prescanArrayRecursive(reinterpret_cast<int**>(prs), dz, dx + 1, (int) N, 0);
    sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP2Int(...) failed");
}

static void encodeThresholdP3_(void *dx, Nd4jLong *hXShapeInfo, int *offsets, Nd4jLong N, int *dz) {
    auto stream = LaunchContext::defaultContext()->getCudaStream();

    int blockSize = 512;
    int numBlocks = N / blockSize + (N % blockSize ? 1 : 0);

    dim3 launchDims(numBlocks, blockSize, 8192); // grid size, block size, shared memory
    auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
    BUILD_SINGLE_SELECTOR(xType, encoderKernelP3Generic, (launchDims, stream, dx, offsets, N, dz), FLOAT_TYPES);

    sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP3Float(...) failed");
}

static NDArray thresholdEstimate_(const NDArray &updates, const float threshold) {
    const int numThreads = 512;
    const int numBlocks = updates.lengthOf() / numThreads + (updates.lengthOf() % numThreads ? 1 : 0);

    // numBlocks + 1 counters: the first one receives the total number of matches,
    // the rest receive the per-block match counts
    auto tmp = NDArrayFactory::create<int>('c', {numBlocks + 1});

    dim3 launchDims(numBlocks, numThreads, 1024);
    auto xType = updates.dataType();

    NDArray::prepareSpecialUse({&tmp}, {&updates});
    BUILD_SINGLE_SELECTOR(xType, encoderKernelP1Generic, (launchDims, LaunchContext::defaultContext()->getCudaStream(), updates.getSpecialBuffer(), updates.lengthOf(), tmp.specialBuffer(), threshold), FLOAT_TYPES);
    NDArray::registerSpecialUse({&tmp}, {&updates});

    return tmp;
}

int32_t thresholdEstimate(const NDArray &updates, const float threshold) {
    return thresholdEstimate_(updates, threshold).e<int>(0);
}
void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) {
    // we need these blocks in order to know how many "updates" will be processed by each GPU block
    auto blocks = thresholdEstimate_(updates, threshold);

    const int numThreads = 512;
    const int numBlocks = updates.lengthOf() / numThreads + (updates.lengthOf() % numThreads ? 1 : 0);

    const int prefixThreads = 512;
    int numElts = numBlocks;
    int level = 0;

    // here we just calculate the number of sumBlock arrays
    do {
        int numPrefixBlocks = sd::math::nd4j_max<int>(1, sd::math::nd4j_ceil<float, int>((float) numElts / (2.0f * prefixThreads)));
        if (numPrefixBlocks > 1) {
            level++;
        }
        numElts = numPrefixBlocks;
    } while (numElts > 1);

    std::vector<NDArray> tempArrays(level);
    std::vector<Nd4jPointer> pointers(level);

    level = 0;
    numElts = numBlocks;

    do {
        int numPrefixBlocks = sd::math::nd4j_max<int>(1, sd::math::nd4j_ceil<float, int>((float) numElts / (2.0f * prefixThreads)));
        if (numPrefixBlocks > 1) {
            tempArrays[level] = NDArrayFactory::create<int>('c', {numPrefixBlocks});
            pointers[level] = tempArrays[level++].specialBuffer();
        }
        numElts = numPrefixBlocks;
    } while (numElts > 1);

    PointersManager pm(LaunchContext::defaultContext(), "thresholdEncode");
    auto dptr = pm.replicatePointer(pointers.data(), pointers.size() * sizeof(Nd4jPointer));
    auto offsets = NDArrayFactory::create<int>('c', {numBlocks});

    // we want to check whether we're hitting the external limit on the number of encoded elements
    auto numMatches = blocks.e<int>(0);
    if (numMatches > encoded.lengthOf() - 4) {
        blocks.p(0, encoded.lengthOf() - 4);
        blocks.syncToDevice();
    }

    NDArray::prepareSpecialUse({}, {&encoded, &updates});

    // filling offsets
    encodeThresholdP2Int_(reinterpret_cast<void **>(dptr),
                          reinterpret_cast<int*>(blocks.getSpecialBuffer()),
                          numBlocks,
                          reinterpret_cast<int*>(offsets.getSpecialBuffer()));

    NDArray::registerSpecialUse({&blocks, &offsets}, {});
    pm.synchronize();

    encodeThresholdP3_(updates.specialBuffer(),
                       updates.shapeInfo(),
                       reinterpret_cast<int*>(offsets.getSpecialBuffer()),
                       updates.lengthOf(),
                       reinterpret_cast<int*>(encoded.specialBuffer()));

    pm.synchronize();

    NDArray::registerSpecialUse({&encoded, &updates}, {});
}
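
//
// A minimal usage sketch (illustrative only; the caller-side array names are hypothetical).
// The "- 4" check above implies four leading int slots are reserved for the header, so the
// encoded buffer is sized as 4 + capacity:
//
//   float threshold = 1e-3f;
//   auto numMatches = thresholdEstimate(gradients, threshold);           // how many entries would be encoded
//   auto encoded    = NDArrayFactory::create<int>('c', {4 + numMatches});
//   thresholdEncode(gradients, encoded, threshold);                      // gradients is updated on device as well
//   // ... ship `encoded` somewhere, then on the receiving side:
//   thresholdDecode(encoded, target);                                    // applies the encoded values to `target`
//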

void thresholdDecode(const NDArray &encoded, NDArray &updates) {
    dim3 launchDims(128, 512, 512);
    auto xType = updates.dataType();

    NDArray::prepareSpecialUse({&updates}, {&encoded});
    BUILD_SINGLE_SELECTOR(xType, decoderKernelGeneric, (launchDims, LaunchContext::defaultContext()->getCudaStream(), encoded.getSpecialBuffer(), updates.lengthOf(), updates.specialBuffer()), FLOAT_TYPES);
    NDArray::registerSpecialUse({&updates}, {&encoded});
}
}  // namespace helpers
}  // namespace ops
}  // namespace sd