cavis/.old/contrib/codegen-tools/onnx-def-gen/onnx_def_gen.py

134 lines
4.7 KiB
Python

# /* ******************************************************************************
# *
# *
# * This program and the accompanying materials are made available under the
# * terms of the Apache License, Version 2.0 which is available at
# * https://www.apache.org/licenses/LICENSE-2.0.
# *
# * See the NOTICE file distributed with this work for additional
# * information regarding copyright ownership.
# * Unless required by applicable law or agreed to in writing, software
# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# * License for the specific language governing permissions and limitations
# * under the License.
# *
# * SPDX-License-Identifier: Apache-2.0
# ******************************************************************************/
from onnx.defs import get_all_schemas
from onnx import NodeProto,GraphProto
from google.protobuf import text_format
import onnx.helper
# Every operator schema reported by onnx.defs.get_all_schemas().
schemas = get_all_schemas()
# Starts out empty; the generated NodeProto list is bound to this name
# further down in the script.
nodes = []
def load_node(input_str):
    """
    Parse a protobuf text-format string into a fresh NodeProto.

    :param input_str: text-format representation of a node
    :return: the NodeProto populated from ``input_str``
    """
    parsed_node = NodeProto()
    # text_format.Parse fills the message it is given in place.
    text_format.Parse(input_str, parsed_node)
    return parsed_node
# Integer attribute-proto type codes for each OpSchema attribute type; used when serializing attributes that have no default value.
# Lookup table keyed on the string form of an OpSchema attribute type.
# The values are the integer type codes written into an attribute
# proto's ``type`` field.
_ATTR_TYPE_CODES = {
    'AttrType.UNDEFINED': 0,
    'AttrType.FLOAT': 1,
    'AttrType.INT': 2,
    'AttrType.STRING': 3,
    'AttrType.TENSOR': 4,
    'AttrType.GRAPH': 5,
    'AttrType.FLOATS': 6,
    'AttrType.INTS': 7,
    'AttrType.STRINGS': 8,
    'AttrType.TENSORS': 9,
    'AttrType.GRAPHS': 10,
    'AttrType.SPARSE_TENSOR': 11,
    'AttrType.SPARSE_TENSORS': 12,
}


def convert_attr_type_to_enum(attr_value):
    """
    Pass in an attribute from OpDescriptor and
    get back out the equivalent enum value
    for conversion to an attribute proto.

    The lookup is keyed on ``str(attr_value.type)``, which must have the
    form ``'AttrType.<NAME>'`` (for example ``'AttrType.INTS'``).

    :param attr_value: the attribute value; only its ``type`` member is read
    :return: the integer type code matching the attribute's type
    :raises ValueError: if the type has no entry in the lookup table
    """
    # Stringify once instead of once per comparison.
    type_name = str(attr_value.type)
    if type_name not in _ATTR_TYPE_CODES:
        # ValueError is still an Exception, so existing handlers keep
        # working; the message now names the type that was rejected.
        raise ValueError('Invalid type passed in: ' + type_name)
    return _ATTR_TYPE_CODES[type_name]
def _make_types_attribute(param):
    """
    Build the ``<name>-types`` attribute for one schema input or output.

    :param param: a schema input/output; its ``name`` and ``types`` are read
    :return: attribute named ``param.name + '-types'`` holding the type names
    """
    # NOTE(review): replace(')', '') strips every closing parenthesis, so a
    # nested name such as 'seq(tensor(float))' becomes 'seq(float' - confirm
    # that consumers of the generated file expect this form.
    type_names = [str(data_type).replace('tensor(', '').replace(')', '')
                  for data_type in param.types]
    return onnx.helper.make_attribute(param.name + '-types', type_names)


def create_node_from_schema(schema):
    """
    Convert an OpSchema to a NodeProto.

    The node receives one attribute per schema attribute, one
    ``<name>-types`` attribute per schema input and output, and the
    input/output names themselves. ``op_type`` and ``name`` are both set
    to the schema name and ``doc_string`` to the schema doc.

    :param schema: the input OpSchema
    :return: the equivalent NodeProto
    """
    node_proto = NodeProto()

    # Read the attribute mapping once rather than on every iteration.
    schema_attributes = schema.attributes
    for attr_name in schema_attributes:
        attr_value = schema_attributes[attr_name]
        if attr_value.default_value.name == '':
            # An empty default name is treated as "no default": emit a
            # placeholder attribute and stamp the declared type onto it.
            placeholder = onnx.helper.make_attribute(attr_value.name, '')
            placeholder.type = convert_attr_type_to_enum(attr_value)
            node_proto.attribute.append(placeholder)
        else:
            node_proto.attribute.append(attr_value.default_value)

    node_proto.op_type = schema.name
    node_proto.doc_string = schema.doc
    node_proto.name = schema.name

    # Repeated protobuf fields are always present (never None) and cannot
    # be assigned to, so names are appended directly without a None guard.
    for input_arr in schema.inputs:
        node_proto.attribute.append(_make_types_attribute(input_arr))
        node_proto.input.append(input_arr.name)

    for output_arr in schema.outputs:
        node_proto.attribute.append(_make_types_attribute(output_arr))
        node_proto.output.append(output_arr.name)

    return node_proto
# Convert every schema to a NodeProto, visiting the schemas in name order.
nodes = list(map(create_node_from_schema,
                 sorted(schemas, key=lambda s: s.name)))

# Collect the nodes in a single GraphProto and write its binary
# serialization to 'onnx-op-defs.pb'.
with open('onnx-op-defs.pb', 'wb') as f:
    graph_proto = GraphProto()
    graph_proto.node.extend(nodes)
    f.write(graph_proto.SerializeToString())
# for node in nodes:
# message_to_string = text_format.MessageToString(node, as_utf8=True)
# node_2 = load_node(message_to_string)
# f.write(message_to_string + '----f\n')
# with open('onnx.pbtxt','r') as f:
# nodes = [load_node(node_str) for node_str in f.read().split('----f\n')]
# print(nodes)