[WIP] Random Uniform (#36)
* args Signed-off-by: raver119@gmail.com <raver119@gmail.com> * T args Signed-off-by: raver119 <raver119@gmail.com>master
parent
24980efde3
commit
51f3a1371d
|
@ -46,11 +46,25 @@ namespace nd4j {
|
||||||
|
|
||||||
auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
|
auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
|
||||||
auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
|
auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
|
||||||
|
bool disposable = false;
|
||||||
|
|
||||||
|
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
||||||
|
min = NDArrayFactory::create_('c', {}, dtype);
|
||||||
|
max = NDArrayFactory::create_('c', {}, dtype);
|
||||||
|
min->assign(T_ARG(0));
|
||||||
|
max->assign(T_ARG(1));
|
||||||
|
disposable = true;
|
||||||
|
}
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
|
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
|
||||||
|
|
||||||
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
|
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
|
||||||
|
|
||||||
|
if (disposable) {
|
||||||
|
delete min;
|
||||||
|
delete max;
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -150,10 +150,6 @@ namespace helpers {
|
||||||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
|
||||||
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue