Rename flatbuffers DataType to DType (#228)

* Rename flatbuffers DataType enum to DType

Signed-off-by: Alex Black <blacka101@gmail.com>

* Rename flatbuffers DataType enum to DType

Signed-off-by: Alex Black <blacka101@gmail.com>

* Updates for flatbuffers datatype enum renaming

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-09-04 16:36:11 +10:00 committed by GitHub
parent 25b01f7850
commit 6cc887bee9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 135 additions and 135 deletions

View File

@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
auto fName = builder.CreateString(*(var->getName())); auto fName = builder.CreateString(*(var->getName()));
auto id = CreateIntPair(builder, var->id(), var->index()); auto id = CreateIntPair(builder, var->id(), var->index());
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray); auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
variables_vector.push_back(fv); variables_vector.push_back(fv);
arrays++; arrays++;

View File

@ -38,7 +38,7 @@ namespace nd4j {
public: public:
static int asInt(DataType type); static int asInt(DataType type);
static DataType fromInt(int dtype); static DataType fromInt(int dtype);
static DataType fromFlatDataType(nd4j::graph::DataType dtype); static DataType fromFlatDataType(nd4j::graph::DType dtype);
FORCEINLINE static std::string asString(DataType dataType); FORCEINLINE static std::string asString(DataType dataType);
template <typename T> template <typename T>

View File

@ -27,7 +27,7 @@ namespace nd4j {
return (DataType) val; return (DataType) val;
} }
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) { DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) {
return (DataType) dtype; return (DataType) dtype;
} }

View File

