From 929c1dc5c7dbfccf8691d1c100cb0ec3a1b6e114 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 8 Nov 2019 08:49:41 +0300 Subject: [PATCH] - new NDArrayFactory scalar constructor - minor tweak in randomuniform - one more test Signed-off-by: raver119 --- libnd4j/blas/NDArrayFactory.h | 1 + libnd4j/blas/cpu/NDArrayFactory.cpp | 5 +++++ .../include/ops/declarable/generic/random/uniform.cpp | 8 ++++---- libnd4j/include/ops/declarable/helpers/cpu/random.cpp | 2 +- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 10 ++++++++++ 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/libnd4j/blas/NDArrayFactory.h b/libnd4j/blas/NDArrayFactory.h index 0b1c1fbd7..cdd8d9f9f 100644 --- a/libnd4j/blas/NDArrayFactory.h +++ b/libnd4j/blas/NDArrayFactory.h @@ -59,6 +59,7 @@ namespace nd4j { template static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray* create_(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); template static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); diff --git a/libnd4j/blas/cpu/NDArrayFactory.cpp b/libnd4j/blas/cpu/NDArrayFactory.cpp index d8b686b12..b091f13b7 100644 --- a/libnd4j/blas/cpu/NDArrayFactory.cpp +++ b/libnd4j/blas/cpu/NDArrayFactory.cpp @@ -422,6 +422,11 @@ NDArray NDArrayFactory::create(nd4j::DataType dtype, nd4j::LaunchContext * conte return res; } +NDArray* NDArrayFactory::create_(nd4j::DataType dtype, nd4j::LaunchContext * context) { + auto result = new NDArray(); + *result = NDArrayFactory::create(dtype, context); + return result; +} //////////////////////////////////////////////////////////////////////// template diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index 82203bbcc..fd65f842a 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -49,10 +49,10 @@ namespace nd4j { bool disposable = false; if (min == nullptr && max == nullptr && block.numT() >= 2) { - min = NDArrayFactory::create_('c', {}, dtype); - max = NDArrayFactory::create_('c', {}, dtype); - min->assign(T_ARG(0)); - max->assign(T_ARG(1)); + min = NDArrayFactory::create_(dtype); + max = NDArrayFactory::create_(dtype); + min->p(0, T_ARG(0)); + max->p(0, T_ARG(1)); disposable = true; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp index a8f108c00..3f9788330 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp @@ -131,7 +131,7 @@ namespace helpers { template void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { T minVal = T(0); - T maxVal = DataTypeUtils::infOrMax(); + T maxVal = DataTypeUtils::max(); if (min) minVal = min->t(0); if (max) diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index e81ea7964..bc4db6e63 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -1001,3 +1001,13 @@ TEST_F(RNGTests, test_choice_1) { delete x; delete prob; } + +TEST_F(RNGTests, test_uniform_119) { + auto x = NDArrayFactory::create('c', {2}, {1, 5}); + auto z = NDArrayFactory::create('c', {1, 5}); + + + nd4j::ops::randomuniform op; + auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {}); + ASSERT_EQ(Status::OK(), status); +} \ No newline at end of file