FlatBuffers dtype conversion fix (missing bfloat16) (#71)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-07-20 21:33:47 +10:00 committed by AlexDBlack
parent 763a225c6a
commit 0d6bb657bc
1 changed files with 4 additions and 1 deletions

View File

@ -87,7 +87,8 @@ public class FlatBuffersMapper {
return DataType.UINT32; return DataType.UINT32;
case UINT64: case UINT64:
return DataType.UINT64; return DataType.UINT64;
case BFLOAT16:
return DataType.BFLOAT16;
default: default:
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
} }
@ -123,6 +124,8 @@ public class FlatBuffersMapper {
return org.nd4j.linalg.api.buffer.DataType.UINT32; return org.nd4j.linalg.api.buffer.DataType.UINT32;
} else if (val == DataType.UINT64) { } else if (val == DataType.UINT64) {
return org.nd4j.linalg.api.buffer.DataType.UINT64; return org.nd4j.linalg.api.buffer.DataType.UINT64;
} else if (val == DataType.BFLOAT16){
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
} else { } else {
throw new RuntimeException("Unknown datatype: " + val); throw new RuntimeException("Unknown datatype: " + val);
} }