parent
3bb22a6ff8
commit
784a2d13f8
|
@ -85,11 +85,11 @@ namespace sd {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen);
|
void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen);
|
||||||
|
#ifdef _OPENMP
|
||||||
template <>
|
template <>
|
||||||
FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
|
FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
#pragma omp parallel for
|
||||||
for (auto i = start; i < stop; i++) {
|
for (Nd4jLong i = 0; i < numOfSubArrs; i++) {
|
||||||
auto inBuff = input + offsets[i];
|
auto inBuff = input + offsets[i];
|
||||||
auto outBuff = output + offsets[i];
|
auto outBuff = output + offsets[i];
|
||||||
|
|
||||||
|
@ -107,6 +107,30 @@ namespace sd {
|
||||||
sum += temp;
|
sum += temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (uint j = 0; j < tadLen; ++j)
|
||||||
|
outBuff[j] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
template <>
|
||||||
|
FORCEINLINE void softmax_loop(float *input, float *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
auto inBuff = input + offsets[i];
|
||||||
|
auto outBuff = output + offsets[i];
|
||||||
|
|
||||||
|
float max = -DataTypeUtils::max<float>();
|
||||||
|
float sum = 0.f;
|
||||||
|
|
||||||
|
for (uint j = 0; j < tadLen; ++j)
|
||||||
|
max = sd::math::nd4j_max<float>(max, inBuff[j]);
|
||||||
|
|
||||||
|
for (uint j = 0; j < tadLen; ++j) {
|
||||||
|
float temp = sd::math::nd4j_exp<float, float>(inBuff[j] - max);
|
||||||
|
outBuff[j] = temp;
|
||||||
|
sum += temp;
|
||||||
|
}
|
||||||
|
|
||||||
for (uint j = 0; j < tadLen; ++j)
|
for (uint j = 0; j < tadLen; ++j)
|
||||||
outBuff[j] /= sum;
|
outBuff[j] /= sum;
|
||||||
}
|
}
|
||||||
|
@ -115,6 +139,8 @@ namespace sd {
|
||||||
samediff::Threads::parallel_tad(func,0, numOfSubArrs);
|
samediff::Threads::parallel_tad(func,0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
|
FORCEINLINE void softmax_loop(T *input, T *output, Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) {
|
||||||
|
|
Loading…
Reference in New Issue