cavis/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu

195 lines
6.2 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include<cmath>
#include <array/DataTypeUtils.h>
#include<ops/declarable/helpers/betaInc.h>
#include <helpers/PointersManager.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
// Modified Lentz's algorithm for evaluating continued fractions.
// Reference: Lentz, W.J. 1976, "Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions,"
//
// Device helper used by thread 0 of betaIncForArrayCuda. It does not compute the
// continued-fraction coefficients itself; it reads them from the block's dynamic
// shared memory:
//   coeffs[0], coeffs[1]          - seed values of the recurrence
//   coeffs[2*i], coeffs[2*i + 1]  - the two coefficients of step i, i >= 1
// The kernel fills pairs i = 0 .. blockDim.x - 1 and is launched with
// blockDim.x == maxIter, so only steps 1 .. maxIter - 1 have valid coefficients.
// Iterating up to i == maxIter would read uninitialized shared memory.
//
// a, b, x - kept for interface compatibility; the precomputed coefficients already encode them.
// Returns the value of the continued fraction once the relative change of a step is within
// DataTypeUtils::eps<T>(), otherwise DataTypeUtils::infOrMax<T>() (no convergence).
template <typename T>
__device__ T continuedFractionCuda(const T a, const T b, const T x) {
    extern __shared__ unsigned char shmem[];
    const T* coeffs = reinterpret_cast<const T*>(shmem);

    // tiny value substituted for (near-)zero denominators to avoid division by zero
    const T min = DataTypeUtils::min<T>() / DataTypeUtils::eps<T>();

    T t2 = coeffs[1];
    T t1 = coeffs[0];
    if(math::nd4j_abs<T>(t1) < min)
        t1 = min;
    t1 = static_cast<T>(1) / t1;
    T result = t1;

    // only pairs 1 .. maxIter-1 were written by the kernel (one pair per thread)
    for(uint i = 1; i < maxIter; ++i) {
        const uint i2 = 2*i;

        // even coefficient of step i
        t1 = static_cast<T>(1) + coeffs[i2] * t1;
        if(math::nd4j_abs<T>(t1) < min)
            t1 = min;
        t1 = static_cast<T>(1) / t1;

        t2 = static_cast<T>(1) + coeffs[i2] / t2;
        if(math::nd4j_abs<T>(t2) < min)
            t2 = min;

        result *= t2 * t1;

        // odd coefficient of step i
        t1 = static_cast<T>(1) + coeffs[i2 + 1] * t1;
        if(math::nd4j_abs<T>(t1) < min)
            t1 = min;
        t1 = static_cast<T>(1) / t1;

        t2 = static_cast<T>(1) + coeffs[i2 + 1] / t2;
        if(math::nd4j_abs<T>(t2) < min)
            t2 = min;

        const T val = t2 * t1;
        result *= val;

        // converged: the last multiplicative update is ~1
        if(math::nd4j_abs<T>(val - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
            return result;
    }

    return DataTypeUtils::infOrMax<T>();    // no convergence, more iterations are required
}
///////////////////////////////////////////////////////////////////
// Element-wise kernel for betaInc: block j computes output element j.
// Expected launch layout (see betaInc): gridDim.x == output length, blockDim.x == maxIter,
// dynamic shared memory of at least 2 * blockDim.x * sizeof(T) bytes for the coefficient table.
//
// Thread 0 loads a, b, x of this element into shared variables. If x > (a + 1) / (a + b + 2),
// it swaps a and b and replaces x with 1 - x; the result is mapped back with z = 1 - z
// (symmetry relation). Threads 1 .. blockDim.x-1 fill coefficient pair i
// (sharedMem[2i], sharedMem[2i+1]). Thread 0 fills pair 0, evaluates the continued
// fraction and writes the result.
template<typename T>
__global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
                                    const void* vb, const Nd4jLong* bShapeInfo,
                                    const void* vx, const Nd4jLong* xShapeInfo,
                                          void* vz, const Nd4jLong* zShapeInfo) {
    extern __shared__ unsigned char shmem[];
    T* sharedMem = reinterpret_cast<T*>(shmem);

    const Nd4jLong j = blockIdx.x;    // one block per each element

    T& z = *(reinterpret_cast<T*>(vz) + shape::getIndexOffset(j, zShapeInfo));

    __shared__ T a, b, x;
    __shared__ bool symmCond;

    if (threadIdx.x == 0) {
        a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
        b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
        x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));

        symmCond = x > (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));

        if(symmCond) {    // swap a and b, x = 1 - x
            T temp = a;
            a = b;
            b = temp;
            x = static_cast<T>(1) - x;
        }
    }
    __syncthreads();

    // The early exits below depend only on shared values written before the barrier,
    // so the whole block takes the same path and the later __syncthreads() stays uniform.
    // Only thread 0 stores the result, to avoid blockDim.x redundant global writes.

    // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5
    if(a == b && x == static_cast<T>(0.5)) {
        if (threadIdx.x == 0)
            z = static_cast<T>(0.5);
        return;
    }

    if (x == static_cast<T>(0) || x == static_cast<T>(1)) {
        if (threadIdx.x == 0)
            z = symmCond ? static_cast<T>(1) - x : x;
        return;
    }

    // calculate two coefficients per thread
    if(threadIdx.x != 0) {
        const int i = threadIdx.x;
        const T aPlus2i = a + 2*i;
        sharedMem[2*i]     = i * (b - i) * x / ((aPlus2i - static_cast<T>(1)) * aPlus2i);
        sharedMem[2*i + 1] = -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast<T>(1)) * aPlus2i);
    }
    __syncthreads();

    if(threadIdx.x == 0) {
        // front = x^a * (1 - x)^b / B(a, b), computed in log space
        const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
        const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(static_cast<T>(1) - x) * b - gammaPart);

        // seed coefficients (pair 0) for the continued fraction
        sharedMem[0] = static_cast<T>(1) - (a + b) * x / (a + static_cast<T>(1));
        sharedMem[1] = static_cast<T>(1);

        z = front * continuedFractionCuda(a, b, x) / a;

        if(symmCond)    // symmetry relation
            z = static_cast<T>(1) - z;
    }
}
///////////////////////////////////////////////////////////////////
// Host-side launch helper, instantiated per data type by BUILD_SINGLE_SELECTOR in betaInc().
// Enqueues betaIncForArrayCuda<T> on *stream with the given grid/block sizes and
// `sharedMem` bytes of dynamic shared memory; it does not synchronize.
template<typename T>
static void betaIncForArrayCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
                                        const void* va, const Nd4jLong* aShapeInfo,
                                        const void* vb, const Nd4jLong* bShapeInfo,
                                        const void* vx, const Nd4jLong* xShapeInfo,
                                              void* vz, const Nd4jLong* zShapeInfo) {
    const dim3 grid(blocksPerGrid);
    const dim3 block(threadsPerBlock);
    const cudaStream_t launchStream = *stream;

    betaIncForArrayCuda<T><<<grid, block, sharedMem, launchStream>>>(
        va, aShapeInfo,
        vb, bShapeInfo,
        vx, xShapeInfo,
        vz, zShapeInfo);
}
///////////////////////////////////////////////////////////////////
// Overload of betaInc for arrays: computes the regularized incomplete beta function
// I_x(a, b) element-wise into `output`.
// Preconditions (not checked here):
//   - a, b, x and output must have the same shape;
//   - output must have the same data type as x, because the kernel is instantiated for
//     x's type and writes that type into output.
// Launch layout: one CUDA block per output element and maxIter threads per block;
// each thread computes one pair of continued-fraction coefficients in shared memory.
void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) {

    const int blocksPerGrid = output.lengthOf();

    // a zero-sized grid is an invalid launch configuration; nothing to compute anyway
    if (blocksPerGrid == 0)
        return;

    const int threadsPerBlock = maxIter;

    // two coefficients of the kernel's element type (x's data type) per thread, plus 128 bytes of slack
    const int sharedMem = 2 * x.sizeOfT() * threadsPerBlock + 128;

    const auto xType = x.dataType();

    PointersManager manager(context, "betaInc");

    NDArray::prepareSpecialUse({&output}, {&a, &b, &x});
    BUILD_SINGLE_SELECTOR(xType, betaIncForArrayCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), a.specialBuffer(), a.specialShapeInfo(), b.specialBuffer(), b.specialShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), FLOAT_TYPES);
    NDArray::registerSpecialUse({&output}, {&a, &b, &x});

    manager.synchronize();
}
}
}
}