Shugeo random expo fix2 (#295)
* Refactored exponential distribution implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored exponential distribution and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored test to new result sets. Signed-off-by: shugeo <sgazeos@gmail.com>master
parent
2497290cb0
commit
5dae4069cf
|
@ -119,13 +119,15 @@ namespace randomOps {
|
|||
|
||||
random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
|
||||
T lambda = extraParams[0];
|
||||
T x = helper->relativeT(idx, -sd::DataTypeUtils::template max<T>() / 10 , sd::DataTypeUtils::template max<T>() / 10);
|
||||
return x <= (T)0.f ? (T)0.f : (T)1.f - sd::math::nd4j_pow<T, T, T>((T) M_E, -(lambda * x));
|
||||
T x = helper->relativeT<T>(idx); //, T(0.f) , max);
|
||||
T xVal = -sd::math::nd4j_log<T,T>(T(1.f) - x);
|
||||
|
||||
return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow<T, T, T>((T) M_E, -(lambda * x));
|
||||
}
|
||||
|
||||
random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
|
||||
T lambda = extraParams[0];
|
||||
return valueX <= (T)0.f ? (T)0.f : (T)1.f - sd::math::nd4j_pow<T, T, T>((T) M_E, -(lambda * valueX));
|
||||
return valueX <= (T)0.f ? (T)0.f : (T)(valueX/lambda); //1.f - sd::math::nd4j_exp<T,T>(-lambda * valueX); //pow<T, T, T>((T) M_E, -(lambda * valueX));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -93,6 +93,21 @@ TEST_F(RNGTests, TestSeeds_2) {
|
|||
ASSERT_EQ(456, generator.nodeState());
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, TestGenerator_SGA_1) {
|
||||
RandomGenerator generator(12, 13);
|
||||
auto array= NDArrayFactory::create<float>('c',{10000000});
|
||||
generator.setStates(123L, 456L);
|
||||
for (auto idx = 0; idx < array.lengthOf(); idx++) {
|
||||
float x = generator.relativeT(idx, -sd::DataTypeUtils::template max<float>() / 10,
|
||||
sd::DataTypeUtils::template max<float>() / 10);
|
||||
array.t<float>(idx) = x;
|
||||
}
|
||||
auto minimum = array.reduceNumber(reduce::AMin);
|
||||
minimum.printBuffer("Randomly float min on 1M array");
|
||||
ASSERT_EQ(123, generator.rootState());
|
||||
ASSERT_EQ(456, generator.nodeState());
|
||||
}
|
||||
|
||||
|
||||
TEST_F(RNGTests, Test_Dropout_1) {
|
||||
auto x0 = NDArrayFactory::create<float>('c', {10, 10});
|
||||
|
@ -573,6 +588,15 @@ TEST_F(RNGTests, Test_Uniform_2) {
|
|||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Uniform_SGA_3) {
|
||||
//auto input = NDArrayFactory::create<Nd4jLong>('c', {1, 2}, {10, 10});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
|
||||
|
||||
RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, -sd::DataTypeUtils::template max<float>(), sd::DataTypeUtils::template max<float>());
|
||||
auto minimumU = x1.reduceNumber(reduce::AMin);
|
||||
minimumU.printBuffer("\nMinimum");
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Gaussian_2) {
|
||||
auto input = NDArrayFactory::create<Nd4jLong>('c', {1, 2}, {10, 10});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
|
||||
|
@ -728,8 +752,37 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
|
|||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
//
|
||||
z->printBuffer("\nExponential1");
|
||||
auto mean = z->reduceNumber(reduce::Mean);
|
||||
auto variance = z->varianceNumber(variance::SummaryStatsVariance, false);
|
||||
mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is");
|
||||
variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is");
|
||||
ASSERT_FALSE(nexp0->equalsTo(z));
|
||||
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||
|
||||
// delete result;
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) {
|
||||
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {10, 10});
|
||||
auto exp0 = NDArrayFactory::create<float>('c', {10, 10});
|
||||
|
||||
|
||||
sd::ops::random_exponential op;
|
||||
auto result = op.evaluate({&x}, {1.f}, {0});
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
//
|
||||
z->printBuffer("\nExponential2");
|
||||
auto mean = z->reduceNumber(reduce::Mean);
|
||||
auto variance = z->varianceNumber(variance::SummaryStatsVariance, false);
|
||||
mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is");
|
||||
variance.printBuffer("Variance for exponential with param 1. (1 exp) is");
|
||||
ASSERT_FALSE(nexp0->equalsTo(z));
|
||||
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||
|
|
Loading…
Reference in New Issue