diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index b803bdb8d..1fbd06f2b 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -78,7 +78,9 @@ (28, LogicalXor) ,\ (29, LogicalNot) ,\ (30, LogicalAnd), \ - (31, DivideNoNan) + (31, DivideNoNan), \ + (32, IGamma), \ + (33, IGammac) // these ops return same data type as input #define TRANSFORM_SAME_OPS \ @@ -245,7 +247,9 @@ (43, TruncateMod) ,\ (44, SquaredReverseSubtract) ,\ (45, ReversePow), \ - (46, DivideNoNan) + (46, DivideNoNan), \ + (47, IGamma), \ + (48, IGammac) @@ -380,7 +384,9 @@ (35, AMinPairwise) ,\ (36, TruncateMod), \ (37, ReplaceNans), \ - (38, DivideNoNan) + (38, DivideNoNan), \ + (39, IGamma), \ + (40, IGammac) diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index c665a0abc..0450e50ab 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -49,6 +49,8 @@ namespace nd4j { static BroadcastOpsTuple DivideNoNan(); static BroadcastOpsTuple Multiply(); static BroadcastOpsTuple Subtract(); + static BroadcastOpsTuple IGamma(); + static BroadcastOpsTuple IGammac(); }; } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp new file mode 100644 index 000000000..6bd1c88ed --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author sgazeos@gmail.com
+//
+
+#include // NOTE(review): the header names are missing on the #include lines in this view of the patch - confirm against the original
+#if NOT_EXCLUDED(OP_igamma)
+
+#include
+#include
+
+namespace nd4j {
+    namespace ops {
+        BROADCASTABLE_OP_IMPL(igamma, 0, 0) {
+            auto x = INPUT_VARIABLE(0); // input 0; presumably the "a" argument of igamma(a, x) - TODO confirm
+            auto y = INPUT_VARIABLE(1); // input 1; presumably the "x" argument of igamma(a, x) - TODO confirm
+            auto z = OUTPUT_VARIABLE(0);
+
+            BROADCAST_CHECK_EMPTY(x,y,z); // NOTE(review): presumably handles empty inputs - confirm against the macro
+
+            // Apply the op element-wise with broadcasting. BroadcastOpsTuple::IGamma() bundles the scalar,
+            // pairwise and broadcast IGamma op ids. The returned pointer is checked below: nullptr is
+            // reported as a kernel failure, and a pointer other than z goes through OVERWRITE_RESULT.
+            auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z);
+
+            if (tZ == nullptr)
+                return ND4J_STATUS_KERNEL_FAILURE;
+            else if (tZ != z) {
+                OVERWRITE_RESULT(tZ);
+            }
+
+            return Status::OK();
+        }
+
+        DECLARE_TYPES(igamma) { // both inputs and the output are declared as ALL_FLOATS
+            getOpDescriptor()
+                ->setAllowedInputTypes(0, {ALL_FLOATS})
+                ->setAllowedInputTypes(1, {ALL_FLOATS})
+                ->setAllowedOutputTypes(0, {ALL_FLOATS});
+        }
+    }
+}
+
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp
new file mode 100644
index 000000000..89494dc4b
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp
@@ -0,0 +1,58 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author sgazeos@gmail.com
+//
+
+#include // NOTE(review): the header names are missing on the #include lines in this view of the patch - confirm against the original
+#if NOT_EXCLUDED(OP_igammac)
+
+#include
+#include
+
+namespace nd4j {
+    namespace ops {
+        BROADCASTABLE_OP_IMPL(igammac, 0, 0) {
+            auto x = INPUT_VARIABLE(0); // input 0; presumably the "a" argument of igammac(a, x) - TODO confirm
+            auto y = INPUT_VARIABLE(1); // input 1; presumably the "x" argument of igammac(a, x) - TODO confirm
+            auto z = OUTPUT_VARIABLE(0);
+
+            BROADCAST_CHECK_EMPTY(x,y,z); // NOTE(review): presumably handles empty inputs - confirm against the macro
+
+            // Apply the op element-wise with broadcasting. BroadcastOpsTuple::IGammac() bundles the scalar,
+            // pairwise and broadcast IGammac op ids. The returned pointer is checked below: nullptr is
+            // reported as a kernel failure, and a pointer other than z goes through OVERWRITE_RESULT.
+            auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z);
+            if (tZ == nullptr)
+                return ND4J_STATUS_KERNEL_FAILURE;
+            else if (tZ != z) {
+                OVERWRITE_RESULT(tZ);
+            }
+
+            return Status::OK();
+        }
+
+        DECLARE_TYPES(igammac) { // both inputs and the output are declared as ALL_FLOATS
+            getOpDescriptor()
+                ->setAllowedInputTypes(0, {ALL_FLOATS})
+                ->setAllowedInputTypes(1, {ALL_FLOATS})
+                ->setAllowedOutputTypes(0, {ALL_FLOATS});
+        }
+    }
+}
+
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/headers/broadcastable.h b/libnd4j/include/ops/declarable/headers/broadcastable.h
index b3b2463cd..7ee53b52a 100644
--- a/libnd4j/include/ops/declarable/headers/broadcastable.h
+++ b/libnd4j/include/ops/declarable/headers/broadcastable.h
@@ -357,6 +357,28 @@ namespace nd4j {
         #if NOT_EXCLUDED(OP_Pow)
         DECLARE_BROADCASTABLE_OP(Pow, 0, 0);
         #endif
+
+
/**
+         * Broadcastable igamma implementation
+         *
+         * igamma(a, x) = gamma(a, x) / Gamma(a) - Gamma distribution function P(a,x)
+         * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
+         * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt }
+         * Inputs and output are declared as ALL_FLOATS in DECLARE_TYPES(igamma)
+         */
+        #if NOT_EXCLUDED(OP_igamma)
+        DECLARE_BROADCASTABLE_OP(igamma, 0, 0);
+        #endif
+        /**
+         * Broadcastable igammac implementation
+         * igammac(a, x) = Gamma(a,x)/Gamma(a) - Gamma distribution function Q(a,x)
+         * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
+         * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt }
+         * Inputs and output are declared as ALL_FLOATS in DECLARE_TYPES(igammac)
+         */
+        #if NOT_EXCLUDED(OP_igammac)
+        DECLARE_BROADCASTABLE_OP(igammac, 0, 0);
+        #endif
     }
 }
