- new NDArrayFactory scalar constructor
- minor tweak in randomuniform - one more test Signed-off-by: raver119 <raver119@gmail.com>master
parent
51f3a1371d
commit
929c1dc5c7
|
@ -59,6 +59,7 @@ namespace nd4j {
|
|||
|
||||
template <typename T>
|
||||
static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* create_(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
||||
template <typename T>
|
||||
static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
|
|
@ -422,6 +422,11 @@ NDArray NDArrayFactory::create(nd4j::DataType dtype, nd4j::LaunchContext * conte
|
|||
return res;
|
||||
}
|
||||
|
||||
NDArray* NDArrayFactory::create_(nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
||||
auto result = new NDArray();
|
||||
*result = NDArrayFactory::create(dtype, context);
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
|
|
|
@ -49,10 +49,10 @@ namespace nd4j {
|
|||
bool disposable = false;
|
||||
|
||||
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
||||
min = NDArrayFactory::create_('c', {}, dtype);
|
||||
max = NDArrayFactory::create_('c', {}, dtype);
|
||||
min->assign(T_ARG(0));
|
||||
max->assign(T_ARG(1));
|
||||
min = NDArrayFactory::create_(dtype);
|
||||
max = NDArrayFactory::create_(dtype);
|
||||
min->p(0, T_ARG(0));
|
||||
max->p(0, T_ARG(1));
|
||||
disposable = true;
|
||||
}
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ namespace helpers {
|
|||
template <typename T>
|
||||
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||
T minVal = T(0);
|
||||
T maxVal = DataTypeUtils::infOrMax<T>();
|
||||
T maxVal = DataTypeUtils::max<T>();
|
||||
if (min)
|
||||
minVal = min->t<T>(0);
|
||||
if (max)
|
||||
|
|
|
@ -1001,3 +1001,13 @@ TEST_F(RNGTests, test_choice_1) {
|
|||
delete x;
|
||||
delete prob;
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, test_uniform_119) {
|
||||
auto x = NDArrayFactory::create<int>('c', {2}, {1, 5});
|
||||
auto z = NDArrayFactory::create<float>('c', {1, 5});
|
||||
|
||||
|
||||
nd4j::ops::randomuniform op;
|
||||
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
Loading…
Reference in New Issue