/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by Yurii Shyrma on 11.12.2017 // #include #include #include #include #include namespace sd { namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// // modified Lentz’s algorithm for continued fractions, // reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions” template static T continuedFraction(const T a, const T b, const T x) { const T min = DataTypeUtils::min() / DataTypeUtils::eps(); const T aPlusb = a + b; T val, aPlus2i; T t2 = 1; T t1 = static_cast(1) - aPlusb * x / (a + static_cast(1)); if(math::nd4j_abs(t1) < min) t1 = min; t1 = static_cast(1) / t1; T result = t1; for(uint i = 1; i <= maxIter; ++i) { aPlus2i = a + static_cast(2*i); val = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); // t1 t1 = static_cast(1) + val * t1; if(math::nd4j_abs(t1) < min) t1 = min; t1 = static_cast(1) / t1; // t2 t2 = static_cast(1) + val / t2; if(math::nd4j_abs(t2) < min) t2 = min; // result result *= t2 * t1; val = -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); // t1 t1 = static_cast(1) + val * t1; if(math::nd4j_abs(t1) < min) t1 = min; t1 = static_cast(1) / t1; // t2 t2 = static_cast(1) + val / t2; if(math::nd4j_abs(t2) < min) t2 = min; // result val = t2 * t1; result *= val; // condition to stop loop if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) return result; } return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity } /////////////////////////////////////////////////////////////////// // evaluates incomplete beta function for positive a and b, and x between 0 and 1. template static T betaIncCore(T a, T b, T x) { // if (a <= (T)0. || b <= (T)0.) // throw("betaInc function: a and b must be > 0 !"); // if (x < (T)0. || x > (T)1.) // throw("betaInc function: x must be within (0, 1) interval !"); // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 if(a == b && x == static_cast(0.5)) return static_cast(0.5); if (x == static_cast(0) || x == static_cast(1)) return x; const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) return front * continuedFraction(a, b, x) / a; else // symmetry relation return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x) / b; } /////////////////////////////////////////////////////////////////// template static void betaIncForArray(sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) { int xLen = x.lengthOf(); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) output.r(i) = betaIncCore(a.t(i), b.t(i), x.t(i)); }; samediff::Threads::parallel_for(func, 0, xLen); } /////////////////////////////////////////////////////////////////// // overload betaInc for arrays, shapes of a, b and x must be the same !!! void betaInc(sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) { auto xType = a.dataType(); BUILD_SINGLE_SELECTOR(xType, betaIncForArray, (context, a, b, x, output), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template void betaIncForArray, (sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output), FLOAT_TYPES); } } }