Added dtype formulation for poisson and gamma distributions. (#442)

* Added dtype formulation for poisson and gamma distributions.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored gamma distribution generator and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added generator for gamma distribution when alpha (shape) is between 0 and 1.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Implemented gamma distribution for shape parameter less than 1, with tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Implemented gamma-distributed random values for shape (alpha) parameter greater than 1.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added cuda implementation for gamma distribution.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored cuda and cpu implementation of gamma distribution.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed crash with default beta param with gamma distribution.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed pow for ARM arch.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed gamma test.

* Cosmetic changes only.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed random value retrieval.

* Eliminated potential overflows.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Modified random value retrieval.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Increased test density for the gamma distribution.

Signed-off-by: shugeo <sgazeos@gmail.com>

Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
Co-authored-by: raver119 <raver119@gmail.com>
Branch: master
Author: shugeo, 2020-06-08 13:14:22 +03:00 (committed via GitHub)
Parent: c8096197c7
Commit: 3a3c952e75
6 changed files with 260 additions and 33 deletions

View File

@@ -65,7 +65,7 @@ namespace sd {
             additionalShape = additionalShapeBroadcasted;
         }
         auto lastDim = shape::sizeAt(alphaShape, 0);
-        auto dtype = ArrayOptions::dataType(alphaShape);
+        auto dtype = block.numD() > 0 ? D_ARG(0) : ArrayOptions::dataType(alphaShape);
         for (auto i = 0; i < shape::rank(additionalShape); i++)
             shape.push_back(shape::sizeAt(additionalShape, i));
         auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape);

View File

@@ -47,7 +47,7 @@ namespace sd {
         auto in = INPUT_VARIABLE(0);
         auto shape = in->template asVectorT<Nd4jLong>();
         auto lambdaShape = inputShape->at(1);
-        auto dtype = ArrayOptions::dataType(lambdaShape);
+        auto dtype = block.numD() > 0 ? D_ARG(0) : ArrayOptions::dataType(lambdaShape);
         for (auto d = 0; d < shape::rank(lambdaShape); ++d) {
             shape.emplace_back(shape::sizeAt(lambdaShape, d));
         }
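
Both hunks make the same fix: the shape function for random_gamma and random_poisson now honors an explicitly supplied dtype argument (D_ARG(0)) instead of always inheriting the dtype of the alpha/lambda input. A minimal usage sketch — hypothetical call, assuming the evaluate overload that takes a trailing vector of DataType args; the exact signature may differ between versions:

    // Hypothetical sketch: request DOUBLE output from random_gamma regardless
    // of alpha's dtype. Assumes evaluate(inputs, tArgs, iArgs, bArgs, dArgs).
    auto shapeArr = NDArrayFactory::create<Nd4jLong>('c', {2}, {100, 100});
    auto alpha    = NDArrayFactory::create<float>(2.f);

    sd::ops::random_gamma op;
    auto result = op.evaluate({&shapeArr, &alpha}, {}, {}, {}, {sd::DataType::DOUBLE});
    // With no dtype arg (block.numD() == 0) the output falls back to
    // ArrayOptions::dataType(alphaShape), i.e. FLOAT32 here.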

View File

@@ -31,6 +31,87 @@ namespace sd {
 namespace ops {
 namespace helpers {
+
+    /**
+     * gammaLess - compute a gamma distributed value for shape (alpha) between 0 and 1
+     * @tparam T - any float type is acceptable
+     * @param rng - random generator for uniform values
+     * @param alpha - shape of the distribution
+     * @param beta - scale of the distributed values
+     * @return gamma distributed value
+     */
+    template <typename T>
+    T gammaLess(graph::RandomGenerator& rng, T const alpha, T const beta) {
+        auto d = T(1.0334f) - T(0.0766f) * math::p_exp(T(2.2942f) * alpha);
+        auto a = math::p_pow(T(2.f), alpha) * math::p_pow(T(1.f) - math::p_exp(-d * T(0.5f)), alpha);
+        auto b = alpha * math::p_pow(d, alpha - T(1.f)) * exp(-d);
+        auto c = a + b;
+        T rawX;
+        static auto index = 0LL;
+        const T underAlpha = T(1.f) / alpha;
+        const T powerAlpha = math::p_pow(T(2.f), alpha - T(1.f));
+
+        for (;;) {
+            auto u = rng.relativeT<T>(index++, T(0.f), T(1.f));
+            if (u <= a / c) rawX = -T(2.f) * math::p_log(T(1.f) - T(0.5f) * math::p_pow(T(c * u), underAlpha));
+            else rawX = -math::p_log(c * (T(1.f) - u) / (alpha * math::p_pow(d, alpha - T(1.f))));
+
+            T v = rng.relativeT(index++, 0.f, 1.f);
+            if (rawX <= d) {
+                auto testVal = (math::p_pow(rawX, alpha - 1.f) * math::p_exp(-T(0.5f) * rawX)) / (powerAlpha * math::p_pow(T(1.f) - math::p_exp(-T(0.5f) * rawX), alpha - T(1.f)));
+                if (testVal < v) continue;
+                break;
+            }
+            else {
+                if (v <= math::p_pow(d / rawX, T(1.f) - alpha)) break;
+                continue;
+            }
+        }
+        return rawX / beta;
+    }
+
+    /**
+     * gammaGreat - generate a gamma distributed value for shape (alpha) greater than 1
+     * @tparam T - given type (any float type is accepted)
+     * @param rng - random generator
+     * @param alpha - shape of the gamma distribution (alpha)
+     * @param beta - scale of the gamma distribution (beta)
+     * @return gamma distributed value with the given params
+     */
+    template <typename T>
+    T gammaGreat(graph::RandomGenerator& rng, T const alpha, T const beta) {
+        auto decreasedAlpha = alpha - T(1.f/3.f);
+        auto c = T(1.) / math::p_sqrt(T(9.f) * decreasedAlpha);
+        static auto index = 0LL;
+        T x;
+        // Box-Muller transform: two uniform draws -> one standard normal sample
+        auto normalDistributed = [](graph::RandomGenerator& rng, Nd4jLong& index) {
+            auto v1 = rng.relativeT(index++, T(0.f), T(1.f));
+            auto v2 = rng.relativeT(index++, T(0.f), T(1.f));
+
+            return math::p_cos(T(2.f * 3.141592f) * v2) * math::p_sqrt(T(-2.f) * math::p_log(v1));
+        };
+
+        float normalizedVar;
+        for (;;) {
+            do {
+                x = normalDistributed(rng, index);
+                normalizedVar = T(1.f) + c * x;
+            } while (normalizedVar < T(0.f));
+
+            normalizedVar = normalizedVar * normalizedVar * normalizedVar;
+            auto u = rng.relativeT<T>(index++, T(0.f), T(1.f));
+            if (u < T(1.f) - T(.0331f) * (x * x) * (x * x))
+                break; // fast acceptance (squeeze)
+            if (log(u) < 0.5f * x * x + decreasedAlpha * (1. - normalizedVar + math::p_log(normalizedVar)))
+                break; // slow acceptance
+        }
+        return (decreasedAlpha * normalizedVar / beta);
+    }

     template <typename T>
     void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
@@ -52,24 +133,19 @@ namespace helpers {
             copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha));
             copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta));
         }
-//        bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c';
         bool directOutput = output->ews() == 1 && output->ordering() == 'c';
         T* outputBuf = output->dataBuffer()->primaryAsT<T>();

         PRAGMA_OMP_PARALLEL_FOR
         for (Nd4jLong k = 0; k < shift; k++) {
             auto pos = k * step;
-            auto u = rng.relativeT<T>(k, 0., 1.);
             for (Nd4jLong e = 0; e < step; e++)
                 if (directOutput) {
-                    outputBuf[pos + e] = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
-                            beta != nullptr ? copyBeta->t<T>(e) * u : u);
+                    outputBuf[pos + e] = copyAlpha->t<T>(e) <= 1 ? gammaLess(rng, copyAlpha->t<T>(e), beta ? copyBeta->t<T>(e) : T(1.f)) : gammaGreat(rng, copyAlpha->t<T>(e), beta ? copyBeta->t<T>(e) : T(1.f));
                 }
                 else {
-                    output->r<T>(pos + e) = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
-                            beta != nullptr ? copyBeta->t<T>(e) * u : u);
+                    output->r<T>(pos + e) = copyAlpha->t<T>(e) <= 1 ? gammaLess(rng, copyAlpha->t<T>(e), beta ? copyBeta->t<T>(e) : T(1.f)) : gammaGreat(rng, copyAlpha->t<T>(e), beta ? copyBeta->t<T>(e) : T(1.f));
                 }
         }
@@ -211,4 +287,4 @@ namespace helpers {
 }
 }
 }
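
A note on why the igamma call was replaced: math::nd4j_igamma evaluates the regularized incomplete gamma function P(a, x), a CDF value in [0, 1], so the old path never produced true gamma variates. The new helpers sample by rejection instead; gammaGreat follows the Marsaglia-Tsang squeeze method, while gammaLess uses a two-branch rejection scheme for alpha in (0, 1). Below is a self-contained sketch of the Marsaglia-Tsang scheme written against std::random for clarity (names are illustrative; the committed code draws its uniforms from graph::RandomGenerator and derives normals via Box-Muller):

    #include <cmath>
    #include <random>

    // Sketch of the rejection scheme used by gammaGreat (valid for alpha > 1;
    // the committed code routes alpha <= 1 to gammaLess).
    double sampleGammaMT(std::mt19937_64& gen, double alpha, double beta) {
        std::normal_distribution<double> normal(0.0, 1.0);
        std::uniform_real_distribution<double> uniform(0.0, 1.0);
        const double d = alpha - 1.0 / 3.0;        // "decreasedAlpha" in the diff
        const double c = 1.0 / std::sqrt(9.0 * d);
        for (;;) {
            double x, v;
            do {                                    // keep 1 + c*x strictly positive
                x = normal(gen);
                v = 1.0 + c * x;
            } while (v <= 0.0);
            v = v * v * v;                          // "normalizedVar" in the diff
            const double u = uniform(gen);
            if (u < 1.0 - 0.0331 * x * x * x * x)   // fast acceptance (squeeze)
                return d * v / beta;
            if (std::log(u) < 0.5 * x * x + d * (1.0 - v + std::log(v)))
                return d * v / beta;                // slow acceptance
        }
    }

Dividing by beta at the end matches the shape-rate parameterization the new tests assume: mean alpha/beta, variance alpha/beta^2.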

View File

@@ -33,6 +33,94 @@
 namespace sd {
 namespace ops {
 namespace helpers {
+
+    /**
+     * gammaLess - compute a gamma distributed value for shape (alpha) between 0 and 1
+     * @tparam T - any float type is acceptable
+     * @param U - pool of pre-generated uniform random values
+     * @param alpha - shape of the distribution
+     * @param beta - scale of the distributed values
+     * @return gamma distributed value
+     */
+    template <typename T>
+    T __device__ gammaLess(T const* U, Nd4jLong index, Nd4jLong maxLength, T const alpha, T const beta) {
+        auto d = T(1.0334f) - T(0.0766f) * math::p_exp(T(2.2942f) * alpha);
+        auto a = math::p_pow(T(2.f), alpha) * math::p_pow(T(1.f) - math::p_exp(-d * T(0.5f)), alpha);
+        auto b = alpha * math::p_pow(d, alpha - T(1.f)) * exp(-d);
+        auto c = a + b;
+        T rawX;
+        auto indexV = index;
+        auto underAlpha = T(1.f) / alpha;
+        auto powerAlpha = math::p_pow(T(2.f), alpha - T(1.f));
+
+        for (;;) {
+            // read the next pooled uniform, wrapping around at the end of the pool
+            auto u = (indexV < maxLength) ? U[indexV++] : U[0];
+            if (indexV >= maxLength) indexV = 0LL;
+            if (u <= a / c) rawX = -T(2.f) * math::p_log(T(1.f) - T(0.5f) * math::p_pow(c * u, underAlpha));
+            else rawX = -math::p_log(c * (T(1.f) - u) / (alpha * math::p_pow(d, alpha - T(1.f))));
+
+            T v = (indexV < maxLength) ? U[indexV++] : U[0];
+            if (indexV >= maxLength) indexV = 0LL;
+            if (rawX <= d) {
+                auto testVal = (math::p_pow(rawX, alpha - 1.f) * math::p_exp(-T(0.5f) * rawX)) / (powerAlpha * math::p_pow(T(1.f) - math::p_exp(-T(0.5f) * rawX), alpha - T(1.f)));
+                if (testVal < v) continue;
+                break;
+            }
+            else {
+                if (v <= math::p_pow(d / rawX, T(1.f) - alpha)) break;
+                continue;
+            }
+        }
+        return rawX / beta;
+    }
+
+    /**
+     * gammaGreat - generate a gamma distributed value for shape (alpha) greater than 1
+     * @tparam T - given type (any float type is accepted)
+     * @param U - pool of pre-generated uniform random values
+     * @param alpha - shape of the gamma distribution (alpha)
+     * @param beta - scale of the gamma distribution (beta)
+     * @return gamma distributed value with the given params
+     */
+    template <typename T>
+    T __device__ gammaGreat(T const* U, Nd4jLong index, Nd4jLong maxLength, T const alpha, T const beta) {
+        auto decreasedAlpha = alpha - T(1.f/3.f);
+        auto c = T(1.) / math::p_sqrt(T(9.f) * decreasedAlpha);
+        auto indexV = index;
+        T x;
+        // Box-Muller transform: two pooled uniforms -> one standard normal sample
+        auto normalDistributed = [U, maxLength](Nd4jLong& index) {
+            auto v1 = index < maxLength ? U[index++] : U[0];
+            if (index >= maxLength) index = 0LL;
+            auto v2 = index < maxLength ? U[index++] : U[0];
+            if (index >= maxLength) index = 0LL;
+
+            return math::p_cos(T(2.f * 3.141592f) * v2) * math::p_sqrt(T(-2.f) * math::p_log(v1));
+        };
+
+        float normalizedVar;
+        for (;;) {
+            do {
+                x = normalDistributed(indexV);
+                normalizedVar = T(1.f) + c * x;
+            } while (normalizedVar < T(0.f));
+
+            normalizedVar = normalizedVar * normalizedVar * normalizedVar;
+            auto u = U[indexV++];
+            if (indexV >= maxLength) indexV = 0LL;
+            if (u < T(1.f) - T(.0331f) * (x * x) * (x * x))
+                break; // fast acceptance (squeeze)
+            if (log(u) < 0.5f * x * x + decreasedAlpha * (1. - normalizedVar + math::p_log(normalizedVar)))
+                break; // slow acceptance
+        }
+        return (decreasedAlpha * normalizedVar / beta);
+    }

 /*
  * fillGammaKernel - fill up output with gamma distributed values
@@ -44,25 +132,28 @@ namespace helpers {
  * output - distributed output.
  * */
     template <typename T>
-    static __global__ void fillGammaKernel(T* uList, Nd4jLong uLength, T* alpha, const Nd4jLong* alphaShape,
-                                           T* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) {
+    static __global__ void fillGammaKernel(T const* uList, Nd4jLong uLength, T const* alpha, const Nd4jLong* alphaShape,
+                                           T const* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) {
         // fill up
         __shared__ Nd4jLong aLength;
+        __shared__ Nd4jLong outLength;
         if (threadIdx.x == 0) {
            aLength = shape::length(alphaShape);
+           outLength = shape::length(outputShape) / aLength;
        }
        __syncthreads();

-        for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) {
+        for (auto k = blockIdx.x; k < (int)outLength; k += gridDim.x) {
            auto pos = k * aLength;
-            auto u = uList[k]; // this is a vector
            for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) {
                auto aIndex = shape::getIndexOffset(e, alphaShape);
                auto bIndex = betaShape ? shape::getIndexOffset(e, betaShape) : -1LL;
-                auto betaV = T(beta != nullptr ? beta[bIndex] * u : u);
+                auto betaV = T(beta != nullptr ? beta[bIndex] : T(1.f));
                auto zIndex = shape::getIndexOffset(e + pos, outputShape);

-                output[zIndex] = math::nd4j_igamma<T, T, T>(alpha[aIndex], betaV);
+                output[zIndex] = alpha[aIndex] > T(1.f) ? gammaGreat(uList, pos, uLength, alpha[aIndex], betaV) : gammaLess(uList, pos, uLength, alpha[aIndex], betaV);
            }
        }
    }
@@ -76,7 +167,7 @@ namespace helpers {
        else
            broadcasted = alpha->shapeInfo();
        auto step = shape::length(broadcasted);
-        auto shift = output->lengthOf() / step;
+        auto shift = output->lengthOf() * 4LL; // oversized pool: rejection sampling may consume several uniforms per value

        auto copyAlpha = alpha;
        auto copyBeta = beta;
@@ -86,19 +177,21 @@
            copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha));
            copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta));
-            copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice();
        }

        auto stream = context->getCudaStream();
        NDArray uniform = NDArrayFactory::create<T>('c', {shift}, context);
        uniform.syncToDevice();
        // fill up uniform with given length
-        RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.);
+        RandomLauncher::fillUniform(context, rng, &uniform, 0.0000000001, 0.9999999999);
+        uniform.syncToDevice();

        fillGammaKernel<T><<<128, 128, 256, *stream>>>(uniform.dataBuffer()->specialAsT<T>(), shift,
                copyAlpha->dataBuffer()->specialAsT<T>(), copyAlpha->specialShapeInfo(),
-                beta ? copyBeta->dataBuffer()->specialAsT<T>() : (T*)nullptr,
-                beta ? copyBeta->specialShapeInfo() : (Nd4jLong*)nullptr,
+                beta ? copyBeta->dataBuffer()->specialAsT<T>() : (T const*)nullptr,
+                beta ? copyBeta->specialShapeInfo() : (Nd4jLong const*)nullptr,
                output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());

        if (beta != nullptr) {
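
A device thread cannot re-draw from the host RNG inside its rejection loop, so this path pre-fills a device-side pool of uniforms (sized output->lengthOf() * 4LL, since each rejection-sampled value consumes an unpredictable number of draws) and each call walks the pool with wraparound. A minimal sketch of that access idiom — hypothetical helper, not part of the commit:

    // Hypothetical device helper mirroring the pool-walking idiom used by the
    // gammaLess / gammaGreat device functions above: read the next pre-generated
    // uniform and wrap to the start of the pool rather than run past its end.
    template <typename T>
    __device__ T nextUniform(T const* pool, Nd4jLong& cursor, Nd4jLong poolLength) {
        T value = (cursor < poolLength) ? pool[cursor++] : pool[0];
        if (cursor >= poolLength) cursor = 0LL; // wrap around
        return value;
    }

Clamping fillUniform to (1e-10, 1 - 1e-10) protects the log() calls in Box-Muller and the acceptance tests: a pooled value of exactly 0 or 1 would otherwise produce -inf.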

View File

@@ -1015,8 +1015,6 @@ TEST_F(RNGTests, Test_GammaDistribution_2) {
     // z->printIndexedBuffer("Gamma distribution");
     ASSERT_TRUE(exp0.isSameShape(z));
     ASSERT_FALSE(exp0.equalsTo(z));
 }

 TEST_F(RNGTests, Test_GammaDistribution_3) {
@@ -1037,7 +1035,62 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
     ASSERT_TRUE(exp0.isSameShape(z));
     ASSERT_FALSE(exp0.equalsTo(z));
 }

+TEST_F(RNGTests, Test_GammaDistribution_4) {
+    auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {1000, 1000});
+    auto al = NDArrayFactory::create<float>(2.f);
+    auto be = NDArrayFactory::create<float>(2.f);
+    auto exp0 = NDArrayFactory::create<float>('c', {1000, 1000});
+
+    sd::ops::random_gamma op;
+    auto result = op.evaluate({&x, &al, &be}, {}, {});
+    ASSERT_EQ(Status::OK(), result.status());
+
+    auto z = result.at(0);
+    ASSERT_TRUE(exp0.isSameShape(z));
+    ASSERT_FALSE(exp0.equalsTo(z));
+
+    sd::ops::reduce_mean testOps1;
+    sd::ops::reduce_variance testOps2;
+    auto testRes1 = testOps1.evaluate({z});
+    auto testRes2 = testOps2.evaluate({z});
+    // mean expected near 1.0, variance near 0.5
+    ASSERT_NEAR(testRes1[0]->t<float>(0), 1.0f, 0.01);
+    ASSERT_NEAR(testRes2[0]->t<float>(0), 0.5f, 0.02);
+}
+
+TEST_F(RNGTests, Test_GammaDistribution_5) {
+    auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {100, 100});
+    auto al = NDArrayFactory::create<float>(0.2f);
+    auto be = NDArrayFactory::create<float>(2.f);
+    auto exp0 = NDArrayFactory::create<float>('c', {100, 100});
+
+    sd::ops::random_gamma op;
+    auto result = op.evaluate({&x, &al, &be}, {}, {});
+    ASSERT_EQ(Status::OK(), result.status());
+
+    auto z = result.at(0);
+    ASSERT_TRUE(exp0.isSameShape(z));
+    ASSERT_FALSE(exp0.equalsTo(z));
+
+    sd::ops::reduce_mean testOps1;
+    sd::ops::reduce_variance testOps2;
+    auto testRes1 = testOps1.evaluate({z});
+    auto testRes2 = testOps2.evaluate({z});
+    // mean expected near 0.1, variance near 0.05
+    ASSERT_NEAR(testRes1[0]->t<float>(0), 0.1f, 0.02);
+    ASSERT_NEAR(testRes2[0]->t<float>(0), 0.05f, 0.02);
+}

 TEST_F(RNGTests, Test_UniformDistribution_04) {
@@ -1055,7 +1108,6 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {
     ASSERT_TRUE(exp0.isSameShape(z));
     ASSERT_FALSE(exp0.equalsTo(z));
 }

 TEST_F(RNGTests, Test_UniformDistribution_05) {
@@ -1237,7 +1289,6 @@ TEST_F(RNGTests, test_multinomial_1) {
     ASSERT_EQ(Status::OK(), result.status());
     ASSERT_TRUE(expectedZ.isSameShape(outputZ));
     ASSERT_TRUE(expectedZ.equalsTo(outputZ));
 }

 TEST_F(RNGTests, test_multinomial_2) {
@@ -1314,7 +1365,6 @@ TEST_F(RNGTests, test_multinomial_5) {
     RandomGenerator rng(1234, 1234);
     ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false));

     auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
     auto mean = output.meanNumber();
     // printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
@@ -1386,7 +1436,6 @@ TEST_F(RNGTests, test_multinomial_6) {
     ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3);

     RandomGenerator rng(1234, 1234);
     NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
     NDArray output('c', { batchValue, Samples }, sd::DataType::INT64);
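
The new moment assertions follow directly from the shape-rate parameterization of the gamma distribution:

    E[X] = alpha / beta,        Var[X] = alpha / beta^2

So Test_GammaDistribution_4 (alpha = 2, beta = 2) expects mean 1.0 and variance 0.5, and Test_GammaDistribution_5 (alpha = 0.2, beta = 2) expects mean 0.1 and variance 0.05.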

View File

@@ -30,6 +30,7 @@ import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 import org.nd4j.linalg.api.ops.random.custom.*;
 import org.nd4j.linalg.api.ops.random.impl.*;
@@ -1479,14 +1480,22 @@ public class RandomTests extends BaseNd4jTest {
     @Test
     public void testGamma(){
         Nd4j.getRandom().setSeed(12345);
-        INDArray shape = Nd4j.createFromArray(new int[] {1,3});
-        INDArray alpha = Nd4j.rand(1,3);
-        val randomGamma = new RandomGamma(shape, alpha, null);
+        INDArray shape = Nd4j.createFromArray(new int[] {1000,1000});
+        INDArray alpha = Nd4j.createFromArray(new float[]{2.f});
+        INDArray beta = Nd4j.createFromArray(new float[]{2.f});
+        val randomGamma = new RandomGamma(shape, alpha, beta);
         INDArray[] res = Nd4j.exec(randomGamma);

-        val randomGamma1 = new RandomGamma(shape, alpha, null);
+        val randomGamma1 = new RandomGamma(shape, alpha, beta);
         INDArray[] res1 = Nd4j.exec(randomGamma1);
-        assertEquals(res[0], res1[0]);
+
+        val meanOp0 = new Mean(res[0]);
+        val meanOp1 = new Mean(res1[0]);
+        INDArray mean0 = Nd4j.exec(meanOp0);
+        INDArray mean1 = Nd4j.exec(meanOp1);
+        // rejection sampling consumes a variable number of RNG draws, so two
+        // runs are not bit-identical; compare means with a tolerance instead
+        assertArrayEquals(mean0.toFloatVector(), mean1.toFloatVector(), 1e-2f);
     }

     @Test