parent
3bb22a6ff8
commit
784a2d13f8
|
@ -85,11 +85,11 @@ namespace sd {
|
|||
|
||||
// Primary template declaration for the per-sub-array softmax kernel.
//
// As the float specializations visible in this file show, each index i in
// [0, numOfSubArrs) selects one sub-array starting at input + offsets[i] and
// writes to output + offsets[i] (the same offset is applied to both buffers).
// For every sub-array of tadLen elements the kernel:
//   1. finds the maximum element,
//   2. stores exp(x - max) into the output while accumulating the sum,
//   3. divides every output element by that sum.
//
// input        - base pointer of the source buffer
// output       - base pointer of the destination buffer
// offsets      - start offset of each sub-array, indexed by i
// numOfSubArrs - number of sub-arrays to process
// tadLen       - number of elements read/written per sub-array
//
// NOTE(review): presumably offsets holds at least numOfSubArrs entries and
// both buffers cover offsets[i] + tadLen elements - confirm against callers.
template <typename T>
|
||||
void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen);
|
||||
|
||||
#ifdef _OPENMP
|
||||
template <>
|
||||
FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
#pragma omp parallel for
|
||||
for (Nd4jLong i = 0; i < numOfSubArrs; i++) {
|
||||
auto inBuff = input + offsets[i];
|
||||
auto outBuff = output + offsets[i];
|
||||
|
||||
|
@ -107,6 +107,30 @@ namespace sd {
|
|||
sum += temp;
|
||||
}
|
||||
|
||||
for (uint j = 0; j < tadLen; ++j)
|
||||
outBuff[j] /= sum;
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <>
|
||||
FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto inBuff = input + offsets[i];
|
||||
auto outBuff = output + offsets[i];
|
||||
|
||||
float max = -DataTypeUtils::max<float>();
|
||||
float sum = 0.f;
|
||||
|
||||
for (uint j = 0; j < tadLen; ++j)
|
||||
max = sd::math::nd4j_max<float>(max, inBuff[j]);
|
||||
|
||||
for (uint j = 0; j < tadLen; ++j) {
|
||||
float temp = sd::math::nd4j_exp<float, float>(inBuff[j] - max);
|
||||
outBuff[j] = temp;
|
||||
sum += temp;
|
||||
}
|
||||
|
||||
for (uint j = 0; j < tadLen; ++j)
|
||||
outBuff[j] /= sum;
|
||||
}
|
||||
|
@ -115,6 +139,8 @@ namespace sd {
|
|||
samediff::Threads::parallel_tad(func,0, numOfSubArrs);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
|
||||
|
|
Loading…
Reference in New Issue