[WIP] Random Uniform (#36)

* args

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* T args

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-07 17:09:47 +03:00 committed by GitHub
parent 24980efde3
commit 51f3a1371d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 12 deletions

View File

@ -46,11 +46,25 @@ namespace nd4j {
auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr; auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr; auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
bool disposable = false;
if (min == nullptr && max == nullptr && block.numT() >= 2) {
min = NDArrayFactory::create_('c', {}, dtype);
max = NDArrayFactory::create_('c', {}, dtype);
min->assign(T_ARG(0));
max->assign(T_ARG(1));
disposable = true;
}
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given."); REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output); helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
if (disposable) {
delete min;
delete max;
}
return Status::OK(); return Status::OK();
} }

View File

@ -150,10 +150,6 @@ namespace helpers {
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES); BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
} }
} }
} }