Added doc comments for fake_quant_with_min_max* op helpers CUDA implementations.

master
shugeo 2019-10-10 18:35:28 +03:00
parent c890de5a7b
commit ace65355c5
1 changed file with 5 additions and 5 deletions


@@ -84,22 +84,22 @@ namespace helpers {
                         T* output, Nd4jLong* outputShape, Nd4jLong length) {
        __shared__ int block;
        if (threadIdx.x == 0) {
-           block = length / channels;
+           block = length / channels; // number of elements per channel, used as the per-block loop bound
        }
        __syncthreads();

        for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) {
            T scale, nudgedMin, nudgedMax;
            nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
-
-           for (auto e = threadIdx.x; e < block; e += blockDim.x) {
-               T val = input[shape::getIndexOffset(e * channels + i, inputShape)];
+           // loop over the block, quantizing values between the nudged min and max
+           for (auto b = threadIdx.x; b < block; b += blockDim.x) {
+               T val = input[shape::getIndexOffset(b * channels + i, inputShape)];
                if (val < nudgedMin) {
                    val = nudgedMin;
                } else if (val > nudgedMax) {
                    val = nudgedMax;
                }
-               output[shape::getIndexOffset(e* channels + i, outputShape)] =
+               output[shape::getIndexOffset(b * channels + i, outputShape)] =
                    (math::nd4j_floor<T, T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin);
            };
        }
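For context, below is a minimal host-side sketch of the per-channel fake-quantization step this kernel performs. The names nudgeRange and fakeQuantize are hypothetical, and the nudging logic is an assumption based on the usual fake_quant_with_min_max_vars semantics; the actual nudge() helper in libnd4j may differ in detail.

// Hypothetical host-side reference for the per-channel fake-quant step.
// nudgeRange() is assumed to follow the common fake_quant_with_min_max_vars
// nudging scheme; it is not the actual libnd4j nudge() helper.
#include <algorithm>
#include <cmath>
#include <cstdio>

template <typename T>
void nudgeRange(T min, T max, int lowIntBound, int upperIntBound,
                T* scale, T* nudgedMin, T* nudgedMax) {
    // quantization step size for the requested integer range
    *scale = (max - min) / static_cast<T>(upperIntBound - lowIntBound);
    // where the real zero would land on the integer grid
    T zeroPointFromMin = static_cast<T>(lowIntBound) - min / *scale;
    // clamp and round the zero point so it sits exactly on the grid
    T nudgedZeroPoint;
    if (zeroPointFromMin < static_cast<T>(lowIntBound))
        nudgedZeroPoint = static_cast<T>(lowIntBound);
    else if (zeroPointFromMin > static_cast<T>(upperIntBound))
        nudgedZeroPoint = static_cast<T>(upperIntBound);
    else
        nudgedZeroPoint = std::round(zeroPointFromMin);
    // nudged limits are exact multiples of the scale
    *nudgedMin = (static_cast<T>(lowIntBound) - nudgedZeroPoint) * (*scale);
    *nudgedMax = (static_cast<T>(upperIntBound) - nudgedZeroPoint) * (*scale);
}

template <typename T>
T fakeQuantize(T val, T scale, T nudgedMin, T nudgedMax) {
    // clamp to the nudged range, then snap to the nearest quantization step,
    // mirroring the kernel's floor((val - min) / scale + 0.5) * scale + min
    val = std::min(std::max(val, nudgedMin), nudgedMax);
    return std::floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin;
}

int main() {
    float scale, nudgedMin, nudgedMax;
    // one channel with range [-0.1, 1.0] quantized to 8 bits (0..255)
    nudgeRange(-0.1f, 1.0f, 0, 255, &scale, &nudgedMin, &nudgedMax);
    std::printf("scale=%g nudgedMin=%g nudgedMax=%g\n", scale, nudgedMin, nudgedMax);
    std::printf("fakeQuantize(0.37) = %g\n",
                fakeQuantize(0.37f, scale, nudgedMin, nudgedMax));
    return 0;
}

In the kernel from the diff, this work is laid out so that each CUDA block handles one channel (blockIdx.x) and its threads stride over that channel's elements, which is why the nudged range is computed only once per channel before the inner loop.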