Implementation of cuda kernel for fake_quant_with_min_max_vars_per_channels op.

master
shugeo 2019-10-10 16:40:56 +03:00
parent 3504b0cda9
commit 02d8616692
4 changed files with 35 additions and 25 deletions

View File

@ -55,7 +55,7 @@ namespace nd4j {
" for quatization should be in between 2 and 16, but %i "
"was given.", numBits);
}
helpers::fakeQuantWithMinMaxVarsPerChannel(x, min, max, numBits, narrowed, output);
helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output);
return ND4J_STATUS_OK;
}

View File

@ -103,7 +103,7 @@ namespace helpers {
void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
}
void fakeQuantWithMinMaxVarsPerChannel(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
}

View File

@ -34,7 +34,7 @@ namespace helpers {
// output - output tensor
//
template <typename T>
static void Nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) {
static __host__ __device__ void Nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) {
T quant_max_float = static_cast<T>(quant_max);
T quant_min_float = static_cast<T>(quant_min);
*scale = (max - min) / (quant_max_float - quant_min_float);
@ -78,44 +78,54 @@ namespace helpers {
}
template <typename T>
void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed?1:0;
int upperIntBound = (1 << numBits) - 1;
min->syncToHost();
max->syncToHost();
T scale, nudged_min, nudged_max;
auto channels = min->lengthOf();
input->syncToHost();
input->syncToDevice();
output->syncToHost();
for (auto i = 0; i < channels; i++) {
Nudge(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
static __global__ void fakeQuantWithMinMaxKernel(T* input, Nd4jLong* inputShape, T* min, T* max,
int lowIntBound, int upperIntBound, Nd4jLong channels,
T* output, Nd4jLong* outputShape, Nd4jLong length) {
for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) {
T scale, nudged_min, nudged_max;
Nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
//auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudged_min, nudged_max, scale) {
for (auto e = 0; e < input->lengthOf(); e += channels) {
T val = input->t<T>(e + i);
for (auto e = threadIdx.x; e < (int)length; e += (int)channels) {
T val = input[shape::getIndexOffset(e + i, inputShape)];
if (val < nudged_min) {
val = nudged_min;
} else if (val > nudged_max) {
val = nudged_max;
}
output->t<T>(e + i) = (math::nd4j_floor<T, T>((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
output[shape::getIndexOffset(e + i, outputShape)] = (math::nd4j_floor<T, T>((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
};
}
output->syncToDevice();
output->tickWriteDevice();
}
template <typename T>
void fakeQuantWithMinMaxVarsPerChannel_(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed?1:0;
int upperIntBound = (1 << numBits) - 1;
auto channels = min->lengthOf();
auto length = input->lengthOf();
NDArray::prepareSpecialUse({output}, {min, max, input});
auto stream = context->getCudaStream();
T* inputBuf = input->dataBuffer()->specialAsT<T>();
T* outputBuf = output->dataBuffer()->specialAsT<T>();
T* minBuf = min->dataBuffer()->specialAsT<T>();
T* maxBuf = max->dataBuffer()->specialAsT<T>();
fakeQuantWithMinMaxKernel<<<1, 1, 256, *stream>>>(inputBuf, input->specialShapeInfo(),
minBuf, maxBuf, lowIntBound, upperIntBound, channels, outputBuf, output->specialShapeInfo(), length);
NDArray::registerSpecialUse({output}, {min, max, input});
}
void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
}
void fakeQuantWithMinMaxVarsPerChannel(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (context, input, min, max, numBits, narrowed, output), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVarsPerChannel_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVarsPerChannel_, (LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
}
}

View File

@ -27,7 +27,7 @@ namespace ops {
namespace helpers {
void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output);
void fakeQuantWithMinMaxVarsPerChannel(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output);
void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output);
}
}
}