* initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored buffer() and shapeInfo() methods usage with NDArray class. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt Graph class methods to use const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt choose op to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt where op shape method to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt lstsq op to use constant empty shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt matrix_diag_part op shape routine to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt determinant ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt mean_pairwssqerr_loss ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for loss ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt log_loss op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt dilation2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted deconv2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted dynamicRNN op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for ops. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for lstm layer ops. Signed-off-by: shugeo <sgazeos@gmail.com> * few updates Signed-off-by: raver119@gmail.com <raver119@gmail.com> * first cuda tweak Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Adopt constant shapes for sconv2d ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes for gru ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes with shape methods for segment ops and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with unsorted_segment_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with gamma op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods of reduce_stddev ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for reduce_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape method for squeeze op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt strided_slice shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored concat op shape method to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape method for mirror_pad op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted split op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted tile ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Added const cast for mkldnn routines handles. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored logSoftMaxForVector_ routine to conform with proper data and shape pointer casts. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetic changes to proper usage of constant pointers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple shape comparators for strides and addBias helpers to proper use data pointers with inplace option. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored depthToSpace helpers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored histogram helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored im2col helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored gather and gatherND helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage on percentile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed gather shape with helpers and range buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with space to depth helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage and constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with LUP decomposition. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored onehot_ helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pad and prefix to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored softmax helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed space to batch helpers to use buffers properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed stack and split helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with sparse to dense helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with mindistance_ helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with tile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with legacy pairwise bool ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple of methods to adopt constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed broadcasting with constant shape. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const usage with inplace reverse and constant shapes with legacy reduction. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops with const shapes. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored sort to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected sort for constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with special methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored Context to conform with constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * CUDA broadcasting headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * pairwise/indexreduce/random headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored native ops to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * legacy reduce3/scalar headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected pullRow signature and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected routines to proper use of constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with NDArray tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed native ops tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed special concat routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with test. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with a test. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored TAD.h and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored calcStrides* routines to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed miscellaneous errors with constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected definitions for declared functions. 
Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed const shapes with shape routines. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed shape method for broadcastable case. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * xw_plus_b BP shape fn restored Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed signatures with broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Repaired backprops shape methods for a set of operations. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored broadcast bool for cuda. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods for 3 args with const qualifier. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed a couple of kernel signatures for broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels signatures for const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise methods to persistent buffers and shapes usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with scalar kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored indexreduce kernels signatures to use const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise kernels to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise bool kernels to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored random special ops to conform with const shapes and buffers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored native ops to conform with const shapes and buffers under cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetical changes only. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes and buffers error. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected start pos routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods to conform with const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored helpers to use proper methods instead. Signed-off-by: shugeo <sgazeos@gmail.com> * bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected const shape cases with sort and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes for sort. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored kernel declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernel declarations to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed segment helpers kernels declarations and so on to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with segment and solve helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernel declaration with adjustWeight helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed cuda implementations for constant shape helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted const shape usage with kernels. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted top_k kernels to use const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernels declarations to adopt const shapes with helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored NDArray definitions to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes with image suppression helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Slight improvement with buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with definitions. Signed-off-by: shugeo <sgazeos@gmail.com> * minor updates on cpu side Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored const shape usage with ConstantDescriptor and native ops with cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tear and tile kernels to adopt with const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * softmax_loop fix Signed-off-by: raver119 <raver119@gmail.com> * update missing signature Signed-off-by: raver119@gmail.com <raver119@gmail.com> * softmax again Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more missing consts Signed-off-by: raver119 <raver119@gmail.com> * new methods updated Signed-off-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com>
232 lines
11 KiB
Plaintext
232 lines
11 KiB
Plaintext
/*******************************************************************************
|
|
* Copyright (c) 2020 Konduit K.K.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
//
|
|
// @author raver119@gmail.com
|
|
//
|
|
|
|
#include <ops/declarable/helpers/threshold.h>
|
|
#include <loops/type_conversions.h>
|
|
#include <helpers/PointersManager.h>
|
|
#include <vector>
|
|
|
|
namespace sd {
|
|
namespace ops {
|
|
namespace helpers {
|
|
/**
 * Host-side driver of a recursive block-wise prefix scan ("prescan") of `numElements` ints from dX into dZ.
 * All work is issued on the default LaunchContext CUDA stream.
 *
 * The input is split into scan blocks of up to 2 * 512 elements. A trailing block that is not full is
 * launched separately, with its own thread count and shared-memory size.
 *
 * When more than one scan block is needed:
 *  - the launches store one total per block in g_scanBlockSums[level];
 *  - that array is scanned in place by a recursive call at level + 1;
 *  - uniformAdd then adds the scanned totals back onto each block of dZ.
 *
 * @param g_scanBlockSums per-level scratch buffers. This array is indexed on the HOST here
 *                        (g_scanBlockSums[level]) and each entry is handed to kernels by value, so the array
 *                        itself must be host-readable memory that holds device pointers.
 * @param dZ              scan output
 * @param dX              scan input
 * @param numElements     number of ints to scan
 * @param level           recursion depth; top-level callers pass 0
 */
