diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index 743fb527a..7db1ae33c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -87,7 +87,8 @@ public class FlatBuffersMapper { return DataType.UINT32; case UINT64: return DataType.UINT64; - + case BFLOAT16: + return DataType.BFLOAT16; default: throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); } @@ -123,6 +124,8 @@ public class FlatBuffersMapper { return org.nd4j.linalg.api.buffer.DataType.UINT32; } else if (val == DataType.UINT64) { return org.nd4j.linalg.api.buffer.DataType.UINT64; + } else if (val == DataType.BFLOAT16){ + return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; } else { throw new RuntimeException("Unknown datatype: " + val); }