[WIP] Random Uniform (#36)
* args Signed-off-by: raver119@gmail.com <raver119@gmail.com> * T args Signed-off-by: raver119 <raver119@gmail.com>master
parent
24980efde3
commit
51f3a1371d
|
@ -44,13 +44,27 @@ namespace nd4j {
|
|||
if (block.getIArguments()->size())
|
||||
dtype = (DataType)INT_ARG(0);
|
||||
|
||||
auto min = block.width() > 1?INPUT_VARIABLE(1):(NDArray*)nullptr;
|
||||
auto max = block.width() > 2?INPUT_VARIABLE(2):(NDArray*)nullptr;
|
||||
auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
|
||||
auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
|
||||
bool disposable = false;
|
||||
|
||||
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
||||
min = NDArrayFactory::create_('c', {}, dtype);
|
||||
max = NDArrayFactory::create_('c', {}, dtype);
|
||||
min->assign(T_ARG(0));
|
||||
max->assign(T_ARG(1));
|
||||
disposable = true;
|
||||
}
|
||||
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
|
||||
|
||||
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
|
||||
|
||||
if (disposable) {
|
||||
delete min;
|
||||
delete max;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -130,12 +130,12 @@ namespace helpers {
|
|||
|
||||
template <typename T>
|
||||
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||
T minVal = T(0);
|
||||
T maxVal = DataTypeUtils::infOrMax<T>();
|
||||
if (min)
|
||||
minVal = min->t<T>(0);
|
||||
if (max)
|
||||
maxVal = max->t<T>(0);
|
||||
T minVal = T(0);
|
||||
T maxVal = DataTypeUtils::infOrMax<T>();
|
||||
if (min)
|
||||
minVal = min->t<T>(0);
|
||||
if (max)
|
||||
maxVal = max->t<T>(0);
|
||||
|
||||
if (output->isR())
|
||||
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
|
||||
|
@ -150,10 +150,6 @@ namespace helpers {
|
|||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
||||
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue