separate omp impl for softmax (#289)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-03-05 11:14:22 +03:00 committed by GitHub
parent 3bb22a6ff8
commit 784a2d13f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 29 additions and 3 deletions

View File

@ -85,11 +85,11 @@ namespace sd {
template <typename T> template <typename T>
void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen); void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen);
#ifdef _OPENMP
template <> template <>
FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) { FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
auto func = PRAGMA_THREADS_FOR { #pragma omp parallel for
for (auto i = start; i < stop; i++) { for (Nd4jLong i = 0; i < numOfSubArrs; i++) {
auto inBuff = input + offsets[i]; auto inBuff = input + offsets[i];
auto outBuff = output + offsets[i]; auto outBuff = output + offsets[i];
@ -107,6 +107,30 @@ namespace sd {
sum += temp; sum += temp;
} }
for (uint j = 0; j < tadLen; ++j)
outBuff[j] /= sum;
}
}
#else
template <>
FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
auto inBuff = input + offsets[i];
auto outBuff = output + offsets[i];
float max = -DataTypeUtils::max<float>();
float sum = 0.f;
for (uint j = 0; j < tadLen; ++j)
max = sd::math::nd4j_max<float>(max, inBuff[j]);
for (uint j = 0; j < tadLen; ++j) {
float temp = sd::math::nd4j_exp<float, float>(inBuff[j] - max);
outBuff[j] = temp;
sum += temp;
}
for (uint j = 0; j < tadLen; ++j) for (uint j = 0; j < tadLen; ++j)
outBuff[j] /= sum; outBuff[j] /= sum;
} }
@ -115,6 +139,8 @@ namespace sd {
samediff::Threads::parallel_tad(func,0, numOfSubArrs); samediff::Threads::parallel_tad(func,0, numOfSubArrs);
} }
#endif
template <typename T> template <typename T>
FORCEINLINE void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) { FORCEINLINE void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {