2022-09-20 15:40:53 +02:00
|
|
|
# /* ******************************************************************************
|
|
|
|
# *
|
|
|
|
# *
|
|
|
|
# * This program and the accompanying materials are made available under the
|
|
|
|
# * terms of the Apache License, Version 2.0 which is available at
|
|
|
|
# * https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
# *
|
|
|
|
# * See the NOTICE file distributed with this work for additional
|
|
|
|
# * information regarding copyright ownership.
|
|
|
|
# * Unless required by applicable law or agreed to in writing, software
|
|
|
|
# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
# * License for the specific language governing permissions and limitations
|
|
|
|
# * under the License.
|
|
|
|
# *
|
|
|
|
# * SPDX-License-Identifier: Apache-2.0
|
|
|
|
# ******************************************************************************/
|
|
|
|
|
|
|
|
from onnx.defs import get_all_schemas
|
|
|
|
from onnx import NodeProto,GraphProto
|
|
|
|
from google.protobuf import text_format
|
|
|
|
import onnx.helper
|
|
|
|
|
|
|
|
|
|
|
|
# Operator schemas returned by onnx.defs.get_all_schemas(); consumed by the
# export code at the bottom of this script. (`nodes` is built there too, so it
# is not pre-initialized here.)
schemas = get_all_schemas()
|
|
|
|
|
|
|
|
|
|
|
|
def load_node(input_str):
    """
    Parse a single NodeProto from its protobuf text-format representation.

    :param input_str: text-format serialization of one NodeProto
    :return: a new NodeProto populated from ``input_str``
    """
    # text_format.Parse fills the message it is given and hands it back.
    return text_format.Parse(input_str, NodeProto())
|
|
|
|
|
|
|
|
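# Attributes without a schema default are serialized as empty placeholder values; the helper below supplies the AttributeProto type enum for those placeholders.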
|
|
|
|
|
|
|
|
|
|
|
|
# str(OpSchema.AttrType) -> integer AttributeProto type enum value.
# TYPE_PROTO / TYPE_PROTOS correspond to AttributeProto.AttributeType values
# 13 and 14 in onnx.proto (newer ONNX releases).
_ATTR_TYPE_TO_ENUM = {
    'AttrType.UNDEFINED': 0,
    'AttrType.FLOAT': 1,
    'AttrType.INT': 2,
    'AttrType.STRING': 3,
    'AttrType.TENSOR': 4,
    'AttrType.GRAPH': 5,
    'AttrType.FLOATS': 6,
    'AttrType.INTS': 7,
    'AttrType.STRINGS': 8,
    'AttrType.TENSORS': 9,
    'AttrType.GRAPHS': 10,
    'AttrType.SPARSE_TENSOR': 11,
    'AttrType.SPARSE_TENSORS': 12,
    'AttrType.TYPE_PROTO': 13,
    'AttrType.TYPE_PROTOS': 14,
}


def convert_attr_type_to_enum(attr_value):
    """
    Map an OpSchema attribute's type to the integer AttributeProto type enum.

    The lookup key is ``str(attr_value.type)`` (e.g. ``'AttrType.INTS'``),
    and the result can be assigned directly to ``AttributeProto.type``.

    :param attr_value: a schema attribute exposing a ``type`` member
    :return: the integer AttributeProto type enum value
    :raises ValueError: if the type is not a known attribute type
        (ValueError subclasses Exception, so existing handlers still catch it)
    """
    type_name = str(attr_value.type)
    try:
        return _ATTR_TYPE_TO_ENUM[type_name]
    except KeyError:
        raise ValueError('Invalid type passed in: ' + type_name) from None
|
|
|
|
|
|
|
|
def _types_attribute(formal_param):
    """
    Build the '<param name>-types' attribute listing a parameter's allowed types.

    Each type string has every 'tensor(' and ')' removed, so 'tensor(float)'
    becomes 'float'. Names are sorted so the output does not depend on the
    iteration order of ``formal_param.types`` (presumably a set in the ONNX
    Python bindings -- TODO confirm).

    :param formal_param: an OpSchema input/output formal parameter
    :return: the attribute built by onnx.helper.make_attribute
    """
    # NOTE(review): nested types such as 'seq(tensor(float))' come out as
    # 'seq(float' with this stripping; kept unchanged because readers of the
    # generated file may depend on the current format.
    type_names = sorted(
        str(data_type).replace('tensor(', '').replace(')', '')
        for data_type in formal_param.types
    )
    return onnx.helper.make_attribute(formal_param.name + '-types', type_names)


def create_node_from_schema(schema):
    """
    Convert an OpSchema to a NodeProto.

    The resulting node carries:
    - op_type and name set to the schema name, doc_string set to schema.doc;
    - one attribute per schema attribute: its default value when the schema
      provides one, otherwise an empty placeholder whose type is set via
      convert_attr_type_to_enum;
    - one '<param>-types' attribute per input and per output;
    - the schema's input/output names as the node's inputs/outputs.

    :param schema: the input OpSchema
    :return: the equivalent NodeProto
    """
    node_proto = NodeProto()

    # Fetch the attribute mapping once instead of re-reading it per key.
    for attr_value in schema.attributes.values():
        if attr_value.default_value.name == '':
            # An unnamed default_value is treated as "no default": emit a
            # placeholder so the attribute's name and type are still recorded.
            attr_value_new = onnx.helper.make_attribute(attr_value.name, '')
            attr_value_new.type = convert_attr_type_to_enum(attr_value)
            node_proto.attribute.append(attr_value_new)
        else:
            node_proto.attribute.append(attr_value.default_value)

    node_proto.op_type = schema.name
    node_proto.doc_string = schema.doc
    node_proto.name = schema.name

    # Repeated protobuf fields are never None and cannot be assigned to, so
    # names are appended directly.
    for input_arr in schema.inputs:
        node_proto.attribute.append(_types_attribute(input_arr))
        node_proto.input.append(input_arr.name)

    for output_arr in schema.outputs:
        node_proto.attribute.append(_types_attribute(output_arr))
        node_proto.output.append(output_arr.name)

    return node_proto
|
|
|
|
|
|
|
|
|
|
|
|
# One NodeProto per schema, ordered by op name (presumably for a stable,
# reproducible output order), stored as the nodes of a single GraphProto.
nodes = [create_node_from_schema(schema)
         for schema in sorted(schemas, key=lambda s: s.name)]

graph_proto = GraphProto()
graph_proto.node.extend(nodes)

# Serialize before opening the output file so a failure while building or
# serializing the graph cannot leave a truncated/empty file behind.
serialized = graph_proto.SerializeToString()
with open('onnx-op-defs.pb', 'wb') as f:
    f.write(serialized)
|
|
|
|
# for node in nodes:
|
|
|
|
# message_to_string = text_format.MessageToString(node, as_utf8=True)
|
|
|
|
# node_2 = load_node(message_to_string)
|
|
|
|
# f.write(message_to_string + '----f\n')
|
|
|
|
|
|
|
|
# with open('onnx.pbtxt','r') as f:
|
|
|
|
# nodes = [load_node(node_str) for node_str in f.read().split('----f\n')]
|
|
|
|
# print(nodes)
|