diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index 47f7c4a0c..82203bbcc 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -44,13 +44,27 @@ namespace nd4j { if (block.getIArguments()->size()) dtype = (DataType)INT_ARG(0); - auto min = block.width() > 1?INPUT_VARIABLE(1):(NDArray*)nullptr; - auto max = block.width() > 2?INPUT_VARIABLE(2):(NDArray*)nullptr; + auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr; + auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr; + bool disposable = false; + + if (min == nullptr && max == nullptr && block.numT() >= 2) { + min = NDArrayFactory::create_('c', {}, dtype); + max = NDArrayFactory::create_('c', {}, dtype); + min->assign(T_ARG(0)); + max->assign(T_ARG(1)); + disposable = true; + } auto output = OUTPUT_VARIABLE(0); REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given."); helpers::fillRandomUniform(block.launchContext(), rng, min, max, output); + + if (disposable) { + delete min; + delete max; + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp index 124474f87..a8f108c00 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp @@ -130,12 +130,12 @@ namespace helpers { template void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - T minVal = T(0); - T maxVal = DataTypeUtils::infOrMax(); - if (min) - minVal = min->t(0); - if (max) - maxVal = max->t(0); + T minVal = T(0); + T maxVal = DataTypeUtils::infOrMax(); + if (min) + minVal = min->t(0); + if (max) + maxVal = max->t(0); if (output->isR()) RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); @@ -150,10 +150,6 @@ namespace helpers { void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES); } - - BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context, - graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES); - } } } \ No newline at end of file