diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index 2ce8cd338..f5b3ab4ed 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -731,28 +731,29 @@ namespace nd4j { template math_def inline void nd4j_swap(T &val1, T &val2) { T temp = val1; val1=val2; val2=temp; - }; + }; -// template -// math_def inline Z nd4j_gamma(X a) { -// -// } - template - math_def inline Z nd4j_igamma(X a, X x) { - Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * std::tgamma(a)); + template + math_def inline Z nd4j_gamma(X a) { + return (Z)std::tgamma(a); + } + + template + math_def inline Z nd4j_igamma(X a, Y x) { + Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * nd4j_gamma(a)); auto sum = Z(0.); auto denom = Z(1.); - for (auto i = 0; Z(1./denom) > Z(1.0e-6); i++) { + for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) { denom *= (a + i); - sum += nd4j_pow(x, i) / denom; + sum += nd4j_pow(x, i) / denom; } return aim * sum; - } + } - template - math_def inline Z nd4j_igammac(X a, X x) { - return Z(1.) - nd4j_igamma(a, x); - } + template + math_def inline Z nd4j_igammac(X a, Y x) { + return Z(1.) - nd4j_igamma(a, x); + } #ifdef __CUDACC__ namespace atomics { @@ -1494,4 +1495,4 @@ inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 } -#endif /* TEMPLATEMATH_H_ */ \ No newline at end of file +#endif /* TEMPLATEMATH_H_ */