@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) {
return EnumNamesByteOrder()[index]; return EnumNamesByteOrder()[index];
} }
enum DataType { enum DType {
DataType_INHERIT = 0, DType_INHERIT = 0,
DataType_BOOL = 1, DType_BOOL = 1,
DataType_FLOAT8 = 2, DType_FLOAT8 = 2,
DataType_HALF = 3, DType_HALF = 3,
DataType_HALF2 = 4, DType_HALF2 = 4,
DataType_FLOAT = 5, DType_FLOAT = 5,
DataType_DOUBLE = 6, DType_DOUBLE = 6,
DataType_INT8 = 7, DType_INT8 = 7,
DataType_INT16 = 8, DType_INT16 = 8,
DataType_INT32 = 9, DType_INT32 = 9,
DataType_INT64 = 10, DType_INT64 = 10,
DataType_UINT8 = 11, DType_UINT8 = 11,
DataType_UINT16 = 12, DType_UINT16 = 12,
DataType_UINT32 = 13, DType_UINT32 = 13,
DataType_UINT64 = 14, DType_UINT64 = 14,
DataType_QINT8 = 15, DType_QINT8 = 15,
DataType_QINT16 = 16, DType_QINT16 = 16,
DataType_BFLOAT16 = 17, DType_BFLOAT16 = 17,
DataType_UTF8 = 50, DType_UTF8 = 50,
DataType_MIN = DataType_INHERIT, DType_MIN = DType_INHERIT,
DataType_MAX = DataType_UTF8 DType_MAX = DType_UTF8
}; };
inline const DataType (&EnumValuesDataType())[19] { inline const DType (&EnumValuesDType())[19] {
static const DataType values[] = { static const DType values[] = {
DataType_INHERIT, DType_INHERIT,
DataType_BOOL, DType_BOOL,
DataType_FLOAT8, DType_FLOAT8,
DataType_HALF, DType_HALF,
DataType_HALF2, DType_HALF2,
DataType_FLOAT, DType_FLOAT,
DataType_DOUBLE, DType_DOUBLE,
DataType_INT8, DType_INT8,
DataType_INT16, DType_INT16,
DataType_INT32, DType_INT32,
DataType_INT64, DType_INT64,
DataType_UINT8, DType_UINT8,
DataType_UINT16, DType_UINT16,
DataType_UINT32, DType_UINT32,
DataType_UINT64, DType_UINT64,
DataType_QINT8, DType_QINT8,
DataType_QINT16, DType_QINT16,
DataType_BFLOAT16, DType_BFLOAT16,
DataType_UTF8 DType_UTF8
}; };
return values; return values;
} }
inline const char * const *EnumNamesDataType() { inline const char * const *EnumNamesDType() {
static const char * const names[] = { static const char * const names[] = {
"INHERIT", "INHERIT",
"BOOL", "BOOL",
@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() {
return names; return names;
} }
inline const char *EnumNameDataType(DataType e) { inline const char *EnumNameDType(DType e) {
const size_t index = static_cast<int>(e); const size_t index = static_cast<int>(e);
return EnumNamesDataType()[index]; return EnumNamesDType()[index];
} }
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::Vector<int8_t> *buffer() const { const flatbuffers::Vector<int8_t> *buffer() const {
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER); return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
} }
DataType dtype() const { DType dtype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0)); return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
} }
ByteOrder byteOrder() const { ByteOrder byteOrder() const {
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0)); return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
@ -192,7 +192,7 @@ struct FlatArrayBuilder {
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) { void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer); fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
} }
void add_dtype(DataType dtype) { void add_dtype(DType dtype) {
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0); fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
} }
void add_byteOrder(ByteOrder byteOrder) { void add_byteOrder(ByteOrder byteOrder) {
@ -214,7 +214,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArray(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0, flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0, flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
DataType dtype = DataType_INHERIT, DType dtype = DType_INHERIT,
ByteOrder byteOrder = ByteOrder_LE) { ByteOrder byteOrder = ByteOrder_LE) {
FlatArrayBuilder builder_(_fbb); FlatArrayBuilder builder_(_fbb);
builder_.add_buffer(buffer); builder_.add_buffer(buffer);
@ -228,7 +228,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArrayDirect(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<int64_t> *shape = nullptr, const std::vector<int64_t> *shape = nullptr,
const std::vector<int8_t> *buffer = nullptr, const std::vector<int8_t> *buffer = nullptr,
DataType dtype = DataType_INHERIT, DType dtype = DType_INHERIT,
ByteOrder byteOrder = ByteOrder_LE) { ByteOrder byteOrder = ByteOrder_LE) {
return nd4j::graph::CreateFlatArray( return nd4j::graph::CreateFlatArray(
_fbb, _fbb,

View File

@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = {
/** /**
* @enum * @enum
*/ */
nd4j.graph.DataType = { nd4j.graph.DType = {
INHERIT: 0, INHERIT: 0,
BOOL: 1, BOOL: 1,
FLOAT8: 2, FLOAT8: 2,
@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() {
}; };
/** /**
* @returns {nd4j.graph.DataType} * @returns {nd4j.graph.DType}
*/ */
nd4j.graph.FlatArray.prototype.dtype = function() { nd4j.graph.FlatArray.prototype.dtype = function() {
var offset = this.bb.__offset(this.bb_pos, 8); var offset = this.bb.__offset(this.bb_pos, 8);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
}; };
/** /**
@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) {
/** /**
* @param {flatbuffers.Builder} builder * @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} dtype * @param {nd4j.graph.DType} dtype
*/ */
nd4j.graph.FlatArray.addDtype = function(builder, dtype) { nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
}; };
/** /**

View File

@ -5,7 +5,7 @@
namespace nd4j.graph namespace nd4j.graph
{ {
public enum DataType : sbyte public enum DType : sbyte
{ {
INHERIT = 0, INHERIT = 0,
BOOL = 1, BOOL = 1,

View File

@ -2,8 +2,8 @@
package nd4j.graph; package nd4j.graph;
public final class DataType { public final class DType {
private DataType() { } private DType() { }
public static final byte INHERIT = 0; public static final byte INHERIT = 0;
public static final byte BOOL = 1; public static final byte BOOL = 1;
public static final byte FLOAT8 = 2; public static final byte FLOAT8 = 2;

View File

@ -2,7 +2,7 @@
# namespace: graph # namespace: graph
class DataType(object): class DType(object):
INHERIT = 0 INHERIT = 0
BOOL = 1 BOOL = 1
FLOAT8 = 2 FLOAT8 = 2

View File

@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); } public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
#endif #endif
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); } public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } } public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder, public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
VectorOffset shapeOffset = default(VectorOffset), VectorOffset shapeOffset = default(VectorOffset),
VectorOffset bufferOffset = default(VectorOffset), VectorOffset bufferOffset = default(VectorOffset),
DataType dtype = DataType.INHERIT, DType dtype = DType.INHERIT,
ByteOrder byteOrder = ByteOrder.LE) { ByteOrder byteOrder = ByteOrder.LE) {
builder.StartObject(4); builder.StartObject(4);
FlatArray.AddBuffer(builder, bufferOffset); FlatArray.AddBuffer(builder, bufferOffset);
@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); } public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); } public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) { public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
int o = builder.EndObject(); int o = builder.EndObject();

View File

@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); } public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
#endif #endif
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); } public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; } public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } } public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T #if ENABLE_SPAN_T
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); } public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
#else #else
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); } public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
#endif #endif
public DataType[] GetOutputTypesArray() { return __p.__vector_as_array<DataType>(38); } public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder, public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); } public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); } public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); } public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) { public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {

View File

@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); } public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
#endif #endif
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); } public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; } public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } } public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T #if ENABLE_SPAN_T
@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder, public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
Offset<IntPair> idOffset = default(Offset<IntPair>), Offset<IntPair> idOffset = default(Offset<IntPair>),
StringOffset nameOffset = default(StringOffset), StringOffset nameOffset = default(StringOffset),
DataType dtype = DataType.INHERIT, DType dtype = DType.INHERIT,
VectorOffset shapeOffset = default(VectorOffset), VectorOffset shapeOffset = default(VectorOffset),
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>), Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
int device = 0, int device = 0,
@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); } public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); } public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); } public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); } public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }

View File

@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) {
/** /**
* @param {number} index * @param {number} index
* @returns {nd4j.graph.DataType} * @returns {nd4j.graph.DType}
*/ */
nd4j.graph.FlatNode.prototype.outputTypes = function(index) { nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
var offset = this.bb.__offset(this.bb_pos, 38); var offset = this.bb.__offset(this.bb_pos, 38);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0); return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
}; };
/** /**
@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) {
/** /**
* @param {flatbuffers.Builder} builder * @param {flatbuffers.Builder} builder
* @param {Array.<nd4j.graph.DataType>} data * @param {Array.<nd4j.graph.DType>} data
* @returns {flatbuffers.Offset} * @returns {flatbuffers.Offset}
*/ */
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) { nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {

View File

@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::String *name() const { const flatbuffers::String *name() const {
return GetPointer<const flatbuffers::String *>(VT_NAME); return GetPointer<const flatbuffers::String *>(VT_NAME);
} }
DataType dtype() const { DType dtype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0)); return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
} }
const flatbuffers::Vector<int64_t> *shape() const { const flatbuffers::Vector<int64_t> *shape() const {
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE); return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
@ -106,7 +106,7 @@ struct FlatVariableBuilder {
void add_name(flatbuffers::Offset<flatbuffers::String> name) { void add_name(flatbuffers::Offset<flatbuffers::String> name) {
fbb_.AddOffset(FlatVariable::VT_NAME, name); fbb_.AddOffset(FlatVariable::VT_NAME, name);
} }
void add_dtype(DataType dtype) { void add_dtype(DType dtype) {
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0); fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
} }
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) { void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
@ -137,7 +137,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<IntPair> id = 0, flatbuffers::Offset<IntPair> id = 0,
flatbuffers::Offset<flatbuffers::String> name = 0, flatbuffers::Offset<flatbuffers::String> name = 0,
DataType dtype = DataType_INHERIT, DType dtype = DType_INHERIT,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0, flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<FlatArray> ndarray = 0, flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0, int32_t device = 0,
@ -157,7 +157,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<IntPair> id = 0, flatbuffers::Offset<IntPair> id = 0,
const char *name = nullptr, const char *name = nullptr,
DataType dtype = DataType_INHERIT, DType dtype = DType_INHERIT,
const std::vector<int64_t> *shape = nullptr, const std::vector<int64_t> *shape = nullptr,
flatbuffers::Offset<FlatArray> ndarray = 0, flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0, int32_t device = 0,

View File

@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) {
}; };
/** /**
* @returns {nd4j.graph.DataType} * @returns {nd4j.graph.DType}
*/ */
nd4j.graph.FlatVariable.prototype.dtype = function() { nd4j.graph.FlatVariable.prototype.dtype = function() {
var offset = this.bb.__offset(this.bb_pos, 8); var offset = this.bb.__offset(this.bb_pos, 8);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
}; };
/** /**
@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) {
/** /**
* @param {flatbuffers.Builder} builder * @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} dtype * @param {nd4j.graph.DType} dtype
*/ */
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) { nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
}; };
/** /**

View File

@ -111,7 +111,7 @@ namespace nd4j {
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder()); auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo); return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DType>(array.dataType()), bo);
} }
} }
} }

View File

@ -219,7 +219,7 @@ namespace nd4j {
throw std::runtime_error("CONSTANT variable must have NDArray bundled"); throw std::runtime_error("CONSTANT variable must have NDArray bundled");
auto ar = flatVariable->ndarray(); auto ar = flatVariable->ndarray();
if (ar->dtype() == DataType_UTF8) { if (ar->dtype() == DType_UTF8) {
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
} else { } else {
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
@ -320,7 +320,7 @@ namespace nd4j {
auto fBuffer = builder.CreateVector(array->asByteVector()); auto fBuffer = builder.CreateVector(array->asByteVector());
// packing array // packing array
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType()); auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType());
// packing id/index of this var // packing id/index of this var
auto fVid = CreateIntPair(builder, this->_id, this->_index); auto fVid = CreateIntPair(builder, this->_id, this->_index);
@ -331,7 +331,7 @@ namespace nd4j {
stringId = builder.CreateString(this->_name); stringId = builder.CreateString(this->_name);
// returning array // returning array
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray); return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
} else { } else {
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList"); throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
} }

View File

@ -23,7 +23,7 @@ enum ByteOrder:byte {
} }
// DataType for arrays/buffers // DataType for arrays/buffers
enum DataType:byte { enum DType:byte {
INHERIT, INHERIT,
BOOL, BOOL,
FLOAT8, FLOAT8,
@ -49,7 +49,7 @@ enum DataType:byte {
table FlatArray { table FlatArray {
shape:[long]; // shape in Nd4j format shape:[long]; // shape in Nd4j format
buffer:[byte]; // byte buffer with data buffer:[byte]; // byte buffer with data
dtype:DataType; // data type of actual data within buffer dtype:DType; // data type of actual data within buffer
byteOrder:ByteOrder; // byte order of buffer byteOrder:ByteOrder; // byte order of buffer
} }

View File

@ -48,7 +48,7 @@ table FlatNode {
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
// output data types (optional) // output data types (optional)
outputTypes:[DataType]; outputTypes:[DType];
//Scalar value - used for scalar ops. Should be single value only. //Scalar value - used for scalar ops. Should be single value only.
scalar:FlatArray; scalar:FlatArray;

View File

@ -51,7 +51,7 @@ table UIVariable {
id:IntPair; //Existing IntPair class id:IntPair; //Existing IntPair class
name:string; name:string;
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
datatype:DataType; datatype:DType;
shape:[long]; shape:[long];
controlDeps:[string]; //Input control dependencies: variable x -> this controlDeps:[string]; //Input control dependencies: variable x -> this
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of

View File

@ -30,7 +30,7 @@ enum VarType:byte {
table FlatVariable { table FlatVariable {
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
name:string; // symbolic ID of the Variable (if defined) name:string; // symbolic ID of the Variable (if defined)
dtype:DataType; dtype:DType;
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
ndarray:FlatArray; ndarray:FlatArray;

View File

@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
auto fBuffer = builder.CreateVector(array->asByteVector()); auto fBuffer = builder.CreateVector(array->asByteVector());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
auto fVid = CreateIntPair(builder, -1); auto fVid = CreateIntPair(builder, -1);
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
std::vector<int> outputs1, outputs2, inputs1, inputs2; std::vector<int> outputs1, outputs2, inputs1, inputs2;
outputs1.push_back(2); outputs1.push_back(2);
@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
auto name1 = builder.CreateString("wow1"); auto name1 = builder.CreateString("wow1");
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT); auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT);
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector; std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
variables_vector.push_back(fXVar); variables_vector.push_back(fXVar);

View File

@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) {
auto fBuffer = builder.CreateVector(vec); auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12); auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
builder.Finish(flatVar); builder.Finish(flatVar);
@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) {
auto fBuffer = builder.CreateVector(vec); auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12); auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
builder.Finish(flatVar); builder.Finish(flatVar);
@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) {
auto fBuffer = builder.CreateVector(vec); auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12); auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
builder.Finish(flatVar); builder.Finish(flatVar);
@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) {
auto fShape = builder.CreateVector(original.getShapeAsFlatVector()); auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
auto fVid = CreateIntPair(builder, 37, 12); auto fVid = CreateIntPair(builder, 37, 12);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
builder.Finish(flatVar); builder.Finish(flatVar);

View File

@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.graph.DataType; import org.nd4j.graph.DType;
import org.nd4j.graph.FlatArray; import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatNode; import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatProperties; import org.nd4j.graph.FlatProperties;
@ -66,33 +66,33 @@ public class FlatBuffersMapper {
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) { public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
switch (type) { switch (type) {
case FLOAT: case FLOAT:
return DataType.FLOAT; return DType.FLOAT;
case DOUBLE: case DOUBLE:
return DataType.DOUBLE; return DType.DOUBLE;
case HALF: case HALF:
return DataType.HALF; return DType.HALF;
case INT: case INT:
return DataType.INT32; return DType.INT32;
case LONG: case LONG:
return DataType.INT64; return DType.INT64;
case BOOL: case BOOL:
return DataType.BOOL; return DType.BOOL;
case SHORT: case SHORT:
return DataType.INT16; return DType.INT16;
case BYTE: case BYTE:
return DataType.INT8; return DType.INT8;
case UBYTE: case UBYTE:
return DataType.UINT8; return DType.UINT8;
case UTF8: case UTF8:
return DataType.UTF8; return DType.UTF8;
case UINT16: case UINT16:
return DataType.UINT16; return DType.UINT16;
case UINT32: case UINT32:
return DataType.UINT32; return DType.UINT32;
case UINT64: case UINT64:
return DataType.UINT64; return DType.UINT64;
case BFLOAT16: case BFLOAT16:
return DataType.BFLOAT16; return DType.BFLOAT16;
default: default:
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
} }
@ -102,33 +102,33 @@ public class FlatBuffersMapper {
* This method converts enums for DataType * This method converts enums for DataType
*/ */
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) { public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
if (val == DataType.FLOAT) { if (val == DType.FLOAT) {
return org.nd4j.linalg.api.buffer.DataType.FLOAT; return org.nd4j.linalg.api.buffer.DataType.FLOAT;
} else if (val == DataType.DOUBLE) { } else if (val == DType.DOUBLE) {
return org.nd4j.linalg.api.buffer.DataType.DOUBLE; return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
} else if (val == DataType.HALF) { } else if (val == DType.HALF) {
return org.nd4j.linalg.api.buffer.DataType.HALF; return org.nd4j.linalg.api.buffer.DataType.HALF;
} else if (val == DataType.INT32) { } else if (val == DType.INT32) {
return org.nd4j.linalg.api.buffer.DataType.INT; return org.nd4j.linalg.api.buffer.DataType.INT;
} else if (val == DataType.INT64) { } else if (val == DType.INT64) {
return org.nd4j.linalg.api.buffer.DataType.LONG; return org.nd4j.linalg.api.buffer.DataType.LONG;
} else if (val == DataType.INT8) { } else if (val == DType.INT8) {
return org.nd4j.linalg.api.buffer.DataType.BYTE; return org.nd4j.linalg.api.buffer.DataType.BYTE;
} else if (val == DataType.BOOL) { } else if (val == DType.BOOL) {
return org.nd4j.linalg.api.buffer.DataType.BOOL; return org.nd4j.linalg.api.buffer.DataType.BOOL;
} else if (val == DataType.UINT8) { } else if (val == DType.UINT8) {
return org.nd4j.linalg.api.buffer.DataType.UBYTE; return org.nd4j.linalg.api.buffer.DataType.UBYTE;
} else if (val == DataType.INT16) { } else if (val == DType.INT16) {
return org.nd4j.linalg.api.buffer.DataType.SHORT; return org.nd4j.linalg.api.buffer.DataType.SHORT;
} else if (val == DataType.UTF8) { } else if (val == DType.UTF8) {
return org.nd4j.linalg.api.buffer.DataType.UTF8; return org.nd4j.linalg.api.buffer.DataType.UTF8;
} else if (val == DataType.UINT16) { } else if (val == DType.UINT16) {
return org.nd4j.linalg.api.buffer.DataType.UINT16; return org.nd4j.linalg.api.buffer.DataType.UINT16;
} else if (val == DataType.UINT32) { } else if (val == DType.UINT32) {
return org.nd4j.linalg.api.buffer.DataType.UINT32; return org.nd4j.linalg.api.buffer.DataType.UINT32;
} else if (val == DataType.UINT64) { } else if (val == DType.UINT64) {
return org.nd4j.linalg.api.buffer.DataType.UINT64; return org.nd4j.linalg.api.buffer.DataType.UINT64;
} else if (val == DataType.BFLOAT16){ } else if (val == DType.BFLOAT16){
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
} else { } else {
throw new RuntimeException("Unknown datatype: " + val); throw new RuntimeException("Unknown datatype: " + val);

View File

@ -2,8 +2,8 @@
package org.nd4j.graph; package org.nd4j.graph;
public final class DataType { public final class DType {
private DataType() { } private DType() { }
public static final byte INHERIT = 0; public static final byte INHERIT = 0;
public static final byte BOOL = 1; public static final byte BOOL = 1;
public static final byte FLOAT8 = 2; public static final byte FLOAT8 = 2;