diff --git a/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp
index bcb4f6cfb..e2c0f5183 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp
@@ -85,11 +85,11 @@ namespace sd {
     template <typename T>
     void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen);
-
+#ifdef _OPENMP
     template <>
     FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
-        auto func = PRAGMA_THREADS_FOR {
-            for (auto i = start; i < stop; i++) {
+#pragma omp parallel for
+            for (Nd4jLong i = 0; i < numOfSubArrs; i++) {
                 auto inBuff = input + offsets[i];
                 auto outBuff = output + offsets[i];
 
@@ -107,6 +107,30 @@ namespace sd {
                     sum += temp;
                 }
 
+                for (uint j = 0; j < tadLen; ++j)
+                    outBuff[j] /= sum;
+            }
+    }
+#else
+    template <>
+    FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
+        auto func = PRAGMA_THREADS_FOR {
+            for (auto i = start; i < stop; i++) {
+                auto inBuff = input + offsets[i];
+                auto outBuff = output + offsets[i];
+
+                float max = -DataTypeUtils::max<float>();
+                float sum = 0.f;
+
+                for (uint j = 0; j < tadLen; ++j)
+                    max = sd::math::nd4j_max<float>(max, inBuff[j]);
+
+                for (uint j = 0; j < tadLen; ++j) {
+                    float temp = sd::math::nd4j_exp<float, float>(inBuff[j] - max);
+                    outBuff[j] = temp;
+                    sum += temp;
+                }
+
                 for (uint j = 0; j < tadLen; ++j)
                     outBuff[j] /= sum;
             }
@@ -115,6 +139,8 @@ namespace sd {
         samediff::Threads::parallel_tad(func,0, numOfSubArrs);
     }
+#endif
+
     template <typename T>
     FORCEINLINE void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
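
// ---------------------------------------------------------------------------
// Not part of the patch above: a minimal standalone sketch of the pattern the
// new _OPENMP branch implements (per-sub-array softmax parallelized with a
// plain "#pragma omp parallel for"). It uses std types (int64_t, float) and
// std::exp/std::max instead of Nd4jLong and the sd::math helpers, so the
// names below are illustrative assumptions, not the library's API.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Softmax over numOfSubArrs sub-arrays of length tadLen each, located at
// input + offsets[i]; results are written to the same offsets in output.
void softmax_loop_sketch(const float *input, float *output,
                         const int64_t *offsets, int64_t numOfSubArrs,
                         uint32_t tadLen) {
#pragma omp parallel for
    for (int64_t i = 0; i < numOfSubArrs; i++) {
        const float *inBuff = input + offsets[i];
        float *outBuff = output + offsets[i];

        // 1) row maximum, for numerical stability
        float max = inBuff[0];
        for (uint32_t j = 1; j < tadLen; ++j)
            max = std::max(max, inBuff[j]);

        // 2) exponentiate shifted values and accumulate their sum
        float sum = 0.f;
        for (uint32_t j = 0; j < tadLen; ++j) {
            float temp = std::exp(inBuff[j] - max);
            outBuff[j] = temp;
            sum += temp;
        }

        // 3) normalize so each sub-array sums to 1
        for (uint32_t j = 0; j < tadLen; ++j)
            outBuff[j] /= sum;
    }
}

int main() {
    // Two rows of length 3, stored back to back (offsets 0 and 3).
    std::vector<float> in = {1.f, 2.f, 3.f, 0.f, 0.f, 0.f};
    std::vector<float> out(in.size());
    std::vector<int64_t> offsets = {0, 3};

    softmax_loop_sketch(in.data(), out.data(), offsets.data(), 2, 3);

    for (float v : out) std::printf("%.4f ", v);  // each row sums to 1
    std::printf("\n");
    return 0;
}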