- bits_hamming_distance dtype fix (#8208)
- DataTypeUtils::asString fixe + new dtypes added Signed-off-by: raver119 <raver119@gmail.com>master
parent
4fbd9b7de0
commit
1de9fb218e
|
@ -335,8 +335,6 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
|||
return std::string("INT8");
|
||||
case INT16:
|
||||
return std::string("INT16");
|
||||
case UINT16:
|
||||
return std::string("UINT16");
|
||||
case INT32:
|
||||
return std::string("INT32");
|
||||
case INT64:
|
||||
|
@ -353,10 +351,16 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
|||
return std::string("BOOL");
|
||||
case UINT8:
|
||||
return std::string("UINT8");
|
||||
case UINT16:
|
||||
return std::string("UINT16");
|
||||
case UINT32:
|
||||
return std::string("UINT32");
|
||||
case UINT64:
|
||||
return std::string("UINT64");
|
||||
case UTF8:
|
||||
return std::string("UTF8");
|
||||
default:
|
||||
throw new std::runtime_error("Unknown data type used");
|
||||
throw std::runtime_error("Unknown data type used");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -48,8 +48,7 @@ namespace nd4j {
|
|||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_INTS})
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setAllowedOutputTypes(0, {ALL_INDICES})
|
||||
->setSameMode(true);
|
||||
->setAllowedOutputTypes(0, {ALL_INDICES});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -129,3 +129,33 @@ TEST_F(DataTypesValidationTests, cast_1) {
|
|||
ASSERT_TRUE(1.f == x);
|
||||
ASSERT_TRUE(y == x);
|
||||
}
|
||||
|
||||
TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) {
|
||||
auto x = NDArrayFactory::create<int>('c', {3}, {0b01011000, 0b01011111, 0b01111110});
|
||||
auto y = NDArrayFactory::create<int>('c', {3}, {0b00010110, 0b01011000, 0b01011000});
|
||||
auto z = NDArrayFactory::create<uint64_t>(0);
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &x);
|
||||
ctx.setInputArray(1, &y);
|
||||
ctx.setOutputArray(0, &z);
|
||||
|
||||
nd4j::ops::bits_hamming_distance op;
|
||||
auto status = op.execute(&ctx);
|
||||
ASSERT_NE(Status::OK(), status);
|
||||
}
|
||||
|
||||
TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) {
|
||||
auto x = NDArrayFactory::create<int>('c', {3}, {0b01011000, 0b01011111, 0b01111110});
|
||||
auto y = NDArrayFactory::create<int>('c', {3}, {0b00010110, 0b01011000, 0b01011000});
|
||||
auto z = NDArrayFactory::create<Nd4jLong>(0);
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &x);
|
||||
ctx.setInputArray(1, &y);
|
||||
ctx.setOutputArray(0, &z);
|
||||
|
||||
nd4j::ops::bits_hamming_distance op;
|
||||
auto status = op.execute(&ctx);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
Loading…
Reference in New Issue