From 46f8c58502a6b90c33b6c904ac2a74a570fa7679 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 6 Sep 2019 08:57:53 +0300 Subject: [PATCH] - bits_hamming_distance dtype fix - DataTypeUtils::asString fixe + new dtypes added Signed-off-by: raver119 --- libnd4j/include/array/DataTypeUtils.h | 10 +++++-- .../generic/bitwise/bits_hamming_distance.cpp | 3 +- .../layers_tests/DataTypesValidationTests.cpp | 30 +++++++++++++++++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index 2a52ba6f5..8b3176c2b 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -335,8 +335,6 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) { return std::string("INT8"); case INT16: return std::string("INT16"); - case UINT16: - return std::string("UINT16"); case INT32: return std::string("INT32"); case INT64: @@ -353,10 +351,16 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) { return std::string("BOOL"); case UINT8: return std::string("UINT8"); + case UINT16: + return std::string("UINT16"); + case UINT32: + return std::string("UINT32"); + case UINT64: + return std::string("UINT64"); case UTF8: return std::string("UTF8"); default: - throw new std::runtime_error("Unknown data type used"); + throw std::runtime_error("Unknown data type used"); } } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp index ff72ff4b9..f2a39b270 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp @@ -48,8 +48,7 @@ namespace nd4j { getOpDescriptor() ->setAllowedInputTypes(0, {ALL_INTS}) ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_INDICES}) - ->setSameMode(true); + ->setAllowedOutputTypes(0, {ALL_INDICES}); } } } diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 9de87b584..c018e58d0 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -129,3 +129,33 @@ TEST_F(DataTypesValidationTests, cast_1) { ASSERT_TRUE(1.f == x); ASSERT_TRUE(y == x); } + +TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) { + auto x = NDArrayFactory::create('c', {3}, {0b01011000, 0b01011111, 0b01111110}); + auto y = NDArrayFactory::create('c', {3}, {0b00010110, 0b01011000, 0b01011000}); + auto z = NDArrayFactory::create(0); + + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + nd4j::ops::bits_hamming_distance op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); +} + +TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) { + auto x = NDArrayFactory::create('c', {3}, {0b01011000, 0b01011111, 0b01111110}); + auto y = NDArrayFactory::create('c', {3}, {0b00010110, 0b01011000, 0b01011000}); + auto z = NDArrayFactory::create(0); + + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + nd4j::ops::bits_hamming_distance op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} \ No newline at end of file