- new NDArrayFactory scalar constructor

- minor tweak in randomuniform
- one more test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-08 08:49:41 +03:00
parent 51f3a1371d
commit 929c1dc5c7
5 changed files with 21 additions and 5 deletions

View File

@ -59,6 +59,7 @@ namespace nd4j {
template <typename T>
static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* create_(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
template <typename T>
static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());

View File

@ -422,6 +422,11 @@ NDArray NDArrayFactory::create(nd4j::DataType dtype, nd4j::LaunchContext * conte
return res;
}
NDArray* NDArrayFactory::create_(nd4j::DataType dtype, nd4j::LaunchContext * context) {
auto result = new NDArray();
*result = NDArrayFactory::create(dtype, context);
return result;
}
////////////////////////////////////////////////////////////////////////
template <typename T>

View File

@ -49,10 +49,10 @@ namespace nd4j {
bool disposable = false;
if (min == nullptr && max == nullptr && block.numT() >= 2) {
min = NDArrayFactory::create_('c', {}, dtype);
max = NDArrayFactory::create_('c', {}, dtype);
min->assign(T_ARG(0));
max->assign(T_ARG(1));
min = NDArrayFactory::create_(dtype);
max = NDArrayFactory::create_(dtype);
min->p(0, T_ARG(0));
max->p(0, T_ARG(1));
disposable = true;
}

View File

@ -131,7 +131,7 @@ namespace helpers {
template <typename T>
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
T minVal = T(0);
T maxVal = DataTypeUtils::infOrMax<T>();
T maxVal = DataTypeUtils::max<T>();
if (min)
minVal = min->t<T>(0);
if (max)

View File

@ -1001,3 +1001,13 @@ TEST_F(RNGTests, test_choice_1) {
delete x;
delete prob;
}
TEST_F(RNGTests, test_uniform_119) {
auto x = NDArrayFactory::create<int>('c', {2}, {1, 5});
auto z = NDArrayFactory::create<float>('c', {1, 5});
nd4j::ops::randomuniform op;
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
ASSERT_EQ(Status::OK(), status);
}