void prescanArrayRecursive(int** g_scanBlockSums, int *dZ, int *dX, int numElements, int level) {
    auto stream = LaunchContext::defaultContext()->getCudaStream();

    const int maxThreadsPerBlock = 512; // max size of the thread blocks
    const int scanBlocks = sd::math::nd4j_max<int>(1, static_cast<int>(ceil(static_cast<float>(numElements) / (2.f * maxThreadsPerBlock))));

    // each thread covers two elements of its scan block
    const int threadsPerBlock = scanBlocks > 1
                                ? maxThreadsPerBlock
                                : (sd::isPowerOfTwo(numElements) ? numElements / 2 : sd::floorPow2(numElements));
    const int eltsPerBlock = threadsPerBlock * 2;

    // For a non-power-of-2 array the last block is not full;
    // compute the smallest power of 2 able to scan it.
    const int eltsLastBlock = numElements - (scanBlocks - 1) * eltsPerBlock;
    const int np2LastBlock = (eltsLastBlock != eltsPerBlock) ? 1 : 0;

    int threadsLastBlock = sd::math::nd4j_max<int>(1, eltsLastBlock / 2);
    int sharedMemLastBlock = 0;
    if (np2LastBlock) {
        if (!isPowerOfTwo(eltsLastBlock))
            threadsLastBlock = floorPow2(eltsLastBlock);

        unsigned int lastPadding = (2 * threadsLastBlock) / NUM_BANKS;
        sharedMemLastBlock = sizeof(int) * (2 * threadsLastBlock + lastPadding);
    }

    // Padding space avoids shared-memory bank conflicts; no launch gets less than 2048 bytes.
    int padding = eltsPerBlock / NUM_BANKS;
    int sharedMemSize = sd::math::nd4j_max<int>(2048, static_cast<int>(sizeof(int) * (eltsPerBlock + padding)));
    sharedMemLastBlock = sd::math::nd4j_max<int>(2048, sharedMemLastBlock);

    // Execution parameters; if the last block is NP2, it is processed separately.
    dim3 grid(sd::math::nd4j_max<int>(1, scanBlocks - np2LastBlock), 1, 1);
    dim3 threads(threadsPerBlock, 1, 1);
    dim3 gridOnes(1, 1, 1);
    dim3 threadsOnes(threadsLastBlock, 1, 1);

    if (scanBlocks <= 1) {
        // A single scan block needs no block sums.
        if (isPowerOfTwo(numElements))
            sd::prescanLauncher<false, false>(grid, threads, sharedMemSize, stream, dZ, dX, 0, eltsPerBlock, 0, 0);
        else
            sd::prescanLauncher<false, true>(grid, threads, sharedMemSize, stream, dZ, dX, 0, numElements, 0, 0);

        sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed");
        return;
    }

    int *levelSums = g_scanBlockSums[level];
    const int lastBlockBase = numElements - eltsLastBlock;

    // Scan all full blocks, then the trailing non-full one; each one stores its total into levelSums.
    sd::prescanLauncher<true, false>(grid, threads, sharedMemSize, stream, dZ, dX, levelSums, eltsPerBlock, 0, 0);
    if (np2LastBlock)
        sd::prescanLauncher<true, true>(gridOnes, threadsOnes, sharedMemLastBlock, stream, dZ, dX, levelSums, eltsLastBlock, scanBlocks - 1, lastBlockBase);

    // After the sub-blocks are scanned, their totals must be scanned as well.
    // That gives the value to be added to each block for the final result (recursive CPU call).
    prescanArrayRecursive(g_scanBlockSums, levelSums, levelSums, scanBlocks, level + 1);

    sd::uniformAdd<<<grid, threads, 1024, *stream>>>(dZ, levelSums, lastBlockBase, 0, 0);
    if (np2LastBlock)
        sd::uniformAdd<<<1, threadsLastBlock, 1024, *stream>>>(dZ, levelSums, eltsLastBlock, scanBlocks - 1, lastBlockBase);

    sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed");
}
|
|
|
|
/**
 * Prefix-scans N ints starting at dx + 1 into dz via prescanArrayRecursive, skipping element 0 of dx.
 * Afterwards the default stream is checked for CUDA errors.
 *
 * @param prs per-level block-sum scratch pointers, forwarded to prescanArrayRecursive as int**
 * @param dx  counter buffer; NOTE(review): element 0 is presumably the total count — confirm
 * @param N   number of ints to scan (narrowed to int)
 * @param dz  scan output
 */
