Rename flatbuffers DataType to DType (#228)

* Rename flatbuffers DataType enum to DType

Signed-off-by: Alex Black <blacka101@gmail.com>

* Rename flatbuffers DataType enum to DType

Signed-off-by: Alex Black <blacka101@gmail.com>

* Updates for flatbuffers datatype enum renaming

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-09-04 16:36:11 +10:00 committed by GitHub
parent 25b01f7850
commit 6cc887bee9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 135 additions and 135 deletions

View File

@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
auto fName = builder.CreateString(*(var->getName()));
auto id = CreateIntPair(builder, var->id(), var->index());
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
variables_vector.push_back(fv);
arrays++;

View File

@ -38,7 +38,7 @@ namespace nd4j {
public:
static int asInt(DataType type);
static DataType fromInt(int dtype);
static DataType fromFlatDataType(nd4j::graph::DataType dtype);
static DataType fromFlatDataType(nd4j::graph::DType dtype);
FORCEINLINE static std::string asString(DataType dataType);
template <typename T>

View File

@ -27,7 +27,7 @@ namespace nd4j {
return (DataType) val;
}
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) {
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) {
return (DataType) dtype;
}

View File

@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) {
return EnumNamesByteOrder()[index];
}
enum DataType {
DataType_INHERIT = 0,
DataType_BOOL = 1,
DataType_FLOAT8 = 2,
DataType_HALF = 3,
DataType_HALF2 = 4,
DataType_FLOAT = 5,
DataType_DOUBLE = 6,
DataType_INT8 = 7,
DataType_INT16 = 8,
DataType_INT32 = 9,
DataType_INT64 = 10,
DataType_UINT8 = 11,
DataType_UINT16 = 12,
DataType_UINT32 = 13,
DataType_UINT64 = 14,
DataType_QINT8 = 15,
DataType_QINT16 = 16,
DataType_BFLOAT16 = 17,
DataType_UTF8 = 50,
DataType_MIN = DataType_INHERIT,
DataType_MAX = DataType_UTF8
enum DType {
DType_INHERIT = 0,
DType_BOOL = 1,
DType_FLOAT8 = 2,
DType_HALF = 3,
DType_HALF2 = 4,
DType_FLOAT = 5,
DType_DOUBLE = 6,
DType_INT8 = 7,
DType_INT16 = 8,
DType_INT32 = 9,
DType_INT64 = 10,
DType_UINT8 = 11,
DType_UINT16 = 12,
DType_UINT32 = 13,
DType_UINT64 = 14,
DType_QINT8 = 15,
DType_QINT16 = 16,
DType_BFLOAT16 = 17,
DType_UTF8 = 50,
DType_MIN = DType_INHERIT,
DType_MAX = DType_UTF8
};
inline const DataType (&EnumValuesDataType())[19] {
static const DataType values[] = {
DataType_INHERIT,
DataType_BOOL,
DataType_FLOAT8,
DataType_HALF,
DataType_HALF2,
DataType_FLOAT,
DataType_DOUBLE,
DataType_INT8,
DataType_INT16,
DataType_INT32,
DataType_INT64,
DataType_UINT8,
DataType_UINT16,
DataType_UINT32,
DataType_UINT64,
DataType_QINT8,
DataType_QINT16,
DataType_BFLOAT16,
DataType_UTF8
inline const DType (&EnumValuesDType())[19] {
static const DType values[] = {
DType_INHERIT,
DType_BOOL,
DType_FLOAT8,
DType_HALF,
DType_HALF2,
DType_FLOAT,
DType_DOUBLE,
DType_INT8,
DType_INT16,
DType_INT32,
DType_INT64,
DType_UINT8,
DType_UINT16,
DType_UINT32,
DType_UINT64,
DType_QINT8,
DType_QINT16,
DType_BFLOAT16,
DType_UTF8
};
return values;
}
inline const char * const *EnumNamesDataType() {
inline const char * const *EnumNamesDType() {
static const char * const names[] = {
"INHERIT",
"BOOL",
@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() {
return names;
}
inline const char *EnumNameDataType(DataType e) {
inline const char *EnumNameDType(DType e) {
const size_t index = static_cast<int>(e);
return EnumNamesDataType()[index];
return EnumNamesDType()[index];
}
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::Vector<int8_t> *buffer() const {
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
}
DataType dtype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
DType dtype() const {
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
}
ByteOrder byteOrder() const {
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
@ -192,7 +192,7 @@ struct FlatArrayBuilder {
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
}
void add_dtype(DataType dtype) {
void add_dtype(DType dtype) {
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
}
void add_byteOrder(ByteOrder byteOrder) {
@ -214,7 +214,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArray(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
DataType dtype = DataType_INHERIT,
DType dtype = DType_INHERIT,
ByteOrder byteOrder = ByteOrder_LE) {
FlatArrayBuilder builder_(_fbb);
builder_.add_buffer(buffer);
@ -228,7 +228,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArrayDirect(
flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<int64_t> *shape = nullptr,
const std::vector<int8_t> *buffer = nullptr,
DataType dtype = DataType_INHERIT,
DType dtype = DType_INHERIT,
ByteOrder byteOrder = ByteOrder_LE) {
return nd4j::graph::CreateFlatArray(
_fbb,

View File

@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = {
/**
* @enum
*/
nd4j.graph.DataType = {
nd4j.graph.DType = {
INHERIT: 0,
BOOL: 1,
FLOAT8: 2,
@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() {
};
/**
* @returns {nd4j.graph.DataType}
* @returns {nd4j.graph.DType}
*/
nd4j.graph.FlatArray.prototype.dtype = function() {
var offset = this.bb.__offset(this.bb_pos, 8);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
};
/**
@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) {
/**
* @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} dtype
* @param {nd4j.graph.DType} dtype
*/
nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
};
/**

View File

@ -5,7 +5,7 @@
namespace nd4j.graph
{
public enum DataType : sbyte
public enum DType : sbyte
{
INHERIT = 0,
BOOL = 1,

View File

@ -2,8 +2,8 @@
package nd4j.graph;
public final class DataType {
private DataType() { }
public final class DType {
private DType() { }
public static final byte INHERIT = 0;
public static final byte BOOL = 1;
public static final byte FLOAT8 = 2;

View File

@ -2,7 +2,7 @@
# namespace: graph
class DataType(object):
class DType(object):
INHERIT = 0
BOOL = 1
FLOAT8 = 2

View File

@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
#endif
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
VectorOffset shapeOffset = default(VectorOffset),
VectorOffset bufferOffset = default(VectorOffset),
DataType dtype = DataType.INHERIT,
DType dtype = DType.INHERIT,
ByteOrder byteOrder = ByteOrder.LE) {
builder.StartObject(4);
FlatArray.AddBuffer(builder, bufferOffset);
@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
int o = builder.EndObject();

View File

@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
#endif
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; }
public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
#else
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
#endif
public DataType[] GetOutputTypesArray() { return __p.__vector_as_array<DataType>(38); }
public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {

View File

@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
#endif
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T
@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
Offset<IntPair> idOffset = default(Offset<IntPair>),
StringOffset nameOffset = default(StringOffset),
DataType dtype = DataType.INHERIT,
DType dtype = DType.INHERIT,
VectorOffset shapeOffset = default(VectorOffset),
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
int device = 0,
@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }

View File

@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) {
/**
* @param {number} index
* @returns {nd4j.graph.DataType}
* @returns {nd4j.graph.DType}
*/
nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
var offset = this.bb.__offset(this.bb_pos, 38);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0);
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
};
/**
@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) {
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<nd4j.graph.DataType>} data
* @param {Array.<nd4j.graph.DType>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {

View File

@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::String *name() const {
return GetPointer<const flatbuffers::String *>(VT_NAME);
}
DataType dtype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
DType dtype() const {
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
}
const flatbuffers::Vector<int64_t> *shape() const {
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
@ -106,7 +106,7 @@ struct FlatVariableBuilder {
void add_name(flatbuffers::Offset<flatbuffers::String> name) {
fbb_.AddOffset(FlatVariable::VT_NAME, name);
}
void add_dtype(DataType dtype) {
void add_dtype(DType dtype) {
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
}
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
@ -137,7 +137,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<IntPair> id = 0,
flatbuffers::Offset<flatbuffers::String> name = 0,
DataType dtype = DataType_INHERIT,
DType dtype = DType_INHERIT,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0,
@ -157,7 +157,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<IntPair> id = 0,
const char *name = nullptr,
DataType dtype = DataType_INHERIT,
DType dtype = DType_INHERIT,
const std::vector<int64_t> *shape = nullptr,
flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0,

View File

@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) {
};
/**
* @returns {nd4j.graph.DataType}
* @returns {nd4j.graph.DType}
*/
nd4j.graph.FlatVariable.prototype.dtype = function() {
var offset = this.bb.__offset(this.bb_pos, 8);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
};
/**
@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) {
/**
* @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} dtype
* @param {nd4j.graph.DType} dtype
*/
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
};
/**

View File

@ -111,7 +111,7 @@ namespace nd4j {
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DType>(array.dataType()), bo);
}
}
}

View File

@ -219,7 +219,7 @@ namespace nd4j {
throw std::runtime_error("CONSTANT variable must have NDArray bundled");
auto ar = flatVariable->ndarray();
if (ar->dtype() == DataType_UTF8) {
if (ar->dtype() == DType_UTF8) {
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
} else {
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
@ -320,7 +320,7 @@ namespace nd4j {
auto fBuffer = builder.CreateVector(array->asByteVector());
// packing array
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType());
// packing id/index of this var
auto fVid = CreateIntPair(builder, this->_id, this->_index);
@ -331,7 +331,7 @@ namespace nd4j {
stringId = builder.CreateString(this->_name);
// returning array
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
} else {
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
}

View File

@ -23,7 +23,7 @@ enum ByteOrder:byte {
}
// DataType for arrays/buffers
enum DataType:byte {
enum DType:byte {
INHERIT,
BOOL,
FLOAT8,
@ -49,7 +49,7 @@ enum DataType:byte {
table FlatArray {
shape:[long]; // shape in Nd4j format
buffer:[byte]; // byte buffer with data
dtype:DataType; // data type of actual data within buffer
dtype:DType; // data type of actual data within buffer
byteOrder:ByteOrder; // byte order of buffer
}

View File

@ -48,7 +48,7 @@ table FlatNode {
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
// output data types (optional)
outputTypes:[DataType];
outputTypes:[DType];
//Scalar value - used for scalar ops. Should be single value only.
scalar:FlatArray;

View File

@ -51,7 +51,7 @@ table UIVariable {
id:IntPair; //Existing IntPair class
name:string;
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
datatype:DataType;
datatype:DType;
shape:[long];
controlDeps:[string]; //Input control dependencies: variable x -> this
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of

View File

@ -30,7 +30,7 @@ enum VarType:byte {
table FlatVariable {
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
name:string; // symbolic ID of the Variable (if defined)
dtype:DataType;
dtype:DType;
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
ndarray:FlatArray;

View File

@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
auto fBuffer = builder.CreateVector(array->asByteVector());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
auto fVid = CreateIntPair(builder, -1);
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
std::vector<int> outputs1, outputs2, inputs1, inputs2;
outputs1.push_back(2);
@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
auto name1 = builder.CreateString("wow1");
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT);
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT);
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
variables_vector.push_back(fXVar);

View File

@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) {
auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
builder.Finish(flatVar);
@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) {
auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
builder.Finish(flatVar);
@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) {
auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
builder.Finish(flatVar);
@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) {
auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
auto fVid = CreateIntPair(builder, 37, 12);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
builder.Finish(flatVar);

View File

@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.DataType;
import org.nd4j.graph.DType;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatProperties;
@ -66,33 +66,33 @@ public class FlatBuffersMapper {
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
switch (type) {
case FLOAT:
return DataType.FLOAT;
return DType.FLOAT;
case DOUBLE:
return DataType.DOUBLE;
return DType.DOUBLE;
case HALF:
return DataType.HALF;
return DType.HALF;
case INT:
return DataType.INT32;
return DType.INT32;
case LONG:
return DataType.INT64;
return DType.INT64;
case BOOL:
return DataType.BOOL;
return DType.BOOL;
case SHORT:
return DataType.INT16;
return DType.INT16;
case BYTE:
return DataType.INT8;
return DType.INT8;
case UBYTE:
return DataType.UINT8;
return DType.UINT8;
case UTF8:
return DataType.UTF8;
return DType.UTF8;
case UINT16:
return DataType.UINT16;
return DType.UINT16;
case UINT32:
return DataType.UINT32;
return DType.UINT32;
case UINT64:
return DataType.UINT64;
return DType.UINT64;
case BFLOAT16:
return DataType.BFLOAT16;
return DType.BFLOAT16;
default:
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
}
@ -102,33 +102,33 @@ public class FlatBuffersMapper {
* This method converts enums for DataType
*/
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
if (val == DataType.FLOAT) {
if (val == DType.FLOAT) {
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
} else if (val == DataType.DOUBLE) {
} else if (val == DType.DOUBLE) {
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
} else if (val == DataType.HALF) {
} else if (val == DType.HALF) {
return org.nd4j.linalg.api.buffer.DataType.HALF;
} else if (val == DataType.INT32) {
} else if (val == DType.INT32) {
return org.nd4j.linalg.api.buffer.DataType.INT;
} else if (val == DataType.INT64) {
} else if (val == DType.INT64) {
return org.nd4j.linalg.api.buffer.DataType.LONG;
} else if (val == DataType.INT8) {
} else if (val == DType.INT8) {
return org.nd4j.linalg.api.buffer.DataType.BYTE;
} else if (val == DataType.BOOL) {
} else if (val == DType.BOOL) {
return org.nd4j.linalg.api.buffer.DataType.BOOL;
} else if (val == DataType.UINT8) {
} else if (val == DType.UINT8) {
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
} else if (val == DataType.INT16) {
} else if (val == DType.INT16) {
return org.nd4j.linalg.api.buffer.DataType.SHORT;
} else if (val == DataType.UTF8) {
} else if (val == DType.UTF8) {
return org.nd4j.linalg.api.buffer.DataType.UTF8;
} else if (val == DataType.UINT16) {
} else if (val == DType.UINT16) {
return org.nd4j.linalg.api.buffer.DataType.UINT16;
} else if (val == DataType.UINT32) {
} else if (val == DType.UINT32) {
return org.nd4j.linalg.api.buffer.DataType.UINT32;
} else if (val == DataType.UINT64) {
} else if (val == DType.UINT64) {
return org.nd4j.linalg.api.buffer.DataType.UINT64;
} else if (val == DataType.BFLOAT16){
} else if (val == DType.BFLOAT16){
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
} else {
throw new RuntimeException("Unknown datatype: " + val);

View File

@ -2,8 +2,8 @@
package org.nd4j.graph;
public final class DataType {
private DataType() { }
public final class DType {
private DType() { }
public static final byte INHERIT = 0;
public static final byte BOOL = 1;
public static final byte FLOAT8 = 2;