Added dtype formulation for poisson and gamma distributions. (#442)

* Added dtype formulation for poisson and gamma distributions.
* Refactored gamma distribution generator and tests.
* Added generator for gamma distribution when alpha (shape) is between 0 and 1.
* Implemented gamma distribution for shape param less than 1, with tests.
* Implemented gamma distributed randoms for shape (alpha) parameter greater than 1.
* Added CUDA implementation for gamma distribution.
* Refactored CUDA and CPU implementations of gamma distribution.
* Fixed crash with default beta param in gamma distribution.
* Fixed pow for ARM arch.
* Fixed gamma test.
* Cosmetic changes only.
* Fixed random value retrieval.
* Eliminated potential overflows.
* Modified random value retrieval.
* Increased density of tests for gamma distribution.

Signed-off-by: shugeo <sgazeos@gmail.com>
Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
Co-authored-by: raver119 <raver119@gmail.com>
branch: master
parent: c8096197c7
commit: 3a3c952e75
@@ -65,7 +65,7 @@ namespace sd {
             additionalShape = additionalShapeBroadcasted;
         }
         auto lastDim = shape::sizeAt(alphaShape, 0);
-        auto dtype = ArrayOptions::dataType(alphaShape);
+        auto dtype = block.numD() > 0 ? D_ARG(0) : ArrayOptions::dataType(alphaShape);
         for (auto i = 0; i < shape::rank(additionalShape); i++)
             shape.push_back(shape::sizeAt(additionalShape, i));
         auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape);
@@ -47,7 +47,7 @@ namespace sd {
         auto in = INPUT_VARIABLE(0);
         auto shape = in->template asVectorT<Nd4jLong>();
         auto lambdaShape = inputShape->at(1);
-        auto dtype = ArrayOptions::dataType(lambdaShape);
+        auto dtype = block.numD() > 0 ? D_ARG(0) : ArrayOptions::dataType(lambdaShape);
         for (auto d = 0; d < shape::rank(lambdaShape); ++d) {
             shape.emplace_back(shape::sizeAt(lambdaShape, d));
         }
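Net effect of both one-line changes: random_gamma and random_poisson now honor an optional data-type argument, falling back to the dtype of the alpha (or lambda) input when the caller passes none. A minimal self-contained sketch of that precedence rule, with the library types replaced by stand-ins (in the real ops the explicit value comes from `block.numD()`/`D_ARG(0)` and the fallback from `ArrayOptions::dataType(...)`):

```cpp
#include <optional>

// Stand-in for sd::DataType; illustration only.
enum class DataType { FLOAT32, DOUBLE, HALF };

// An explicit dtype argument from the caller wins; otherwise the output
// inherits the dtype of the distribution-parameter input.
DataType pickOutputDtype(std::optional<DataType> explicitDtype, DataType inputDtype) {
    return explicitDtype.value_or(inputDtype);
}

// Usage: pickOutputDtype({}, DataType::FLOAT32)               -> FLOAT32
//        pickOutputDtype(DataType::DOUBLE, DataType::FLOAT32) -> DOUBLE
```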
@@ -31,6 +31,87 @@ namespace sd {
 namespace ops {
 namespace helpers {
 
+    /**
+     * gammaLess - compute gamma distributed value for shapes (alpha) from 0 to 1
+     * @tparam T - any float type is acceptable
+     * @param rng - random generator for uniformly distributed values
+     * @param alpha - shape of the distribution
+     * @param beta - scale of the distributed values
+     * @return gamma distributed value
+     */
+    template <typename T>
+    T gammaLess(graph::RandomGenerator& rng, T const alpha, T const beta) {
+        auto d = T(1.0334f) - T(0.0766f) * math::p_exp(T(2.2942f) * alpha);
+        auto a = math::p_pow(T(2.f), alpha) * math::p_pow(T(1.f) - math::p_exp(-d * T(0.5f)), alpha);
+        auto b = alpha * math::p_pow(d, alpha - T(1.f)) * math::p_exp(-d);
+        auto c = a + b;
+        T rawX;
+        static auto index = 0LL; // counter of uniform draws, shared across calls
+        const T underAlpha = T(1.f) / alpha;
+        const T powerAlpha = math::p_pow(T(2.f), alpha - T(1.f));
+
+        for (;;) {
+            auto u = rng.relativeT<T>(index++, T(0.f), T(1.f));
+
+            // candidate from the two-part envelope: body by inversion, exponential tail
+            if (u <= a / c) rawX = -T(2.f) * math::p_log(T(1.f) - T(0.5f) * math::p_pow(T(c * u), underAlpha));
+            else rawX = -math::p_log(c * (T(1.f) - u) / (alpha * math::p_pow(d, alpha - T(1.f))));
+
+            T v = rng.relativeT<T>(index++, T(0.f), T(1.f));
+            if (rawX <= d) {
+                auto testVal = (math::p_pow(rawX, alpha - T(1.f)) * math::p_exp(-T(0.5f) * rawX)) / (powerAlpha * math::p_pow(T(1.f) - math::p_exp(-T(0.5f) * rawX), alpha - T(1.f)));
+                if (testVal < v) continue;
+                break;
+            }
+            else {
+                if (v <= math::p_pow(d / rawX, T(1.f) - alpha)) break;
+                continue;
+            }
+        }
+
+        return rawX / beta;
+    }
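For 0 < alpha < 1 this appears to be a Kundu-Gupta style hybrid rejection scheme (the fitted constant d = 1.0334 - 0.0766 * e^(2.2942 * alpha) is characteristic of it): candidates come from a two-part envelope, the body sampled by inversion and the tail exponentially, each then accepted against a second uniform. A standalone re-expression over the C++ standard library, useful for checking the math outside the library (illustration only; the production code draws from sd::graph::RandomGenerator):

```cpp
#include <cmath>
#include <random>

// Gamma(alpha, rate beta) for 0 < alpha < 1, mirroring gammaLess above.
double gammaLessSketch(std::mt19937& gen, double alpha, double beta) {
    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    const double d = 1.0334 - 0.0766 * std::exp(2.2942 * alpha);
    const double a = std::pow(2.0, alpha) * std::pow(1.0 - std::exp(-d / 2.0), alpha);
    const double b = alpha * std::pow(d, alpha - 1.0) * std::exp(-d);
    const double c = a + b;
    for (;;) {
        double u = uniform(gen);
        // candidate from the envelope: body by inversion, exponential tail
        double x = (u <= a / c)
                 ? -2.0 * std::log(1.0 - 0.5 * std::pow(c * u, 1.0 / alpha))
                 : -std::log(c * (1.0 - u) / (alpha * std::pow(d, alpha - 1.0)));
        double v = uniform(gen);
        if (x <= d) {
            // density-to-envelope ratio for the body part
            double t = std::pow(x, alpha - 1.0) * std::exp(-x / 2.0)
                     / (std::pow(2.0, alpha - 1.0) * std::pow(1.0 - std::exp(-x / 2.0), alpha - 1.0));
            if (t >= v) return x / beta; // accepted
        } else if (v <= std::pow(d / x, 1.0 - alpha)) {
            return x / beta;             // accepted from the tail
        }
    }
}
// Usage: std::mt19937 gen(119); double g = gammaLessSketch(gen, 0.2, 2.0);
```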
+
+    /**
+     * gammaGreat - generate gamma distributed value for shape (alpha) greater than 1
+     * @tparam T - given type (any float type is accepted)
+     * @param rng - random generator
+     * @param alpha - shape of the gamma distribution (alpha)
+     * @param beta - scale of the gamma distribution (beta)
+     * @return - gamma distributed value with given params
+     */
+    template <typename T>
+    T gammaGreat(graph::RandomGenerator& rng, T const alpha, T const beta) {
+        auto decreasedAlpha = alpha - T(1.f / 3.f);
+        auto c = T(1.) / math::p_sqrt(T(9.f) * decreasedAlpha);
+        static auto index = 0LL; // counter of uniform draws, shared across calls
+        T x;
+        // Box-Muller transform: two uniforms give one standard normal
+        auto normalDistributed = [](graph::RandomGenerator& rng, Nd4jLong& index) {
+            auto v1 = rng.relativeT<T>(index++, T(0.f), T(1.f));
+            auto v2 = rng.relativeT<T>(index++, T(0.f), T(1.f));
+
+            return math::p_cos(T(2.f * 3.141592f) * v2) * math::p_sqrt(T(-2.f) * math::p_log(v1));
+        };
+
+        T normalizedVar;
+        for (;;) {
+            do {
+                x = normalDistributed(rng, index);
+                normalizedVar = T(1.f) + c * x;
+            } while (normalizedVar < T(0.f));
+            normalizedVar = normalizedVar * normalizedVar * normalizedVar;
+
+            auto u = rng.relativeT<T>(index++, T(0.f), T(1.f));
+            if (u < T(1.f) - T(.0331f) * (x * x) * (x * x))
+                break; // fast squeeze acceptance
+            if (math::p_log(u) < T(0.5f) * x * x + decreasedAlpha * (T(1.f) - normalizedVar + math::p_log(normalizedVar)))
+                break; // exact acceptance test
+        }
+        return (decreasedAlpha * normalizedVar / beta);
+    }
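gammaGreat follows the Marsaglia-Tsang (2000) method: with d = alpha - 1/3 and c = 1/sqrt(9d), cube a shifted standard normal (supplied here by the Box-Muller lambda), try the cheap squeeze u < 1 - 0.0331 * x^4 first, and only fall back to the exact log test when it fails. The same scheme in standalone form (a sketch, not the library's API):

```cpp
#include <cmath>
#include <random>

// Gamma(alpha, rate beta) for alpha > 1, Marsaglia-Tsang squeeze method.
double gammaGreatSketch(std::mt19937& gen, double alpha, double beta) {
    std::normal_distribution<double> normal(0.0, 1.0);
    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    const double d = alpha - 1.0 / 3.0;
    const double c = 1.0 / std::sqrt(9.0 * d);
    for (;;) {
        double x, v;
        do {                        // v = (1 + c*x)^3 must stay positive
            x = normal(gen);
            v = 1.0 + c * x;
        } while (v <= 0.0);
        v = v * v * v;
        double u = uniform(gen);
        if (u < 1.0 - 0.0331 * x * x * x * x)
            return d * v / beta;    // fast squeeze acceptance
        if (std::log(u) < 0.5 * x * x + d * (1.0 - v + std::log(v)))
            return d * v / beta;    // exact acceptance test
    }
}
```

At alpha exactly 1 either sampler behaves sensibly; the dispatch in the fill routine below sends alpha <= 1 to gammaLess and everything else here.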
 
     template <typename T>
     void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
@@ -52,24 +133,19 @@ namespace helpers {
 
         copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha));
         copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta));
 
     }
     bool directOutput = output->ews() == 1 && output->ordering() == 'c';
     T* outputBuf = output->dataBuffer()->primaryAsT<T>();
 
     PRAGMA_OMP_PARALLEL_FOR
     for (Nd4jLong k = 0; k < shift; k++) {
         auto pos = k * step;
-        auto u = rng.relativeT<T>(k, 0., 1.);
         for (Nd4jLong e = 0; e < step; e++)
             if (directOutput) {
-                outputBuf[pos + e] = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
-                        beta != nullptr ? copyBeta->t<T>(e) * u : u);
+                outputBuf[pos + e] = copyAlpha->t<T>(e) <= 1 ? gammaLess(rng, copyAlpha->t<T>(e), beta ? copyBeta->t<T>(e) : T(1.f)) : gammaGreat(rng, copyAlpha->t<T>(e), beta ? copyBeta->t<T>(e) : T(1.f));
             }
             else {
-                output->r<T>(pos + e) = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
-                        beta != nullptr ? copyBeta->t<T>(e) * u : u);
+                output->r<T>(pos + e) = copyAlpha->t<T>(e) <= 1 ? gammaLess(rng, copyAlpha->t<T>(e), beta ? copyBeta->t<T>(e) : T(1.f)) : gammaGreat(rng, copyAlpha->t<T>(e), beta ? copyBeta->t<T>(e) : T(1.f));
             }
     }
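The substantive change in this hunk: elements are no longer produced by pushing a scaled uniform through math::nd4j_igamma (an incomplete-gamma transform rather than a true gamma variate); each element now goes through genuine rejection sampling, alpha <= 1 via gammaLess and alpha > 1 via gammaGreat, with a missing beta treated as 1 (the "crash with default beta param" fix from the commit message). One caveat worth noting: the samplers' static index counters are shared across the PRAGMA_OMP_PARALLEL_FOR threads without synchronization, so per-element draw sequences, and hence exact outputs, may vary between runs.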
@@ -33,6 +33,94 @@
 namespace sd {
 namespace ops {
 namespace helpers {
+    /**
+     * gammaLess - compute gamma distributed value for shapes (alpha) from 0 to 1
+     * @tparam T - any float type is acceptable
+     * @param U - pre-generated buffer of uniformly distributed values
+     * @param index - starting offset into U
+     * @param maxLength - length of U; consumption wraps around
+     * @param alpha - shape of the distribution
+     * @param beta - scale of the distributed values
+     * @return gamma distributed value
+     */
+    template <typename T>
+    T __device__ gammaLess(T const* U, Nd4jLong index, Nd4jLong maxLength, T const alpha, T const beta) {
+        auto d = T(1.0334f) - T(0.0766f) * math::p_exp(T(2.2942f) * alpha);
+        auto a = math::p_pow(T(2.f), alpha) * math::p_pow(T(1.f) - math::p_exp(-d * T(0.5f)), alpha);
+        auto b = alpha * math::p_pow(d, alpha - T(1.f)) * math::p_exp(-d);
+        auto c = a + b;
+        T rawX;
+        auto indexV = index;
+        auto underAlpha = T(1.f) / alpha;
+        auto powerAlpha = math::p_pow(T(2.f), alpha - T(1.f));
+
+        for (;;) {
+            // consume uniforms from the buffer, wrapping at maxLength
+            auto u = (indexV < maxLength) ? U[indexV++] : U[0];
+            if (indexV >= maxLength) indexV = 0LL;
+            if (u <= a / c) rawX = -T(2.f) * math::p_log(T(1.f) - T(0.5f) * math::p_pow(c * u, underAlpha));
+            else rawX = -math::p_log(c * (T(1.f) - u) / (alpha * math::p_pow(d, alpha - T(1.f))));
+
+            T v = (indexV < maxLength) ? U[indexV++] : U[0];
+            if (indexV >= maxLength) indexV = 0LL;
+
+            if (rawX <= d) {
+                auto testVal = (math::p_pow(rawX, alpha - T(1.f)) * math::p_exp(-T(0.5f) * rawX)) / (powerAlpha * math::p_pow(T(1.f) - math::p_exp(-T(0.5f) * rawX), alpha - T(1.f)));
+                if (testVal < v) continue;
+                break;
+            }
+            else {
+                if (v <= math::p_pow(d / rawX, T(1.f) - alpha)) break;
+                continue;
+            }
+        }
+        return rawX / beta;
+    }
+
+    /**
+     * gammaGreat - generate gamma distributed value for shape (alpha) greater than 1
+     * @tparam T - given type (any float type is accepted)
+     * @param U - pre-generated buffer of uniformly distributed values
+     * @param index - starting offset into U
+     * @param maxLength - length of U; consumption wraps around
+     * @param alpha - shape of the gamma distribution (alpha)
+     * @param beta - scale of the gamma distribution (beta)
+     * @return - gamma distributed value with given params
+     */
+    template <typename T>
+    T __device__ gammaGreat(T const* U, Nd4jLong index, Nd4jLong maxLength, T const alpha, T const beta) {
+        auto decreasedAlpha = alpha - T(1.f / 3.f);
+        auto c = T(1.) / math::p_sqrt(T(9.f) * decreasedAlpha);
+        auto indexV = index;
+        T x;
+        // Box-Muller transform over two buffered uniforms, wrapping at maxLength
+        auto normalDistributed = [U, maxLength](Nd4jLong& index) {
+            auto v1 = index < maxLength ? U[index++] : U[0];
+            if (index >= maxLength) index = 0LL;
+            auto v2 = index < maxLength ? U[index++] : U[0];
+            if (index >= maxLength) index = 0LL;
+
+            return math::p_cos(T(2.f * 3.141592f) * v2) * math::p_sqrt(T(-2.f) * math::p_log(v1));
+        };
+
+        T normalizedVar;
+        for (;;) {
+            do {
+                x = normalDistributed(indexV);
+                normalizedVar = T(1.f) + c * x;
+            } while (normalizedVar < T(0.f));
+            normalizedVar = normalizedVar * normalizedVar * normalizedVar;
+
+            auto u = U[indexV++];
+            if (indexV >= maxLength) indexV = 0LL;
+
+            if (u < T(1.f) - T(.0331f) * (x * x) * (x * x))
+                break; // fast squeeze acceptance
+            if (math::p_log(u) < T(0.5f) * x * x + decreasedAlpha * (T(1.f) - normalizedVar + math::p_log(normalizedVar)))
+                break; // exact acceptance test
+        }
+        return (decreasedAlpha * normalizedVar / beta);
+    }
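Design note on the buffered uniforms: rejection sampling consumes an unpredictable number of draws per output element, which is hard to coordinate across CUDA threads. The device samplers therefore read from a pre-filled pool U, starting at a caller-supplied offset and wrapping at maxLength, instead of advancing a shared global counter; determinism is preserved at the cost of reusing pool entries when a rejection streak runs long.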
 
     /*
      * fillGammaKernel - fill up output with gamma distributed values
@@ -44,25 +132,28 @@ namespace helpers {
      * output - distributed output.
      * */
     template <typename T>
-    static __global__ void fillGammaKernel(T* uList, Nd4jLong uLength, T* alpha, const Nd4jLong* alphaShape,
-            T* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) {
+    static __global__ void fillGammaKernel(T const* uList, Nd4jLong uLength, T const* alpha, const Nd4jLong* alphaShape,
+            T const* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) {
         // fill up
         __shared__ Nd4jLong aLength;
+        __shared__ Nd4jLong outLength;
         if (threadIdx.x == 0) {
             aLength = shape::length(alphaShape);
+            outLength = shape::length(outputShape) / aLength;
         }
         __syncthreads();
 
-        for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) {
+        for (auto k = blockIdx.x; k < (int)outLength; k += gridDim.x) {
             auto pos = k * aLength;
-            auto u = uList[k]; // this is a vector
             for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) {
                 auto aIndex = shape::getIndexOffset(e, alphaShape);
                 auto bIndex = betaShape ? shape::getIndexOffset(e, betaShape) : -1LL;
-                auto betaV = T(beta != nullptr ? beta[bIndex] * u : u);
+                auto betaV = T(beta != nullptr ? beta[bIndex] : T(1.f));
                 auto zIndex = shape::getIndexOffset(e + pos, outputShape);
 
-                output[zIndex] = math::nd4j_igamma<T, T, T>(alpha[aIndex], betaV);
+                output[zIndex] = alpha[aIndex] > T(1.f) ? gammaGreat(uList, pos, uLength, alpha[aIndex], betaV) : gammaLess(uList, pos, uLength, alpha[aIndex], betaV);
             }
         }
     }
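Grid mapping in fillGammaKernel: each block owns one output row k out of outLength = length(output) / length(alpha) rows, each thread strides across that row's aLength alpha/beta elements, and pos = k * aLength serves double duty as the output offset and as the starting index into the uniform pool, so different rows consume different slices of uList.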
@@ -76,7 +167,7 @@ namespace helpers {
     else
         broadcasted = alpha->shapeInfo();
     auto step = shape::length(broadcasted);
-    auto shift = output->lengthOf() / step;
+    auto shift = output->lengthOf() * 4LL; // uniform pool sized 4x the output length for the rejection samplers
 
     auto copyAlpha = alpha;
     auto copyBeta = beta;
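Sizing rationale (an inference, not stated in the commit): each gammaLess attempt costs two uniforms and each gammaGreat attempt roughly three (two for the Box-Muller normal, one for the acceptance test), and both accept with high probability on the first try, so a pool of four times the output length normally suffices, with the samplers' wraparound covering the rare long rejection streak.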
@@ -86,19 +177,21 @@ namespace helpers {
 
         copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha));
         copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta));
+        copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice();
     }
 
     auto stream = context->getCudaStream();
     NDArray uniform = NDArrayFactory::create<T>('c', {shift}, context);
-    uniform.syncToDevice();
     // fill up uniform with given length
-    RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.);
+    RandomLauncher::fillUniform(context, rng, &uniform, 0.0000000001, 0.9999999999);
+    uniform.syncToDevice();
     fillGammaKernel<T><<<128, 128, 256, *stream>>>(uniform.dataBuffer()->specialAsT<T>(), shift,
             copyAlpha->dataBuffer()->specialAsT<T>(), copyAlpha->specialShapeInfo(),
-            beta ? copyBeta->dataBuffer()->specialAsT<T>() : (T*)nullptr,
-            beta ? copyBeta->specialShapeInfo() : (Nd4jLong*)nullptr,
+            beta ? copyBeta->dataBuffer()->specialAsT<T>() : (T const*)nullptr,
+            beta ? copyBeta->specialShapeInfo() : (Nd4jLong const*)nullptr,
             output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
 
     if (beta != nullptr) {
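Two small but load-bearing details here: fillUniform is asked for [1e-10, 1 - 1e-10] rather than [0, 1] because the samplers take p_log of v1, of (1 - u) and of expressions derived from them, and an endpoint value would produce inf/NaN (presumably the "eliminated overflows" item in the commit message); and tickWriteDevice() marks the broadcast alpha/beta copies as freshly written on the device so the kernel does not read stale host buffers.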
@@ -1015,8 +1015,6 @@ TEST_F(RNGTests, Test_GammaDistribution_2) {
     // z->printIndexedBuffer("Gamma distribution");
     ASSERT_TRUE(exp0.isSameShape(z));
     ASSERT_FALSE(exp0.equalsTo(z));
-
-
 }
 
 TEST_F(RNGTests, Test_GammaDistribution_3) {
@@ -1040,6 +1038,61 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
 
 }
 
+TEST_F(RNGTests, Test_GammaDistribution_4) {
+    auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {1000, 1000});
+    auto al = NDArrayFactory::create<float>(2.f);
+    auto be = NDArrayFactory::create<float>(2.f);
+    auto exp0 = NDArrayFactory::create<float>('c', {1000, 1000});
+
+    sd::ops::random_gamma op;
+    auto result = op.evaluate({&x, &al, &be}, {}, {});
+    ASSERT_EQ(Status::OK(), result.status());
+
+    auto z = result.at(0);
+    ASSERT_TRUE(exp0.isSameShape(z));
+    ASSERT_FALSE(exp0.equalsTo(z));
+
+    sd::ops::reduce_mean testOps1;
+    sd::ops::reduce_variance testOps2;
+    auto testRes1 = testOps1.evaluate({z});
+    auto testRes2 = testOps2.evaluate({z});
+    ASSERT_NEAR(testRes1[0]->t<float>(0), 1.0f, 0.01);  // expected mean 1.0
+    ASSERT_NEAR(testRes2[0]->t<float>(0), 0.5f, 0.02);  // expected variance 0.5
+}
+
+TEST_F(RNGTests, Test_GammaDistribution_5) {
+    auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {100, 100});
+    auto al = NDArrayFactory::create<float>(0.2f);
+    auto be = NDArrayFactory::create<float>(2.f);
+    auto exp0 = NDArrayFactory::create<float>('c', {100, 100});
+
+    sd::ops::random_gamma op;
+    auto result = op.evaluate({&x, &al, &be}, {}, {});
+    ASSERT_EQ(Status::OK(), result.status());
+
+    auto z = result.at(0);
+    ASSERT_TRUE(exp0.isSameShape(z));
+    ASSERT_FALSE(exp0.equalsTo(z));
+
+    sd::ops::reduce_mean testOps1;
+    sd::ops::reduce_variance testOps2;
+    auto testRes1 = testOps1.evaluate({z});
+    auto testRes2 = testOps2.evaluate({z});
+    ASSERT_NEAR(testRes1[0]->t<float>(0), 0.1f, 0.02);  // expected mean 0.1
+    ASSERT_NEAR(testRes2[0]->t<float>(0), 0.05f, 0.02); // expected variance 0.05
+}
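The arithmetic behind the tolerance checks: since the samplers return rawX / beta, beta acts as a rate, and X ~ Gamma(alpha, beta) has E[X] = alpha/beta and Var[X] = alpha/beta^2. Test 4 (alpha = 2, beta = 2) therefore expects mean 1.0 and variance 0.5; Test 5 (alpha = 0.2, beta = 2) expects mean 0.1 and variance 0.05, which is exactly what the ASSERT_NEAR bounds encode.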
+
 TEST_F(RNGTests, Test_UniformDistribution_04) {
     auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
     auto al = NDArrayFactory::create<int>(1);
@@ -1055,7 +1108,6 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {
     ASSERT_TRUE(exp0.isSameShape(z));
     ASSERT_FALSE(exp0.equalsTo(z));
 
-
 }
 
 TEST_F(RNGTests, Test_UniformDistribution_05) {
@@ -1237,7 +1289,6 @@ TEST_F(RNGTests, test_multinomial_1) {
     ASSERT_EQ(Status::OK(), result.status());
     ASSERT_TRUE(expectedZ.isSameShape(outputZ));
     ASSERT_TRUE(expectedZ.equalsTo(outputZ));
-
 }
 
 TEST_F(RNGTests, test_multinomial_2) {
@@ -1314,7 +1365,6 @@ TEST_F(RNGTests, test_multinomial_5) {
     RandomGenerator rng(1234, 1234);
 
     ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false));
-
     auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
     auto mean = output.meanNumber();
     // printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
@@ -1386,7 +1436,6 @@ TEST_F(RNGTests, test_multinomial_6) {
     ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3);
 
-
 
     RandomGenerator rng(1234, 1234);
     NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
     NDArray output('c', { batchValue, Samples }, sd::DataType::INT64);
@@ -30,6 +30,7 @@ import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 import org.nd4j.linalg.api.ops.random.custom.*;
 import org.nd4j.linalg.api.ops.random.impl.*;
@@ -1479,14 +1480,22 @@ public class RandomTests extends BaseNd4jTest {
     @Test
     public void testGamma(){
         Nd4j.getRandom().setSeed(12345);
-        INDArray shape = Nd4j.createFromArray(new int[] {1,3});
-        INDArray alpha = Nd4j.rand(1,3);
-        val randomGamma = new RandomGamma(shape, alpha, null);
+        INDArray shape = Nd4j.createFromArray(new int[] {1000,1000});
+        INDArray alpha = Nd4j.createFromArray(new float[]{2.f});
+        INDArray beta = Nd4j.createFromArray(new float[]{2.f});
+        val randomGamma = new RandomGamma(shape, alpha, beta);
         INDArray[] res = Nd4j.exec(randomGamma);
 
-        val randomGamma1 = new RandomGamma(shape, alpha, null);
+        val randomGamma1 = new RandomGamma(shape, alpha, beta);
         INDArray[] res1 = Nd4j.exec(randomGamma1);
-        assertEquals(res[0], res1[0]);
+
+        val meanOp0 = new Mean(res[0]);
+        val meanOp1 = new Mean(res1[0]);
+
+        INDArray mean0 = Nd4j.exec(meanOp0);
+        INDArray mean1 = Nd4j.exec(meanOp1);
+
+        assertArrayEquals(mean0.toFloatVector(), mean1.toFloatVector(), 1e-2f);
     }
 
     @Test
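The reworked Java test mirrors Test_GammaDistribution_4 (a 1000x1000 output with alpha = beta = 2) and compares the means of two identically seeded RandomGamma executions instead of asserting exact array equality; presumably the mean comparison tolerates scheduling-dependent draw ordering in the parallel native fill while still pinning the distribution.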