Shugeo random expo fix2 (#295)

* Refactored exponential distribution implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored exponential distribution and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored test to new result sets.

Signed-off-by: shugeo <sgazeos@gmail.com>
master
shugeo 2020-03-20 10:33:20 +02:00 committed by GitHub
parent 2497290cb0
commit 5dae4069cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 3 deletions

View File

@ -119,13 +119,15 @@ namespace randomOps {
random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
T lambda = extraParams[0];
T x = helper->relativeT(idx, -sd::DataTypeUtils::template max<T>() / 10 , sd::DataTypeUtils::template max<T>() / 10);
return x <= (T)0.f ? (T)0.f : (T)1.f - sd::math::nd4j_pow<T, T, T>((T) M_E, -(lambda * x));
T x = helper->relativeT<T>(idx); //, T(0.f) , max);
T xVal = -sd::math::nd4j_log<T,T>(T(1.f) - x);
return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow<T, T, T>((T) M_E, -(lambda * x));
}
random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
T lambda = extraParams[0];
return valueX <= (T)0.f ? (T)0.f : (T)1.f - sd::math::nd4j_pow<T, T, T>((T) M_E, -(lambda * valueX));
return valueX <= (T)0.f ? (T)0.f : (T)(valueX/lambda); //1.f - sd::math::nd4j_exp<T,T>(-lambda * valueX); //pow<T, T, T>((T) M_E, -(lambda * valueX));
}
};

View File

@ -93,6 +93,21 @@ TEST_F(RNGTests, TestSeeds_2) {
ASSERT_EQ(456, generator.nodeState());
}
TEST_F(RNGTests, TestGenerator_SGA_1) {
RandomGenerator generator(12, 13);
auto array= NDArrayFactory::create<float>('c',{10000000});
generator.setStates(123L, 456L);
for (auto idx = 0; idx < array.lengthOf(); idx++) {
float x = generator.relativeT(idx, -sd::DataTypeUtils::template max<float>() / 10,
sd::DataTypeUtils::template max<float>() / 10);
array.t<float>(idx) = x;
}
auto minimum = array.reduceNumber(reduce::AMin);
minimum.printBuffer("Randomly float min on 1M array");
ASSERT_EQ(123, generator.rootState());
ASSERT_EQ(456, generator.nodeState());
}
TEST_F(RNGTests, Test_Dropout_1) {
auto x0 = NDArrayFactory::create<float>('c', {10, 10});
@ -573,6 +588,15 @@ TEST_F(RNGTests, Test_Uniform_2) {
}
TEST_F(RNGTests, Test_Uniform_SGA_3) {
//auto input = NDArrayFactory::create<Nd4jLong>('c', {1, 2}, {10, 10});
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, -sd::DataTypeUtils::template max<float>(), sd::DataTypeUtils::template max<float>());
auto minimumU = x1.reduceNumber(reduce::AMin);
minimumU.printBuffer("\nMinimum");
}
TEST_F(RNGTests, Test_Gaussian_2) {
auto input = NDArrayFactory::create<Nd4jLong>('c', {1, 2}, {10, 10});
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
@ -728,8 +752,37 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
//
z->printBuffer("\nExponential1");
auto mean = z->reduceNumber(reduce::Mean);
auto variance = z->varianceNumber(variance::SummaryStatsVariance, false);
mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is");
variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is");
ASSERT_FALSE(nexp0->equalsTo(z));
ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z));
// delete result;
}
TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {10, 10});
auto exp0 = NDArrayFactory::create<float>('c', {10, 10});
sd::ops::random_exponential op;
auto result = op.evaluate({&x}, {1.f}, {0});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
//
z->printBuffer("\nExponential2");
auto mean = z->reduceNumber(reduce::Mean);
auto variance = z->varianceNumber(variance::SummaryStatsVariance, false);
mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is");
variance.printBuffer("Variance for exponential with param 1. (1 exp) is");
ASSERT_FALSE(nexp0->equalsTo(z));
ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z));