Merge remote-tracking branch 'origin/master'
commit
1b8077f66a
|
@ -27,29 +27,8 @@
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) {
|
CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) {
|
||||||
// uniform distribution
|
// random generator for distribution
|
||||||
auto rng = block.randomGenerator();
|
auto rng = block.randomGenerator();
|
||||||
|
|
||||||
// FIXME: to be implemented
|
|
||||||
/*
|
|
||||||
if (rng == nullptr)
|
|
||||||
return Status::THROW("RNG is null, aborting...");
|
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
if (block.width() == 1)
|
|
||||||
functions::random::RandomFunction<T>::template execTransform<randomOps::ExponentialDistribution<T>>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data());
|
|
||||||
else {
|
|
||||||
auto y = INPUT_VARIABLE(1);
|
|
||||||
REQUIRE_TRUE(y->isSameShape(z), 0, "ExponentialDistribution: Y shape should be equal to Z shape");
|
|
||||||
|
|
||||||
functions::random::RandomFunction<T>::template execTransform<randomOps::ExponentialDistribution<T>>(block.getRNG(), y->getBuffer(), y->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data());
|
|
||||||
}
|
|
||||||
|
|
||||||
STORE_RESULT(*z);
|
|
||||||
*/
|
|
||||||
|
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
auto lambda = T_ARG(0);
|
auto lambda = T_ARG(0);
|
||||||
|
|
||||||
|
|
|
@ -119,8 +119,8 @@ namespace randomOps {
|
||||||
|
|
||||||
random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
|
random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
|
||||||
T lambda = extraParams[0];
|
T lambda = extraParams[0];
|
||||||
T x = helper->relativeT<T>(idx); //, T(0.f) , max);
|
T x = helper->relativeT<T>(idx, sd::DataTypeUtils::min<T>(), T(1.f) - sd::DataTypeUtils::template min<T>()); // x from (0, 1) without bounds
|
||||||
T xVal = -sd::math::nd4j_log<T,T>(T(1.f) - x);
|
T xVal = -sd::math::nd4j_log<T,T>(x);
|
||||||
|
|
||||||
return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow<T, T, T>((T) M_E, -(lambda * x));
|
return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow<T, T, T>((T) M_E, -(lambda * x));
|
||||||
}
|
}
|
||||||
|
@ -270,7 +270,7 @@ namespace randomOps {
|
||||||
|
|
||||||
random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
|
random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) {
|
||||||
T lambda = extraParams[0];
|
T lambda = extraParams[0];
|
||||||
T x = helper->relativeT(idx, sd::DataTypeUtils::template min<T>(), (T)1.f);
|
T x = helper->relativeT(idx, sd::DataTypeUtils::template min<T>(), (T)1.f - sd::DataTypeUtils::template min<T>());
|
||||||
return -sd::math::nd4j_log<T, T>((T)1.f - x) / lambda;
|
return -sd::math::nd4j_log<T, T>((T)1.f - x) / lambda;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -790,6 +790,75 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(RNGTests, Test_ExponentialDistribution_2_SGA) {
|
||||||
|
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {10, 10});
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
RandomGenerator oc(2716049175077475646L, -6182841917129177862L);
|
||||||
|
|
||||||
|
sd::ops::random_exponential op;
|
||||||
|
RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f);
|
||||||
|
auto result = op.evaluate({&x}, {1.f}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
//
|
||||||
|
// z->printBuffer("\nExponential2+");
|
||||||
|
auto mean = z->reduceNumber(reduce::Mean);
|
||||||
|
auto variance = z->varianceNumber(variance::SummaryStatsVariance, false);
|
||||||
|
mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is");
|
||||||
|
variance.printBuffer("Variance for exponential with param 1. (1 exp) is");
|
||||||
|
ASSERT_FALSE(nexp0->equalsTo(z));
|
||||||
|
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||||
|
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||||
|
mean = exp0.reduceNumber(reduce::Mean);
|
||||||
|
variance = exp0.varianceNumber(variance::SummaryStatsVariance, false);
|
||||||
|
mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is");
|
||||||
|
variance.printBuffer("Variance for exponential with param 2. (1/2 exp) is");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RNGTests, Test_ExponentialDistribution_3_SGA) {
|
||||||
|
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {1000, 1000});
|
||||||
|
auto exp0 = NDArrayFactory::create<double>('c', {1000, 1000});
|
||||||
|
RandomGenerator oc(2716049175077475646L, -6182841917129177862L);
|
||||||
|
auto expMean = NDArrayFactory::create<double>(0.5f);
|
||||||
|
auto expVar = NDArrayFactory::create<double>(0.25f);
|
||||||
|
sd::ops::random_exponential op;
|
||||||
|
RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 2.f);
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x}, {1.});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
//ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
|
//ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
//
|
||||||
|
// z->printBuffer("\nExponential2+");
|
||||||
|
auto mean = z->reduceNumber(reduce::Mean);
|
||||||
|
auto variance = z->varianceNumber(variance::SummaryStatsVariance, false);
|
||||||
|
mean.printBuffer("Mean");
|
||||||
|
variance.printBuffer("Variance");
|
||||||
|
ASSERT_NEAR(mean.e<double>(0), 1.f, 1.e-2f);
|
||||||
|
ASSERT_NEAR(variance.e<double>(0), 1.f, 1.e-2f);
|
||||||
|
// mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is");
|
||||||
|
// variance.printBuffer("Variance for exponential with param 1. (1 exp) is");
|
||||||
|
// ASSERT_FALSE(nexp0->equalsTo(z));
|
||||||
|
// ASSERT_FALSE(nexp1->equalsTo(z));
|
||||||
|
// ASSERT_FALSE(nexp2->equalsTo(z));
|
||||||
|
mean = exp0.reduceNumber(reduce::Mean);
|
||||||
|
variance = exp0.varianceNumber(variance::SummaryStatsVariance, false);
|
||||||
|
mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is");
|
||||||
|
variance.printBuffer("Variance for exponential with param 2. (1/4 exp) is");
|
||||||
|
ASSERT_TRUE(mean.equalsTo(expMean, 1.e-3));
|
||||||
|
ASSERT_TRUE(variance.equalsTo(expVar, 1.e-3));
|
||||||
|
RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 1.f);
|
||||||
|
mean = exp0.reduceNumber(reduce::Mean);
|
||||||
|
variance = exp0.varianceNumber(variance::SummaryStatsVariance, false);
|
||||||
|
mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is");
|
||||||
|
variance.printBuffer("Variance for exponential with param 1.0 (1 exp) is");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
||||||
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {10, 10});
|
auto x = NDArrayFactory::create<Nd4jLong>('c', {2}, {10, 10});
|
||||||
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
|
Loading…
Reference in New Issue