Added doc for fake_quant_with_min_max* op helpers implementations.
This commit is contained in:
		
							parent
							
								
									c3f755d975
								
							
						
					
					
						commit
						c890de5a7b
					
				@ -25,43 +25,54 @@ namespace nd4j {
 | 
				
			|||||||
namespace ops {
 | 
					namespace ops {
 | 
				
			||||||
namespace helpers {
 | 
					namespace helpers {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // nudge - nudged min max over scale
 | 
				
			||||||
 | 
					    // scale = (Max - Min) / (quantMax - quantMin)
 | 
				
			||||||
 | 
					    // quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
    template <typename T>
 | 
					    template <typename T>
 | 
				
			||||||
    static void nudge(T min, T max, int quant_min, int quant_max, T* scale, T* nudged_min, T* nudged_max) {
 | 
					    static void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) {
 | 
				
			||||||
        T quant_max_float = static_cast<T>(quant_max);
 | 
					        // floating point instead integers
 | 
				
			||||||
        T quant_min_float = static_cast<T>(quant_min);
 | 
					        T quantMaxF = static_cast<T>(quantMax);
 | 
				
			||||||
        *scale = (max - min) / (quant_max_float - quant_min_float);
 | 
					        T quantMinF = static_cast<T>(quantMin);
 | 
				
			||||||
        auto zero_point_from_min = quant_min_float - min / *scale;
 | 
					        // compute scale
 | 
				
			||||||
        uint16_t const nudged_zero_point = [zero_point_from_min, quant_min, quant_max, quant_max_float, quant_min_float] {
 | 
					        *scale = (max - min) / (quantMaxF - quantMinF);
 | 
				
			||||||
                if (zero_point_from_min < quant_min_float) {
 | 
					        // compute left bound point
 | 
				
			||||||
                    return static_cast<uint16_t>(quant_min);
 | 
					        auto zeroPointFromMin = quantMinF - min / *scale;
 | 
				
			||||||
 | 
					        // bound zero point to conform with range [0 or 1, 2^b - 1]
 | 
				
			||||||
 | 
					        uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] {
 | 
				
			||||||
 | 
					                if (zeroPointFromMin < quantMinF) {
 | 
				
			||||||
 | 
					                    return static_cast<uint16_t>(quantMin);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                if (zero_point_from_min > quant_max_float) {
 | 
					                if (zeroPointFromMin > quantMaxF) {
 | 
				
			||||||
                    return static_cast<uint16_t>(quant_max);
 | 
					                    return static_cast<uint16_t>(quantMax);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                return nd4j::math::nd4j_round<T,uint16_t>(zero_point_from_min);
 | 
					                return nd4j::math::nd4j_round<T,uint16_t>(zeroPointFromMin);
 | 
				
			||||||
            }();
 | 
					        }();
 | 
				
			||||||
            *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
 | 
					        // compute nudged min and max with computed nudged zero point
 | 
				
			||||||
            *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
 | 
					        *nudgedMin = (quantMinF - nudged_zero_point) * (*scale);
 | 
				
			||||||
 | 
					        *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    template <typename T>
 | 
					    template <typename T>
 | 
				
			||||||
    void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
 | 
					    void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
 | 
				
			||||||
        int lowIntBound = narrowed ? 1 : 0;
 | 
					        int lowIntBound = narrowed ? 1 : 0; // 0 or 1
 | 
				
			||||||
        int upperIntBound = (1 << numBits) - 1;
 | 
					        int upperIntBound = (1 << numBits) - 1; // 2^b - 1
 | 
				
			||||||
        auto channels = input->sizeAt(-1);
 | 
					        auto channels = input->sizeAt(-1); // last dimension
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        PRAGMA_OMP_PARALLEL_FOR
 | 
					        PRAGMA_OMP_PARALLEL_FOR
 | 
				
			||||||
        for (auto i = 0; i < channels; i++) {
 | 
					        for (auto i = 0; i < channels; i++) {
 | 
				
			||||||
            T scale, nudged_min, nudged_max;
 | 
					            T scale, nudged_min, nudged_max;
 | 
				
			||||||
 | 
					            // nudge min and max first, with scale computing
 | 
				
			||||||
            nudge<T>(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
 | 
					            nudge<T>(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
 | 
				
			||||||
 | 
					            // slide using last dimension and process all for given channel
 | 
				
			||||||
            for (auto e = 0; e < input->lengthOf(); e += channels) {
 | 
					            for (auto e = 0; e < input->lengthOf(); e += channels) {
 | 
				
			||||||
                T val = input->t<T>(e + i);
 | 
					                T val = input->t<T>(e + i);
 | 
				
			||||||
                if ( val <= nudged_min)
 | 
					                if ( val <= nudged_min)
 | 
				
			||||||
                    val = nudged_min;
 | 
					                    val = nudged_min;
 | 
				
			||||||
                else if (val >= nudged_max)
 | 
					                else if (val >= nudged_max)
 | 
				
			||||||
                    val = nudged_max;
 | 
					                    val = nudged_max;
 | 
				
			||||||
 | 
					                // quantization itself
 | 
				
			||||||
                output->t<T>(e + i) = math::nd4j_floor<T,T>((val - nudged_min)/scale + T(0.5)) * scale + nudged_min;
 | 
					                output->t<T>(e + i) = math::nd4j_floor<T,T>((val - nudged_min)/scale + T(0.5)) * scale + nudged_min;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@ -73,16 +84,17 @@ namespace helpers {
 | 
				
			|||||||
        int upperIntBound = (1 << numBits) - 1;
 | 
					        int upperIntBound = (1 << numBits) - 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        T nudgedMin, nudgedMax, scale;
 | 
					        T nudgedMin, nudgedMax, scale;
 | 
				
			||||||
 | 
					        // nudge with given min and max and compute scale and nudged min and max
 | 
				
			||||||
        nudge<T>(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
 | 
					        nudge<T>(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
 | 
				
			||||||
 | 
					        // quantization as one
 | 
				
			||||||
        auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) {
 | 
					        auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) {
 | 
				
			||||||
            T val = x;
 | 
					            T val = x; // boundign value between nudged min and max
 | 
				
			||||||
            if (val < nudgedMin) {
 | 
					            if (val < nudgedMin) {
 | 
				
			||||||
                val = nudgedMin;
 | 
					                val = nudgedMin;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            else if (val > nudgedMax)
 | 
					            else if (val > nudgedMax)
 | 
				
			||||||
                val = nudgedMax;
 | 
					                val = nudgedMax;
 | 
				
			||||||
 | 
					            // converse value with scale and shifted with nudged min
 | 
				
			||||||
            return (nd4j::math::nd4j_floor<T,T>((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin);
 | 
					            return (nd4j::math::nd4j_floor<T,T>((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user