static void encodeThresholdP2Int_(void **prs, int *dx, Nd4jLong N, int *dz) {
    auto cudaStream = LaunchContext::defaultContext()->getCudaStream();

    auto levelSums = reinterpret_cast<int**>(prs);
    int *counters = dx + 1;
    const int count = static_cast<int>(N);
    prescanArrayRecursive(levelSums, dz, counters, count, 0);

    sd::DebugHelper::checkErrorCode(cudaStream, "encodeThresholdP2Int(...) failed");
}
|
|
|
|
/**
 * Final encoding pass: launches encoderKernelP3Generic over N elements of dx, using `offsets` and writing
 * into dz. The kernel is instantiated for the data type stored in hXShapeInfo (FLOAT_TYPES), and CUDA
 * errors are checked afterwards.
 *
 * @param dx          device buffer of the source values
 * @param hXShapeInfo host shape info of dx; only its data type is read here
 * @param offsets     per-block offsets; NOTE(review): presumably the prescan output — confirm with callers
 * @param N           number of elements in dx
 * @param dz          encoded output buffer
 */
static void encodeThresholdP3_(void *dx, const Nd4jLong *hXShapeInfo, int *offsets, Nd4jLong N, int *dz){
    auto stream = LaunchContext::defaultContext()->getCudaStream();

    // one thread per element, rounded up to whole blocks
    const int threadsPerBlock = 512;
    const bool hasTail = (N % threadsPerBlock) != 0;
    const int gridSize = static_cast<int>(N / threadsPerBlock + (hasTail ? 1 : 0));

    // NOTE(review): third component assumed to be the shared-memory size in bytes — confirm
    dim3 launchDims(gridSize, threadsPerBlock, 8192);
    const auto valueType = sd::ArrayOptions::dataType(hXShapeInfo);
    BUILD_SINGLE_SELECTOR(valueType, encoderKernelP3Generic, (launchDims, stream, dx, offsets, N, dz), FLOAT_TYPES);

    sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP3Float(...) failed");
}
|
|
|
|
|
|
/**
 * Launches encoderKernelP1Generic over `updates` on the default CUDA stream.
 *
 * Returns the kernel's int output array of numBlocks + 1 entries, where
 * numBlocks = ceil(updates.lengthOf() / 512). Callers read element 0 as the number of elements that match
 * `threshold`. NOTE(review): the remaining entries are presumably per-block counts — confirm in
 * encoderKernelP1Generic.
 *
 * @param updates   input array (FLOAT_TYPES)
 * @param threshold encoding threshold
 */
static NDArray thresholdEstimate_(const NDArray &updates, const float threshold) {
    const int numThreads = 512;
    const int numBlocks = updates.lengthOf() / numThreads + (updates.lengthOf() % numThreads ? 1 : 0);

    auto tmp = NDArrayFactory::create<int>('c', {numBlocks + 1});

    dim3 launchDims(numBlocks, numThreads, 1024);
    auto xType = updates.dataType();

    NDArray::prepareSpecialUse({&tmp}, {&updates});
    BUILD_SINGLE_SELECTOR(xType, encoderKernelP1Generic, (launchDims, LaunchContext::defaultContext()->getCudaStream(), updates.specialBuffer(), updates.lengthOf(), tmp.specialBuffer(), threshold), FLOAT_TYPES);
    NDArray::registerSpecialUse({&tmp}, {&updates});

    // Plain return: `return std::move(tmp);` would disable copy elision (NRVO).
    return tmp;
}
|
|
|
|
/**
 * Returns element 0 of the buffer produced by thresholdEstimate_. thresholdEncode treats this value as the
 * number of elements of `updates` that match `threshold`.
 */
