[WIP] Random Uniform (#36)

* args

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* T args

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-07 17:09:47 +03:00 committed by GitHub
parent 24980efde3
commit 51f3a1371d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 12 deletions

View File

@ -44,13 +44,27 @@ namespace nd4j {
if (block.getIArguments()->size()) if (block.getIArguments()->size())
dtype = (DataType)INT_ARG(0); dtype = (DataType)INT_ARG(0);
auto min = block.width() > 1?INPUT_VARIABLE(1):(NDArray*)nullptr; auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
auto max = block.width() > 2?INPUT_VARIABLE(2):(NDArray*)nullptr; auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
bool disposable = false;
if (min == nullptr && max == nullptr && block.numT() >= 2) {
min = NDArrayFactory::create_('c', {}, dtype);
max = NDArrayFactory::create_('c', {}, dtype);
min->assign(T_ARG(0));
max->assign(T_ARG(1));
disposable = true;
}
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given."); REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output); helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
if (disposable) {
delete min;
delete max;
}
return Status::OK(); return Status::OK();
} }

View File

@ -130,12 +130,12 @@ namespace helpers {
template <typename T> template <typename T>
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
T minVal = T(0); T minVal = T(0);
T maxVal = DataTypeUtils::infOrMax<T>(); T maxVal = DataTypeUtils::infOrMax<T>();
if (min) if (min)
minVal = min->t<T>(0); minVal = min->t<T>(0);
if (max) if (max)
maxVal = max->t<T>(0); maxVal = max->t<T>(0);
if (output->isR()) if (output->isR())
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
@ -150,10 +150,6 @@ namespace helpers {
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES); BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
} }
} }
} }