From fc3f5d4ffb04c829cf7ba581c1b80c5f64aae304 Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 22 Apr 2020 12:12:00 +0300 Subject: [PATCH] Shugeo exponential distribution infinities fix (#403) * Fixed bound problem with Exponential distribution implementation. Signed-off-by: shugeo * Added test for Exponential distribution to avoid infinities. Signed-off-by: shugeo * Added a test for exponential distribution with 1M elements. Signed-off-by: shugeo * Cosmetical changes only and tests. Signed-off-by: shugeo * Modified test and implementation for exponential_distribution op. Signed-off-by: shugeo Co-authored-by: raver119 --- .../declarable/generic/random/exponential.cpp | 23 +------ libnd4j/include/ops/random_ops.h | 6 +- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 69 +++++++++++++++++++ 3 files changed, 73 insertions(+), 25 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/random/exponential.cpp b/libnd4j/include/ops/declarable/generic/random/exponential.cpp index 8605ffafe..cac3d1a88 100644 --- a/libnd4j/include/ops/declarable/generic/random/exponential.cpp +++ b/libnd4j/include/ops/declarable/generic/random/exponential.cpp @@ -27,29 +27,8 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) { - // uniform distribution + // random generator for distribution auto rng = block.randomGenerator(); - - // FIXME: to be implemented - /* - if (rng == nullptr) - return Status::THROW("RNG is null, aborting..."); - - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - if (block.width() == 1) - functions::random::RandomFunction::template execTransform>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data()); - else { - auto y = INPUT_VARIABLE(1); - REQUIRE_TRUE(y->isSameShape(z), 0, "ExponentialDistribution: Y shape should be equal to Z shape"); - - functions::random::RandomFunction::template execTransform>(block.getRNG(), y->getBuffer(), y->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data()); - } - - STORE_RESULT(*z); -*/ - auto z = OUTPUT_VARIABLE(0); auto lambda = T_ARG(0); diff --git a/libnd4j/include/ops/random_ops.h b/libnd4j/include/ops/random_ops.h index d16b4f68a..939ffa975 100644 --- a/libnd4j/include/ops/random_ops.h +++ b/libnd4j/include/ops/random_ops.h @@ -119,8 +119,8 @@ namespace randomOps { random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { T lambda = extraParams[0]; - T x = helper->relativeT(idx); //, T(0.f) , max); - T xVal = -sd::math::nd4j_log(T(1.f) - x); + T x = helper->relativeT(idx, sd::DataTypeUtils::min(), T(1.f) - sd::DataTypeUtils::template min()); // x from (0, 1) without bounds + T xVal = -sd::math::nd4j_log(x); return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow((T) M_E, -(lambda * x)); } @@ -270,7 +270,7 @@ namespace randomOps { random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { T lambda = extraParams[0]; - T x = helper->relativeT(idx, sd::DataTypeUtils::template min(), (T)1.f); + T x = helper->relativeT(idx, sd::DataTypeUtils::template min(), (T)1.f - sd::DataTypeUtils::template min()); return -sd::math::nd4j_log((T)1.f - x) / lambda; } diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 889e194a6..56ca6b95e 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -790,6 +790,75 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) { } +TEST_F(RNGTests, Test_ExponentialDistribution_2_SGA) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + + sd::ops::random_exponential op; + RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f); + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // +// z->printBuffer("\nExponential2+"); + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); + variance.printBuffer("Variance for exponential with param 2. (1/2 exp) is"); +} + +TEST_F(RNGTests, Test_ExponentialDistribution_3_SGA) { + auto x = NDArrayFactory::create('c', {2}, {1000, 1000}); + auto exp0 = NDArrayFactory::create('c', {1000, 1000}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + auto expMean = NDArrayFactory::create(0.5f); + auto expVar = NDArrayFactory::create(0.25f); + sd::ops::random_exponential op; + RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 2.f); + + auto result = op.evaluate({&x}, {1.}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + //ASSERT_TRUE(exp0.isSameShape(z)); + //ASSERT_FALSE(exp0.equalsTo(z)); + // +// z->printBuffer("\nExponential2+"); + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean"); + variance.printBuffer("Variance"); + ASSERT_NEAR(mean.e(0), 1.f, 1.e-2f); + ASSERT_NEAR(variance.e(0), 1.f, 1.e-2f); +// mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); +// variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); +// ASSERT_FALSE(nexp0->equalsTo(z)); +// ASSERT_FALSE(nexp1->equalsTo(z)); +// ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); + variance.printBuffer("Variance for exponential with param 2. (1/4 exp) is"); + ASSERT_TRUE(mean.equalsTo(expMean, 1.e-3)); + ASSERT_TRUE(variance.equalsTo(expVar, 1.e-3)); + RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 1.f); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1.0 (1 exp) is"); +} + TEST_F(RNGTests, Test_ExponentialDistribution_2) { auto x = NDArrayFactory::create('c', {2}, {10, 10}); auto y = NDArrayFactory::create('c', {10, 10});