Merge remote-tracking branch 'origin/master'

master
raver119 2020-04-22 12:38:37 +03:00
commit 1b8077f66a
3 changed files with 73 additions and 25 deletions

View File

@ -27,29 +27,8 @@
namespace sd {
namespace ops {
CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) {
// uniform distribution
// random generator for distribution
auto rng = block.randomGenerator();
// FIXME: to be implemented
/*
if (rng == nullptr)
return Status::THROW("RNG is null, aborting...");
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
if (block.width() == 1)
functions::random::RandomFunction<T>::template execTransform<randomOps::ExponentialDistribution<T>>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data());
else {
auto y = INPUT_VARIABLE(1);
REQUIRE_TRUE(y->isSameShape(z), 0, "ExponentialDistribution: Y shape should be equal to Z shape");
functions::random::RandomFunction<T>::template execTransform<randomOps::ExponentialDistribution<T>>(block.getRNG(), y->getBuffer(), y->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data());
}
STORE_RESULT(*z);
*/
auto z = OUTPUT_VARIABLE(0);
auto lambda = T_ARG(0);

View File

@ -119,8 +119,8 @@ namespace randomOps {
random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
T lambda = extraParams[0];
T x = helper->relativeT<T>(idx); //, T(0.f) , max);
T xVal = -sd::math::nd4j_log<T,T>(T(1.f) - x);
T x = helper->relativeT<T>(idx, sd::DataTypeUtils::min<T>(), T(1.f) - sd::DataTypeUtils::template min<T>()); // x from (0, 1) without bounds
T xVal = -sd::math::nd4j_log<T,T>(x);
return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow<T, T, T>((T) M_E, -(lambda * x));
}
@ -270,7 +270,7 @@ namespace randomOps {
random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
T lambda = extraParams[0];
T x = helper->relativeT(idx, sd::DataTypeUtils::template min<T>(), (T)1.f);
T x = helper->relativeT(idx, sd::DataTypeUtils::template min<T>(), (T)1.f - sd::DataTypeUtils::template min<T>());
return -sd::math::nd4j_log<T, T>((T)1.f - x) / lambda;
}

View File

@ -790,6 +790,75 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) {
}
TEST_F(RNGTests, Test_ExponentialDistribution_2_SGA) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {10, 10});
auto exp0 = NDArrayFactory::create<float>('c', {10, 10});
RandomGenerator oc(2716049175077475646L, -6182841917129177862L);
sd::ops::random_exponential op;
RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f);
auto result = op.evaluate({&x}, {1.f}, {0});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
//
// z->printBuffer("\nExponential2+");
auto mean = z->reduceNumber(reduce::Mean);
auto variance = z->varianceNumber(variance::SummaryStatsVariance, false);
mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is");
variance.printBuffer("Variance for exponential with param 1. (1 exp) is");
ASSERT_FALSE(nexp0->equalsTo(z));
ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z));
mean = exp0.reduceNumber(reduce::Mean);
variance = exp0.varianceNumber(variance::SummaryStatsVariance, false);
mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is");
variance.printBuffer("Variance for exponential with param 2. (1/2 exp) is");
}
TEST_F(RNGTests, Test_ExponentialDistribution_3_SGA) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {1000, 1000});
auto exp0 = NDArrayFactory::create<double>('c', {1000, 1000});
RandomGenerator oc(2716049175077475646L, -6182841917129177862L);
auto expMean = NDArrayFactory::create<double>(0.5f);
auto expVar = NDArrayFactory::create<double>(0.25f);
sd::ops::random_exponential op;
RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 2.f);
auto result = op.evaluate({&x}, {1.});
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
//ASSERT_TRUE(exp0.isSameShape(z));
//ASSERT_FALSE(exp0.equalsTo(z));
//
// z->printBuffer("\nExponential2+");
auto mean = z->reduceNumber(reduce::Mean);
auto variance = z->varianceNumber(variance::SummaryStatsVariance, false);
mean.printBuffer("Mean");
variance.printBuffer("Variance");
ASSERT_NEAR(mean.e<double>(0), 1.f, 1.e-2f);
ASSERT_NEAR(variance.e<double>(0), 1.f, 1.e-2f);
// mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is");
// variance.printBuffer("Variance for exponential with param 1. (1 exp) is");
// ASSERT_FALSE(nexp0->equalsTo(z));
// ASSERT_FALSE(nexp1->equalsTo(z));
// ASSERT_FALSE(nexp2->equalsTo(z));
mean = exp0.reduceNumber(reduce::Mean);
variance = exp0.varianceNumber(variance::SummaryStatsVariance, false);
mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is");
variance.printBuffer("Variance for exponential with param 2. (1/4 exp) is");
ASSERT_TRUE(mean.equalsTo(expMean, 1.e-3));
ASSERT_TRUE(variance.equalsTo(expVar, 1.e-3));
RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 1.f);
mean = exp0.reduceNumber(reduce::Mean);
variance = exp0.varianceNumber(variance::SummaryStatsVariance, false);
mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is");
variance.printBuffer("Variance for exponential with param 1.0 (1 exp) is");
}
TEST_F(RNGTests, Test_ExponentialDistribution_2) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {10, 10});
auto y = NDArrayFactory::create<float>('c', {10, 10});