- bits_hamming_distance dtype fix (#8208)

- DataTypeUtils::asString fixe + new dtypes added

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-06 08:59:05 +03:00 committed by GitHub
parent 4fbd9b7de0
commit 1de9fb218e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 5 deletions

View File

@ -335,8 +335,6 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
return std::string("INT8"); return std::string("INT8");
case INT16: case INT16:
return std::string("INT16"); return std::string("INT16");
case UINT16:
return std::string("UINT16");
case INT32: case INT32:
return std::string("INT32"); return std::string("INT32");
case INT64: case INT64:
@ -353,10 +351,16 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
return std::string("BOOL"); return std::string("BOOL");
case UINT8: case UINT8:
return std::string("UINT8"); return std::string("UINT8");
case UINT16:
return std::string("UINT16");
case UINT32:
return std::string("UINT32");
case UINT64:
return std::string("UINT64");
case UTF8: case UTF8:
return std::string("UTF8"); return std::string("UTF8");
default: default:
throw new std::runtime_error("Unknown data type used"); throw std::runtime_error("Unknown data type used");
} }
} }

View File

@ -48,8 +48,7 @@ namespace nd4j {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS}) ->setAllowedInputTypes(0, {ALL_INTS})
->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(1, {ALL_INTS})
->setAllowedOutputTypes(0, {ALL_INDICES}) ->setAllowedOutputTypes(0, {ALL_INDICES});
->setSameMode(true);
} }
} }
} }

View File

@ -129,3 +129,33 @@ TEST_F(DataTypesValidationTests, cast_1) {
ASSERT_TRUE(1.f == x); ASSERT_TRUE(1.f == x);
ASSERT_TRUE(y == x); ASSERT_TRUE(y == x);
} }
TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) {
auto x = NDArrayFactory::create<int>('c', {3}, {0b01011000, 0b01011111, 0b01111110});
auto y = NDArrayFactory::create<int>('c', {3}, {0b00010110, 0b01011000, 0b01011000});
auto z = NDArrayFactory::create<uint64_t>(0);
Context ctx(1);
ctx.setInputArray(0, &x);
ctx.setInputArray(1, &y);
ctx.setOutputArray(0, &z);
nd4j::ops::bits_hamming_distance op;
auto status = op.execute(&ctx);
ASSERT_NE(Status::OK(), status);
}
TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) {
auto x = NDArrayFactory::create<int>('c', {3}, {0b01011000, 0b01011111, 0b01111110});
auto y = NDArrayFactory::create<int>('c', {3}, {0b00010110, 0b01011000, 0b01011000});
auto z = NDArrayFactory::create<Nd4jLong>(0);
Context ctx(1);
ctx.setInputArray(0, &x);
ctx.setInputArray(1, &y);
ctx.setOutputArray(0, &z);
nd4j::ops::bits_hamming_distance op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
}