Rename flatbuffers DataType to DType (#228)
* Rename flatbuffers DataType enum to DType Signed-off-by: Alex Black <blacka101@gmail.com> * Rename flatbuffers DataType enum to DType Signed-off-by: Alex Black <blacka101@gmail.com> * Updates for flatbuffers datatype enum renaming Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
25b01f7850
commit
6cc887bee9
|
@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
auto fName = builder.CreateString(*(var->getName()));
|
||||
auto id = CreateIntPair(builder, var->id(), var->index());
|
||||
|
||||
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
|
||||
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
|
||||
|
||||
variables_vector.push_back(fv);
|
||||
arrays++;
|
||||
|
|
|
@ -38,7 +38,7 @@ namespace nd4j {
|
|||
public:
|
||||
static int asInt(DataType type);
|
||||
static DataType fromInt(int dtype);
|
||||
static DataType fromFlatDataType(nd4j::graph::DataType dtype);
|
||||
static DataType fromFlatDataType(nd4j::graph::DType dtype);
|
||||
FORCEINLINE static std::string asString(DataType dataType);
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace nd4j {
|
|||
return (DataType) val;
|
||||
}
|
||||
|
||||
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) {
|
||||
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) {
|
||||
return (DataType) dtype;
|
||||
}
|
||||
|
||||
|
|
|
@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) {
|
|||
return EnumNamesByteOrder()[index];
|
||||
}
|
||||
|
||||
enum DataType {
|
||||
DataType_INHERIT = 0,
|
||||
DataType_BOOL = 1,
|
||||
DataType_FLOAT8 = 2,
|
||||
DataType_HALF = 3,
|
||||
DataType_HALF2 = 4,
|
||||
DataType_FLOAT = 5,
|
||||
DataType_DOUBLE = 6,
|
||||
DataType_INT8 = 7,
|
||||
DataType_INT16 = 8,
|
||||
DataType_INT32 = 9,
|
||||
DataType_INT64 = 10,
|
||||
DataType_UINT8 = 11,
|
||||
DataType_UINT16 = 12,
|
||||
DataType_UINT32 = 13,
|
||||
DataType_UINT64 = 14,
|
||||
DataType_QINT8 = 15,
|
||||
DataType_QINT16 = 16,
|
||||
DataType_BFLOAT16 = 17,
|
||||
DataType_UTF8 = 50,
|
||||
DataType_MIN = DataType_INHERIT,
|
||||
DataType_MAX = DataType_UTF8
|
||||
enum DType {
|
||||
DType_INHERIT = 0,
|
||||
DType_BOOL = 1,
|
||||
DType_FLOAT8 = 2,
|
||||
DType_HALF = 3,
|
||||
DType_HALF2 = 4,
|
||||
DType_FLOAT = 5,
|
||||
DType_DOUBLE = 6,
|
||||
DType_INT8 = 7,
|
||||
DType_INT16 = 8,
|
||||
DType_INT32 = 9,
|
||||
DType_INT64 = 10,
|
||||
DType_UINT8 = 11,
|
||||
DType_UINT16 = 12,
|
||||
DType_UINT32 = 13,
|
||||
DType_UINT64 = 14,
|
||||
DType_QINT8 = 15,
|
||||
DType_QINT16 = 16,
|
||||
DType_BFLOAT16 = 17,
|
||||
DType_UTF8 = 50,
|
||||
DType_MIN = DType_INHERIT,
|
||||
DType_MAX = DType_UTF8
|
||||
};
|
||||
|
||||
inline const DataType (&EnumValuesDataType())[19] {
|
||||
static const DataType values[] = {
|
||||
DataType_INHERIT,
|
||||
DataType_BOOL,
|
||||
DataType_FLOAT8,
|
||||
DataType_HALF,
|
||||
DataType_HALF2,
|
||||
DataType_FLOAT,
|
||||
DataType_DOUBLE,
|
||||
DataType_INT8,
|
||||
DataType_INT16,
|
||||
DataType_INT32,
|
||||
DataType_INT64,
|
||||
DataType_UINT8,
|
||||
DataType_UINT16,
|
||||
DataType_UINT32,
|
||||
DataType_UINT64,
|
||||
DataType_QINT8,
|
||||
DataType_QINT16,
|
||||
DataType_BFLOAT16,
|
||||
DataType_UTF8
|
||||
inline const DType (&EnumValuesDType())[19] {
|
||||
static const DType values[] = {
|
||||
DType_INHERIT,
|
||||
DType_BOOL,
|
||||
DType_FLOAT8,
|
||||
DType_HALF,
|
||||
DType_HALF2,
|
||||
DType_FLOAT,
|
||||
DType_DOUBLE,
|
||||
DType_INT8,
|
||||
DType_INT16,
|
||||
DType_INT32,
|
||||
DType_INT64,
|
||||
DType_UINT8,
|
||||
DType_UINT16,
|
||||
DType_UINT32,
|
||||
DType_UINT64,
|
||||
DType_QINT8,
|
||||
DType_QINT16,
|
||||
DType_BFLOAT16,
|
||||
DType_UTF8
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesDataType() {
|
||||
inline const char * const *EnumNamesDType() {
|
||||
static const char * const names[] = {
|
||||
"INHERIT",
|
||||
"BOOL",
|
||||
|
@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() {
|
|||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameDataType(DataType e) {
|
||||
inline const char *EnumNameDType(DType e) {
|
||||
const size_t index = static_cast<int>(e);
|
||||
return EnumNamesDataType()[index];
|
||||
return EnumNamesDType()[index];
|
||||
}
|
||||
|
||||
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
|
@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const flatbuffers::Vector<int8_t> *buffer() const {
|
||||
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
|
||||
}
|
||||
DataType dtype() const {
|
||||
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||
DType dtype() const {
|
||||
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||
}
|
||||
ByteOrder byteOrder() const {
|
||||
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
|
||||
|
@ -192,7 +192,7 @@ struct FlatArrayBuilder {
|
|||
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
|
||||
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
|
||||
}
|
||||
void add_dtype(DataType dtype) {
|
||||
void add_dtype(DType dtype) {
|
||||
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
||||
}
|
||||
void add_byteOrder(ByteOrder byteOrder) {
|
||||
|
@ -214,7 +214,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArray(
|
|||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
|
||||
DataType dtype = DataType_INHERIT,
|
||||
DType dtype = DType_INHERIT,
|
||||
ByteOrder byteOrder = ByteOrder_LE) {
|
||||
FlatArrayBuilder builder_(_fbb);
|
||||
builder_.add_buffer(buffer);
|
||||
|
@ -228,7 +228,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArrayDirect(
|
|||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
const std::vector<int64_t> *shape = nullptr,
|
||||
const std::vector<int8_t> *buffer = nullptr,
|
||||
DataType dtype = DataType_INHERIT,
|
||||
DType dtype = DType_INHERIT,
|
||||
ByteOrder byteOrder = ByteOrder_LE) {
|
||||
return nd4j::graph::CreateFlatArray(
|
||||
_fbb,
|
||||
|
|
|
@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = {
|
|||
/**
|
||||
* @enum
|
||||
*/
|
||||
nd4j.graph.DataType = {
|
||||
nd4j.graph.DType = {
|
||||
INHERIT: 0,
|
||||
BOOL: 1,
|
||||
FLOAT8: 2,
|
||||
|
@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() {
|
|||
};
|
||||
|
||||
/**
|
||||
* @returns {nd4j.graph.DataType}
|
||||
* @returns {nd4j.graph.DType}
|
||||
*/
|
||||
nd4j.graph.FlatArray.prototype.dtype = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 8);
|
||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
|
||||
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) {
|
|||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {nd4j.graph.DataType} dtype
|
||||
* @param {nd4j.graph.DType} dtype
|
||||
*/
|
||||
nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
|
||||
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
|
||||
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
namespace nd4j.graph
|
||||
{
|
||||
|
||||
public enum DataType : sbyte
|
||||
public enum DType : sbyte
|
||||
{
|
||||
INHERIT = 0,
|
||||
BOOL = 1,
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
package nd4j.graph;
|
||||
|
||||
public final class DataType {
|
||||
private DataType() { }
|
||||
public final class DType {
|
||||
private DType() { }
|
||||
public static final byte INHERIT = 0;
|
||||
public static final byte BOOL = 1;
|
||||
public static final byte FLOAT8 = 2;
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
# namespace: graph
|
||||
|
||||
class DataType(object):
|
||||
class DType(object):
|
||||
INHERIT = 0
|
||||
BOOL = 1
|
||||
FLOAT8 = 2
|
|
@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject
|
|||
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
|
||||
#endif
|
||||
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
|
||||
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
|
||||
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
|
||||
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
|
||||
|
||||
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
|
||||
VectorOffset shapeOffset = default(VectorOffset),
|
||||
VectorOffset bufferOffset = default(VectorOffset),
|
||||
DataType dtype = DataType.INHERIT,
|
||||
DType dtype = DType.INHERIT,
|
||||
ByteOrder byteOrder = ByteOrder.LE) {
|
||||
builder.StartObject(4);
|
||||
FlatArray.AddBuffer(builder, bufferOffset);
|
||||
|
@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject
|
|||
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
|
||||
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
|
||||
int o = builder.EndObject();
|
||||
|
|
|
@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject
|
|||
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
|
||||
#endif
|
||||
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
|
||||
public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; }
|
||||
public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
|
||||
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
#if ENABLE_SPAN_T
|
||||
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
|
||||
#else
|
||||
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
|
||||
#endif
|
||||
public DataType[] GetOutputTypesArray() { return __p.__vector_as_array<DataType>(38); }
|
||||
public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
|
||||
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
|
||||
|
||||
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
||||
|
@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject
|
|||
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
|
||||
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
|
||||
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
|
||||
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
||||
|
|
|
@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject
|
|||
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
|
||||
#endif
|
||||
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
|
||||
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
|
||||
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
|
||||
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
|
||||
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
#if ENABLE_SPAN_T
|
||||
|
@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject
|
|||
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
|
||||
Offset<IntPair> idOffset = default(Offset<IntPair>),
|
||||
StringOffset nameOffset = default(StringOffset),
|
||||
DataType dtype = DataType.INHERIT,
|
||||
DType dtype = DType.INHERIT,
|
||||
VectorOffset shapeOffset = default(VectorOffset),
|
||||
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
|
||||
int device = 0,
|
||||
|
@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject
|
|||
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
|
||||
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
|
||||
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
|
||||
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }
|
||||
|
|
|
@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) {
|
|||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @returns {nd4j.graph.DataType}
|
||||
* @returns {nd4j.graph.DType}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 38);
|
||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0);
|
||||
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) {
|
|||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<nd4j.graph.DataType>} data
|
||||
* @param {Array.<nd4j.graph.DType>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {
|
||||
|
|
|
@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const flatbuffers::String *name() const {
|
||||
return GetPointer<const flatbuffers::String *>(VT_NAME);
|
||||
}
|
||||
DataType dtype() const {
|
||||
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||
DType dtype() const {
|
||||
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||
}
|
||||
const flatbuffers::Vector<int64_t> *shape() const {
|
||||
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
|
||||
|
@ -106,7 +106,7 @@ struct FlatVariableBuilder {
|
|||
void add_name(flatbuffers::Offset<flatbuffers::String> name) {
|
||||
fbb_.AddOffset(FlatVariable::VT_NAME, name);
|
||||
}
|
||||
void add_dtype(DataType dtype) {
|
||||
void add_dtype(DType dtype) {
|
||||
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
||||
}
|
||||
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
|
||||
|
@ -137,7 +137,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
|
|||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
flatbuffers::Offset<IntPair> id = 0,
|
||||
flatbuffers::Offset<flatbuffers::String> name = 0,
|
||||
DataType dtype = DataType_INHERIT,
|
||||
DType dtype = DType_INHERIT,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||
int32_t device = 0,
|
||||
|
@ -157,7 +157,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
|
|||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
flatbuffers::Offset<IntPair> id = 0,
|
||||
const char *name = nullptr,
|
||||
DataType dtype = DataType_INHERIT,
|
||||
DType dtype = DType_INHERIT,
|
||||
const std::vector<int64_t> *shape = nullptr,
|
||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||
int32_t device = 0,
|
||||
|
|
|
@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) {
|
|||
};
|
||||
|
||||
/**
|
||||
* @returns {nd4j.graph.DataType}
|
||||
* @returns {nd4j.graph.DType}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.prototype.dtype = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 8);
|
||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
|
||||
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) {
|
|||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {nd4j.graph.DataType} dtype
|
||||
* @param {nd4j.graph.DType} dtype
|
||||
*/
|
||||
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
|
||||
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
|
||||
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
@ -111,7 +111,7 @@ namespace nd4j {
|
|||
|
||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||
|
||||
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
|
||||
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DType>(array.dataType()), bo);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -219,7 +219,7 @@ namespace nd4j {
|
|||
throw std::runtime_error("CONSTANT variable must have NDArray bundled");
|
||||
|
||||
auto ar = flatVariable->ndarray();
|
||||
if (ar->dtype() == DataType_UTF8) {
|
||||
if (ar->dtype() == DType_UTF8) {
|
||||
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
||||
} else {
|
||||
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
||||
|
@ -320,7 +320,7 @@ namespace nd4j {
|
|||
auto fBuffer = builder.CreateVector(array->asByteVector());
|
||||
|
||||
// packing array
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType());
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType());
|
||||
|
||||
// packing id/index of this var
|
||||
auto fVid = CreateIntPair(builder, this->_id, this->_index);
|
||||
|
@ -331,7 +331,7 @@ namespace nd4j {
|
|||
stringId = builder.CreateString(this->_name);
|
||||
|
||||
// returning array
|
||||
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
|
||||
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
|
||||
} else {
|
||||
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ enum ByteOrder:byte {
|
|||
}
|
||||
|
||||
// DataType for arrays/buffers
|
||||
enum DataType:byte {
|
||||
enum DType:byte {
|
||||
INHERIT,
|
||||
BOOL,
|
||||
FLOAT8,
|
||||
|
@ -49,7 +49,7 @@ enum DataType:byte {
|
|||
table FlatArray {
|
||||
shape:[long]; // shape in Nd4j format
|
||||
buffer:[byte]; // byte buffer with data
|
||||
dtype:DataType; // data type of actual data within buffer
|
||||
dtype:DType; // data type of actual data within buffer
|
||||
byteOrder:ByteOrder; // byte order of buffer
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ table FlatNode {
|
|||
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
|
||||
|
||||
// output data types (optional)
|
||||
outputTypes:[DataType];
|
||||
outputTypes:[DType];
|
||||
|
||||
//Scalar value - used for scalar ops. Should be single value only.
|
||||
scalar:FlatArray;
|
||||
|
|
|
@ -51,7 +51,7 @@ table UIVariable {
|
|||
id:IntPair; //Existing IntPair class
|
||||
name:string;
|
||||
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
|
||||
datatype:DataType;
|
||||
datatype:DType;
|
||||
shape:[long];
|
||||
controlDeps:[string]; //Input control dependencies: variable x -> this
|
||||
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
||||
|
|
|
@ -30,7 +30,7 @@ enum VarType:byte {
|
|||
table FlatVariable {
|
||||
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
|
||||
name:string; // symbolic ID of the Variable (if defined)
|
||||
dtype:DataType;
|
||||
dtype:DType;
|
||||
|
||||
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
|
||||
ndarray:FlatArray;
|
||||
|
|
|
@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
|
|||
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
||||
auto fBuffer = builder.CreateVector(array->asByteVector());
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
|
||||
auto fVid = CreateIntPair(builder, -1);
|
||||
|
||||
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
|
||||
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
|
||||
|
||||
std::vector<int> outputs1, outputs2, inputs1, inputs2;
|
||||
outputs1.push_back(2);
|
||||
|
@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
|
|||
|
||||
auto name1 = builder.CreateString("wow1");
|
||||
|
||||
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT);
|
||||
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT);
|
||||
|
||||
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
|
||||
variables_vector.push_back(fXVar);
|
||||
|
|
|
@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) {
|
|||
auto fBuffer = builder.CreateVector(vec);
|
||||
auto fVid = CreateIntPair(builder, 1, 12);
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
|
||||
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
|
||||
|
||||
builder.Finish(flatVar);
|
||||
|
||||
|
@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) {
|
|||
auto fBuffer = builder.CreateVector(vec);
|
||||
auto fVid = CreateIntPair(builder, 1, 12);
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
|
||||
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
|
||||
|
||||
builder.Finish(flatVar);
|
||||
|
||||
|
@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) {
|
|||
auto fBuffer = builder.CreateVector(vec);
|
||||
auto fVid = CreateIntPair(builder, 1, 12);
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
|
||||
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
|
||||
|
||||
builder.Finish(flatVar);
|
||||
|
||||
|
@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) {
|
|||
auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
|
||||
auto fVid = CreateIntPair(builder, 37, 12);
|
||||
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
|
||||
|
||||
builder.Finish(flatVar);
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.graph.DataType;
|
||||
import org.nd4j.graph.DType;
|
||||
import org.nd4j.graph.FlatArray;
|
||||
import org.nd4j.graph.FlatNode;
|
||||
import org.nd4j.graph.FlatProperties;
|
||||
|
@ -66,33 +66,33 @@ public class FlatBuffersMapper {
|
|||
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
|
||||
switch (type) {
|
||||
case FLOAT:
|
||||
return DataType.FLOAT;
|
||||
return DType.FLOAT;
|
||||
case DOUBLE:
|
||||
return DataType.DOUBLE;
|
||||
return DType.DOUBLE;
|
||||
case HALF:
|
||||
return DataType.HALF;
|
||||
return DType.HALF;
|
||||
case INT:
|
||||
return DataType.INT32;
|
||||
return DType.INT32;
|
||||
case LONG:
|
||||
return DataType.INT64;
|
||||
return DType.INT64;
|
||||
case BOOL:
|
||||
return DataType.BOOL;
|
||||
return DType.BOOL;
|
||||
case SHORT:
|
||||
return DataType.INT16;
|
||||
return DType.INT16;
|
||||
case BYTE:
|
||||
return DataType.INT8;
|
||||
return DType.INT8;
|
||||
case UBYTE:
|
||||
return DataType.UINT8;
|
||||
return DType.UINT8;
|
||||
case UTF8:
|
||||
return DataType.UTF8;
|
||||
return DType.UTF8;
|
||||
case UINT16:
|
||||
return DataType.UINT16;
|
||||
return DType.UINT16;
|
||||
case UINT32:
|
||||
return DataType.UINT32;
|
||||
return DType.UINT32;
|
||||
case UINT64:
|
||||
return DataType.UINT64;
|
||||
return DType.UINT64;
|
||||
case BFLOAT16:
|
||||
return DataType.BFLOAT16;
|
||||
return DType.BFLOAT16;
|
||||
default:
|
||||
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
|
||||
}
|
||||
|
@ -102,33 +102,33 @@ public class FlatBuffersMapper {
|
|||
* This method converts enums for DataType
|
||||
*/
|
||||
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
|
||||
if (val == DataType.FLOAT) {
|
||||
if (val == DType.FLOAT) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||
} else if (val == DataType.DOUBLE) {
|
||||
} else if (val == DType.DOUBLE) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
|
||||
} else if (val == DataType.HALF) {
|
||||
} else if (val == DType.HALF) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.HALF;
|
||||
} else if (val == DataType.INT32) {
|
||||
} else if (val == DType.INT32) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.INT;
|
||||
} else if (val == DataType.INT64) {
|
||||
} else if (val == DType.INT64) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.LONG;
|
||||
} else if (val == DataType.INT8) {
|
||||
} else if (val == DType.INT8) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.BYTE;
|
||||
} else if (val == DataType.BOOL) {
|
||||
} else if (val == DType.BOOL) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
||||
} else if (val == DataType.UINT8) {
|
||||
} else if (val == DType.UINT8) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
|
||||
} else if (val == DataType.INT16) {
|
||||
} else if (val == DType.INT16) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.SHORT;
|
||||
} else if (val == DataType.UTF8) {
|
||||
} else if (val == DType.UTF8) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UTF8;
|
||||
} else if (val == DataType.UINT16) {
|
||||
} else if (val == DType.UINT16) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UINT16;
|
||||
} else if (val == DataType.UINT32) {
|
||||
} else if (val == DType.UINT32) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UINT32;
|
||||
} else if (val == DataType.UINT64) {
|
||||
} else if (val == DType.UINT64) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UINT64;
|
||||
} else if (val == DataType.BFLOAT16){
|
||||
} else if (val == DType.BFLOAT16){
|
||||
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
|
||||
} else {
|
||||
throw new RuntimeException("Unknown datatype: " + val);
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
package org.nd4j.graph;
|
||||
|
||||
public final class DataType {
|
||||
private DataType() { }
|
||||
public final class DType {
|
||||
private DType() { }
|
||||
public static final byte INHERIT = 0;
|
||||
public static final byte BOOL = 1;
|
||||
public static final byte FLOAT8 = 2;
|
Loading…
Reference in New Issue