fix for is_increasing/non_decreasing ops for empty input case (#63)
Signed-off-by: raver119 <raver119@gmail.com>master
parent
3f38900c33
commit
7898f3c0cc
|
@ -27,9 +27,12 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) {
|
BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
// in case of empty input there's nothing to do
|
||||||
|
if (input->isEmpty())
|
||||||
|
return ND4J_STATUS_TRUE;
|
||||||
|
|
||||||
bool isNonDecreasing = true;
|
bool isNonDecreasing = true;
|
||||||
|
|
||||||
nd4j::ops::helpers::compare_elem(block.launchContext(), input, false, isNonDecreasing);
|
nd4j::ops::helpers::compare_elem(block.launchContext(), input, false, isNonDecreasing);
|
||||||
|
|
|
@ -27,9 +27,12 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) {
|
BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
// in case of empty input there's nothing to do
|
||||||
|
if (input->isEmpty())
|
||||||
|
return ND4J_STATUS_TRUE;
|
||||||
|
|
||||||
bool isStrictlyIncreasing = true;
|
bool isStrictlyIncreasing = true;
|
||||||
|
|
||||||
nd4j::ops::helpers::compare_elem(block.launchContext(), input, true, isStrictlyIncreasing);
|
nd4j::ops::helpers::compare_elem(block.launchContext(), input, true, isStrictlyIncreasing);
|
||||||
|
|
|
@ -552,3 +552,33 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_3) {
|
||||||
auto temp2 = temp1 * cLast;
|
auto temp2 = temp1 * cLast;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests15, test_empty_increasing_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 3});
|
||||||
|
auto z = NDArrayFactory::create<bool>(false);
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
nd4j::ops::is_strictly_increasing op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(true, z.e<bool>(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 3});
|
||||||
|
auto z = NDArrayFactory::create<bool>(false);
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
nd4j::ops::is_non_decreasing op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(true, z.e<bool>(0));
|
||||||
|
}
|
Loading…
Reference in New Issue