Added doc for fake_quant_with_min_max* op helpers implementations.
This commit is contained in:
		
							parent
							
								
									c3f755d975
								
							
						
					
					
						commit
						c890de5a7b
					
				| @ -25,43 +25,54 @@ namespace nd4j { | ||||
| namespace ops { | ||||
| namespace helpers { | ||||
| 
 | ||||
|     //
 | ||||
|     // nudge - computes the scale and the nudged min and max for the given range
 | ||||
|     // scale = (Max - Min) / (quantMax - quantMin)
 | ||||
|     // quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1
 | ||||
|     //
 | ||||
|     template <typename T> | ||||
|     static void nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) { | ||||
|         T quant_max_float = static_cast<T>(quant_max); | ||||
|         T quant_min_float = static_cast<T>(quant_min); | ||||
|         *scale = (max - min) / (quant_max_float - quant_min_float); | ||||
|         auto zero_point_from_min = quant_min_float - min / *scale; | ||||
|         uint16_t const nudged_zero_point = [zero_point_from_min, quant_min, quant_max, quant_max_float, quant_min_float] { | ||||
|                 if (zero_point_from_min < quant_min_float) { | ||||
|                     return static_cast<uint16_t>(quant_min); | ||||
|     static void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) { | ||||
|         // floating point instead of integers
 | ||||
|         T quantMaxF = static_cast<T>(quantMax); | ||||
|         T quantMinF = static_cast<T>(quantMin); | ||||
|         // compute scale
 | ||||
|         *scale = (max - min) / (quantMaxF - quantMinF); | ||||
|         // compute left bound point
 | ||||
|         auto zeroPointFromMin = quantMinF - min / *scale; | ||||
|         // bound zero point to conform with range [0 or 1, 2^b - 1]
 | ||||
|         uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] { | ||||
|                 if (zeroPointFromMin < quantMinF) { | ||||
|                     return static_cast<uint16_t>(quantMin); | ||||
|                 } | ||||
|                 if (zero_point_from_min > quant_max_float) { | ||||
|                     return static_cast<uint16_t>(quant_max); | ||||
|                 if (zeroPointFromMin > quantMaxF) { | ||||
|                     return static_cast<uint16_t>(quantMax); | ||||
|                 } | ||||
|                 return nd4j::math::nd4j_round<T,uint16_t>(zero_point_from_min); | ||||
|             }(); | ||||
|             *nudged_min = (quant_min_float - nudged_zero_point) * (*scale); | ||||
|             *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); | ||||
|                 return nd4j::math::nd4j_round<T,uint16_t>(zeroPointFromMin); | ||||
|         }(); | ||||
|         // compute nudged min and max with computed nudged zero point
 | ||||
|         *nudgedMin = (quantMinF - nudged_zero_point) * (*scale); | ||||
|         *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale); | ||||
|     } | ||||
| 
 | ||||
|     template <typename T> | ||||
|     void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { | ||||
|         int lowIntBound = narrowed ? 1 : 0; | ||||
|         int upperIntBound = (1 << numBits) - 1; | ||||
|         auto channels = input->sizeAt(-1); | ||||
|         int lowIntBound = narrowed ? 1 : 0; // 0 or 1
 | ||||
|         int upperIntBound = (1 << numBits) - 1; // 2^b - 1
 | ||||
|         auto channels = input->sizeAt(-1); // last dimension
 | ||||
| 
 | ||||
|         PRAGMA_OMP_PARALLEL_FOR | ||||
|         for (auto i = 0; i < channels; i++) { | ||||
|             T scale, nudged_min, nudged_max; | ||||
|             // nudge min and max first, computing the scale as well
 | ||||
|             nudge<T>(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max); | ||||
| 
 | ||||
|             // slide using last dimension and process all for given channel
 | ||||
|             for (auto e = 0; e < input->lengthOf(); e += channels) { | ||||
|                 T val = input->t<T>(e + i); | ||||
|                 if ( val <= nudged_min) | ||||
|                     val = nudged_min; | ||||
|                 else if (val >= nudged_max) | ||||
|                     val = nudged_max; | ||||
| 
 | ||||
|                 // quantization itself
 | ||||
|                 output->t<T>(e + i) = math::nd4j_floor<T,T>((val - nudged_min)/scale + T(0.5)) * scale + nudged_min; | ||||
|             } | ||||
|         } | ||||
| @ -73,16 +84,17 @@ namespace helpers { | ||||
|         int upperIntBound = (1 << numBits) - 1; | ||||
| 
 | ||||
|         T nudgedMin, nudgedMax, scale; | ||||
| 
 | ||||
|         // nudge with given min and max and compute scale and nudged min and max
 | ||||
|         nudge<T>(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); | ||||
| 
 | ||||
|         // quantization as one
 | ||||
|         auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { | ||||
|             T val = x; | ||||
|             T val = x; // bounding value between nudged min and max
 | ||||
|             if (val < nudgedMin) { | ||||
|                 val = nudgedMin; | ||||
|             } | ||||
|             else if (val > nudgedMax) | ||||
|                 val = nudgedMax; | ||||
|             // round value to the nearest scale step and shift it back by nudged min
 | ||||
|             return (nd4j::math::nd4j_floor<T,T>((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin); | ||||
|         }; | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user