Rename flatbuffers DataType to DType (#228)
* Rename flatbuffers DataType enum to DType Signed-off-by: Alex Black <blacka101@gmail.com> * Rename flatbuffers DataType enum to DType Signed-off-by: Alex Black <blacka101@gmail.com> * Updates for flatbuffers datatype enum renaming Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
25b01f7850
commit
6cc887bee9
|
@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
||||||
auto fName = builder.CreateString(*(var->getName()));
|
auto fName = builder.CreateString(*(var->getName()));
|
||||||
auto id = CreateIntPair(builder, var->id(), var->index());
|
auto id = CreateIntPair(builder, var->id(), var->index());
|
||||||
|
|
||||||
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
|
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
|
||||||
|
|
||||||
variables_vector.push_back(fv);
|
variables_vector.push_back(fv);
|
||||||
arrays++;
|
arrays++;
|
||||||
|
|
|
@ -38,7 +38,7 @@ namespace nd4j {
|
||||||
public:
|
public:
|
||||||
static int asInt(DataType type);
|
static int asInt(DataType type);
|
||||||
static DataType fromInt(int dtype);
|
static DataType fromInt(int dtype);
|
||||||
static DataType fromFlatDataType(nd4j::graph::DataType dtype);
|
static DataType fromFlatDataType(nd4j::graph::DType dtype);
|
||||||
FORCEINLINE static std::string asString(DataType dataType);
|
FORCEINLINE static std::string asString(DataType dataType);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace nd4j {
|
||||||
return (DataType) val;
|
return (DataType) val;
|
||||||
}
|
}
|
||||||
|
|
||||||
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) {
|
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) {
|
||||||
return (DataType) dtype;
|
return (DataType) dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) {
|
||||||
return EnumNamesByteOrder()[index];
|
return EnumNamesByteOrder()[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
enum DataType {
|
enum DType {
|
||||||
DataType_INHERIT = 0,
|
DType_INHERIT = 0,
|
||||||
DataType_BOOL = 1,
|
DType_BOOL = 1,
|
||||||
DataType_FLOAT8 = 2,
|
DType_FLOAT8 = 2,
|
||||||
DataType_HALF = 3,
|
DType_HALF = 3,
|
||||||
DataType_HALF2 = 4,
|
DType_HALF2 = 4,
|
||||||
DataType_FLOAT = 5,
|
DType_FLOAT = 5,
|
||||||
DataType_DOUBLE = 6,
|
DType_DOUBLE = 6,
|
||||||
DataType_INT8 = 7,
|
DType_INT8 = 7,
|
||||||
DataType_INT16 = 8,
|
DType_INT16 = 8,
|
||||||
DataType_INT32 = 9,
|
DType_INT32 = 9,
|
||||||
DataType_INT64 = 10,
|
DType_INT64 = 10,
|
||||||
DataType_UINT8 = 11,
|
DType_UINT8 = 11,
|
||||||
DataType_UINT16 = 12,
|
DType_UINT16 = 12,
|
||||||
DataType_UINT32 = 13,
|
DType_UINT32 = 13,
|
||||||
DataType_UINT64 = 14,
|
DType_UINT64 = 14,
|
||||||
DataType_QINT8 = 15,
|
DType_QINT8 = 15,
|
||||||
DataType_QINT16 = 16,
|
DType_QINT16 = 16,
|
||||||
DataType_BFLOAT16 = 17,
|
DType_BFLOAT16 = 17,
|
||||||
DataType_UTF8 = 50,
|
DType_UTF8 = 50,
|
||||||
DataType_MIN = DataType_INHERIT,
|
DType_MIN = DType_INHERIT,
|
||||||
DataType_MAX = DataType_UTF8
|
DType_MAX = DType_UTF8
|
||||||
};
|
};
|
||||||
|
|
||||||
inline const DataType (&EnumValuesDataType())[19] {
|
inline const DType (&EnumValuesDType())[19] {
|
||||||
static const DataType values[] = {
|
static const DType values[] = {
|
||||||
DataType_INHERIT,
|
DType_INHERIT,
|
||||||
DataType_BOOL,
|
DType_BOOL,
|
||||||
DataType_FLOAT8,
|
DType_FLOAT8,
|
||||||
DataType_HALF,
|
DType_HALF,
|
||||||
DataType_HALF2,
|
DType_HALF2,
|
||||||
DataType_FLOAT,
|
DType_FLOAT,
|
||||||
DataType_DOUBLE,
|
DType_DOUBLE,
|
||||||
DataType_INT8,
|
DType_INT8,
|
||||||
DataType_INT16,
|
DType_INT16,
|
||||||
DataType_INT32,
|
DType_INT32,
|
||||||
DataType_INT64,
|
DType_INT64,
|
||||||
DataType_UINT8,
|
DType_UINT8,
|
||||||
DataType_UINT16,
|
DType_UINT16,
|
||||||
DataType_UINT32,
|
DType_UINT32,
|
||||||
DataType_UINT64,
|
DType_UINT64,
|
||||||
DataType_QINT8,
|
DType_QINT8,
|
||||||
DataType_QINT16,
|
DType_QINT16,
|
||||||
DataType_BFLOAT16,
|
DType_BFLOAT16,
|
||||||
DataType_UTF8
|
DType_UTF8
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const char * const *EnumNamesDataType() {
|
inline const char * const *EnumNamesDType() {
|
||||||
static const char * const names[] = {
|
static const char * const names[] = {
|
||||||
"INHERIT",
|
"INHERIT",
|
||||||
"BOOL",
|
"BOOL",
|
||||||
|
@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() {
|
||||||
return names;
|
return names;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const char *EnumNameDataType(DataType e) {
|
inline const char *EnumNameDType(DType e) {
|
||||||
const size_t index = static_cast<int>(e);
|
const size_t index = static_cast<int>(e);
|
||||||
return EnumNamesDataType()[index];
|
return EnumNamesDType()[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
|
@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
const flatbuffers::Vector<int8_t> *buffer() const {
|
const flatbuffers::Vector<int8_t> *buffer() const {
|
||||||
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
|
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
|
||||||
}
|
}
|
||||||
DataType dtype() const {
|
DType dtype() const {
|
||||||
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
|
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||||
}
|
}
|
||||||
ByteOrder byteOrder() const {
|
ByteOrder byteOrder() const {
|
||||||
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
|
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
|
||||||
|
@ -192,7 +192,7 @@ struct FlatArrayBuilder {
|
||||||
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
|
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
|
||||||
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
|
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
|
||||||
}
|
}
|
||||||
void add_dtype(DataType dtype) {
|
void add_dtype(DType dtype) {
|
||||||
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
||||||
}
|
}
|
||||||
void add_byteOrder(ByteOrder byteOrder) {
|
void add_byteOrder(ByteOrder byteOrder) {
|
||||||
|
@ -214,7 +214,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArray(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
|
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
|
||||||
DataType dtype = DataType_INHERIT,
|
DType dtype = DType_INHERIT,
|
||||||
ByteOrder byteOrder = ByteOrder_LE) {
|
ByteOrder byteOrder = ByteOrder_LE) {
|
||||||
FlatArrayBuilder builder_(_fbb);
|
FlatArrayBuilder builder_(_fbb);
|
||||||
builder_.add_buffer(buffer);
|
builder_.add_buffer(buffer);
|
||||||
|
@ -228,7 +228,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArrayDirect(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
const std::vector<int64_t> *shape = nullptr,
|
const std::vector<int64_t> *shape = nullptr,
|
||||||
const std::vector<int8_t> *buffer = nullptr,
|
const std::vector<int8_t> *buffer = nullptr,
|
||||||
DataType dtype = DataType_INHERIT,
|
DType dtype = DType_INHERIT,
|
||||||
ByteOrder byteOrder = ByteOrder_LE) {
|
ByteOrder byteOrder = ByteOrder_LE) {
|
||||||
return nd4j::graph::CreateFlatArray(
|
return nd4j::graph::CreateFlatArray(
|
||||||
_fbb,
|
_fbb,
|
||||||
|
|
|
@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = {
|
||||||
/**
|
/**
|
||||||
* @enum
|
* @enum
|
||||||
*/
|
*/
|
||||||
nd4j.graph.DataType = {
|
nd4j.graph.DType = {
|
||||||
INHERIT: 0,
|
INHERIT: 0,
|
||||||
BOOL: 1,
|
BOOL: 1,
|
||||||
FLOAT8: 2,
|
FLOAT8: 2,
|
||||||
|
@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() {
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @returns {nd4j.graph.DataType}
|
* @returns {nd4j.graph.DType}
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatArray.prototype.dtype = function() {
|
nd4j.graph.FlatArray.prototype.dtype = function() {
|
||||||
var offset = this.bb.__offset(this.bb_pos, 8);
|
var offset = this.bb.__offset(this.bb_pos, 8);
|
||||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
|
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {flatbuffers.Builder} builder
|
* @param {flatbuffers.Builder} builder
|
||||||
* @param {nd4j.graph.DataType} dtype
|
* @param {nd4j.graph.DType} dtype
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
|
nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
|
||||||
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
|
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
namespace nd4j.graph
|
namespace nd4j.graph
|
||||||
{
|
{
|
||||||
|
|
||||||
public enum DataType : sbyte
|
public enum DType : sbyte
|
||||||
{
|
{
|
||||||
INHERIT = 0,
|
INHERIT = 0,
|
||||||
BOOL = 1,
|
BOOL = 1,
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
package nd4j.graph;
|
package nd4j.graph;
|
||||||
|
|
||||||
public final class DataType {
|
public final class DType {
|
||||||
private DataType() { }
|
private DType() { }
|
||||||
public static final byte INHERIT = 0;
|
public static final byte INHERIT = 0;
|
||||||
public static final byte BOOL = 1;
|
public static final byte BOOL = 1;
|
||||||
public static final byte FLOAT8 = 2;
|
public static final byte FLOAT8 = 2;
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
# namespace: graph
|
# namespace: graph
|
||||||
|
|
||||||
class DataType(object):
|
class DType(object):
|
||||||
INHERIT = 0
|
INHERIT = 0
|
||||||
BOOL = 1
|
BOOL = 1
|
||||||
FLOAT8 = 2
|
FLOAT8 = 2
|
|
@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject
|
||||||
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
|
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
|
||||||
#endif
|
#endif
|
||||||
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
|
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
|
||||||
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
|
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
|
||||||
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
|
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
|
||||||
|
|
||||||
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
|
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
|
||||||
VectorOffset shapeOffset = default(VectorOffset),
|
VectorOffset shapeOffset = default(VectorOffset),
|
||||||
VectorOffset bufferOffset = default(VectorOffset),
|
VectorOffset bufferOffset = default(VectorOffset),
|
||||||
DataType dtype = DataType.INHERIT,
|
DType dtype = DType.INHERIT,
|
||||||
ByteOrder byteOrder = ByteOrder.LE) {
|
ByteOrder byteOrder = ByteOrder.LE) {
|
||||||
builder.StartObject(4);
|
builder.StartObject(4);
|
||||||
FlatArray.AddBuffer(builder, bufferOffset);
|
FlatArray.AddBuffer(builder, bufferOffset);
|
||||||
|
@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject
|
||||||
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
|
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
|
||||||
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||||
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||||
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||||
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
|
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
|
||||||
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
|
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
|
||||||
int o = builder.EndObject();
|
int o = builder.EndObject();
|
||||||
|
|
|
@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject
|
||||||
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
|
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
|
||||||
#endif
|
#endif
|
||||||
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
|
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
|
||||||
public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; }
|
public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
|
||||||
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
|
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||||
#if ENABLE_SPAN_T
|
#if ENABLE_SPAN_T
|
||||||
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
|
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
|
||||||
#else
|
#else
|
||||||
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
|
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
|
||||||
#endif
|
#endif
|
||||||
public DataType[] GetOutputTypesArray() { return __p.__vector_as_array<DataType>(38); }
|
public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
|
||||||
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
|
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
|
||||||
|
|
||||||
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
||||||
|
@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject
|
||||||
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||||
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
|
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
|
||||||
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
|
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
|
||||||
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
||||||
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||||
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||||
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
|
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
|
||||||
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
||||||
|
|
|
@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject
|
||||||
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
|
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
|
||||||
#endif
|
#endif
|
||||||
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
|
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
|
||||||
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
|
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
|
||||||
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
|
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
|
||||||
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
|
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||||
#if ENABLE_SPAN_T
|
#if ENABLE_SPAN_T
|
||||||
|
@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject
|
||||||
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
|
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
|
||||||
Offset<IntPair> idOffset = default(Offset<IntPair>),
|
Offset<IntPair> idOffset = default(Offset<IntPair>),
|
||||||
StringOffset nameOffset = default(StringOffset),
|
StringOffset nameOffset = default(StringOffset),
|
||||||
DataType dtype = DataType.INHERIT,
|
DType dtype = DType.INHERIT,
|
||||||
VectorOffset shapeOffset = default(VectorOffset),
|
VectorOffset shapeOffset = default(VectorOffset),
|
||||||
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
|
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
|
||||||
int device = 0,
|
int device = 0,
|
||||||
|
@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject
|
||||||
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
|
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
|
||||||
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
|
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
|
||||||
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
||||||
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||||
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
|
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
|
||||||
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
|
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
|
||||||
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }
|
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }
|
||||||
|
|
|
@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {number} index
|
* @param {number} index
|
||||||
* @returns {nd4j.graph.DataType}
|
* @returns {nd4j.graph.DType}
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
|
nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
|
||||||
var offset = this.bb.__offset(this.bb_pos, 38);
|
var offset = this.bb.__offset(this.bb_pos, 38);
|
||||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0);
|
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {flatbuffers.Builder} builder
|
* @param {flatbuffers.Builder} builder
|
||||||
* @param {Array.<nd4j.graph.DataType>} data
|
* @param {Array.<nd4j.graph.DType>} data
|
||||||
* @returns {flatbuffers.Offset}
|
* @returns {flatbuffers.Offset}
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {
|
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {
|
||||||
|
|
|
@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
const flatbuffers::String *name() const {
|
const flatbuffers::String *name() const {
|
||||||
return GetPointer<const flatbuffers::String *>(VT_NAME);
|
return GetPointer<const flatbuffers::String *>(VT_NAME);
|
||||||
}
|
}
|
||||||
DataType dtype() const {
|
DType dtype() const {
|
||||||
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
|
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||||
}
|
}
|
||||||
const flatbuffers::Vector<int64_t> *shape() const {
|
const flatbuffers::Vector<int64_t> *shape() const {
|
||||||
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
|
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
|
||||||
|
@ -106,7 +106,7 @@ struct FlatVariableBuilder {
|
||||||
void add_name(flatbuffers::Offset<flatbuffers::String> name) {
|
void add_name(flatbuffers::Offset<flatbuffers::String> name) {
|
||||||
fbb_.AddOffset(FlatVariable::VT_NAME, name);
|
fbb_.AddOffset(FlatVariable::VT_NAME, name);
|
||||||
}
|
}
|
||||||
void add_dtype(DataType dtype) {
|
void add_dtype(DType dtype) {
|
||||||
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
||||||
}
|
}
|
||||||
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
|
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
|
||||||
|
@ -137,7 +137,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
flatbuffers::Offset<IntPair> id = 0,
|
flatbuffers::Offset<IntPair> id = 0,
|
||||||
flatbuffers::Offset<flatbuffers::String> name = 0,
|
flatbuffers::Offset<flatbuffers::String> name = 0,
|
||||||
DataType dtype = DataType_INHERIT,
|
DType dtype = DType_INHERIT,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||||
int32_t device = 0,
|
int32_t device = 0,
|
||||||
|
@ -157,7 +157,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
flatbuffers::Offset<IntPair> id = 0,
|
flatbuffers::Offset<IntPair> id = 0,
|
||||||
const char *name = nullptr,
|
const char *name = nullptr,
|
||||||
DataType dtype = DataType_INHERIT,
|
DType dtype = DType_INHERIT,
|
||||||
const std::vector<int64_t> *shape = nullptr,
|
const std::vector<int64_t> *shape = nullptr,
|
||||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||||
int32_t device = 0,
|
int32_t device = 0,
|
||||||
|
|
|
@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) {
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @returns {nd4j.graph.DataType}
|
* @returns {nd4j.graph.DType}
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatVariable.prototype.dtype = function() {
|
nd4j.graph.FlatVariable.prototype.dtype = function() {
|
||||||
var offset = this.bb.__offset(this.bb_pos, 8);
|
var offset = this.bb.__offset(this.bb_pos, 8);
|
||||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
|
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {flatbuffers.Builder} builder
|
* @param {flatbuffers.Builder} builder
|
||||||
* @param {nd4j.graph.DataType} dtype
|
* @param {nd4j.graph.DType} dtype
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
|
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
|
||||||
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
|
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -111,7 +111,7 @@ namespace nd4j {
|
||||||
|
|
||||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||||
|
|
||||||
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
|
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DType>(array.dataType()), bo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -219,7 +219,7 @@ namespace nd4j {
|
||||||
throw std::runtime_error("CONSTANT variable must have NDArray bundled");
|
throw std::runtime_error("CONSTANT variable must have NDArray bundled");
|
||||||
|
|
||||||
auto ar = flatVariable->ndarray();
|
auto ar = flatVariable->ndarray();
|
||||||
if (ar->dtype() == DataType_UTF8) {
|
if (ar->dtype() == DType_UTF8) {
|
||||||
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
||||||
} else {
|
} else {
|
||||||
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
||||||
|
@ -320,7 +320,7 @@ namespace nd4j {
|
||||||
auto fBuffer = builder.CreateVector(array->asByteVector());
|
auto fBuffer = builder.CreateVector(array->asByteVector());
|
||||||
|
|
||||||
// packing array
|
// packing array
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType());
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType());
|
||||||
|
|
||||||
// packing id/index of this var
|
// packing id/index of this var
|
||||||
auto fVid = CreateIntPair(builder, this->_id, this->_index);
|
auto fVid = CreateIntPair(builder, this->_id, this->_index);
|
||||||
|
@ -331,7 +331,7 @@ namespace nd4j {
|
||||||
stringId = builder.CreateString(this->_name);
|
stringId = builder.CreateString(this->_name);
|
||||||
|
|
||||||
// returning array
|
// returning array
|
||||||
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
|
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
|
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ enum ByteOrder:byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataType for arrays/buffers
|
// DataType for arrays/buffers
|
||||||
enum DataType:byte {
|
enum DType:byte {
|
||||||
INHERIT,
|
INHERIT,
|
||||||
BOOL,
|
BOOL,
|
||||||
FLOAT8,
|
FLOAT8,
|
||||||
|
@ -49,7 +49,7 @@ enum DataType:byte {
|
||||||
table FlatArray {
|
table FlatArray {
|
||||||
shape:[long]; // shape in Nd4j format
|
shape:[long]; // shape in Nd4j format
|
||||||
buffer:[byte]; // byte buffer with data
|
buffer:[byte]; // byte buffer with data
|
||||||
dtype:DataType; // data type of actual data within buffer
|
dtype:DType; // data type of actual data within buffer
|
||||||
byteOrder:ByteOrder; // byte order of buffer
|
byteOrder:ByteOrder; // byte order of buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ table FlatNode {
|
||||||
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
|
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
|
||||||
|
|
||||||
// output data types (optional)
|
// output data types (optional)
|
||||||
outputTypes:[DataType];
|
outputTypes:[DType];
|
||||||
|
|
||||||
//Scalar value - used for scalar ops. Should be single value only.
|
//Scalar value - used for scalar ops. Should be single value only.
|
||||||
scalar:FlatArray;
|
scalar:FlatArray;
|
||||||
|
|
|
@ -51,7 +51,7 @@ table UIVariable {
|
||||||
id:IntPair; //Existing IntPair class
|
id:IntPair; //Existing IntPair class
|
||||||
name:string;
|
name:string;
|
||||||
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
|
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
|
||||||
datatype:DataType;
|
datatype:DType;
|
||||||
shape:[long];
|
shape:[long];
|
||||||
controlDeps:[string]; //Input control dependencies: variable x -> this
|
controlDeps:[string]; //Input control dependencies: variable x -> this
|
||||||
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
||||||
|
|
|
@ -30,7 +30,7 @@ enum VarType:byte {
|
||||||
table FlatVariable {
|
table FlatVariable {
|
||||||
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
|
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
|
||||||
name:string; // symbolic ID of the Variable (if defined)
|
name:string; // symbolic ID of the Variable (if defined)
|
||||||
dtype:DataType;
|
dtype:DType;
|
||||||
|
|
||||||
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
|
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
|
||||||
ndarray:FlatArray;
|
ndarray:FlatArray;
|
||||||
|
|
|
@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
|
||||||
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
||||||
auto fBuffer = builder.CreateVector(array->asByteVector());
|
auto fBuffer = builder.CreateVector(array->asByteVector());
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
|
||||||
auto fVid = CreateIntPair(builder, -1);
|
auto fVid = CreateIntPair(builder, -1);
|
||||||
|
|
||||||
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
|
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
|
||||||
|
|
||||||
std::vector<int> outputs1, outputs2, inputs1, inputs2;
|
std::vector<int> outputs1, outputs2, inputs1, inputs2;
|
||||||
outputs1.push_back(2);
|
outputs1.push_back(2);
|
||||||
|
@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
|
||||||
|
|
||||||
auto name1 = builder.CreateString("wow1");
|
auto name1 = builder.CreateString("wow1");
|
||||||
|
|
||||||
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT);
|
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT);
|
||||||
|
|
||||||
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
|
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
|
||||||
variables_vector.push_back(fXVar);
|
variables_vector.push_back(fXVar);
|
||||||
|
|
|
@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) {
|
||||||
auto fBuffer = builder.CreateVector(vec);
|
auto fBuffer = builder.CreateVector(vec);
|
||||||
auto fVid = CreateIntPair(builder, 1, 12);
|
auto fVid = CreateIntPair(builder, 1, 12);
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
|
||||||
|
|
||||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
|
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
|
||||||
|
|
||||||
builder.Finish(flatVar);
|
builder.Finish(flatVar);
|
||||||
|
|
||||||
|
@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) {
|
||||||
auto fBuffer = builder.CreateVector(vec);
|
auto fBuffer = builder.CreateVector(vec);
|
||||||
auto fVid = CreateIntPair(builder, 1, 12);
|
auto fVid = CreateIntPair(builder, 1, 12);
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
|
||||||
|
|
||||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
|
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
|
||||||
|
|
||||||
builder.Finish(flatVar);
|
builder.Finish(flatVar);
|
||||||
|
|
||||||
|
@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) {
|
||||||
auto fBuffer = builder.CreateVector(vec);
|
auto fBuffer = builder.CreateVector(vec);
|
||||||
auto fVid = CreateIntPair(builder, 1, 12);
|
auto fVid = CreateIntPair(builder, 1, 12);
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
|
||||||
|
|
||||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
|
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
|
||||||
|
|
||||||
builder.Finish(flatVar);
|
builder.Finish(flatVar);
|
||||||
|
|
||||||
|
@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) {
|
||||||
auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
|
auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
|
||||||
auto fVid = CreateIntPair(builder, 37, 12);
|
auto fVid = CreateIntPair(builder, 37, 12);
|
||||||
|
|
||||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
|
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
|
||||||
|
|
||||||
builder.Finish(flatVar);
|
builder.Finish(flatVar);
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.VariableType;
|
import org.nd4j.autodiff.samediff.VariableType;
|
||||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.graph.DataType;
|
import org.nd4j.graph.DType;
|
||||||
import org.nd4j.graph.FlatArray;
|
import org.nd4j.graph.FlatArray;
|
||||||
import org.nd4j.graph.FlatNode;
|
import org.nd4j.graph.FlatNode;
|
||||||
import org.nd4j.graph.FlatProperties;
|
import org.nd4j.graph.FlatProperties;
|
||||||
|
@ -66,33 +66,33 @@ public class FlatBuffersMapper {
|
||||||
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
|
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
return DataType.FLOAT;
|
return DType.FLOAT;
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
return DataType.DOUBLE;
|
return DType.DOUBLE;
|
||||||
case HALF:
|
case HALF:
|
||||||
return DataType.HALF;
|
return DType.HALF;
|
||||||
case INT:
|
case INT:
|
||||||
return DataType.INT32;
|
return DType.INT32;
|
||||||
case LONG:
|
case LONG:
|
||||||
return DataType.INT64;
|
return DType.INT64;
|
||||||
case BOOL:
|
case BOOL:
|
||||||
return DataType.BOOL;
|
return DType.BOOL;
|
||||||
case SHORT:
|
case SHORT:
|
||||||
return DataType.INT16;
|
return DType.INT16;
|
||||||
case BYTE:
|
case BYTE:
|
||||||
return DataType.INT8;
|
return DType.INT8;
|
||||||
case UBYTE:
|
case UBYTE:
|
||||||
return DataType.UINT8;
|
return DType.UINT8;
|
||||||
case UTF8:
|
case UTF8:
|
||||||
return DataType.UTF8;
|
return DType.UTF8;
|
||||||
case UINT16:
|
case UINT16:
|
||||||
return DataType.UINT16;
|
return DType.UINT16;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
return DataType.UINT32;
|
return DType.UINT32;
|
||||||
case UINT64:
|
case UINT64:
|
||||||
return DataType.UINT64;
|
return DType.UINT64;
|
||||||
case BFLOAT16:
|
case BFLOAT16:
|
||||||
return DataType.BFLOAT16;
|
return DType.BFLOAT16;
|
||||||
default:
|
default:
|
||||||
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
|
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
|
||||||
}
|
}
|
||||||
|
@ -102,33 +102,33 @@ public class FlatBuffersMapper {
|
||||||
* This method converts enums for DataType
|
* This method converts enums for DataType
|
||||||
*/
|
*/
|
||||||
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
|
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
|
||||||
if (val == DataType.FLOAT) {
|
if (val == DType.FLOAT) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||||
} else if (val == DataType.DOUBLE) {
|
} else if (val == DType.DOUBLE) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
|
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
|
||||||
} else if (val == DataType.HALF) {
|
} else if (val == DType.HALF) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.HALF;
|
return org.nd4j.linalg.api.buffer.DataType.HALF;
|
||||||
} else if (val == DataType.INT32) {
|
} else if (val == DType.INT32) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.INT;
|
return org.nd4j.linalg.api.buffer.DataType.INT;
|
||||||
} else if (val == DataType.INT64) {
|
} else if (val == DType.INT64) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.LONG;
|
return org.nd4j.linalg.api.buffer.DataType.LONG;
|
||||||
} else if (val == DataType.INT8) {
|
} else if (val == DType.INT8) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.BYTE;
|
return org.nd4j.linalg.api.buffer.DataType.BYTE;
|
||||||
} else if (val == DataType.BOOL) {
|
} else if (val == DType.BOOL) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
||||||
} else if (val == DataType.UINT8) {
|
} else if (val == DType.UINT8) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
|
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
|
||||||
} else if (val == DataType.INT16) {
|
} else if (val == DType.INT16) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.SHORT;
|
return org.nd4j.linalg.api.buffer.DataType.SHORT;
|
||||||
} else if (val == DataType.UTF8) {
|
} else if (val == DType.UTF8) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UTF8;
|
return org.nd4j.linalg.api.buffer.DataType.UTF8;
|
||||||
} else if (val == DataType.UINT16) {
|
} else if (val == DType.UINT16) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UINT16;
|
return org.nd4j.linalg.api.buffer.DataType.UINT16;
|
||||||
} else if (val == DataType.UINT32) {
|
} else if (val == DType.UINT32) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UINT32;
|
return org.nd4j.linalg.api.buffer.DataType.UINT32;
|
||||||
} else if (val == DataType.UINT64) {
|
} else if (val == DType.UINT64) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UINT64;
|
return org.nd4j.linalg.api.buffer.DataType.UINT64;
|
||||||
} else if (val == DataType.BFLOAT16){
|
} else if (val == DType.BFLOAT16){
|
||||||
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
|
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException("Unknown datatype: " + val);
|
throw new RuntimeException("Unknown datatype: " + val);
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
package org.nd4j.graph;
|
package org.nd4j.graph;
|
||||||
|
|
||||||
public final class DataType {
|
public final class DType {
|
||||||
private DataType() { }
|
private DType() { }
|
||||||
public static final byte INHERIT = 0;
|
public static final byte INHERIT = 0;
|
||||||
public static final byte BOOL = 1;
|
public static final byte BOOL = 1;
|
||||||
public static final byte FLOAT8 = 2;
|
public static final byte FLOAT8 = 2;
|
Loading…
Reference in New Issue