Implementation of a CUDA kernel for the fake_quant_with_min_max_vars_per_channels op. Final revision.

Branch: master
Author: shugeo, 2019-10-10 16:51:29 +03:00
Parent: 02d8616692
Commit: d5b352273d
2 changed files with 11 additions and 6 deletions

@@ -81,19 +81,24 @@ namespace helpers {
     static __global__ void fakeQuantWithMinMaxKernel(T* input, Nd4jLong* inputShape, T* min, T* max,
                                                      int lowIntBound, int upperIntBound, Nd4jLong channels,
                                                      T* output, Nd4jLong* outputShape, Nd4jLong length) {
+        __shared__ int block;
+        if (threadIdx.x == 0) {
+            block = length / channels;
+        }
+        __syncthreads();
         for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) {
             T scale, nudged_min, nudged_max;
             Nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
             //auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudged_min, nudged_max, scale) {
-            for (auto e = threadIdx.x; e < (int)length; e += (int)channels) {
-                T val = input[shape::getIndexOffset(e + i, inputShape)];
+            for (auto e = threadIdx.x; e < block; e += blockDim.x) {
+                T val = input[shape::getIndexOffset(e * channels + i, inputShape)];
                 if (val < nudged_min) {
                     val = nudged_min;
                 } else if (val > nudged_max) {
                     val = nudged_max;
                 }
-                output[shape::getIndexOffset(e + i, outputShape)] = (math::nd4j_floor<T, T>((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
+                output[shape::getIndexOffset(e * channels + i, outputShape)] = (math::nd4j_floor<T, T>((val - nudged_min) / scale + T(0.5)) * scale + nudged_min);
             };
         }
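
For reference, the rewritten loop is equivalent to the host-side sketch below. The names (fakeQuantRefHost, nudgedMin, and so on) are illustrative and not from the patch; it assumes a channel-last layout, so element e of channel i sits at flat index e * channels + i, and that scale and the nudged bounds were already computed per channel, as Nudge does on the device. It also shows why the change matters: the old "e += channels" indexing only covered channel i correctly when a single thread ran the loop, while the new form lets blockDim.x threads partition the length / channels elements of each channel.

#include <cmath>

template <typename T>
void fakeQuantRefHost(const T* input, T* output, long length, long channels,
                      const T* scale, const T* nudgedMin, const T* nudgedMax) {
    const long block = length / channels;            // elements per channel
    for (long i = 0; i < channels; ++i) {            // kernel: blockIdx.x stride
        for (long e = 0; e < block; ++e) {           // kernel: threadIdx.x stride
            T val = input[e * channels + i];         // channel-last flat index
            if (val < nudgedMin[i]) val = nudgedMin[i];
            else if (val > nudgedMax[i]) val = nudgedMax[i];
            // snap to the nearest quantization step, then map back to real values
            output[e * channels + i] =
                std::floor((val - nudgedMin[i]) / scale[i] + T(0.5)) * scale[i] + nudgedMin[i];
        }
    }
}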
@@ -111,7 +116,7 @@ namespace helpers {
         T* outputBuf = output->dataBuffer()->specialAsT<T>();
         T* minBuf = min->dataBuffer()->specialAsT<T>();
         T* maxBuf = max->dataBuffer()->specialAsT<T>();
-        fakeQuantWithMinMaxKernel<<<1, 1, 256, *stream>>>(inputBuf, input->specialShapeInfo(),
+        fakeQuantWithMinMaxKernel<<<128, 256, 256, *stream>>>(inputBuf, input->specialShapeInfo(),
             minBuf, maxBuf, lowIntBound, upperIntBound, channels, outputBuf, output->specialShapeInfo(), length);
         NDArray::registerSpecialUse({output}, {min, max, input});
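
A note on the new launch configuration: the previous <<<1, 1, 256, *stream>>> ran one thread in total, while <<<128, 256, 256, *stream>>> requests 128 blocks that stride across channels and 256 threads per block that stride across each channel's elements. The third argument is dynamic shared memory in bytes; since the kernel's __shared__ int block is statically declared, those 256 bytes appear unused and 0 would likely suffice. A minimal sketch of the same two-level grid-stride pattern, with an illustrative kernel name:

__global__ void perChannelSketch(const float* in, float* out, long length, long channels) {
    const long block = length / channels;                       // elements per channel
    for (long i = blockIdx.x; i < channels; i += gridDim.x)     // blocks stride over channels
        for (long e = threadIdx.x; e < block; e += blockDim.x)  // threads stride over elements
            out[e * channels + i] = in[e * channels + i];       // per-channel work goes here
}

// Hypothetical usage, mirroring the launch above:
// perChannelSketch<<<128, 256, 0, stream>>>(in, out, length, channels);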

@@ -2127,8 +2127,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     auto result = results->at(0);
-    result->printBuffer("Quantized");
-    exp.printBuffer("Expected");
+    // result->printBuffer("Quantized");
+    // exp.printBuffer("Expected");
     ASSERT_TRUE(exp.isSameShapeStrict(result));
     ASSERT_TRUE(exp.equalsTo(result));
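
For intuition about what the test's expected buffer holds, here is a worked example of the quantize-dequantize formula the kernel applies, using assumed bounds (min 0, max 1, 8-bit range [0, 255]) rather than the test's actual data:

#include <cmath>
#include <cstdio>

int main() {
    const float nudgedMin = 0.0f, nudgedMax = 1.0f;         // assumed, not the test's values
    const float scale = (nudgedMax - nudgedMin) / 255.0f;   // ~0.0039216 per step
    const float val = 0.37f;                                // already inside [min, max]
    const float out = std::floor((val - nudgedMin) / scale + 0.5f) * scale + nudgedMin;
    std::printf("%.6f\n", out);                             // ~0.368627 (step 94 of 255)
    return 0;
}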