- bits_hamming_distance dtype fix (#8208)
- DataTypeUtils::asString fixe + new dtypes added Signed-off-by: raver119 <raver119@gmail.com>master
parent
4fbd9b7de0
commit
1de9fb218e
|
@ -335,8 +335,6 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
||||||
return std::string("INT8");
|
return std::string("INT8");
|
||||||
case INT16:
|
case INT16:
|
||||||
return std::string("INT16");
|
return std::string("INT16");
|
||||||
case UINT16:
|
|
||||||
return std::string("UINT16");
|
|
||||||
case INT32:
|
case INT32:
|
||||||
return std::string("INT32");
|
return std::string("INT32");
|
||||||
case INT64:
|
case INT64:
|
||||||
|
@ -353,10 +351,16 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
||||||
return std::string("BOOL");
|
return std::string("BOOL");
|
||||||
case UINT8:
|
case UINT8:
|
||||||
return std::string("UINT8");
|
return std::string("UINT8");
|
||||||
|
case UINT16:
|
||||||
|
return std::string("UINT16");
|
||||||
|
case UINT32:
|
||||||
|
return std::string("UINT32");
|
||||||
|
case UINT64:
|
||||||
|
return std::string("UINT64");
|
||||||
case UTF8:
|
case UTF8:
|
||||||
return std::string("UTF8");
|
return std::string("UTF8");
|
||||||
default:
|
default:
|
||||||
throw new std::runtime_error("Unknown data type used");
|
throw std::runtime_error("Unknown data type used");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,8 +48,7 @@ namespace nd4j {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, {ALL_INTS})
|
->setAllowedInputTypes(0, {ALL_INTS})
|
||||||
->setAllowedInputTypes(1, {ALL_INTS})
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
->setAllowedOutputTypes(0, {ALL_INDICES})
|
->setAllowedOutputTypes(0, {ALL_INDICES});
|
||||||
->setSameMode(true);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -129,3 +129,33 @@ TEST_F(DataTypesValidationTests, cast_1) {
|
||||||
ASSERT_TRUE(1.f == x);
|
ASSERT_TRUE(1.f == x);
|
||||||
ASSERT_TRUE(y == x);
|
ASSERT_TRUE(y == x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {3}, {0b01011000, 0b01011111, 0b01111110});
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {3}, {0b00010110, 0b01011000, 0b01011000});
|
||||||
|
auto z = NDArrayFactory::create<uint64_t>(0);
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setInputArray(1, &y);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
nd4j::ops::bits_hamming_distance op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_NE(Status::OK(), status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {3}, {0b01011000, 0b01011111, 0b01111110});
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {3}, {0b00010110, 0b01011000, 0b01011000});
|
||||||
|
auto z = NDArrayFactory::create<Nd4jLong>(0);
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setInputArray(1, &y);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
nd4j::ops::bits_hamming_distance op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
Loading…
Reference in New Issue