fix for is_increasing/non_decreasing ops for empty input case (#63)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-20 11:12:15 +03:00 committed by GitHub
parent 3f38900c33
commit 7898f3c0cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 2 deletions

View File

@ -27,9 +27,12 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) { BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
// in case of empty input there's nothing to do
if (input->isEmpty())
return ND4J_STATUS_TRUE;
bool isNonDecreasing = true; bool isNonDecreasing = true;
nd4j::ops::helpers::compare_elem(block.launchContext(), input, false, isNonDecreasing); nd4j::ops::helpers::compare_elem(block.launchContext(), input, false, isNonDecreasing);

View File

@ -27,9 +27,12 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) { BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
// in case of empty input there's nothing to do
if (input->isEmpty())
return ND4J_STATUS_TRUE;
bool isStrictlyIncreasing = true; bool isStrictlyIncreasing = true;
nd4j::ops::helpers::compare_elem(block.launchContext(), input, true, isStrictlyIncreasing); nd4j::ops::helpers::compare_elem(block.launchContext(), input, true, isStrictlyIncreasing);

View File

@ -552,3 +552,33 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_3) {
auto temp2 = temp1 * cLast; auto temp2 = temp1 * cLast;
} }
} }
TEST_F(DeclarableOpsTests15, test_empty_increasing_1) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 3});
auto z = NDArrayFactory::create<bool>(false);
Context ctx(1);
ctx.setInputArray(0, &x);
ctx.setOutputArray(0, &z);
nd4j::ops::is_strictly_increasing op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(true, z.e<bool>(0));
}
TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 3});
auto z = NDArrayFactory::create<bool>(false);
Context ctx(1);
ctx.setInputArray(0, &x);
ctx.setOutputArray(0, &z);
nd4j::ops::is_non_decreasing op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(true, z.e<bool>(0));
}