- new NDArrayFactory scalar constructor

- minor tweak in randomuniform
- one more test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-08 08:49:41 +03:00
parent 51f3a1371d
commit 929c1dc5c7
5 changed files with 21 additions and 5 deletions

View File

@ -59,6 +59,7 @@ namespace nd4j {
template <typename T> template <typename T>
static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* create_(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
template <typename T> template <typename T>
static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());

View File

@ -422,6 +422,11 @@ NDArray NDArrayFactory::create(nd4j::DataType dtype, nd4j::LaunchContext * conte
return res; return res;
} }
NDArray* NDArrayFactory::create_(nd4j::DataType dtype, nd4j::LaunchContext * context) {
auto result = new NDArray();
*result = NDArrayFactory::create(dtype, context);
return result;
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>

View File

@ -49,10 +49,10 @@ namespace nd4j {
bool disposable = false; bool disposable = false;
if (min == nullptr && max == nullptr && block.numT() >= 2) { if (min == nullptr && max == nullptr && block.numT() >= 2) {
min = NDArrayFactory::create_('c', {}, dtype); min = NDArrayFactory::create_(dtype);
max = NDArrayFactory::create_('c', {}, dtype); max = NDArrayFactory::create_(dtype);
min->assign(T_ARG(0)); min->p(0, T_ARG(0));
max->assign(T_ARG(1)); max->p(0, T_ARG(1));
disposable = true; disposable = true;
} }

View File

@ -131,7 +131,7 @@ namespace helpers {
template <typename T> template <typename T>
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
T minVal = T(0); T minVal = T(0);
T maxVal = DataTypeUtils::infOrMax<T>(); T maxVal = DataTypeUtils::max<T>();
if (min) if (min)
minVal = min->t<T>(0); minVal = min->t<T>(0);
if (max) if (max)

View File

@ -1001,3 +1001,13 @@ TEST_F(RNGTests, test_choice_1) {
delete x; delete x;
delete prob; delete prob;
} }
TEST_F(RNGTests, test_uniform_119) {
auto x = NDArrayFactory::create<int>('c', {2}, {1, 5});
auto z = NDArrayFactory::create<float>('c', {1, 5});
nd4j::ops::randomuniform op;
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
ASSERT_EQ(Status::OK(), status);
}