From 5dae4069cf788a9ea1c34bc675548d2aed5ad517 Mon Sep 17 00:00:00 2001 From: shugeo Date: Fri, 20 Mar 2020 10:33:20 +0200 Subject: [PATCH] Shugeo random expo fix2 (#295) * Refactored exponential distribution implementation. Signed-off-by: shugeo * Refactored exponential distribution and tests. Signed-off-by: shugeo * Refactored test to new result sets. Signed-off-by: shugeo --- libnd4j/include/ops/random_ops.h | 8 ++-- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 53 +++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/ops/random_ops.h b/libnd4j/include/ops/random_ops.h index 844f88ed3..d16b4f68a 100644 --- a/libnd4j/include/ops/random_ops.h +++ b/libnd4j/include/ops/random_ops.h @@ -119,13 +119,15 @@ namespace randomOps { random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { T lambda = extraParams[0]; - T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10 , sd::DataTypeUtils::template max() / 10); - return x <= (T)0.f ? (T)0.f : (T)1.f - sd::math::nd4j_pow((T) M_E, -(lambda * x)); + T x = helper->relativeT(idx); //, T(0.f) , max); + T xVal = -sd::math::nd4j_log(T(1.f) - x); + + return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow((T) M_E, -(lambda * x)); } random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { T lambda = extraParams[0]; - return valueX <= (T)0.f ? (T)0.f : (T)1.f - sd::math::nd4j_pow((T) M_E, -(lambda * valueX)); + return valueX <= (T)0.f ? (T)0.f : (T)(valueX/lambda); //1.f - sd::math::nd4j_exp(-lambda * valueX); //pow((T) M_E, -(lambda * valueX)); } }; diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 64ab1781d..889e194a6 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -93,6 +93,21 @@ TEST_F(RNGTests, TestSeeds_2) { ASSERT_EQ(456, generator.nodeState()); } +TEST_F(RNGTests, TestGenerator_SGA_1) { + RandomGenerator generator(12, 13); + auto array= NDArrayFactory::create('c',{10000000}); + generator.setStates(123L, 456L); + for (auto idx = 0; idx < array.lengthOf(); idx++) { + float x = generator.relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + array.t(idx) = x; + } + auto minimum = array.reduceNumber(reduce::AMin); + minimum.printBuffer("Randomly float min on 1M array"); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); +} + TEST_F(RNGTests, Test_Dropout_1) { auto x0 = NDArrayFactory::create('c', {10, 10}); @@ -573,6 +588,15 @@ TEST_F(RNGTests, Test_Uniform_2) { } +TEST_F(RNGTests, Test_Uniform_SGA_3) { + //auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, -sd::DataTypeUtils::template max(), sd::DataTypeUtils::template max()); + auto minimumU = x1.reduceNumber(reduce::AMin); + minimumU.printBuffer("\nMinimum"); +} + TEST_F(RNGTests, Test_Gaussian_2) { auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); auto x1 = NDArrayFactory::create('c', {10, 10}); @@ -728,8 +752,37 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) { auto z = result.at(0); ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_FALSE(exp0.equalsTo(z)); + // + z->printBuffer("\nExponential1"); + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is"); + variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + +// delete result; +} + +TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + sd::ops::random_exponential op; + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // + z->printBuffer("\nExponential2"); + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); ASSERT_FALSE(nexp0->equalsTo(z)); ASSERT_FALSE(nexp1->equalsTo(z)); ASSERT_FALSE(nexp2->equalsTo(z));