cavis/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp

245 lines
8.8 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by Yurii Shyrma on 11.12.2017
//
#include<cmath>
#include <DataTypeUtils.h>
#include<ops/declarable/helpers/betaInc.h>
#include <NDArrayFactory.h>
namespace nd4j {
namespace ops {
namespace helpers {
const int maxIter = 10000; // max number of loop iterations in function for continued fractions
const int maxValue = 3000; // if a and b are both > maxValue, then apply Gauss-Legendre quadrature.
// 18 values of abscissas and weights for 36-point Gauss-Legendre integration,
// take a note - weights and abscissas are symmetric around the midpoint of the range of integration: 36/2 = 18
const double abscissas[18] = {0.0021695375159141994,
0.011413521097787704,0.027972308950302116,0.051727015600492421,
0.082502225484340941, 0.12007019910960293,0.16415283300752470,
0.21442376986779355, 0.27051082840644336, 0.33199876341447887,
0.39843234186401943, 0.46931971407375483, 0.54413605556657973,
0.62232745288031077, 0.70331500465597174, 0.78649910768313447,
0.87126389619061517, 0.95698180152629142};
const double weights[18] = {0.0055657196642445571,
0.012915947284065419,0.020181515297735382,0.027298621498568734,
0.034213810770299537,0.040875750923643261,0.047235083490265582,
0.053244713977759692,0.058860144245324798,0.064039797355015485,
0.068745323835736408,0.072941885005653087,0.076598410645870640,
0.079687828912071670,0.082187266704339706,0.084078218979661945,
0.085346685739338721,0.085983275670394821};
///////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////
// modified Lentzs algorithm for continued fractions,
// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions,”
template <typename T>
static T continFract(const T a, const T b, const T x) {
const T min = DataTypeUtils::min<T>() / DataTypeUtils::eps<T>();
const T amu = a - (T)1.;
const T apu = a + (T)1.;
const T apb = a + b;
// first iteration
T coeff1 = (T)1.;
T coeff2 = (T)1. - apb * x / apu;
if(math::nd4j_abs<T>(coeff2) < min)
coeff2 = min;
coeff2 = (T)1./coeff2;
T result = coeff2;
T val, delta;
int i2;
// rest iterations
for(int i=1; i <= maxIter; i+=2) {
i2 = 2*i;
// even step
val = i * (b - (T)i) * x / ((amu + (T)i2) * (a + (T)i2));
coeff2 = (T)(1.) + val * coeff2;
if(math::nd4j_abs<T>(coeff2) < min)
coeff2 = min;
coeff2 = (T)1. / coeff2;
coeff1 = (T)(1.) + val / coeff1;
if(math::nd4j_abs<T>(coeff1) < min)
coeff1 = min;
result *= coeff1 * coeff2;
//***********************************************//
// odd step
val = -(a + (T)i) * (apb + (T)i) * x / ((a + (T)i2) * (apu + (T)i2));
coeff2 = (T)(1.) + val * coeff2;
if(math::nd4j_abs<T>(coeff2) < min)
coeff2 = min;
coeff2 = (T)1. / coeff2;
coeff1 = (T)(1.) + val / coeff1;
if(math::nd4j_abs<T>(coeff1) < min)
coeff1 = min;
delta = coeff1 * coeff2;
result *= delta;
// condition to stop loop
if(math::nd4j_abs<T>(delta - (T)1.) <= DataTypeUtils::eps<T>())
break;
}
return result;
}
///////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////
// evaluates incomplete beta integral using Gauss-Legendre quadrature method
template <typename T>
static T gausLegQuad(const T a, const T b, const T x) {
T upLim, t, result;
T sum = (T)0.;
T amu = a - (T)1.;
T bmu = b - (T)1.;
T rat = a / (a + b);
T lnrat = math::nd4j_log<T,T>(rat);
T lnratm = math::nd4j_log<T,T>((T)1. - rat);
t = math::nd4j_sqrt<T,T>(a * b /((a + b) * (a + b) * (a + b + (T)1.)));
if (x > rat) {
if (x >= (T)1.)
return (T)1.0;
upLim = math::nd4j_min<T>((T)1., math::nd4j_max<T>(rat + (T)1.*t, x + (T)5.*t));
}
else {
if (x <= (T)0.)
return (T)0.;
upLim = math::nd4j_max<T>(0., math::nd4j_min<T>(rat - (T)10.*t, x - (T)5.*t));
}
// Gauss-Legendre
PRAGMA_OMP_SIMD_SUM(sum)
for (int i = 0; i < 18; ++i) {
auto t = x + (upLim - x) * (T)abscissas[i];
sum += (T)weights[i] * math::nd4j_exp<T,T>(amu * (math::nd4j_log<T,T>(t) - lnrat) + bmu * (math::nd4j_log<T,T>((T)1. - t) - lnratm));
}
if (std::is_same<T, double>::value) {
result = sum * (upLim - x) * math::nd4j_exp<T,T>(amu * lnrat - lgamma(static_cast<double>(a)) + bmu * lnratm - lgamma(static_cast<double>(b)) + lgamma(static_cast<double>(a + b)));
} else {
result = sum * (upLim - x) * math::nd4j_exp<T,T>(amu * lnrat - lgamma((float) a) + bmu * lnratm - lgamma((float) b) + lgamma(static_cast<float>(a + b)));
}
if(result > (T)0.)
return (T)1. - result;
return -result;
}
///////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////
// evaluates incomplete beta function for positive a and b, and x between 0 and 1.
template <typename T>
static T betaIncTA(T a, T b, T x) {
// if (a <= (T)0. || b <= (T)0.)
// throw("betaInc function: a and b must be > 0 !");
// if (x < (T)0. || x > (T)1.)
// throw("betaInc function: x must be within (0, 1) interval !");
// t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5
if(a == b && x == (T)0.5)
return (T)0.5;
if (x == (T)0. || x == (T)1.)
return x;
if (a > (T)maxValue && b > (T)maxValue)
return gausLegQuad<T>(a, b, x);
T front = math::nd4j_exp<T,T>( lgamma(static_cast<double>(a + b)) - lgamma(static_cast<double>(a)) - lgamma(static_cast<double>(b)) + a * math::nd4j_log<T, T>(x) + b * math::nd4j_log<T, T>((T)1. - x));
// continued fractions
if (x < (a + (T)1.) / (a + b + (T)2.))
return front * continFract(a, b, x) / a;
// symmetry relation
else
return (T)1. - front * continFract(b, a, (T)1. - x) / b;
}
template<typename T>
NDArray betaIncT(nd4j::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x) {
auto result = NDArray(&x, false, x.getContext());
int xLen = x.lengthOf();
PRAGMA_OMP_PARALLEL_FOR_IF(xLen > Environment::getInstance()->elementwiseThreshold())
for(int i = 0; i < xLen; ++i) {
result.p(i, betaIncTA<T>(a.e<T>(i), b.e<T>(i), x.e<T>(i)));
}
return result;
}
///////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////
// overload betaInc for arrays, shapes of a, b and x must be the same !!!
NDArray betaInc(nd4j::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x) {
auto xType = a.dataType();
BUILD_SINGLE_SELECTOR(xType, return betaIncT, (context,a, b, x), FLOAT_TYPES);
return a;
}
template float continFract<float> (const float a, const float b, const float x);
template float16 continFract<float16>(const float16 a, const float16 b, const float16 x);
template bfloat16 continFract<bfloat16>(const bfloat16 a, const bfloat16 b, const bfloat16 x);
template double continFract<double> (const double a, const double b, const double x);
template float gausLegQuad<float> (const float a, const float b, const float x);
template float16 gausLegQuad<float16>(const float16 a, const float16 b, const float16 x);
template bfloat16 gausLegQuad<bfloat16>(const bfloat16 a, const bfloat16 b, const bfloat16 x);
template double gausLegQuad<double> (const double a, const double b, const double x);
template float betaIncTA<float> (const float a, const float b, const float x);
template float16 betaIncTA<float16>(const float16 a, const float16 b, const float16 x);
template bfloat16 betaIncTA<bfloat16>(const bfloat16 a, const bfloat16 b, const bfloat16 x);
template double betaIncTA<double> (const double a, const double b, const double x);
template NDArray betaIncT<float> (nd4j::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x);
template NDArray betaIncT<float16>(nd4j::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x);
template NDArray betaIncT<bfloat16>(nd4j::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x);
template NDArray betaIncT<double> (nd4j::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x);
}
}
}