// Source path: cavis/libnd4j/include/ops/declarable/helpers/zeta.h
//
// (source-viewer metadata: 103 lines, 2.8 KiB, C++)

/* ******************************************************************************
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by Yurii Shyrma on 12.12.2017.
//
#ifndef LIBND4J_ZETA_H
#define LIBND4J_ZETA_H
#include <ops/declarable/helpers/helpers.h>
#include "array/NDArray.h"
namespace sd {
namespace ops {
namespace helpers {
/**
 * @brief Computes the Hurwitz zeta function zeta(x, q) for whole arrays.
 *
 * @param context  launch context the computation is dispatched on
 *                 (NOTE(review): CPU or CUDA depending on build — confirm).
 * @param x        exponents of the series.
 * @param q        shifts of the series.
 * @param output   receives the result; presumably filled element-wise from
 *                 x and q of matching shape — TODO confirm against the
 *                 implementation.
 */
void zeta(sd::LaunchContext* context, NDArray const& x, NDArray const& q, NDArray& output);
// calculate the Hurwitz zeta function for scalars
// fast implementation, it is based on Euler-Maclaurin summation formula
template <typename T>
_CUDA_HD T zetaScalar(const T x, const T q) {
const T machep = 1.11022302462515654042e-16;
// FIXME: @raver119
// expansion coeffZetaicients for Euler-Maclaurin summation formula (2k)! / B2k, where B2k are Bernoulli numbers
const T coeffZeta[] = { 12.0,-720.0,30240.0,-1209600.0,47900160.0,-1.8924375803183791606e9,7.47242496e10,-2.950130727918164224e12, 1.1646782814350067249e14, -4.5979787224074726105e15, 1.8152105401943546773e17, -7.1661652561756670113e18};
// if (x <= (T)1.)
// throw("zeta function: x must be > 1 !");
// if (q <= (T)0.)
// throw("zeta function: q must be > 0 !");
T a, b(0.), k, s, t, w;
s = math::nd4j_pow<T, T, T>(q, -x);
a = q;
int i = 0;
while(i < 9 || a <= (T)9.) {
i += 1;
a += (T)1.0;
b = math::nd4j_pow<T, T, T>(a, -x);
s += b;
if(math::nd4j_abs(b / s) < (T)machep)
return s;
}
w = a;
s += b * (w / (x - (T)1.) - (T)0.5);
a = (T)1.;
k = (T)0.;
for(i = 0; i < 12; ++i) {
a *= x + k;
b /= w;
t = a * b / coeffZeta[i];
s += t;
t = math::nd4j_abs(t / s);
if(t < (T)machep)
return s;
k += (T)1.f;
a *= x + k;
b /= w;
k += (T)1.f;
}
return s;
}
}
}
}
#endif //LIBND4J_ZETA_H