diff --git a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp
index ca408e8dc..0e9c99636 100644
--- a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp
+++ b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp
@@ -48,4 +48,11 @@ namespace nd4j {
     BroadcastOpsTuple BroadcastOpsTuple::Subtract() {
         return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract);
     }
+    BroadcastOpsTuple BroadcastOpsTuple::IGamma() { // scalar, pairwise and broadcast op ids for igamma
+        return custom(nd4j::scalar::IGamma, nd4j::pairwise::IGamma, nd4j::broadcast::IGamma);
+    }
+    BroadcastOpsTuple BroadcastOpsTuple::IGammac() { // scalar, pairwise and broadcast op ids for igammac
+        return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac);
+    }
+
 }
diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h
index a738f0bdc..132b58033 100644
--- a/libnd4j/include/ops/ops.h
+++ b/libnd4j/include/ops/ops.h
@@ -1482,6 +1482,52 @@ namespace simdOps {
     };
+    template // NOTE(review): the template parameter lists are missing in this view of the patch - confirm against the original
+    class IGamma { // element-wise op wrapper around nd4j::math::nd4j_igamma
+    public:
+        no_op_exec_special
+        no_op_exec_special_cuda
+
+        op_def static Z op(X d1, Z *params) { // params[0] supplies the second argument
+            return nd4j::math::nd4j_igamma(d1, params[0]);
+        }
+
+        op_def static Z op(X d1, Y d2) {
+            return nd4j::math::nd4j_igamma(d1, d2);
+        }
+
+        op_def static Z op(X d1, Y d2, Z *params) { // params is not used
+            return nd4j::math::nd4j_igamma(d1, d2);
+        }
+
+
+        op_def static Z op(X d1) { // single-argument form returns d1 unchanged
+            return d1;
+        }
+    };
+
+    template
+    class IGammac { // element-wise op wrapper around nd4j::math::nd4j_igammac
+    public:
+        no_op_exec_special
+        no_op_exec_special_cuda
+
+        op_def static Z op(X d1, Z *params) { // params[0] supplies the second argument
+            return nd4j::math::nd4j_igammac(d1, params[0]);
+        }
+
+        op_def static Z op(X d1, Y d2) {
+            return nd4j::math::nd4j_igammac(d1, d2);
+        }
+
+        op_def static Z op(X d1, Y d2, Z *params) { // params is not used
+            return nd4j::math::nd4j_igammac(d1, d2);
+        }
+
+        op_def static Z op(X d1) { // single-argument form returns d1 unchanged
+            return d1;
+        }
+    };
+
+
     template
     class Round {
     public:
diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h
index 908323369..d0af6c8ed 100644
--- a/libnd4j/include/templatemath.h
+++ b/libnd4j/include/templatemath.h
@@ -223,6 +223,12 @@ namespace nd4j {
             return nd4j_sgn(val);
         }
+        template
+        math_def inline Z nd4j_gamma(X a); // forward declaration; the definition is added further down in this header
+
+        template
+        math_def inline Z nd4j_lgamma(X x); // forward declaration; the definition is added further down in this header
+
 //#ifndef __CUDACC__
 /*
         template<>
@@ -656,9 +662,56 @@ namespace nd4j {
             return p_pow(static_cast(val), static_cast(val2));
         }
+        /**
+         * LogGamma(a) = ln(Gamma(a)) - floating point extension of ln((a - 1)!)
+ **/ + template + math_def inline Z nd4j_lgamma(X x) { +// if (x <= X(0.0)) +// { +// std::stringstream os; +// os << "Logarithm of Gamma has sence only for positive values, but " << x << " was given."; +// throw std::invalid_argument( os.str() ); +// } + + if (x < X(12.0)) { + return nd4j_log(nd4j_gamma(x)); + } + + // Abramowitz and Stegun 6.1.41 + // Asymptotic series should be good to at least 11 or 12 figures + // For error analysis, see Whittiker and Watson + // A Course in Modern Analysis (1927), page 252 + + static const double c[8] = { + 1.0/12.0, + -1.0/360.0, + 1.0/1260.0, + -1.0/1680.0, + 1.0/1188.0, + -691.0/360360.0, + 1.0/156.0, + -3617.0/122400.0 + }; + + double z = Z(1.0 / Z(x * x)); + double sum = c[7]; + + for (int i = 6; i >= 0; i--) { + sum *= z; + sum += c[i]; + } + + double series = sum / Z(x); + + static const double halfLogTwoPi = 0.91893853320467274178032973640562; + + return Z((double(x) - 0.5) * nd4j_log(x) - double(x) + halfLogTwoPi + series); + } - template + + template math_def inline T nd4j_re(T val1, T val2) { if (val1 == (T) 0.0f && val2 == (T) 0.0f) return (T) 0.0f; @@ -731,7 +784,127 @@ namespace nd4j { template math_def inline void nd4j_swap(T &val1, T &val2) { T temp = val1; val1=val2; val2=temp; - }; + }; + + template + math_def inline Z nd4j_gamma(X a) { +// nd4j_lgamma(a); +// return (Z)std::tgamma(a); + // Split the function domain into three intervals: + // (0, 0.001), [0.001, 12), and (12, infinity) + + /////////////////////////////////////////////////////////////////////////// + // First interval: (0, 0.001) + // + // For small a, 1/Gamma(a) has power series a + gamma a^2 - ... + // So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3. + // The relative error over this interval is less than 6e-7. 
+ + const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant + + if (a < X(0.001)) + return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a))); + + /////////////////////////////////////////////////////////////////////////// + // Second interval: [0.001, 12) + + if (a < X(12.0)) { + // The algorithm directly approximates gamma over (1,2) and uses + // reduction identities to reduce other arguments to this interval. + + double y = (double)a; + int n = 0; + bool argWasLessThanOne = y < 1.0; + + // Add or subtract integers as necessary to bring y into (1,2) + // Will correct for this below + if (argWasLessThanOne) { + y += 1.0; + } + else { + n = static_cast(floor(y)) - 1; // will use n later + y -= n; + } + + // numerator coefficients for approximation over the interval (1,2) + static const double p[] = { + -1.71618513886549492533811E+0, + 2.47656508055759199108314E+1, + -3.79804256470945635097577E+2, + 6.29331155312818442661052E+2, + 8.66966202790413211295064E+2, + -3.14512729688483675254357E+4, + -3.61444134186911729807069E+4, + 6.64561438202405440627855E+4 + }; + + // denominator coefficients for approximation over the interval (1,2) + static const double q[] = { + -3.08402300119738975254353E+1, + 3.15350626979604161529144E+2, + -1.01515636749021914166146E+3, + -3.10777167157231109440444E+3, + 2.25381184209801510330112E+4, + 4.75584627752788110767815E+3, + -1.34659959864969306392456E+5, + -1.15132259675553483497211E+5 + }; + + double num = 0.0; + double den = 1.0; + + + double z = y - 1; + for (auto i = 0; i < 8; i++) { + num = (num + p[i]) * z; + den = den * z + q[i]; + } + double result = num / den + 1.0; + + // Apply correction if argument was not initially in (1,2) + if (argWasLessThanOne) { + // Use identity gamma(z) = gamma(z+1)/z + // The variable "result" now holds gamma of the original y + 1 + // Thus we use y-1 to get back the orginal y. + result /= (y - 1.0); + } + else { + // Use the identity gamma(z+n) = z*(z+1)* ... 
*(z+n-1)*gamma(z) + for (auto i = 0; i < n; i++) + result *= y++; + } + + return Z(result); + } + + /////////////////////////////////////////////////////////////////////////// + // Third interval: [12, infinity) + + if (a > 171.624) { + // Correct answer too large to display. Force +infinity. + return Z(DOUBLE_MAX_VALUE); + //DataTypeUtils::infOrMax(); + } + + return nd4j::math::nd4j_exp(nd4j::math::nd4j_lgamma(a)); + } + + template + math_def inline Z nd4j_igamma(X a, Y x) { + Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * nd4j_gamma(a)); + auto sum = Z(0.); + auto denom = Z(1.); + for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) { + denom *= (a + i); + sum += nd4j_pow(x, i) / denom; + } + return aim * sum; + } + + template + math_def inline Z nd4j_igammac(X a, Y x) { + return Z(1.) - nd4j_igamma(a, x); + } #ifdef __CUDACC__ namespace atomics { @@ -1473,4 +1646,4 @@ inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 } -#endif /* TEMPLATEMATH_H_ */ \ No newline at end of file +#endif /* TEMPLATEMATH_H_ */ diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 0652a398e..25fe3429a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -537,6 +537,50 @@ TEST_F(DeclarableOpsTests10, atan2_test6) { delete result; } +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, IGamma_Test1) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); + auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); + + auto exp = NDArrayFactory::create('c', {1,3,4}, { + 0.659917, 0.61757898, 0.59726304, 0.58478117, + 0.0066205109, 0.022211598, 0.040677428, 0.059117373, + 0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735}); + + nd4j::ops::igamma op; + auto result = op.execute({&y, &x}, 
{}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, IGamma_Test2) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 , + 7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); + auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); + auto exp = NDArrayFactory::create('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221, + 0.993379, 0.977788, 0.959323, 0.940883, + 0.999996, 0.999914, 0.999564, 0.998773}); + + nd4j::ops::igammac op; + auto result = op.execute({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test10) {