int32_t thresholdEstimate(const NDArray &updates, const float threshold) {
    auto estimate = thresholdEstimate_(updates, threshold);
    const auto matches = estimate.e<int>(0);
    return matches;
}
|
|
|
|
/**
 * Threshold-encodes `updates` into `encoded` on the default CUDA stream.
 *
 * Steps:
 *  1. thresholdEstimate_ produces a count buffer; element 0 is read as the total number of matches.
 *  2. The per-block counts are prefix-scanned into per-block offsets (encodeThresholdP2Int_).
 *  3. encodeThresholdP3_ writes the encoded entries using those offsets.
 *
 * If the match count exceeds encoded.lengthOf() - 4, it is clamped before encoding.
 * NOTE(review): the 4 reserved ints are presumably a header — confirm.
 *
 * @param updates   input array (FLOAT_TYPES); registered as written after encoding
 * @param encoded   int output buffer
 * @param threshold encoding threshold
 */
void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) {
    // we need these blocks in order to know how many "updates" will be processed by each GPU block
    auto blocks = thresholdEstimate_(updates, threshold);

    const int numThreads = 512;
    const int numBlocks = updates.lengthOf() / numThreads + (updates.lengthOf() % numThreads ? 1 : 0);

    // One scratch buffer per prescanArrayRecursive level that needs block sums. A scan over numElts ints
    // uses ceil(numElts / (2 * prefixThreads)) blocks, and only more than one block needs a sums buffer.
    const int prefixThreads = 512;
    std::vector<NDArray> tempArrays;
    for (int numElts = numBlocks; numElts > 1; ) {
        const int numPrefixBlocks = sd::math::nd4j_max<int>(1, sd::math::nd4j_ceil<float, int>((float) numElts / (2.0f * prefixThreads)));
        if (numPrefixBlocks > 1)
            tempArrays.emplace_back(NDArrayFactory::create<int>('c', {numPrefixBlocks}));
        numElts = numPrefixBlocks;
    }

    // prescanArrayRecursive indexes this array on the host and passes each entry to kernels by value.
    // It must therefore be host memory holding device pointers, not a device-side replica.
    // Pointers are collected after all allocations, so vector growth above cannot invalidate them.
    std::vector<Nd4jPointer> pointers;
    pointers.reserve(tempArrays.size());
    for (auto &scratch : tempArrays)
        pointers.push_back(scratch.specialBuffer());

    PointersManager pm(LaunchContext::defaultContext(), "thresholdEncode");
    auto offsets = NDArrayFactory::create<int>('c', {numBlocks});

    // check whether we're hitting the external limit on the number of encoded elements
    auto numMatches = blocks.e<int>(0);
    if (numMatches > encoded.lengthOf() - 4) {
        blocks.p(0, encoded.lengthOf() - 4);
        blocks.syncToDevice();
    }

    NDArray::prepareSpecialUse({}, {&encoded, &updates});

    // filling offsets
    encodeThresholdP2Int_(pointers.data(),
                          reinterpret_cast<int*>(blocks.specialBuffer()),
                          numBlocks,
                          reinterpret_cast<int*>(offsets.specialBuffer()));

    NDArray::registerSpecialUse({&blocks, &offsets}, {});
    pm.synchronize();

    encodeThresholdP3_(updates.specialBuffer(),
                       updates.shapeInfo(),
                       reinterpret_cast<int*>(offsets.specialBuffer()),
                       updates.lengthOf(),
                       reinterpret_cast<int*>(encoded.specialBuffer()));

    pm.synchronize();

    NDArray::registerSpecialUse({&encoded, &updates}, {});
}
|
|
|
|
/**
 * Decodes `encoded` back into `updates` on the default CUDA stream. It launches decoderKernelGeneric,
 * instantiated for the element type of `updates` (FLOAT_TYPES).
 *
 * @param encoded read-only encoded buffer
 * @param updates destination array; passed as the written list to prepare/registerSpecialUse
 */
void thresholdDecode(const NDArray &encoded, NDArray &updates) {
    // NOTE(review): presumably (blocks, threads, shared-memory bytes), as in the other launches here — confirm
    dim3 launchDims(128, 512, 512);
    const auto elementType = updates.dataType();

    NDArray::prepareSpecialUse({&updates}, {&encoded});

    auto cudaStream = LaunchContext::defaultContext()->getCudaStream();
    BUILD_SINGLE_SELECTOR(elementType, decoderKernelGeneric, (launchDims, cudaStream, encoded.specialBuffer(), updates.lengthOf(), updates.specialBuffer()), FLOAT_TYPES);

    NDArray::registerSpecialUse({&updates}, {&encoded});
}
|
|
}
|
|
}
|
|
}
|