commit
3662657d5c
|
@ -78,7 +78,9 @@
|
||||||
(28, LogicalXor) ,\
|
(28, LogicalXor) ,\
|
||||||
(29, LogicalNot) ,\
|
(29, LogicalNot) ,\
|
||||||
(30, LogicalAnd), \
|
(30, LogicalAnd), \
|
||||||
(31, DivideNoNan)
|
(31, DivideNoNan), \
|
||||||
|
(32, IGamma), \
|
||||||
|
(33, IGammac)
|
||||||
|
|
||||||
// these ops return same data type as input
|
// these ops return same data type as input
|
||||||
#define TRANSFORM_SAME_OPS \
|
#define TRANSFORM_SAME_OPS \
|
||||||
|
@ -245,7 +247,9 @@
|
||||||
(43, TruncateMod) ,\
|
(43, TruncateMod) ,\
|
||||||
(44, SquaredReverseSubtract) ,\
|
(44, SquaredReverseSubtract) ,\
|
||||||
(45, ReversePow), \
|
(45, ReversePow), \
|
||||||
(46, DivideNoNan)
|
(46, DivideNoNan), \
|
||||||
|
(47, IGamma), \
|
||||||
|
(48, IGammac)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -380,7 +384,9 @@
|
||||||
(35, AMinPairwise) ,\
|
(35, AMinPairwise) ,\
|
||||||
(36, TruncateMod), \
|
(36, TruncateMod), \
|
||||||
(37, ReplaceNans), \
|
(37, ReplaceNans), \
|
||||||
(38, DivideNoNan)
|
(38, DivideNoNan), \
|
||||||
|
(39, IGamma), \
|
||||||
|
(40, IGammac)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,8 @@ namespace nd4j {
|
||||||
static BroadcastOpsTuple DivideNoNan();
|
static BroadcastOpsTuple DivideNoNan();
|
||||||
static BroadcastOpsTuple Multiply();
|
static BroadcastOpsTuple Multiply();
|
||||||
static BroadcastOpsTuple Subtract();
|
static BroadcastOpsTuple Subtract();
|
||||||
|
static BroadcastOpsTuple IGamma();
|
||||||
|
static BroadcastOpsTuple IGammac();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_igamma)
|
||||||
|
|
||||||
|
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
BROADCASTABLE_OP_IMPL(igamma, 0, 0) {
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||||
|
|
||||||
|
//REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
|
||||||
|
|
||||||
|
// auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, pairwise::IGamma, broadcast::IGamma}, x, y, z);
|
||||||
|
auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z);
|
||||||
|
|
||||||
|
if (tZ == nullptr)
|
||||||
|
return ND4J_STATUS_KERNEL_FAILURE;
|
||||||
|
else if (tZ != z) {
|
||||||
|
OVERWRITE_RESULT(tZ);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(igamma) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,58 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_igammac)
|
||||||
|
|
||||||
|
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
BROADCASTABLE_OP_IMPL(igammac, 0, 0) {
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||||
|
|
||||||
|
//REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
|
||||||
|
|
||||||
|
// auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, pairwise::IGammac, broadcast::IGammac}, x, y, z);
|
||||||
|
auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z);
|
||||||
|
if (tZ == nullptr)
|
||||||
|
return ND4J_STATUS_KERNEL_FAILURE;
|
||||||
|
else if (tZ != z) {
|
||||||
|
OVERWRITE_RESULT(tZ);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(igammac) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -357,6 +357,28 @@ namespace nd4j {
|
||||||
#if NOT_EXCLUDED(OP_Pow)
|
#if NOT_EXCLUDED(OP_Pow)
|
||||||
DECLARE_BROADCASTABLE_OP(Pow, 0, 0);
|
DECLARE_BROADCASTABLE_OP(Pow, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Broadcastable igamma implementation
|
||||||
|
*
|
||||||
|
* igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x)
|
||||||
|
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
|
||||||
|
* gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt }
|
||||||
|
* @tparam T
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_igamma)
|
||||||
|
DECLARE_BROADCASTABLE_OP(igamma, 0, 0);
|
||||||
|
#endif
|
||||||
|
/**
|
||||||
|
* Broadcastable igammac implementation
|
||||||
|
* igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x)
|
||||||
|
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
|
||||||
|
* Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt }
|
||||||
|
* @tparam T
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_igammac)
|
||||||
|
DECLARE_BROADCASTABLE_OP(igammac, 0, 0);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,4 +48,11 @@ namespace nd4j {
|
||||||
BroadcastOpsTuple BroadcastOpsTuple::Subtract() {
|
BroadcastOpsTuple BroadcastOpsTuple::Subtract() {
|
||||||
return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract);
|
return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract);
|
||||||
}
|
}
|
||||||
|
BroadcastOpsTuple BroadcastOpsTuple::IGamma() {
|
||||||
|
return custom(nd4j::scalar::IGamma, nd4j::pairwise::IGamma, nd4j::broadcast::IGamma);
|
||||||
|
}
|
||||||
|
BroadcastOpsTuple BroadcastOpsTuple::IGammac() {
|
||||||
|
return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1482,6 +1482,52 @@ namespace simdOps {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <typename X, typename Y, typename Z>
|
||||||
|
class IGamma {
|
||||||
|
public:
|
||||||
|
no_op_exec_special
|
||||||
|
no_op_exec_special_cuda
|
||||||
|
|
||||||
|
op_def static Z op(X d1, Z *params) {
|
||||||
|
return nd4j::math::nd4j_igamma<X, X, Z>(d1, params[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static Z op(X d1, Y d2) {
|
||||||
|
return nd4j::math::nd4j_igamma<X, Y, Z>(d1, d2);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static Z op(X d1, Y d2, Z *params) {
|
||||||
|
return nd4j::math::nd4j_igamma<X, Y, Z>(d1, d2);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static Z op(X d1) {
|
||||||
|
return d1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename X, typename Y, typename Z>
|
||||||
|
class IGammac {
|
||||||
|
public:
|
||||||
|
no_op_exec_special
|
||||||
|
no_op_exec_special_cuda
|
||||||
|
|
||||||
|
op_def static Z op(X d1, Z *params) {
|
||||||
|
return nd4j::math::nd4j_igammac<X, X, Z>(d1, params[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static Z op(X d1, Y d2) {
|
||||||
|
return nd4j::math::nd4j_igammac<X, Y, Z>(d1, d2);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static Z op(X d1, Y d2, Z *params) {
|
||||||
|
return nd4j::math::nd4j_igammac<X, Y, Z>(d1, d2);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_def static Z op(X d1) {
|
||||||
|
return d1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
class Round {
|
class Round {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -223,6 +223,12 @@ namespace nd4j {
|
||||||
return nd4j_sgn<T, Z>(val);
|
return nd4j_sgn<T, Z>(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
math_def inline Z nd4j_gamma(X a);
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
math_def inline Z nd4j_lgamma(X x);
|
||||||
|
|
||||||
//#ifndef __CUDACC__
|
//#ifndef __CUDACC__
|
||||||
/*
|
/*
|
||||||
template<>
|
template<>
|
||||||
|
@ -656,6 +662,53 @@ namespace nd4j {
|
||||||
return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2));
|
return p_pow<Z>(static_cast<Z>(val), static_cast<Z>(val2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LogGamma(a) - float point extension of ln(n!)
|
||||||
|
**/
|
||||||
|
template <typename X, typename Z>
|
||||||
|
math_def inline Z nd4j_lgamma(X x) {
|
||||||
|
// if (x <= X(0.0))
|
||||||
|
// {
|
||||||
|
// std::stringstream os;
|
||||||
|
// os << "Logarithm of Gamma has sence only for positive values, but " << x << " was given.";
|
||||||
|
// throw std::invalid_argument( os.str() );
|
||||||
|
// }
|
||||||
|
|
||||||
|
if (x < X(12.0)) {
|
||||||
|
return nd4j_log<Z,Z>(nd4j_gamma<X,Z>(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abramowitz and Stegun 6.1.41
|
||||||
|
// Asymptotic series should be good to at least 11 or 12 figures
|
||||||
|
// For error analysis, see Whittiker and Watson
|
||||||
|
// A Course in Modern Analysis (1927), page 252
|
||||||
|
|
||||||
|
static const double c[8] = {
|
||||||
|
1.0/12.0,
|
||||||
|
-1.0/360.0,
|
||||||
|
1.0/1260.0,
|
||||||
|
-1.0/1680.0,
|
||||||
|
1.0/1188.0,
|
||||||
|
-691.0/360360.0,
|
||||||
|
1.0/156.0,
|
||||||
|
-3617.0/122400.0
|
||||||
|
};
|
||||||
|
|
||||||
|
double z = Z(1.0 / Z(x * x));
|
||||||
|
double sum = c[7];
|
||||||
|
|
||||||
|
for (int i = 6; i >= 0; i--) {
|
||||||
|
sum *= z;
|
||||||
|
sum += c[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
double series = sum / Z(x);
|
||||||
|
|
||||||
|
static const double halfLogTwoPi = 0.91893853320467274178032973640562;
|
||||||
|
|
||||||
|
return Z((double(x) - 0.5) * nd4j_log<X,double>(x) - double(x) + halfLogTwoPi + series);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -733,6 +786,126 @@ namespace nd4j {
|
||||||
T temp = val1; val1=val2; val2=temp;
|
T temp = val1; val1=val2; val2=temp;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename X, typename Z>
|
||||||
|
math_def inline Z nd4j_gamma(X a) {
|
||||||
|
// nd4j_lgamma<X,Z>(a);
|
||||||
|
// return (Z)std::tgamma(a);
|
||||||
|
// Split the function domain into three intervals:
|
||||||
|
// (0, 0.001), [0.001, 12), and (12, infinity)
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
// First interval: (0, 0.001)
|
||||||
|
//
|
||||||
|
// For small a, 1/Gamma(a) has power series a + gamma a^2 - ...
|
||||||
|
// So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3.
|
||||||
|
// The relative error over this interval is less than 6e-7.
|
||||||
|
|
||||||
|
const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant
|
||||||
|
|
||||||
|
if (a < X(0.001))
|
||||||
|
return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a)));
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
// Second interval: [0.001, 12)
|
||||||
|
|
||||||
|
if (a < X(12.0)) {
|
||||||
|
// The algorithm directly approximates gamma over (1,2) and uses
|
||||||
|
// reduction identities to reduce other arguments to this interval.
|
||||||
|
|
||||||
|
double y = (double)a;
|
||||||
|
int n = 0;
|
||||||
|
bool argWasLessThanOne = y < 1.0;
|
||||||
|
|
||||||
|
// Add or subtract integers as necessary to bring y into (1,2)
|
||||||
|
// Will correct for this below
|
||||||
|
if (argWasLessThanOne) {
|
||||||
|
y += 1.0;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
n = static_cast<int>(floor(y)) - 1; // will use n later
|
||||||
|
y -= n;
|
||||||
|
}
|
||||||
|
|
||||||
|
// numerator coefficients for approximation over the interval (1,2)
|
||||||
|
static const double p[] = {
|
||||||
|
-1.71618513886549492533811E+0,
|
||||||
|
2.47656508055759199108314E+1,
|
||||||
|
-3.79804256470945635097577E+2,
|
||||||
|
6.29331155312818442661052E+2,
|
||||||
|
8.66966202790413211295064E+2,
|
||||||
|
-3.14512729688483675254357E+4,
|
||||||
|
-3.61444134186911729807069E+4,
|
||||||
|
6.64561438202405440627855E+4
|
||||||
|
};
|
||||||
|
|
||||||
|
// denominator coefficients for approximation over the interval (1,2)
|
||||||
|
static const double q[] = {
|
||||||
|
-3.08402300119738975254353E+1,
|
||||||
|
3.15350626979604161529144E+2,
|
||||||
|
-1.01515636749021914166146E+3,
|
||||||
|
-3.10777167157231109440444E+3,
|
||||||
|
2.25381184209801510330112E+4,
|
||||||
|
4.75584627752788110767815E+3,
|
||||||
|
-1.34659959864969306392456E+5,
|
||||||
|
-1.15132259675553483497211E+5
|
||||||
|
};
|
||||||
|
|
||||||
|
double num = 0.0;
|
||||||
|
double den = 1.0;
|
||||||
|
|
||||||
|
|
||||||
|
double z = y - 1;
|
||||||
|
for (auto i = 0; i < 8; i++) {
|
||||||
|
num = (num + p[i]) * z;
|
||||||
|
den = den * z + q[i];
|
||||||
|
}
|
||||||
|
double result = num / den + 1.0;
|
||||||
|
|
||||||
|
// Apply correction if argument was not initially in (1,2)
|
||||||
|
if (argWasLessThanOne) {
|
||||||
|
// Use identity gamma(z) = gamma(z+1)/z
|
||||||
|
// The variable "result" now holds gamma of the original y + 1
|
||||||
|
// Thus we use y-1 to get back the orginal y.
|
||||||
|
result /= (y - 1.0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
|
||||||
|
for (auto i = 0; i < n; i++)
|
||||||
|
result *= y++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Z(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
// Third interval: [12, infinity)
|
||||||
|
|
||||||
|
if (a > 171.624) {
|
||||||
|
// Correct answer too large to display. Force +infinity.
|
||||||
|
return Z(DOUBLE_MAX_VALUE);
|
||||||
|
//DataTypeUtils::infOrMax<Z>();
|
||||||
|
}
|
||||||
|
|
||||||
|
return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Y, typename Z>
|
||||||
|
math_def inline Z nd4j_igamma(X a, Y x) {
|
||||||
|
Z aim = nd4j_pow<X, X, Z>(x, a) / (nd4j_exp<X, Z>(x) * nd4j_gamma<Y, Z>(a));
|
||||||
|
auto sum = Z(0.);
|
||||||
|
auto denom = Z(1.);
|
||||||
|
for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) {
|
||||||
|
denom *= (a + i);
|
||||||
|
sum += nd4j_pow<X, int, Z>(x, i) / denom;
|
||||||
|
}
|
||||||
|
return aim * sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Y, typename Z>
|
||||||
|
math_def inline Z nd4j_igammac(X a, Y x) {
|
||||||
|
return Z(1.) - nd4j_igamma<X, Y, Z>(a, x);
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef __CUDACC__
|
#ifdef __CUDACC__
|
||||||
namespace atomics {
|
namespace atomics {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -537,6 +537,50 @@ TEST_F(DeclarableOpsTests10, atan2_test6) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, IGamma_Test1) {
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1});
|
||||||
|
auto x = NDArrayFactory::create<double>('c', { 4}, {1.2, 2.2, 3.2, 4.2});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {
|
||||||
|
0.659917, 0.61757898, 0.59726304, 0.58478117,
|
||||||
|
0.0066205109, 0.022211598, 0.040677428, 0.059117373,
|
||||||
|
0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735});
|
||||||
|
|
||||||
|
nd4j::ops::igamma op;
|
||||||
|
auto result = op.execute({&y, &x}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printBuffer("OUtput");
|
||||||
|
// exp.printBuffer("EXpect");
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, IGamma_Test2) {
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,
|
||||||
|
7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1});
|
||||||
|
auto x = NDArrayFactory::create<double>('c', { 4}, {1.2, 2.2, 3.2, 4.2});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221,
|
||||||
|
0.993379, 0.977788, 0.959323, 0.940883,
|
||||||
|
0.999996, 0.999914, 0.999564, 0.998773});
|
||||||
|
|
||||||
|
nd4j::ops::igammac op;
|
||||||
|
auto result = op.execute({&y, &x}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printBuffer("OUtput");
|
||||||
|
// exp.printBuffer("EXpect");
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, range_test10) {
|
TEST_F(DeclarableOpsTests10, range_test10) {
|
||||||
|
|
Loading…
Reference in New Issue