Added gamma and lgamma functions.
parent
7617682a46
commit
24a2b2933f
|
@ -223,6 +223,12 @@ namespace nd4j {
|
||||||
return nd4j_sgn<T, Z>(val);
|
return nd4j_sgn<T, Z>(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forward declaration of nd4j_gamma (takes an X, returns a Z). It is declared
// here because nd4j_lgamma, which is defined earlier in this file than
// nd4j_gamma, calls it.
template<typename X, typename Z>
|
||||||
|
math_def inline Z nd4j_gamma(X a);
|
||||||
|
|
||||||
|
// Forward declaration of nd4j_lgamma (takes an X, returns a Z); the
// definition appears further down in this file.
template<typename X, typename Z>
|
||||||
|
math_def inline Z nd4j_lgamma(X x);
|
||||||
|
|
||||||
//#ifndef __CUDACC__
|
//#ifndef __CUDACC__
|
||||||
/*
|
/*
|
||||||
template<>
|
template<>
|
||||||
|
@ -656,6 +662,53 @@ namespace nd4j {
|
||||||
return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2));
|
return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LogGamma(a) - float point extension of ln(n!)
|
||||||
|
**/
|
||||||
|
template <typename X, typename Z>
|
||||||
|
math_def inline Z nd4j_lgamma(X x) {
|
||||||
|
// if (x <= X(0.0))
|
||||||
|
// {
|
||||||
|
// std::stringstream os;
|
||||||
|
// os << "Logarithm of Gamma has sence only for positive values, but " << x << " was given.";
|
||||||
|
// throw std::invalid_argument( os.str() );
|
||||||
|
// }
|
||||||
|
|
||||||
|
if (x < X(12.0)) {
|
||||||
|
return nd4j_log<Z,Z>(nd4j_gamma<X,Z>(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abramowitz and Stegun 6.1.41
|
||||||
|
// Asymptotic series should be good to at least 11 or 12 figures
|
||||||
|
// For error analysis, see Whittiker and Watson
|
||||||
|
// A Course in Modern Analysis (1927), page 252
|
||||||
|
|
||||||
|
static const double c[8] = {
|
||||||
|
1.0/12.0,
|
||||||
|
-1.0/360.0,
|
||||||
|
1.0/1260.0,
|
||||||
|
-1.0/1680.0,
|
||||||
|
1.0/1188.0,
|
||||||
|
-691.0/360360.0,
|
||||||
|
1.0/156.0,
|
||||||
|
-3617.0/122400.0
|
||||||
|
};
|
||||||
|
|
||||||
|
double z = Z(1.0 / Z(x * x));
|
||||||
|
double sum = c[7];
|
||||||
|
|
||||||
|
for (int i = 6; i >= 0; i--) {
|
||||||
|
sum *= z;
|
||||||
|
sum += c[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
double series = sum / Z(x);
|
||||||
|
|
||||||
|
static const double halfLogTwoPi = 0.91893853320467274178032973640562;
|
||||||
|
|
||||||
|
return Z((double(x) - 0.5) * nd4j_log<X,double>(x) - double(x) + halfLogTwoPi + series);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -735,7 +788,105 @@ namespace nd4j {
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
math_def inline Z nd4j_gamma(X a) {
|
math_def inline Z nd4j_gamma(X a) {
|
||||||
return (Z)std::tgamma(a);
|
// nd4j_lgamma<X,Z>(a);
|
||||||
|
// return (Z)std::tgamma(a);
|
||||||
|
// Split the function domain into three intervals:
|
||||||
|
// (0, 0.001), [0.001, 12), and (12, infinity)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
// First interval: (0, 0.001)
|
||||||
|
//
|
||||||
|
// For small a, 1/Gamma(a) has power series a + gamma a^2 - ...
|
||||||
|
// So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3.
|
||||||
|
// The relative error over this interval is less than 6e-7.
|
||||||
|
|
||||||
|
const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant
|
||||||
|
|
||||||
|
if (a < X(0.001))
|
||||||
|
return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a)));
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
// Second interval: [0.001, 12)
|
||||||
|
|
||||||
|
if (a < X(12.0)) {
|
||||||
|
// The algorithm directly approximates gamma over (1,2) and uses
|
||||||
|
// reduction identities to reduce other arguments to this interval.
|
||||||
|
|
||||||
|
double y = (double)a;
|
||||||
|
int n = 0;
|
||||||
|
bool argWasLessThanOne = y < 1.0;
|
||||||
|
|
||||||
|
// Add or subtract integers as necessary to bring y into (1,2)
|
||||||
|
// Will correct for this below
|
||||||
|
if (argWasLessThanOne) {
|
||||||
|
y += 1.0;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
n = static_cast<int>(floor(y)) - 1; // will use n later
|
||||||
|
y -= n;
|
||||||
|
}
|
||||||
|
|
||||||
|
// numerator coefficients for approximation over the interval (1,2)
|
||||||
|
static const double p[] = {
|
||||||
|
-1.71618513886549492533811E+0,
|
||||||
|
2.47656508055759199108314E+1,
|
||||||
|
-3.79804256470945635097577E+2,
|
||||||
|
6.29331155312818442661052E+2,
|
||||||
|
8.66966202790413211295064E+2,
|
||||||
|
-3.14512729688483675254357E+4,
|
||||||
|
-3.61444134186911729807069E+4,
|
||||||
|
6.64561438202405440627855E+4
|
||||||
|
};
|
||||||
|
|
||||||
|
// denominator coefficients for approximation over the interval (1,2)
|
||||||
|
static const double q[] = {
|
||||||
|
-3.08402300119738975254353E+1,
|
||||||
|
3.15350626979604161529144E+2,
|
||||||
|
-1.01515636749021914166146E+3,
|
||||||
|
-3.10777167157231109440444E+3,
|
||||||
|
2.25381184209801510330112E+4,
|
||||||
|
4.75584627752788110767815E+3,
|
||||||
|
-1.34659959864969306392456E+5,
|
||||||
|
-1.15132259675553483497211E+5
|
||||||
|
};
|
||||||
|
|
||||||
|
double num = 0.0;
|
||||||
|
double den = 1.0;
|
||||||
|
|
||||||
|
|
||||||
|
double z = y - 1;
|
||||||
|
for (auto i = 0; i < 8; i++) {
|
||||||
|
num = (num + p[i]) * z;
|
||||||
|
den = den * z + q[i];
|
||||||
|
}
|
||||||
|
double result = num / den + 1.0;
|
||||||
|
|
||||||
|
// Apply correction if argument was not initially in (1,2)
|
||||||
|
if (argWasLessThanOne) {
|
||||||
|
// Use identity gamma(z) = gamma(z+1)/z
|
||||||
|
// The variable "result" now holds gamma of the original y + 1
|
||||||
|
// Thus we use y-1 to get back the orginal y.
|
||||||
|
result /= (y - 1.0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
|
||||||
|
for (auto i = 0; i < n; i++)
|
||||||
|
result *= y++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Z(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
// Third interval: [12, infinity)
|
||||||
|
|
||||||
|
if (a > 171.624) {
|
||||||
|
// Correct answer too large to display. Force +infinity.
|
||||||
|
return Z(DOUBLE_MAX_VALUE);
|
||||||
|
//DataTypeUtils::infOrMax<Z>();
|
||||||
|
}
|
||||||
|
|
||||||
|
return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
|
|
Loading…
Reference in New Issue