[WIP] Random Uniform (#36)
* args Signed-off-by: raver119@gmail.com <raver119@gmail.com> * T args Signed-off-by: raver119 <raver119@gmail.com>master
parent
24980efde3
commit
51f3a1371d
|
@ -44,13 +44,27 @@ namespace nd4j {
|
||||||
if (block.getIArguments()->size())
|
if (block.getIArguments()->size())
|
||||||
dtype = (DataType)INT_ARG(0);
|
dtype = (DataType)INT_ARG(0);
|
||||||
|
|
||||||
auto min = block.width() > 1?INPUT_VARIABLE(1):(NDArray*)nullptr;
|
auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
|
||||||
auto max = block.width() > 2?INPUT_VARIABLE(2):(NDArray*)nullptr;
|
auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
|
||||||
|
bool disposable = false;
|
||||||
|
|
||||||
|
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
||||||
|
min = NDArrayFactory::create_('c', {}, dtype);
|
||||||
|
max = NDArrayFactory::create_('c', {}, dtype);
|
||||||
|
min->assign(T_ARG(0));
|
||||||
|
max->assign(T_ARG(1));
|
||||||
|
disposable = true;
|
||||||
|
}
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
|
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
|
||||||
|
|
||||||
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
|
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
|
||||||
|
|
||||||
|
if (disposable) {
|
||||||
|
delete min;
|
||||||
|
delete max;
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -130,12 +130,12 @@ namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
T minVal = T(0);
|
T minVal = T(0);
|
||||||
T maxVal = DataTypeUtils::infOrMax<T>();
|
T maxVal = DataTypeUtils::infOrMax<T>();
|
||||||
if (min)
|
if (min)
|
||||||
minVal = min->t<T>(0);
|
minVal = min->t<T>(0);
|
||||||
if (max)
|
if (max)
|
||||||
maxVal = max->t<T>(0);
|
maxVal = max->t<T>(0);
|
||||||
|
|
||||||
if (output->isR())
|
if (output->isR())
|
||||||
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
|
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
|
||||||
|
@ -150,10 +150,6 @@ namespace helpers {
|
||||||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
|
||||||
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue