- new NDArrayFactory scalar constructor
- minor tweak in randomuniform - one more test Signed-off-by: raver119 <raver119@gmail.com>master
parent
51f3a1371d
commit
929c1dc5c7
|
@ -59,6 +59,7 @@ namespace nd4j {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
static NDArray* create_(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray* create_(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
static NDArray create(const T value, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
|
|
@ -422,6 +422,11 @@ NDArray NDArrayFactory::create(nd4j::DataType dtype, nd4j::LaunchContext * conte
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NDArray* NDArrayFactory::create_(nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
||||||
|
auto result = new NDArray();
|
||||||
|
*result = NDArrayFactory::create(dtype, context);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -49,10 +49,10 @@ namespace nd4j {
|
||||||
bool disposable = false;
|
bool disposable = false;
|
||||||
|
|
||||||
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
||||||
min = NDArrayFactory::create_('c', {}, dtype);
|
min = NDArrayFactory::create_(dtype);
|
||||||
max = NDArrayFactory::create_('c', {}, dtype);
|
max = NDArrayFactory::create_(dtype);
|
||||||
min->assign(T_ARG(0));
|
min->p(0, T_ARG(0));
|
||||||
max->assign(T_ARG(1));
|
max->p(0, T_ARG(1));
|
||||||
disposable = true;
|
disposable = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -131,7 +131,7 @@ namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
T minVal = T(0);
|
T minVal = T(0);
|
||||||
T maxVal = DataTypeUtils::infOrMax<T>();
|
T maxVal = DataTypeUtils::max<T>();
|
||||||
if (min)
|
if (min)
|
||||||
minVal = min->t<T>(0);
|
minVal = min->t<T>(0);
|
||||||
if (max)
|
if (max)
|
||||||
|
|
|
@ -1001,3 +1001,13 @@ TEST_F(RNGTests, test_choice_1) {
|
||||||
delete x;
|
delete x;
|
||||||
delete prob;
|
delete prob;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(RNGTests, test_uniform_119) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {2}, {1, 5});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {1, 5});
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::ops::randomuniform op;
|
||||||
|
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
Loading…
Reference in New Issue