Separate OpenMP implementation for softmax (#289)
Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									3bb22a6ff8
								
							
						
					
					
						commit
						784a2d13f8
					
				@ -85,11 +85,11 @@ namespace sd {
 | 
			
		||||
 | 
			
		||||
            template <typename T>
 | 
			
		||||
            void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen);
 | 
			
		||||
 | 
			
		||||
#ifdef _OPENMP
 | 
			
		||||
            template <>
 | 
			
		||||
            FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
 | 
			
		||||
                auto func = PRAGMA_THREADS_FOR {
 | 
			
		||||
                    for (auto i = start; i < stop; i++) {
 | 
			
		||||
#pragma omp parallel for
 | 
			
		||||
                    for (Nd4jLong i = 0; i < numOfSubArrs; i++) {
 | 
			
		||||
                        auto inBuff = input + offsets[i];
 | 
			
		||||
                        auto outBuff = output + offsets[i];
 | 
			
		||||
 | 
			
		||||
@ -107,6 +107,30 @@ namespace sd {
 | 
			
		||||
                            sum += temp;
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        for (uint j = 0; j < tadLen; ++j)
 | 
			
		||||
                            outBuff[j] /= sum;
 | 
			
		||||
                    }
 | 
			
		||||
            }
 | 
			
		||||
#else
 | 
			
		||||
            template <>
 | 
			
		||||
            FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
 | 
			
		||||
                auto func = PRAGMA_THREADS_FOR {
 | 
			
		||||
                    for (auto i = start; i < stop; i++) {
 | 
			
		||||
                        auto inBuff = input + offsets[i];
 | 
			
		||||
                        auto outBuff = output + offsets[i];
 | 
			
		||||
 | 
			
		||||
                        float max = -DataTypeUtils::max<float>();
 | 
			
		||||
                        float sum = 0.f;
 | 
			
		||||
 | 
			
		||||
                        for (uint j = 0; j < tadLen; ++j)
 | 
			
		||||
                            max = sd::math::nd4j_max<float>(max, inBuff[j]);
 | 
			
		||||
 | 
			
		||||
                        for (uint j = 0; j < tadLen; ++j) {
 | 
			
		||||
                            float temp = sd::math::nd4j_exp<float, float>(inBuff[j] - max);
 | 
			
		||||
                            outBuff[j] = temp;
 | 
			
		||||
                            sum += temp;
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        for (uint j = 0; j < tadLen; ++j)
 | 
			
		||||
                            outBuff[j] /= sum;
 | 
			
		||||
                    }
 | 
			
		||||
@ -115,6 +139,8 @@ namespace sd {
 | 
			
		||||
                samediff::Threads::parallel_tad(func,0, numOfSubArrs);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            template <typename T>
 | 
			
		||||
            FORCEINLINE void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user