commit
						3662657d5c
					
				| @ -78,7 +78,9 @@ | |||||||
|        (28, LogicalXor) ,\ |        (28, LogicalXor) ,\ | ||||||
|        (29, LogicalNot) ,\ |        (29, LogicalNot) ,\ | ||||||
|        (30, LogicalAnd), \ |        (30, LogicalAnd), \ | ||||||
|        (31, DivideNoNan) |        (31, DivideNoNan), \ | ||||||
|  |        (32, IGamma), \ | ||||||
|  |        (33, IGammac) | ||||||
| 
 | 
 | ||||||
| // these ops return same data type as input
 | // these ops return same data type as input
 | ||||||
| #define TRANSFORM_SAME_OPS \ | #define TRANSFORM_SAME_OPS \ | ||||||
| @ -245,7 +247,9 @@ | |||||||
|         (43, TruncateMod) ,\ |         (43, TruncateMod) ,\ | ||||||
|         (44, SquaredReverseSubtract) ,\ |         (44, SquaredReverseSubtract) ,\ | ||||||
|         (45, ReversePow), \ |         (45, ReversePow), \ | ||||||
|         (46, DivideNoNan) |         (46, DivideNoNan), \ | ||||||
|  |         (47, IGamma), \ | ||||||
|  |         (48, IGammac) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -380,7 +384,9 @@ | |||||||
|         (35, AMinPairwise) ,\ |         (35, AMinPairwise) ,\ | ||||||
|         (36, TruncateMod), \ |         (36, TruncateMod), \ | ||||||
|         (37, ReplaceNans), \ |         (37, ReplaceNans), \ | ||||||
|         (38, DivideNoNan) |         (38, DivideNoNan), \ | ||||||
|  |         (39, IGamma), \ | ||||||
|  |         (40, IGammac) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -49,6 +49,8 @@ namespace nd4j { | |||||||
|         static BroadcastOpsTuple DivideNoNan(); |         static BroadcastOpsTuple DivideNoNan(); | ||||||
|         static BroadcastOpsTuple Multiply(); |         static BroadcastOpsTuple Multiply(); | ||||||
|         static BroadcastOpsTuple Subtract(); |         static BroadcastOpsTuple Subtract(); | ||||||
|  |         static BroadcastOpsTuple IGamma(); | ||||||
|  |         static BroadcastOpsTuple IGammac(); | ||||||
|     }; |     }; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -0,0 +1,59 @@ | |||||||
|  | /*******************************************************************************
 | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | //
 | ||||||
|  | // @author sgazeos@gmail.com
 | ||||||
|  | //
 | ||||||
|  | 
 | ||||||
|  | #include <op_boilerplate.h> | ||||||
|  | #if NOT_EXCLUDED(OP_igamma) | ||||||
|  | 
 | ||||||
|  | #include <ops/declarable/generic/helpers/BroadcastHelper.h> | ||||||
|  | #include <ops/declarable/CustomOperations.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  |     namespace ops { | ||||||
|  |         BROADCASTABLE_OP_IMPL(igamma, 0, 0) { | ||||||
|  |             auto x = INPUT_VARIABLE(0); | ||||||
|  |             auto y = INPUT_VARIABLE(1); | ||||||
|  |             auto z = OUTPUT_VARIABLE(0); | ||||||
|  | 
 | ||||||
|  |             BROADCAST_CHECK_EMPTY(x,y,z); | ||||||
|  | 
 | ||||||
|  |             //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
 | ||||||
|  | 
 | ||||||
|  | //            auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, pairwise::IGamma, broadcast::IGamma}, x, y, z);
 | ||||||
|  |             auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z); | ||||||
|  | 
 | ||||||
|  |             if (tZ == nullptr) | ||||||
|  |                 return ND4J_STATUS_KERNEL_FAILURE; | ||||||
|  |             else if (tZ != z) { | ||||||
|  |                 OVERWRITE_RESULT(tZ); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             return Status::OK(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         DECLARE_TYPES(igamma) { | ||||||
|  |             getOpDescriptor() | ||||||
|  |                 ->setAllowedInputTypes(0, {ALL_FLOATS}) | ||||||
|  |                 ->setAllowedInputTypes(1, {ALL_FLOATS}) | ||||||
|  |                 ->setAllowedOutputTypes(0, {ALL_FLOATS}); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
| @ -0,0 +1,58 @@ | |||||||
|  | /*******************************************************************************
 | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | //
 | ||||||
|  | // @author sgazeos@gmail.com
 | ||||||
|  | //
 | ||||||
|  | 
 | ||||||
|  | #include <op_boilerplate.h> | ||||||
|  | #if NOT_EXCLUDED(OP_igammac) | ||||||
|  | 
 | ||||||
|  | #include <ops/declarable/generic/helpers/BroadcastHelper.h> | ||||||
|  | #include <ops/declarable/CustomOperations.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  |     namespace ops { | ||||||
|  |         BROADCASTABLE_OP_IMPL(igammac, 0, 0) { | ||||||
|  |             auto x = INPUT_VARIABLE(0); | ||||||
|  |             auto y = INPUT_VARIABLE(1); | ||||||
|  |             auto z = OUTPUT_VARIABLE(0); | ||||||
|  | 
 | ||||||
|  |             BROADCAST_CHECK_EMPTY(x,y,z); | ||||||
|  | 
 | ||||||
|  |             //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
 | ||||||
|  | 
 | ||||||
|  | //            auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, pairwise::IGammac, broadcast::IGammac}, x, y, z);
 | ||||||
|  |             auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z); | ||||||
|  |             if (tZ == nullptr) | ||||||
|  |                 return ND4J_STATUS_KERNEL_FAILURE; | ||||||
|  |             else if (tZ != z) { | ||||||
|  |                 OVERWRITE_RESULT(tZ); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             return Status::OK(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         DECLARE_TYPES(igammac) { | ||||||
|  |             getOpDescriptor() | ||||||
|  |                 ->setAllowedInputTypes(0, {ALL_FLOATS}) | ||||||
|  |                 ->setAllowedInputTypes(1, {ALL_FLOATS}) | ||||||
|  |                 ->setAllowedOutputTypes(0, {ALL_FLOATS}); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
| @ -357,6 +357,28 @@ namespace nd4j { | |||||||
|         #if NOT_EXCLUDED(OP_Pow) |         #if NOT_EXCLUDED(OP_Pow) | ||||||
|         DECLARE_BROADCASTABLE_OP(Pow, 0, 0); |         DECLARE_BROADCASTABLE_OP(Pow, 0, 0); | ||||||
|         #endif |         #endif | ||||||
|  | 
 | ||||||
|  |         /**
 | ||||||
|  |          * Broadcastable igamma implementation | ||||||
|  |          * | ||||||
|  |          * igamma(a, x) = gamma(a, x) / Gamma(a) - Gamma distribution function P(a,x) (regularized lower incomplete gamma function) | ||||||
|  |          * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } | ||||||
|  |          * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } | ||||||
|  |          * @tparam T | ||||||
|  |          */ | ||||||
|  |         #if NOT_EXCLUDED(OP_igamma) | ||||||
|  |                 DECLARE_BROADCASTABLE_OP(igamma, 0, 0); | ||||||
|  |         #endif | ||||||
|  |         /**
 | ||||||
|  |          * Broadcastable igammac implementation | ||||||
|  |          * igammac(a, x) = Gamma(a,x)/Gamma(a) - Gamma distribution function Q(a,x) (regularized upper incomplete gamma function) | ||||||
|  |          * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } | ||||||
|  |          * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } | ||||||
|  |          * @tparam T | ||||||
|  |          */ | ||||||
|  |         #if NOT_EXCLUDED(OP_igammac) | ||||||
|  |                 DECLARE_BROADCASTABLE_OP(igammac, 0, 0); | ||||||
|  |         #endif | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -48,4 +48,11 @@ namespace nd4j { | |||||||
|     BroadcastOpsTuple BroadcastOpsTuple::Subtract() { |     BroadcastOpsTuple BroadcastOpsTuple::Subtract() { | ||||||
|         return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract); |         return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract); | ||||||
|     } |     } | ||||||
|  |     BroadcastOpsTuple BroadcastOpsTuple::IGamma() { | ||||||
|  |         return custom(nd4j::scalar::IGamma, nd4j::pairwise::IGamma, nd4j::broadcast::IGamma); | ||||||
|  |     } | ||||||
|  |     BroadcastOpsTuple BroadcastOpsTuple::IGammac() { | ||||||
|  |         return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
| } | } | ||||||
|  | |||||||
| @ -1482,6 +1482,52 @@ namespace simdOps { | |||||||
| 	}; | 	}; | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | 	template <typename X, typename Y, typename Z> | ||||||
|  | 	class IGamma { | ||||||
|  | 	public: | ||||||
|  | 		no_op_exec_special | ||||||
|  | 		no_op_exec_special_cuda | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1, Z *params) { | ||||||
|  | 			return nd4j::math::nd4j_igamma<X, X, Z>(d1, params[0]); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1, Y d2) { | ||||||
|  | 			return nd4j::math::nd4j_igamma<X, Y, Z>(d1, d2); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1, Y d2, Z *params) { | ||||||
|  | 			return nd4j::math::nd4j_igamma<X, Y, Z>(d1, d2); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1) { | ||||||
|  | 			return d1; | ||||||
|  | 		} | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
|  | 	template <typename X, typename Y, typename Z> | ||||||
|  | 	class IGammac { | ||||||
|  | 	public: | ||||||
|  | 		no_op_exec_special | ||||||
|  | 		no_op_exec_special_cuda | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1, Z *params) { | ||||||
|  | 			return nd4j::math::nd4j_igammac<X, X, Z>(d1, params[0]); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1, Y d2) { | ||||||
|  | 			return nd4j::math::nd4j_igammac<X, Y, Z>(d1, d2); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1, Y d2, Z *params) { | ||||||
|  | 			return nd4j::math::nd4j_igammac<X, Y, Z>(d1, d2); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1) { | ||||||
|  | 			return d1; | ||||||
|  | 		} | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
| 	template <typename X> | 	template <typename X> | ||||||
| 	class Round { | 	class Round { | ||||||
| 	public: | 	public: | ||||||
|  | |||||||
| @ -223,6 +223,12 @@ namespace nd4j { | |||||||
|             return nd4j_sgn<T, Z>(val); |             return nd4j_sgn<T, Z>(val); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         template<typename X, typename Z> | ||||||
|  |         math_def inline Z nd4j_gamma(X a); | ||||||
|  | 
 | ||||||
|  |         template<typename X, typename Z> | ||||||
|  |         math_def inline Z nd4j_lgamma(X x); | ||||||
|  | 
 | ||||||
| //#ifndef __CUDACC__
 | //#ifndef __CUDACC__
 | ||||||
| /*
 | /*
 | ||||||
|         template<> |         template<> | ||||||
| @ -656,9 +662,56 @@ namespace nd4j { | |||||||
|             return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2)); |             return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2)); | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  |         /**
 | ||||||
|  |          * LogGamma(x) = ln(Gamma(x)) - floating-point extension of ln((n - 1)!) | ||||||
|  |          **/ | ||||||
|  |         template <typename X, typename Z> | ||||||
|  |         math_def inline Z nd4j_lgamma(X x) { | ||||||
|  | //            if (x <= X(0.0))
 | ||||||
|  | //            {
 | ||||||
|  | //                std::stringstream os;
 | ||||||
|  | //                os << "Logarithm of Gamma has sense only for positive values, but " << x <<  " was given.";
 | ||||||
|  | //                throw std::invalid_argument( os.str() );
 | ||||||
|  | //            }
 | ||||||
|  | 
 | ||||||
|  |             if (x < X(12.0)) { | ||||||
|  |                 return nd4j_log<Z,Z>(nd4j_gamma<X,Z>(x)); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Abramowitz and Stegun 6.1.41
 | ||||||
|  |             // Asymptotic series should be good to at least 11 or 12 figures
 | ||||||
|  |             // For error analysis, see Whittaker and Watson,
 | ||||||
|  |             // A Course of Modern Analysis (1927), page 252
 | ||||||
|  | 
 | ||||||
|  |             static const double c[8] = { | ||||||
|  |                     1.0/12.0, | ||||||
|  |                     -1.0/360.0, | ||||||
|  |                     1.0/1260.0, | ||||||
|  |                     -1.0/1680.0, | ||||||
|  |                     1.0/1188.0, | ||||||
|  |                     -691.0/360360.0, | ||||||
|  |                     1.0/156.0, | ||||||
|  |                     -3617.0/122400.0 | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             double z = Z(1.0 / Z(x * x)); | ||||||
|  |             double sum = c[7]; | ||||||
|  | 
 | ||||||
|  |             for (int i = 6; i >= 0; i--) { | ||||||
|  |                 sum *= z; | ||||||
|  |                 sum += c[i]; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             double series = sum / Z(x); | ||||||
|  | 
 | ||||||
|  |             static const double halfLogTwoPi = 0.91893853320467274178032973640562; | ||||||
|  | 
 | ||||||
|  |             return Z((double(x) - 0.5) * nd4j_log<X,double>(x) - double(x) + halfLogTwoPi + series); | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 		template<typename T> | 
 | ||||||
|  |         template<typename T> | ||||||
| 		math_def inline T nd4j_re(T val1, T val2) { | 		math_def inline T nd4j_re(T val1, T val2) { | ||||||
| 			if (val1 == (T) 0.0f && val2 == (T) 0.0f) | 			if (val1 == (T) 0.0f && val2 == (T) 0.0f) | ||||||
| 				return (T) 0.0f; | 				return (T) 0.0f; | ||||||
| @ -731,7 +784,127 @@ namespace nd4j { | |||||||
|         template<typename T> |         template<typename T> | ||||||
|         math_def inline void nd4j_swap(T &val1, T &val2) { |         math_def inline void nd4j_swap(T &val1, T &val2) { | ||||||
|             T temp = val1; val1=val2; val2=temp; |             T temp = val1; val1=val2; val2=temp; | ||||||
| 		}; |         }; | ||||||
|  | 
 | ||||||
|  |         template <typename X, typename Z> | ||||||
|  |         math_def inline Z nd4j_gamma(X a) { | ||||||
|  | //            nd4j_lgamma<X,Z>(a);
 | ||||||
|  | //            return (Z)std::tgamma(a);
 | ||||||
|  |             // Split the function domain into three intervals:
 | ||||||
|  |             // (0, 0.001), [0.001, 12), and [12, infinity)
 | ||||||
|  | 
 | ||||||
|  |             ///////////////////////////////////////////////////////////////////////////
 | ||||||
|  |             // First interval: (0, 0.001)
 | ||||||
|  |             //
 | ||||||
|  |             // For small a, 1/Gamma(a) has power series a + gamma a^2  - ...
 | ||||||
|  |             // So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3.
 | ||||||
|  |             // The relative error over this interval is less than 6e-7.
 | ||||||
|  | 
 | ||||||
|  |             const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant
 | ||||||
|  | 
 | ||||||
|  |             if (a < X(0.001)) | ||||||
|  |                 return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a))); | ||||||
|  | 
 | ||||||
|  |             ///////////////////////////////////////////////////////////////////////////
 | ||||||
|  |             // Second interval: [0.001, 12)
 | ||||||
|  | 
 | ||||||
|  |             if (a < X(12.0)) { | ||||||
|  |                 // The algorithm directly approximates gamma over (1,2) and uses
 | ||||||
|  |                 // reduction identities to reduce other arguments to this interval.
 | ||||||
|  | 
 | ||||||
|  |                 double y = (double)a; | ||||||
|  |                 int n = 0; | ||||||
|  |                 bool argWasLessThanOne = y < 1.0; | ||||||
|  | 
 | ||||||
|  |                 // Add or subtract integers as necessary to bring y into (1,2)
 | ||||||
|  |                 // Will correct for this below
 | ||||||
|  |                 if (argWasLessThanOne) { | ||||||
|  |                     y += 1.0; | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     n = static_cast<int>(floor(y)) - 1;  // will use n later
 | ||||||
|  |                     y -= n; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 // numerator coefficients for approximation over the interval (1,2)
 | ||||||
|  |                 static const double p[] = { | ||||||
|  |                         -1.71618513886549492533811E+0, | ||||||
|  |                         2.47656508055759199108314E+1, | ||||||
|  |                         -3.79804256470945635097577E+2, | ||||||
|  |                         6.29331155312818442661052E+2, | ||||||
|  |                         8.66966202790413211295064E+2, | ||||||
|  |                         -3.14512729688483675254357E+4, | ||||||
|  |                         -3.61444134186911729807069E+4, | ||||||
|  |                         6.64561438202405440627855E+4 | ||||||
|  |                 }; | ||||||
|  | 
 | ||||||
|  |                 // denominator coefficients for approximation over the interval (1,2)
 | ||||||
|  |                 static const double q[] = { | ||||||
|  |                         -3.08402300119738975254353E+1, | ||||||
|  |                         3.15350626979604161529144E+2, | ||||||
|  |                         -1.01515636749021914166146E+3, | ||||||
|  |                         -3.10777167157231109440444E+3, | ||||||
|  |                         2.25381184209801510330112E+4, | ||||||
|  |                         4.75584627752788110767815E+3, | ||||||
|  |                         -1.34659959864969306392456E+5, | ||||||
|  |                         -1.15132259675553483497211E+5 | ||||||
|  |                 }; | ||||||
|  | 
 | ||||||
|  |                 double num = 0.0; | ||||||
|  |                 double den = 1.0; | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |                 double z = y - 1; | ||||||
|  |                 for (auto i = 0; i < 8; i++) { | ||||||
|  |                     num = (num + p[i]) * z; | ||||||
|  |                     den = den * z + q[i]; | ||||||
|  |                 } | ||||||
|  |                 double result = num / den + 1.0; | ||||||
|  | 
 | ||||||
|  |                 // Apply correction if argument was not initially in (1,2)
 | ||||||
|  |                 if (argWasLessThanOne) { | ||||||
|  |                     // Use identity gamma(z) = gamma(z+1)/z
 | ||||||
|  |                     // The variable "result" now holds gamma of the original y + 1
 | ||||||
|  |                     // Thus we use y-1 to get back the original y.
 | ||||||
|  |                     result /= (y - 1.0); | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
 | ||||||
|  |                     for (auto i = 0; i < n; i++) | ||||||
|  |                         result *= y++; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 return Z(result); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             ///////////////////////////////////////////////////////////////////////////
 | ||||||
|  |             // Third interval: [12, infinity)
 | ||||||
|  | 
 | ||||||
|  |             if (a > 171.624) { | ||||||
|  |                 // Correct answer too large to represent. Saturate to DOUBLE_MAX_VALUE (not a true +infinity).
 | ||||||
|  |                 return Z(DOUBLE_MAX_VALUE); | ||||||
|  |                 //DataTypeUtils::infOrMax<Z>();
 | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         template <typename X, typename Y, typename Z> | ||||||
|  |         math_def inline Z nd4j_igamma(X a, Y x) { | ||||||
|  |             Z aim = nd4j_pow<X, X, Z>(x, a) / (nd4j_exp<X, Z>(x) * nd4j_gamma<Y, Z>(a)); | ||||||
|  |             auto sum = Z(0.); | ||||||
|  |             auto denom = Z(1.); | ||||||
|  |             for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) { | ||||||
|  |                 denom *= (a + i); | ||||||
|  |                 sum += nd4j_pow<X, int, Z>(x, i) / denom; | ||||||
|  |             } | ||||||
|  |             return aim * sum; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         template <typename X, typename Y, typename Z> | ||||||
|  |         math_def inline Z nd4j_igammac(X a, Y x) { | ||||||
|  |             return Z(1.) - nd4j_igamma<X, Y, Z>(a, x); | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
| #ifdef __CUDACC__ | #ifdef __CUDACC__ | ||||||
| 		namespace atomics { | 		namespace atomics { | ||||||
| @ -1473,4 +1646,4 @@ inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 | |||||||
| 
 | 
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #endif /* TEMPLATEMATH_H_ */ | #endif /* TEMPLATEMATH_H_ */ | ||||||
|  | |||||||
| @ -537,6 +537,50 @@ TEST_F(DeclarableOpsTests10, atan2_test6) { | |||||||
|     delete result; |     delete result; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests10, IGamma_Test1) { | ||||||
|  | 
 | ||||||
|  |     auto y = NDArrayFactory::create<double>('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); | ||||||
|  |     auto x = NDArrayFactory::create<double>('c', {      4}, {1.2, 2.2, 3.2, 4.2}); | ||||||
|  | 
 | ||||||
|  |     auto exp = NDArrayFactory::create<double>('c', {1,3,4}, { | ||||||
|  |                0.659917,     0.61757898,  0.59726304,   0.58478117, | ||||||
|  |            0.0066205109,    0.022211598, 0.040677428,  0.059117373, | ||||||
|  |         0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735}); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::igamma op; | ||||||
|  |     auto result = op.execute({&y, &x}, {}, {}, {}); | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, result->status()); | ||||||
|  |     auto z = result->at(0); | ||||||
|  | //    z->printBuffer("OUtput");
 | ||||||
|  | //    exp.printBuffer("EXpect");
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(z)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(z)); | ||||||
|  | 
 | ||||||
|  |     delete result; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests10, IGamma_Test2) { | ||||||
|  | 
 | ||||||
|  |     auto y = NDArrayFactory::create<double>('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 , | ||||||
|  |                                                              7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); | ||||||
|  |     auto x = NDArrayFactory::create<double>('c', {      4}, {1.2, 2.2, 3.2, 4.2}); | ||||||
|  |     auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221, | ||||||
|  |                                                              0.993379, 0.977788, 0.959323, 0.940883, | ||||||
|  |                                                              0.999996, 0.999914, 0.999564, 0.998773}); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::igammac op; | ||||||
|  |     auto result = op.execute({&y, &x}, {}, {}, {}); | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, result->status()); | ||||||
|  |     auto z = result->at(0); | ||||||
|  | //    z->printBuffer("OUtput");
 | ||||||
|  | //    exp.printBuffer("EXpect");
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(z)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(z)); | ||||||
|  | 
 | ||||||
|  |     delete result; | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests10, range_test10) { | TEST_F(DeclarableOpsTests10, range_test10) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user