From 3a3c952e755744457fa1b02c4c3772a60297006f Mon Sep 17 00:00:00 2001 From: shugeo Date: Mon, 8 Jun 2020 13:14:22 +0300 Subject: [PATCH] Added dtype formulation for poisson and gamma distributions. (#442) * Added dtype formulation for poisson and gamma distributions. Signed-off-by: shugeo * Refactored gamma distribution generator and tests. Signed-off-by: shugeo * Added generator for gamma distribution when alpha (shape) between 0 and 1 Signed-off-by: shugeo * Implemented gamma distribution for shape param less than 1 and tests. Signed-off-by: shugeo * Implemented gamma distributed randoms for shape (alpha) parameter greater then 1. Signed-off-by: shugeo * Added cuda implementation for gamma distribution. Signed-off-by: shugeo * Refactored cuda and cpu implementation of gamma distribution. Signed-off-by: shugeo * Fixed crash with default beta param with gamma distribution. Signed-off-by: shugeo * Fixed pow for arm arch. Signed-off-by: shugeo * Gamma test fixed * Cosmetic changes only. Signed-off-by: shugeo * Fixed random value retrieving * Eliminated overflow attemptions. Signed-off-by: shugeo * Modified random retrieving. Signed-off-by: shugeo * enlighted density of tests for Gamma distribution. Signed-off-by: shugeo Co-authored-by: Alexander Stoyakin Co-authored-by: raver119 --- .../ops/declarable/generic/random/gamma.cpp | 2 +- .../ops/declarable/generic/random/poisson.cpp | 2 +- .../ops/declarable/helpers/cpu/random.cpp | 92 ++++++++++++-- .../ops/declarable/helpers/cuda/random.cu | 117 ++++++++++++++++-- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 61 ++++++++- .../java/org/nd4j/linalg/rng/RandomTests.java | 19 ++- 6 files changed, 260 insertions(+), 33 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/random/gamma.cpp b/libnd4j/include/ops/declarable/generic/random/gamma.cpp index a00ce2b7e..b7dfc9f06 100644 --- a/libnd4j/include/ops/declarable/generic/random/gamma.cpp +++ b/libnd4j/include/ops/declarable/generic/random/gamma.cpp @@ -65,7 +65,7 @@ namespace sd { additionalShape = additionalShapeBroadcasted; } auto lastDim = shape::sizeAt(alphaShape, 0); - auto dtype = ArrayOptions::dataType(alphaShape); + auto dtype = block.numD() > 0? D_ARG(0): ArrayOptions::dataType(alphaShape); for (auto i = 0; i < shape::rank(additionalShape); i++) shape.push_back(shape::sizeAt(additionalShape, i)); auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', shape); diff --git a/libnd4j/include/ops/declarable/generic/random/poisson.cpp b/libnd4j/include/ops/declarable/generic/random/poisson.cpp index eedfbbe1f..2eb601bc9 100644 --- a/libnd4j/include/ops/declarable/generic/random/poisson.cpp +++ b/libnd4j/include/ops/declarable/generic/random/poisson.cpp @@ -47,7 +47,7 @@ namespace sd { auto in = INPUT_VARIABLE(0); auto shape = in->template asVectorT(); auto lambdaShape = inputShape->at(1); - auto dtype = ArrayOptions::dataType(lambdaShape); + auto dtype = block.numD() > 0? D_ARG(0) : ArrayOptions::dataType(lambdaShape); for (auto d = 0; d < shape::rank(lambdaShape); ++d ) { shape.emplace_back(shape::sizeAt(lambdaShape, d)); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp index b0e1553e4..d96b30175 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp @@ -31,6 +31,87 @@ namespace sd { namespace ops { namespace helpers { + /** + * gammaLess - compute gamma distributed value for shapes (alpha) from 0 to 1 + * @tparam T - any float types are acceptable + * @param rng - random generator for uniformly vals + * @param alpha - shape of distribution + * @param beta - scale of distributed values + * @return gamma distributed value + */ + template + T gammaLess(graph::RandomGenerator& rng, T const alpha, T const beta) { + auto d = T(1.0334f) - T(0.0766f) * math::p_exp(T(2.2942f) * alpha); + auto a = math::p_pow(T(2.f), alpha) * math::p_pow(T(1.f) - math::p_exp(-d * T(0.5f)), alpha); + auto b = alpha * math::p_pow(d, alpha - T(1.f)) * exp(-d); + auto c = a + b; + T rawX; + static auto index = 0LL; + const T underAlpha = T(1.f) / alpha; + const T powerAlpha = math::p_pow(T(2.f), alpha - T(1.f)); + + for (;;) { + auto u = rng.relativeT(index++, T(0.f), T(1.f)); + + if (u <= a / c) rawX = -T(2.f) * math::p_log(T(1.f) - T(0.5f) * math::p_pow(T(c * u), underAlpha)); + else rawX = - math::p_log(c * (T(1.f) - u)/(alpha * math::p_pow(d, alpha - T(1.f)))); + + T v = rng.relativeT(index++, 0.f, 1.f); + if (rawX <= d) { + auto testVal = (math::p_pow(rawX, alpha - 1.f) * math::p_exp(-T(0.5f) * rawX)) / (powerAlpha * math::p_pow(T(1.f) - math::p_exp(-T(0.5f) * rawX), alpha - T(1.f))); + if (testVal < v) continue; + break; + } + else { + if (v <= math::p_pow(d / rawX, T(1.f) - alpha)) break; + continue; + } + } + + return rawX / beta; + } + + /** + * gammaGreat - generate gamma distributed value for shape (alpha) greater then 1 + * @tparam T - given type (any float type is accepted.) + * @param rng - random generator + * @param alpha - shape of the gamma distribution (alpha) + * @param beta - scale of the gamma distribution (beta) + * @return - gamma distributed value with given params + */ + template + T gammaGreat(graph::RandomGenerator& rng, T const alpha, T const beta) { + auto decreasedAlpha = alpha - T(1.f/3.f); + auto c = T(1.)/ math::p_sqrt(T(9.f) * decreasedAlpha); + static auto index = 0LL; + T x; + auto normalDistributed = [](graph::RandomGenerator& rng, Nd4jLong& index) { + auto v1 = rng.relativeT(index++, T(0.f), T(1.f)); + auto v2 = rng.relativeT(index++, T(0.f), T(1.f)); + + return math::p_cos(T(2.f * 3.141592f) * v2) * math::p_sqrt(T(-2.f) * math::p_log(v1)); + }; + +// const T underAlpha = T(1.f) / alpha; +// const T powerAlpha = math::p_pow(T(2.f), alpha - T(1.f)); + + float normalizedVar; + for(;;) { + do { + x = normalDistributed(rng, index); //printf("X = %f\n", x); + normalizedVar = T(1.f) + c * x; + } while(normalizedVar < T(0.f)); + normalizedVar = normalizedVar * normalizedVar * normalizedVar; //v * v * v; + + auto u = rng.relativeT(index++, T(0.f), T(1.f)); //printf("UNI = %f\n", u); + if( u < T(1.f) - T(.0331f) * (x * x) * (x * x) ) + break; //return (d * v / b); + if( log(u) < 0.5f * x * x + decreasedAlpha * (1. - normalizedVar + math::p_log(normalizedVar)) ) + break; + } + return (decreasedAlpha * normalizedVar / beta); + } + template void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { @@ -52,24 +133,19 @@ namespace helpers { copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); - } -// bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c'; bool directOutput = output->ews() == 1 && output->ordering() == 'c'; T* outputBuf = output->dataBuffer()->primaryAsT(); PRAGMA_OMP_PARALLEL_FOR for (Nd4jLong k = 0; k < shift; k++) { auto pos = k * step; - auto u = rng.relativeT(k, 0., 1.); for (Nd4jLong e = 0; e < step; e++) if (directOutput) { - outputBuf[pos + e] = math::nd4j_igamma(copyAlpha->t(e), - beta != nullptr ? copyBeta->t(e) * u : u); + outputBuf[pos + e] = copyAlpha->t(e) <= 1? gammaLess(rng, copyAlpha->t(e), beta?copyBeta->t(e):T(1.f)):gammaGreat(rng, copyAlpha->t(e), beta?copyBeta->t(e):T(1.f)); } else { - output->r(pos + e) = math::nd4j_igamma(copyAlpha->t(e), - beta != nullptr ? copyBeta->t(e) * u : u); + output->r(pos + e) = copyAlpha->t(e) <= 1? gammaLess(rng, copyAlpha->t(e), beta?copyBeta->t(e):T(1.f)):gammaGreat(rng, copyAlpha->t(e), beta?copyBeta->t(e):T(1.f)); } } @@ -211,4 +287,4 @@ namespace helpers { } } -} \ No newline at end of file +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random.cu b/libnd4j/include/ops/declarable/helpers/cuda/random.cu index fe692a0df..e13883515 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/random.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/random.cu @@ -33,6 +33,94 @@ namespace sd { namespace ops { namespace helpers { + /** + * gammaLess - compute gamma distributed value for shapes (alpha) from 0 to 1 + * @tparam T - any float types are acceptable + * @param U - uniform random generated vals + * @param alpha - shape of distribution + * @param beta - scale of distributed values + * @return gamma distributed value + */ + template + T __device__ gammaLess(T const* U, Nd4jLong index, Nd4jLong maxLength, T const alpha, T const beta) { + auto d = T(1.0334f) - T(0.0766f) * math::p_exp(T(2.2942f) * alpha); + auto a = math::p_pow(T(2.f), alpha) * math::p_pow(T(1.f) - math::p_exp(-d * T(0.5f)), alpha); + auto b = alpha * math::p_pow(d, alpha - T(1.f)) * exp(-d); + auto c = a + b; + T rawX; + auto indexV = index; + auto underAlpha = T(1.f) / alpha; + auto powerAlpha = math::p_pow(T(2.f), alpha - T(1.f)); + + for (;;) { + auto u = (indexV < maxLength)?U[indexV++]:U[0]; + if (indexV >= maxLength) indexV = 0LL; +// math::atomics::nd4j_atomicAdd(index, 1LL); + if (u <= a / c) rawX = -T(2.f) * math::p_log(T(1.f) - T(0.5f) * math::p_pow(c * u, underAlpha)); + else rawX = - math::p_log(c * (T(1.f) - u)/(alpha * math::p_pow(d, alpha - T(1.f)))); + + T v = indexV < maxLength?U[indexV++]:U[0]; + if (indexV >= maxLength) indexV = 0LL; +// math::atomics::nd4j_atomicAdd(index, 1LL); + + if (rawX <= d) { + auto testVal = (math::p_pow(rawX, alpha - 1.f) * math::p_exp(-T(0.5f) * rawX)) / (powerAlpha * math::p_pow(T(1.f) - math::p_exp(-T(0.5f) * rawX), alpha - T(1.f))); + if (testVal < v) continue; + break; + } + else { + if (v <= math::p_pow(d / rawX, T(1.f) - alpha)) break; + continue; + } + } + return rawX / beta; + } + + /** + * gammaGreat - generate gamma distributed value for shape (alpha) greater then 1 + * @tparam T - given type (any float type is accepted.) + * @param rng - random generator + * @param alpha - shape of the gamma distribution (alpha) + * @param beta - scale of the gamma distribution (beta) + * @return - gamma distributed value with given params + */ + template + T __device__ gammaGreat(T const* U, Nd4jLong index, Nd4jLong maxLength, T const alpha, T const beta) { + auto decreasedAlpha = alpha - T(1.f/3.f); + auto c = T(1.)/ math::p_sqrt(T(9.f) * decreasedAlpha); +// static auto index = 0LL; + auto indexV = index; + T x; + auto normalDistributed = [U, maxLength](Nd4jLong& index) { + auto v1 = index < maxLength?U[index++]:U[0]; + if (index >= maxLength) index = 0LL; +// math::atomics::nd4j_atomicAdd(index, 1LL); + auto v2 = index < maxLength?U[index++]:U[0]; + if (index >= maxLength) index = 0LL; +// math::atomics::nd4j_atomicAdd(index, 1LL); + + return math::p_cos(T(2.f * 3.141592f) * v2) * math::p_sqrt(T(-2.f) * math::p_log(v1)); + }; + + float normalizedVar; + for(;;) { + do { + x = normalDistributed(indexV); //printf("X = %f\n", x); + normalizedVar = T(1.f) + c * x; + } while(normalizedVar < T(0.f)); + normalizedVar = normalizedVar * normalizedVar * normalizedVar; //v * v * v; + + auto u = U[indexV++]; + if (indexV >= maxLength) indexV = 0LL; +// math::atomics::nd4j_atomicAdd(index, 1LL); + + if( u < T(1.f) - T(.0331f) * (x * x) * (x * x) ) + break; //return (d * v / b); + if( log(u) < 0.5f * x * x + decreasedAlpha * (1. - normalizedVar + math::p_log(normalizedVar)) ) + break; + } + return (decreasedAlpha * normalizedVar / beta); + } /* * fillGammaKernel - fill up output with gamma distributed values @@ -44,25 +132,28 @@ namespace helpers { * output - distributed output. * */ template - static __global__ void fillGammaKernel(T* uList, Nd4jLong uLength, T* alpha, const Nd4jLong* alphaShape, - T* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) { + static __global__ void fillGammaKernel(T const* uList, Nd4jLong uLength, T const* alpha, const Nd4jLong* alphaShape, + T const* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) { // fill up __shared__ Nd4jLong aLength; + __shared__ Nd4jLong outLength; if (threadIdx.x == 0) { aLength = shape::length(alphaShape); + outLength = shape::length(outputShape) / aLength; } __syncthreads(); - for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) { + for (auto k = blockIdx.x; k < (int)outLength; k += gridDim.x) { auto pos = k * aLength; - auto u = uList[k]; // this is a vector +// auto u = uList[k]; // this is a vector + //Nd4jLong index = k; for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) { auto aIndex = shape::getIndexOffset(e, alphaShape); auto bIndex = betaShape?shape::getIndexOffset(e, betaShape):-1LL; - auto betaV = T(beta != nullptr ? beta[bIndex] * u : u); + auto betaV = T(beta != nullptr ? beta[bIndex] : T(1.f)); auto zIndex = shape::getIndexOffset(e + pos, outputShape); - output[zIndex] = math::nd4j_igamma(alpha[aIndex], betaV); + output[zIndex] = alpha[aIndex] > T(1.f)?gammaGreat(uList, pos, uLength, alpha[aIndex], betaV):gammaLess(uList, pos, uLength, alpha[aIndex], betaV); } } } @@ -76,7 +167,7 @@ namespace helpers { else broadcasted = alpha->shapeInfo(); auto step = shape::length(broadcasted); - auto shift = output->lengthOf() / step; + auto shift = output->lengthOf() * 4LL; // 2-wise greater case for uniform vals auto copyAlpha = alpha; auto copyBeta = beta; @@ -86,19 +177,21 @@ namespace helpers { copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); - copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice(); +// if (!copyAlpha->isActualOnDevice()) copyAlpha->syncToDevice(); +// if (!copyBeta->isActualOnDevice()) copyBeta->syncToDevice(); } auto stream = context->getCudaStream(); NDArray uniform = NDArrayFactory::create('c', {shift}, context); uniform.syncToDevice(); // fill up uniform with given length - RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.); - + RandomLauncher::fillUniform(context, rng, &uniform, 0.0000000001, 0.9999999999); + uniform.syncToDevice(); +// uniform.printIndexedBuffer("Uniform"); fillGammaKernel<<<128, 128, 256, *stream>>>(uniform.dataBuffer()->specialAsT(), shift, copyAlpha->dataBuffer()->specialAsT(), copyAlpha->specialShapeInfo(), - beta?copyBeta->dataBuffer()->specialAsT():(T*)nullptr, - beta?copyBeta->specialShapeInfo():(Nd4jLong*)nullptr, + beta?copyBeta->dataBuffer()->specialAsT():(T const*)nullptr, + beta?copyBeta->specialShapeInfo():(Nd4jLong const*)nullptr, output->dataBuffer()->specialAsT(), output->specialShapeInfo()); if (beta != nullptr) { diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 469cc77be..a2c33374a 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -1015,8 +1015,6 @@ TEST_F(RNGTests, Test_GammaDistribution_2) { // z->printIndexedBuffer("Gamma distribution"); ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_FALSE(exp0.equalsTo(z)); - - } TEST_F(RNGTests, Test_GammaDistribution_3) { @@ -1037,7 +1035,62 @@ TEST_F(RNGTests, Test_GammaDistribution_3) { ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_FALSE(exp0.equalsTo(z)); + +} +TEST_F(RNGTests, Test_GammaDistribution_4) { + auto x = NDArrayFactory::create('c', {2}, {1000, 1000}); + auto al = NDArrayFactory::create(2.f); + auto be = NDArrayFactory::create(2.f); + auto exp0 = NDArrayFactory::create('c', {1000, 1000}); + +// al.linspace(1.0); +// be.assign(2.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + sd::ops::reduce_mean testOps1; + sd::ops::reduce_variance testOps2; + auto testRes1 = testOps1.evaluate({z}); + auto testRes2 = testOps2.evaluate({z}); +// testRes1[0]->printBuffer("Mean (expected 1.0)"); +// testRes2[0]->printBuffer("Variance (expected 0.5)"); + ASSERT_NEAR(testRes1[0]->t(0), 1.0f, 0.01); + ASSERT_NEAR(testRes2[0]->t(0), 0.5f, 0.02); +} + +TEST_F(RNGTests, Test_GammaDistribution_5) { + auto x = NDArrayFactory::create('c', {2}, {100, 100}); + auto al = NDArrayFactory::create(0.2f); + auto be = NDArrayFactory::create(2.f); + auto exp0 = NDArrayFactory::create('c', {100, 100}); + +// al.linspace(1.0); +// be.assign(2.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); +// z->printIndexedBuffer("Gamma distributed"); + sd::ops::reduce_mean testOps1; + sd::ops::reduce_variance testOps2; + auto testRes1 = testOps1.evaluate({z}); + auto testRes2 = testOps2.evaluate({z}); +// testRes1[0]->printBuffer("Mean (expected 0.1)"); +// testRes2[0]->printBuffer("Variance (expected 0.05)"); + ASSERT_NEAR(testRes1[0]->t(0), 0.1f, 0.02); + ASSERT_NEAR(testRes2[0]->t(0), 0.05f, 0.02); } TEST_F(RNGTests, Test_UniformDistribution_04) { @@ -1055,7 +1108,6 @@ TEST_F(RNGTests, Test_UniformDistribution_04) { ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_FALSE(exp0.equalsTo(z)); - } TEST_F(RNGTests, Test_UniformDistribution_05) { @@ -1237,7 +1289,6 @@ TEST_F(RNGTests, test_multinomial_1) { ASSERT_EQ(Status::OK(), result.status()); ASSERT_TRUE(expectedZ.isSameShape(outputZ)); ASSERT_TRUE(expectedZ.equalsTo(outputZ)); - } TEST_F(RNGTests, test_multinomial_2) { @@ -1314,7 +1365,6 @@ TEST_F(RNGTests, test_multinomial_5) { RandomGenerator rng(1234, 1234); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false)); - auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); auto mean = output.meanNumber(); // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); @@ -1386,7 +1436,6 @@ TEST_F(RNGTests, test_multinomial_6) { ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); - RandomGenerator rng(1234, 1234); NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); NDArray output('c', { batchValue, Samples }, sd::DataType::INT64); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index 4e885db96..c6bb6743c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.random.custom.*; import org.nd4j.linalg.api.ops.random.impl.*; @@ -1479,14 +1480,22 @@ public class RandomTests extends BaseNd4jTest { @Test public void testGamma(){ Nd4j.getRandom().setSeed(12345); - INDArray shape = Nd4j.createFromArray(new int[] {1,3}); - INDArray alpha = Nd4j.rand(1,3); - val randomGamma = new RandomGamma(shape, alpha, null); + INDArray shape = Nd4j.createFromArray(new int[] {1000,1000}); + INDArray alpha = Nd4j.createFromArray(new float[]{2.f}); + INDArray beta = Nd4j.createFromArray(new float[]{2.f}); + val randomGamma = new RandomGamma(shape, alpha, beta); INDArray[] res = Nd4j.exec(randomGamma); - val randomGamma1 = new RandomGamma(shape, alpha, null); + val randomGamma1 = new RandomGamma(shape, alpha, beta); INDArray[] res1 = Nd4j.exec(randomGamma1); - assertEquals(res[0], res1[0]); + + val meanOp0 = new Mean(res[0]); + val meanOp1 = new Mean(res1[0]); + + INDArray mean0 = Nd4j.exec(meanOp0); + INDArray mean1 = Nd4j.exec(meanOp1); + + assertArrayEquals(mean0.toFloatVector(), mean1.toFloatVector(), 1e-2f); } @Test