From 0d6bb657bc40e60d5d898de556df24e29e29cb92 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 20 Jul 2019 21:33:47 +1000 Subject: [PATCH] FlatBuffers dtype conversion fix (missing bfloat16) (#71) Signed-off-by: AlexDBlack --- .../org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index 743fb527a..7db1ae33c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -87,7 +87,8 @@ public class FlatBuffersMapper { return DataType.UINT32; case UINT64: return DataType.UINT64; - + case BFLOAT16: + return DataType.BFLOAT16; default: throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); } @@ -123,6 +124,8 @@ public class FlatBuffersMapper { return org.nd4j.linalg.api.buffer.DataType.UINT32; } else if (val == DataType.UINT64) { return org.nd4j.linalg.api.buffer.DataType.UINT64; + } else if (val == DataType.BFLOAT16){ + return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; } else { throw new RuntimeException("Unknown datatype: " + val); }