[WIP] Random Uniform (#36)
* args Signed-off-by: raver119@gmail.com <raver119@gmail.com> * T args Signed-off-by: raver119 <raver119@gmail.com>master
parent
24980efde3
commit
51f3a1371d
|
@ -46,11 +46,25 @@ namespace nd4j {
|
|||
|
||||
auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
|
||||
auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
|
||||
bool disposable = false;
|
||||
|
||||
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
||||
min = NDArrayFactory::create_('c', {}, dtype);
|
||||
max = NDArrayFactory::create_('c', {}, dtype);
|
||||
min->assign(T_ARG(0));
|
||||
max->assign(T_ARG(1));
|
||||
disposable = true;
|
||||
}
|
||||
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
|
||||
|
||||
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
|
||||
|
||||
if (disposable) {
|
||||
delete min;
|
||||
delete max;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -150,10 +150,6 @@ namespace helpers {
|
|||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
||||
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue