diff --git a/libnd4j/include/helpers/impl/RandomLauncher.cpp b/libnd4j/include/helpers/impl/RandomLauncher.cpp index 8114c2ec4..f7cdd0f3a 100644 --- a/libnd4j/include/helpers/impl/RandomLauncher.cpp +++ b/libnd4j/include/helpers/impl/RandomLauncher.cpp @@ -26,8 +26,6 @@ #include namespace sd { - // FIXME: implement this - void RandomLauncher::applyDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { if (z == nullptr) z = array; @@ -35,8 +33,12 @@ namespace sd { ExtraArguments arguments({retainProb}); PointersManager pm(context, "applyDropOut"); + NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::applyInvertedDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { @@ -46,8 +48,12 @@ namespace sd { ExtraArguments arguments({retainProb}); PointersManager pm(context, "applyInvertedDropOut"); + NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::applyAlphaDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) { @@ -57,63 +63,95 @@ namespace sd { ExtraArguments arguments({retainProb, alpha, beta, alphaPrime}); PointersManager pm(context, "applyAlphaDropOut"); + NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::fillBernoulli(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double prob) { ExtraArguments arguments({prob}); PointersManager pm(context, "fillBernoulli"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillUniform(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double from, double to) { ExtraArguments arguments({from, to}); PointersManager pm(context, "fillUniform"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillGaussian(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillGaussian"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillExponential(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double lambda) { ExtraArguments arguments({lambda}); PointersManager pm(context, "fillExponential"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillLogNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillLogNormal"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillTruncatedNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillTruncatedNormal"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillBinomial(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) { ExtraArguments arguments({(double) trials, prob}); PointersManager pm(context, "fillBinomial"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } }