Merge pull request #1 from KonduitAI/shugeo_gamma

Shugeo gamma
master
shugeo 2019-10-16 18:49:33 +03:00 committed by GitHub
commit 3662657d5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 423 additions and 6 deletions

View File

@ -78,7 +78,9 @@
(28, LogicalXor) ,\ (28, LogicalXor) ,\
(29, LogicalNot) ,\ (29, LogicalNot) ,\
(30, LogicalAnd), \ (30, LogicalAnd), \
(31, DivideNoNan) (31, DivideNoNan), \
(32, IGamma), \
(33, IGammac)
// these ops return same data type as input // these ops return same data type as input
#define TRANSFORM_SAME_OPS \ #define TRANSFORM_SAME_OPS \
@ -245,7 +247,9 @@
(43, TruncateMod) ,\ (43, TruncateMod) ,\
(44, SquaredReverseSubtract) ,\ (44, SquaredReverseSubtract) ,\
(45, ReversePow), \ (45, ReversePow), \
(46, DivideNoNan) (46, DivideNoNan), \
(47, IGamma), \
(48, IGammac)
@ -380,7 +384,9 @@
(35, AMinPairwise) ,\ (35, AMinPairwise) ,\
(36, TruncateMod), \ (36, TruncateMod), \
(37, ReplaceNans), \ (37, ReplaceNans), \
(38, DivideNoNan) (38, DivideNoNan), \
(39, IGamma), \
(40, IGammac)

View File

@ -49,6 +49,8 @@ namespace nd4j {
static BroadcastOpsTuple DivideNoNan(); static BroadcastOpsTuple DivideNoNan();
static BroadcastOpsTuple Multiply(); static BroadcastOpsTuple Multiply();
static BroadcastOpsTuple Subtract(); static BroadcastOpsTuple Subtract();
static BroadcastOpsTuple IGamma();
static BroadcastOpsTuple IGammac();
}; };
} }

View File

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_igamma)
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(igamma, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
//REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
// auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, pairwise::IGamma, broadcast::IGamma}, x, y, z);
auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z);
if (tZ == nullptr)
return ND4J_STATUS_KERNEL_FAILURE;
else if (tZ != z) {
OVERWRITE_RESULT(tZ);
}
return Status::OK();
}
DECLARE_TYPES(igamma) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
}
}
#endif

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_igammac)
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(igammac, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
//REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
// auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, pairwise::IGammac, broadcast::IGammac}, x, y, z);
auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z);
if (tZ == nullptr)
return ND4J_STATUS_KERNEL_FAILURE;
else if (tZ != z) {
OVERWRITE_RESULT(tZ);
}
return Status::OK();
}
DECLARE_TYPES(igammac) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
}
}
#endif

View File

@ -357,6 +357,28 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_Pow) #if NOT_EXCLUDED(OP_Pow)
DECLARE_BROADCASTABLE_OP(Pow, 0, 0); DECLARE_BROADCASTABLE_OP(Pow, 0, 0);
#endif #endif
/**
* Broadcastable igamma implementation
*
* igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x)
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
* gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt }
* @tparam T
*/
#if NOT_EXCLUDED(OP_igamma)
DECLARE_BROADCASTABLE_OP(igamma, 0, 0);
#endif
/**
* Broadcastable igammac implementation
* igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x)
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
* Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt }
* @tparam T
*/
#if NOT_EXCLUDED(OP_igammac)
DECLARE_BROADCASTABLE_OP(igammac, 0, 0);
#endif
} }
} }

View File

@ -48,4 +48,11 @@ namespace nd4j {
BroadcastOpsTuple BroadcastOpsTuple::Subtract() { BroadcastOpsTuple BroadcastOpsTuple::Subtract() {
return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract); return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract);
} }
BroadcastOpsTuple BroadcastOpsTuple::IGamma() {
return custom(nd4j::scalar::IGamma, nd4j::pairwise::IGamma, nd4j::broadcast::IGamma);
}
BroadcastOpsTuple BroadcastOpsTuple::IGammac() {
return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac);
}
} }

View File

@ -1482,6 +1482,52 @@ namespace simdOps {
}; };
template <typename X, typename Y, typename Z>
class IGamma {
public:
no_op_exec_special
no_op_exec_special_cuda
op_def static Z op(X d1, Z *params) {
return nd4j::math::nd4j_igamma<X, X, Z>(d1, params[0]);
}
op_def static Z op(X d1, Y d2) {
return nd4j::math::nd4j_igamma<X, Y, Z>(d1, d2);
}
op_def static Z op(X d1, Y d2, Z *params) {
return nd4j::math::nd4j_igamma<X, Y, Z>(d1, d2);
}
op_def static Z op(X d1) {
return d1;
}
};
template <typename X, typename Y, typename Z>
class IGammac {
public:
no_op_exec_special
no_op_exec_special_cuda
op_def static Z op(X d1, Z *params) {
return nd4j::math::nd4j_igammac<X, X, Z>(d1, params[0]);
}
op_def static Z op(X d1, Y d2) {
return nd4j::math::nd4j_igammac<X, Y, Z>(d1, d2);
}
op_def static Z op(X d1, Y d2, Z *params) {
return nd4j::math::nd4j_igammac<X, Y, Z>(d1, d2);
}
op_def static Z op(X d1) {
return d1;
}
};
template <typename X> template <typename X>
class Round { class Round {
public: public:

View File

@ -223,6 +223,12 @@ namespace nd4j {
return nd4j_sgn<T, Z>(val); return nd4j_sgn<T, Z>(val);
} }
template<typename X, typename Z>
math_def inline Z nd4j_gamma(X a);
template<typename X, typename Z>
math_def inline Z nd4j_lgamma(X x);
//#ifndef __CUDACC__ //#ifndef __CUDACC__
/* /*
template<> template<>
@ -656,9 +662,56 @@ namespace nd4j {
return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2)); return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2));
} }
/**
* LogGamma(a) - float point extension of ln(n!)
**/
template <typename X, typename Z>
math_def inline Z nd4j_lgamma(X x) {
// if (x <= X(0.0))
// {
// std::stringstream os;
// os << "Logarithm of Gamma has sence only for positive values, but " << x << " was given.";
// throw std::invalid_argument( os.str() );
// }
if (x < X(12.0)) {
return nd4j_log<Z,Z>(nd4j_gamma<X,Z>(x));
}
// Abramowitz and Stegun 6.1.41
// Asymptotic series should be good to at least 11 or 12 figures
// For error analysis, see Whittiker and Watson
// A Course in Modern Analysis (1927), page 252
static const double c[8] = {
1.0/12.0,
-1.0/360.0,
1.0/1260.0,
-1.0/1680.0,
1.0/1188.0,
-691.0/360360.0,
1.0/156.0,
-3617.0/122400.0
};
double z = Z(1.0 / Z(x * x));
double sum = c[7];
for (int i = 6; i >= 0; i--) {
sum *= z;
sum += c[i];
}
double series = sum / Z(x);
static const double halfLogTwoPi = 0.91893853320467274178032973640562;
return Z((double(x) - 0.5) * nd4j_log<X,double>(x) - double(x) + halfLogTwoPi + series);
}
template<typename T>
template<typename T>
math_def inline T nd4j_re(T val1, T val2) { math_def inline T nd4j_re(T val1, T val2) {
if (val1 == (T) 0.0f && val2 == (T) 0.0f) if (val1 == (T) 0.0f && val2 == (T) 0.0f)
return (T) 0.0f; return (T) 0.0f;
@ -731,7 +784,127 @@ namespace nd4j {
template<typename T> template<typename T>
math_def inline void nd4j_swap(T &val1, T &val2) { math_def inline void nd4j_swap(T &val1, T &val2) {
T temp = val1; val1=val2; val2=temp; T temp = val1; val1=val2; val2=temp;
}; };
template <typename X, typename Z>
math_def inline Z nd4j_gamma(X a) {
// nd4j_lgamma<X,Z>(a);
// return (Z)std::tgamma(a);
// Split the function domain into three intervals:
// (0, 0.001), [0.001, 12), and (12, infinity)
///////////////////////////////////////////////////////////////////////////
// First interval: (0, 0.001)
//
// For small a, 1/Gamma(a) has power series a + gamma a^2 - ...
// So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3.
// The relative error over this interval is less than 6e-7.
const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant
if (a < X(0.001))
return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a)));
///////////////////////////////////////////////////////////////////////////
// Second interval: [0.001, 12)
if (a < X(12.0)) {
// The algorithm directly approximates gamma over (1,2) and uses
// reduction identities to reduce other arguments to this interval.
double y = (double)a;
int n = 0;
bool argWasLessThanOne = y < 1.0;
// Add or subtract integers as necessary to bring y into (1,2)
// Will correct for this below
if (argWasLessThanOne) {
y += 1.0;
}
else {
n = static_cast<int>(floor(y)) - 1; // will use n later
y -= n;
}
// numerator coefficients for approximation over the interval (1,2)
static const double p[] = {
-1.71618513886549492533811E+0,
2.47656508055759199108314E+1,
-3.79804256470945635097577E+2,
6.29331155312818442661052E+2,
8.66966202790413211295064E+2,
-3.14512729688483675254357E+4,
-3.61444134186911729807069E+4,
6.64561438202405440627855E+4
};
// denominator coefficients for approximation over the interval (1,2)
static const double q[] = {
-3.08402300119738975254353E+1,
3.15350626979604161529144E+2,
-1.01515636749021914166146E+3,
-3.10777167157231109440444E+3,
2.25381184209801510330112E+4,
4.75584627752788110767815E+3,
-1.34659959864969306392456E+5,
-1.15132259675553483497211E+5
};
double num = 0.0;
double den = 1.0;
double z = y - 1;
for (auto i = 0; i < 8; i++) {
num = (num + p[i]) * z;
den = den * z + q[i];
}
double result = num / den + 1.0;
// Apply correction if argument was not initially in (1,2)
if (argWasLessThanOne) {
// Use identity gamma(z) = gamma(z+1)/z
// The variable "result" now holds gamma of the original y + 1
// Thus we use y-1 to get back the orginal y.
result /= (y - 1.0);
}
else {
// Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
for (auto i = 0; i < n; i++)
result *= y++;
}
return Z(result);
}
///////////////////////////////////////////////////////////////////////////
// Third interval: [12, infinity)
if (a > 171.624) {
// Correct answer too large to display. Force +infinity.
return Z(DOUBLE_MAX_VALUE);
//DataTypeUtils::infOrMax<Z>();
}
return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a));
}
template <typename X, typename Y, typename Z>
math_def inline Z nd4j_igamma(X a, Y x) {
Z aim = nd4j_pow<X, X, Z>(x, a) / (nd4j_exp<X, Z>(x) * nd4j_gamma<Y, Z>(a));
auto sum = Z(0.);
auto denom = Z(1.);
for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) {
denom *= (a + i);
sum += nd4j_pow<X, int, Z>(x, i) / denom;
}
return aim * sum;
}
template <typename X, typename Y, typename Z>
math_def inline Z nd4j_igammac(X a, Y x) {
return Z(1.) - nd4j_igamma<X, Y, Z>(a, x);
}
#ifdef __CUDACC__ #ifdef __CUDACC__
namespace atomics { namespace atomics {

View File

@ -537,6 +537,50 @@ TEST_F(DeclarableOpsTests10, atan2_test6) {
delete result; delete result;
} }
//////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, IGamma_Test1) {
auto y = NDArrayFactory::create<double>('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1});
auto x = NDArrayFactory::create<double>('c', { 4}, {1.2, 2.2, 3.2, 4.2});
auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {
0.659917, 0.61757898, 0.59726304, 0.58478117,
0.0066205109, 0.022211598, 0.040677428, 0.059117373,
0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735});
nd4j::ops::igamma op;
auto result = op.execute({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printBuffer("OUtput");
// exp.printBuffer("EXpect");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, IGamma_Test2) {
auto y = NDArrayFactory::create<double>('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,
7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1});
auto x = NDArrayFactory::create<double>('c', { 4}, {1.2, 2.2, 3.2, 4.2});
auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221,
0.993379, 0.977788, 0.959323, 0.940883,
0.999996, 0.999914, 0.999564, 0.998773});
nd4j::ops::igammac op;
auto result = op.execute({&y, &x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printBuffer("OUtput");
// exp.printBuffer("EXpect");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, range_test10) { TEST_F(DeclarableOpsTests10, range_test10) {