commit
3662657d5c
|
@ -78,7 +78,9 @@
|
|||
(28, LogicalXor) ,\
|
||||
(29, LogicalNot) ,\
|
||||
(30, LogicalAnd), \
|
||||
(31, DivideNoNan)
|
||||
(31, DivideNoNan), \
|
||||
(32, IGamma), \
|
||||
(33, IGammac)
|
||||
|
||||
// these ops return same data type as input
|
||||
#define TRANSFORM_SAME_OPS \
|
||||
|
@ -245,7 +247,9 @@
|
|||
(43, TruncateMod) ,\
|
||||
(44, SquaredReverseSubtract) ,\
|
||||
(45, ReversePow), \
|
||||
(46, DivideNoNan)
|
||||
(46, DivideNoNan), \
|
||||
(47, IGamma), \
|
||||
(48, IGammac)
|
||||
|
||||
|
||||
|
||||
|
@ -380,7 +384,9 @@
|
|||
(35, AMinPairwise) ,\
|
||||
(36, TruncateMod), \
|
||||
(37, ReplaceNans), \
|
||||
(38, DivideNoNan)
|
||||
(38, DivideNoNan), \
|
||||
(39, IGamma), \
|
||||
(40, IGammac)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -49,6 +49,8 @@ namespace nd4j {
|
|||
static BroadcastOpsTuple DivideNoNan();
|
||||
static BroadcastOpsTuple Multiply();
|
||||
static BroadcastOpsTuple Subtract();
|
||||
static BroadcastOpsTuple IGamma();
|
||||
static BroadcastOpsTuple IGammac();
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author sgazeos@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_igamma)
|
||||
|
||||
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(igamma, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||
|
||||
//REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
|
||||
|
||||
// auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, pairwise::IGamma, broadcast::IGamma}, x, y, z);
|
||||
auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z);
|
||||
|
||||
if (tZ == nullptr)
|
||||
return ND4J_STATUS_KERNEL_FAILURE;
|
||||
else if (tZ != z) {
|
||||
OVERWRITE_RESULT(tZ);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(igamma) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,58 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author sgazeos@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_igammac)
|
||||
|
||||
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(igammac, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||
|
||||
//REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
|
||||
|
||||
// auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, pairwise::IGammac, broadcast::IGammac}, x, y, z);
|
||||
auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z);
|
||||
if (tZ == nullptr)
|
||||
return ND4J_STATUS_KERNEL_FAILURE;
|
||||
else if (tZ != z) {
|
||||
OVERWRITE_RESULT(tZ);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(igammac) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -357,6 +357,28 @@ namespace nd4j {
|
|||
#if NOT_EXCLUDED(OP_Pow)
|
||||
DECLARE_BROADCASTABLE_OP(Pow, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Broadcastable igamma implementation
|
||||
*
|
||||
* igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x)
|
||||
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
|
||||
* gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt }
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_igamma)
|
||||
DECLARE_BROADCASTABLE_OP(igamma, 0, 0);
|
||||
#endif
|
||||
/**
|
||||
* Broadcastable igammac implementation
|
||||
* igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x)
|
||||
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
|
||||
* Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt }
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_igammac)
|
||||
DECLARE_BROADCASTABLE_OP(igammac, 0, 0);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -48,4 +48,11 @@ namespace nd4j {
|
|||
BroadcastOpsTuple BroadcastOpsTuple::Subtract() {
|
||||
return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract);
|
||||
}
|
||||
BroadcastOpsTuple BroadcastOpsTuple::IGamma() {
|
||||
return custom(nd4j::scalar::IGamma, nd4j::pairwise::IGamma, nd4j::broadcast::IGamma);
|
||||
}
|
||||
BroadcastOpsTuple BroadcastOpsTuple::IGammac() {
|
||||
return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1482,6 +1482,52 @@ namespace simdOps {
|
|||
};
|
||||
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
class IGamma {
|
||||
public:
|
||||
no_op_exec_special
|
||||
no_op_exec_special_cuda
|
||||
|
||||
op_def static Z op(X d1, Z *params) {
|
||||
return nd4j::math::nd4j_igamma<X, X, Z>(d1, params[0]);
|
||||
}
|
||||
|
||||
op_def static Z op(X d1, Y d2) {
|
||||
return nd4j::math::nd4j_igamma<X, Y, Z>(d1, d2);
|
||||
}
|
||||
|
||||
op_def static Z op(X d1, Y d2, Z *params) {
|
||||
return nd4j::math::nd4j_igamma<X, Y, Z>(d1, d2);
|
||||
}
|
||||
|
||||
op_def static Z op(X d1) {
|
||||
return d1;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
class IGammac {
|
||||
public:
|
||||
no_op_exec_special
|
||||
no_op_exec_special_cuda
|
||||
|
||||
op_def static Z op(X d1, Z *params) {
|
||||
return nd4j::math::nd4j_igammac<X, X, Z>(d1, params[0]);
|
||||
}
|
||||
|
||||
op_def static Z op(X d1, Y d2) {
|
||||
return nd4j::math::nd4j_igammac<X, Y, Z>(d1, d2);
|
||||
}
|
||||
|
||||
op_def static Z op(X d1, Y d2, Z *params) {
|
||||
return nd4j::math::nd4j_igammac<X, Y, Z>(d1, d2);
|
||||
}
|
||||
|
||||
op_def static Z op(X d1) {
|
||||
return d1;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
class Round {
|
||||
public:
|
||||
|
|
|
@ -223,6 +223,12 @@ namespace nd4j {
|
|||
return nd4j_sgn<T, Z>(val);
|
||||
}
|
||||
|
||||
template<typename X, typename Z>
|
||||
math_def inline Z nd4j_gamma(X a);
|
||||
|
||||
template<typename X, typename Z>
|
||||
math_def inline Z nd4j_lgamma(X x);
|
||||
|
||||
//#ifndef __CUDACC__
|
||||
/*
|
||||
template<>
|
||||
|
@ -656,9 +662,56 @@ namespace nd4j {
|
|||
return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2));
|
||||
}
|
||||
|
||||
/**
|
||||
* LogGamma(a) - float point extension of ln(n!)
|
||||
**/
|
||||
template <typename X, typename Z>
|
||||
math_def inline Z nd4j_lgamma(X x) {
|
||||
// if (x <= X(0.0))
|
||||
// {
|
||||
// std::stringstream os;
|
||||
// os << "Logarithm of Gamma has sence only for positive values, but " << x << " was given.";
|
||||
// throw std::invalid_argument( os.str() );
|
||||
// }
|
||||
|
||||
if (x < X(12.0)) {
|
||||
return nd4j_log<Z,Z>(nd4j_gamma<X,Z>(x));
|
||||
}
|
||||
|
||||
// Abramowitz and Stegun 6.1.41
|
||||
// Asymptotic series should be good to at least 11 or 12 figures
|
||||
// For error analysis, see Whittiker and Watson
|
||||
// A Course in Modern Analysis (1927), page 252
|
||||
|
||||
static const double c[8] = {
|
||||
1.0/12.0,
|
||||
-1.0/360.0,
|
||||
1.0/1260.0,
|
||||
-1.0/1680.0,
|
||||
1.0/1188.0,
|
||||
-691.0/360360.0,
|
||||
1.0/156.0,
|
||||
-3617.0/122400.0
|
||||
};
|
||||
|
||||
double z = Z(1.0 / Z(x * x));
|
||||
double sum = c[7];
|
||||
|
||||
for (int i = 6; i >= 0; i--) {
|
||||
sum *= z;
|
||||
sum += c[i];
|
||||
}
|
||||
|
||||
double series = sum / Z(x);
|
||||
|
||||
static const double halfLogTwoPi = 0.91893853320467274178032973640562;
|
||||
|
||||
return Z((double(x) - 0.5) * nd4j_log<X,double>(x) - double(x) + halfLogTwoPi + series);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
|
||||
template<typename T>
|
||||
math_def inline T nd4j_re(T val1, T val2) {
|
||||
if (val1 == (T) 0.0f && val2 == (T) 0.0f)
|
||||
return (T) 0.0f;
|
||||
|
@ -731,7 +784,127 @@ namespace nd4j {
|
|||
template<typename T>
|
||||
math_def inline void nd4j_swap(T &val1, T &val2) {
|
||||
T temp = val1; val1=val2; val2=temp;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename X, typename Z>
|
||||
math_def inline Z nd4j_gamma(X a) {
|
||||
// nd4j_lgamma<X,Z>(a);
|
||||
// return (Z)std::tgamma(a);
|
||||
// Split the function domain into three intervals:
|
||||
// (0, 0.001), [0.001, 12), and (12, infinity)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// First interval: (0, 0.001)
|
||||
//
|
||||
// For small a, 1/Gamma(a) has power series a + gamma a^2 - ...
|
||||
// So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3.
|
||||
// The relative error over this interval is less than 6e-7.
|
||||
|
||||
const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant
|
||||
|
||||
if (a < X(0.001))
|
||||
return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a)));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Second interval: [0.001, 12)
|
||||
|
||||
if (a < X(12.0)) {
|
||||
// The algorithm directly approximates gamma over (1,2) and uses
|
||||
// reduction identities to reduce other arguments to this interval.
|
||||
|
||||
double y = (double)a;
|
||||
int n = 0;
|
||||
bool argWasLessThanOne = y < 1.0;
|
||||
|
||||
// Add or subtract integers as necessary to bring y into (1,2)
|
||||
// Will correct for this below
|
||||
if (argWasLessThanOne) {
|
||||
y += 1.0;
|
||||
}
|
||||
else {
|
||||
n = static_cast<int>(floor(y)) - 1; // will use n later
|
||||
y -= n;
|
||||
}
|
||||
|
||||
// numerator coefficients for approximation over the interval (1,2)
|
||||
static const double p[] = {
|
||||
-1.71618513886549492533811E+0,
|
||||
2.47656508055759199108314E+1,
|
||||
-3.79804256470945635097577E+2,
|
||||
6.29331155312818442661052E+2,
|
||||
8.66966202790413211295064E+2,
|
||||
-3.14512729688483675254357E+4,
|
||||
-3.61444134186911729807069E+4,
|
||||
6.64561438202405440627855E+4
|
||||
};
|
||||
|
||||
// denominator coefficients for approximation over the interval (1,2)
|
||||
static const double q[] = {
|
||||
-3.08402300119738975254353E+1,
|
||||
3.15350626979604161529144E+2,
|
||||
-1.01515636749021914166146E+3,
|
||||
-3.10777167157231109440444E+3,
|
||||
2.25381184209801510330112E+4,
|
||||
4.75584627752788110767815E+3,
|
||||
-1.34659959864969306392456E+5,
|
||||
-1.15132259675553483497211E+5
|
||||
};
|
||||
|
||||
double num = 0.0;
|
||||
double den = 1.0;
|
||||
|
||||
|
||||
double z = y - 1;
|
||||
for (auto i = 0; i < 8; i++) {
|
||||
num = (num + p[i]) * z;
|
||||
den = den * z + q[i];
|
||||
}
|
||||
double result = num / den + 1.0;
|
||||
|
||||
// Apply correction if argument was not initially in (1,2)
|
||||
if (argWasLessThanOne) {
|
||||
// Use identity gamma(z) = gamma(z+1)/z
|
||||
// The variable "result" now holds gamma of the original y + 1
|
||||
// Thus we use y-1 to get back the orginal y.
|
||||
result /= (y - 1.0);
|
||||
}
|
||||
else {
|
||||
// Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
|
||||
for (auto i = 0; i < n; i++)
|
||||
result *= y++;
|
||||
}
|
||||
|
||||
return Z(result);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Third interval: [12, infinity)
|
||||
|
||||
if (a > 171.624) {
|
||||
// Correct answer too large to display. Force +infinity.
|
||||
return Z(DOUBLE_MAX_VALUE);
|
||||
//DataTypeUtils::infOrMax<Z>();
|
||||
}
|
||||
|
||||
return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a));
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
math_def inline Z nd4j_igamma(X a, Y x) {
|
||||
Z aim = nd4j_pow<X, X, Z>(x, a) / (nd4j_exp<X, Z>(x) * nd4j_gamma<Y, Z>(a));
|
||||
auto sum = Z(0.);
|
||||
auto denom = Z(1.);
|
||||
for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) {
|
||||
denom *= (a + i);
|
||||
sum += nd4j_pow<X, int, Z>(x, i) / denom;
|
||||
}
|
||||
return aim * sum;
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
math_def inline Z nd4j_igammac(X a, Y x) {
|
||||
return Z(1.) - nd4j_igamma<X, Y, Z>(a, x);
|
||||
}
|
||||
|
||||
#ifdef __CUDACC__
|
||||
namespace atomics {
|
||||
|
|
|
@ -537,6 +537,50 @@ TEST_F(DeclarableOpsTests10, atan2_test6) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, IGamma_Test1) {
|
||||
|
||||
auto y = NDArrayFactory::create<double>('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1});
|
||||
auto x = NDArrayFactory::create<double>('c', { 4}, {1.2, 2.2, 3.2, 4.2});
|
||||
|
||||
auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {
|
||||
0.659917, 0.61757898, 0.59726304, 0.58478117,
|
||||
0.0066205109, 0.022211598, 0.040677428, 0.059117373,
|
||||
0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735});
|
||||
|
||||
nd4j::ops::igamma op;
|
||||
auto result = op.execute({&y, &x}, {}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
auto z = result->at(0);
|
||||
// z->printBuffer("OUtput");
|
||||
// exp.printBuffer("EXpect");
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, IGamma_Test2) {
|
||||
|
||||
auto y = NDArrayFactory::create<double>('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,
|
||||
7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1});
|
||||
auto x = NDArrayFactory::create<double>('c', { 4}, {1.2, 2.2, 3.2, 4.2});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221,
|
||||
0.993379, 0.977788, 0.959323, 0.940883,
|
||||
0.999996, 0.999914, 0.999564, 0.998773});
|
||||
|
||||
nd4j::ops::igammac op;
|
||||
auto result = op.execute({&y, &x}, {}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
auto z = result->at(0);
|
||||
// z->printBuffer("OUtput");
|
||||
// exp.printBuffer("EXpect");
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, range_test10) {
|
||||
|
|
Loading…
Reference in New Issue