This post untangles the basic concepts of ONNX and reinforces them with a worked example of modifying an ONNX model.
Getting started with ONNX
Tensor + Node = Graph
Graph + Metadata = Model
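The two formulas above show that building an ONNX model takes two kinds of building blocks: Tensors and Nodes. A Tensor carries data and a Node is an operation on data. Together they form a Graph, and a Graph plus some metadata forms a Model.
This process involves four make functions, all in onnx.helper:
make_tensor_value_info: given a tensor's name, element type and shape, creates a ValueInfoProto that describes the tensor (used for graph inputs, outputs and value_info; it holds no data). Tensors that carry actual data, such as weights, are created with onnx.helper.make_tensor or onnx.numpy_helper.from_array.
make_node: given an operator type (op_type), the input and output tensor names, an optional node name, and any attributes (Attribute) as keyword arguments, creates a NodeProto.
make_graph: given the list of nodes, a graph name, the input and output ValueInfoProtos and, optionally, the initializers, creates a GraphProto.
make_model: given a Graph and metadata such as producer_name or opset_imports as keyword arguments, creates a ModelProto.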
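Key concepts, viewed from a single Node
Input: the tensors a Node consumes.
Output: the tensors a Node produces.
Initializer: a tensor whose constant data is stored in the Graph, typically used for model parameters such as weights and biases.
Attribute: a Node's attributes, which hold extra non-tensor settings; for the Conv operator, for example, kernel_shape, dilations, pads and strides.
If an input never changes, it can be stored as an Initializer instead of being fed in at run time.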
All of this information is stored in the Graph. Below is the definition of ONNX's GraphProto (from the file onnx_ml_pb2.pyi, which requires onnx<=1.14.0):
```python
class GraphProto(Message):
    name = ... # type: str
    doc_string = ... # type: str

    @property
    def node(self) -> RepeatedCompositeFieldContainer[NodeProto]: ...
    @property
    def initializer(self) -> RepeatedCompositeFieldContainer[TensorProto]: ...
    @property
    def sparse_initializer(self) -> RepeatedCompositeFieldContainer[SparseTensorProto]: ...
    @property
    def input(self) -> RepeatedCompositeFieldContainer[ValueInfoProto]: ...
    @property
    def output(self) -> RepeatedCompositeFieldContainer[ValueInfoProto]: ...
    @property
    def value_info(self) -> RepeatedCompositeFieldContainer[ValueInfoProto]: ...
    @property
    def quantization_annotation(self) -> RepeatedCompositeFieldContainer[TensorAnnotation]: ...

    def __init__(self,
        node : OptionalType[Iterable[NodeProto]] = None,
        name : OptionalType[str] = None,
        initializer : OptionalType[Iterable[TensorProto]] = None,
        sparse_initializer : OptionalType[Iterable[SparseTensorProto]] = None,
        doc_string : OptionalType[str] = None,
        input : OptionalType[Iterable[ValueInfoProto]] = None,
        output : OptionalType[Iterable[ValueInfoProto]] = None,
        value_info : OptionalType[Iterable[ValueInfoProto]] = None,
        quantization_annotation : OptionalType[Iterable[TensorAnnotation]] = None,
        ) -> None: ...
    @classmethod
    def FromString(cls, s: bytes) -> GraphProto: ...
    def MergeFrom(self, other_msg: Message) -> None: ...
    def CopyFrom(self, other_msg: Message) -> None: ...
```
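ps: sparse_initializer and quantization_annotation are outside the scope of this post. value_info holds optional type and shape information for intermediate tensors, i.e. tensors that are neither graph inputs nor graph outputs.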
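As shown, apart from the scalar fields name and doc_string, every member (node, initializer, input, output, value_info, ...) is a RepeatedCompositeFieldContainer. To see how to operate on them, here is part of the definition of RepeatedCompositeFieldContainer from protobuf (the implementation is not important and is omitted):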
```python
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
    """Simple, list-like container for holding repeated composite fields."""

    # Disallows assignment to other attributes.
    __slots__ = ['_message_descriptor']

    def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
        """
        Note that we pass in a descriptor instead of the generated directly,
        since at the time we construct a _RepeatedCompositeFieldContainer we
        haven't yet necessarily initialized the type that will be contained in the
        container.

        Args:
          message_listener: A MessageListener implementation.
            The RepeatedCompositeFieldContainer will call this object's
            Modified() method when it is modified.
          message_descriptor: A Descriptor instance describing the protocol type
            that should be present in this container. We'll use the
            _concrete_class field of this descriptor when the client calls add().
        """

    def add(self, **kwargs: Any) -> _T:
        """Adds a new element at the end of the list and returns it. Keyword
        arguments may be used to initialize the element.
        """

    def append(self, value: _T) -> None:
        """Appends one element by copying the message."""

    def insert(self, key: int, value: _T) -> None:
        """Inserts the item at the specified position by copying."""

    def extend(self, elem_seq: Iterable[_T]) -> None:
        """Extends by appending the given sequence of elements of the same type
        as this one, copying each individual message.
        """

    def MergeFrom(
        self,
        other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
    ) -> None:
        """Appends the contents of another repeated field of the same type to this
        one, copying each individual message.
        """

    def remove(self, elem: _T) -> None:
        """Removes an item from the list. Similar to list.remove()."""

    def pop(self, key: Optional[int] = -1) -> _T:
        """Removes and returns an item at a given index. Similar to list.pop()."""
```
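RepeatedCompositeFieldContainer is therefore a list-like container that can be manipulated with add, append, insert, extend, MergeFrom, remove, pop and similar methods. Note that append, insert and extend store copies of the messages passed in.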
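Example: modifying an ONNX model
With the basics above in place, a worked example of modifying an ONNX model makes them concrete.
Scenario:
A model contains two convolution layers, each with 128 input channels, 128 output channels and a 3x3 kernel. Because these layers have too many parameters, we want to replace each of them with a DepthwiseConv + PointwiseConv pair.
The input of both convolution layers has NCHW = (1, 128, 144, 240).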
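How much the replacement saves can be checked with a bit of arithmetic. A 2-D Conv weight has shape [C_out, C_in / group, k, k], plus C_out bias values; the helper below (a hypothetical name, assuming every conv has a bias) applies that formula to the original conv and to the depthwise and pointwise replacements.

```python
def conv_params(c_in, c_out, k, groups=1, bias=True):
    """Parameter count of a 2-D conv: weight [c_out, c_in // groups, k, k] plus optional bias."""
    weight = c_out * (c_in // groups) * k * k
    return weight + (c_out if bias else 0)

standard = conv_params(128, 128, 3)               # original 3x3 conv
depthwise = conv_params(128, 128, 3, groups=128)  # 3x3, group == channels
pointwise = conv_params(128, 128, 1)              # 1x1, group 1
print(standard, depthwise, pointwise)                # 147584 1280 16512
print(round(standard / (depthwise + pointwise), 2))  # 8.29
```

So each replaced layer needs roughly 8x fewer parameters.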
```python
import numpy as np
import onnx
from onnx import helper, numpy_helper

# Load the ONNX model
model_path = 'model.onnx'
model = onnx.load(model_path)

new_group = 128                        # group == channel count -> depthwise conv
choose_node_name = ['node1', 'node2']  # names of the Conv nodes to replace
idxs = []
nodes = []

# Walk the graph
for idx, node in enumerate(model.graph.node):
    if node.op_type == 'Conv' and node.name in choose_node_name:
        # Read attributes by name (their order in node.attribute depends on the exporter);
        # assumes kernel_shape, strides, pads and dilations are all set explicitly
        attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}

        # Intermediate tensor between the depthwise and pointwise convs
        # (shape assumes stride 1 and "same" padding, as in this model)
        intermediate_output = node.name + '/inter'
        inter_tensor = helper.make_tensor_value_info(
            intermediate_output, onnx.TensorProto.FLOAT, [1, 128, 144, 240])
        model.graph.value_info.append(inter_tensor)

        # Grouped (depthwise) conv: reuses the original inputs (X, W, B) and attributes
        group_conv_node = helper.make_node(
            'Conv',
            inputs=list(node.input),
            outputs=[intermediate_output],
            name=node.name,
            group=new_group,
            kernel_shape=attrs['kernel_shape'],
            strides=attrs['strides'],
            pads=attrs['pads'],
            dilations=attrs['dilations'],
        )

        # Two new initializers (weight, bias) for the pointwise conv; assumes the original
        # Conv has a bias (node.input[2]). Values are random: the new layers need (re)training.
        pw_input_1_name = node.input[1] + '_pw'
        pw_input_2_name = node.input[2] + '_pw'
        pw_input_1 = numpy_helper.from_array(np.random.randn(128, 128, 1, 1).astype(np.float32), pw_input_1_name)
        pw_input_2 = numpy_helper.from_array(np.random.randn(128).astype(np.float32), pw_input_2_name)
        model.graph.initializer.append(pw_input_1)
        model.graph.initializer.append(pw_input_2)

        # Pointwise (1x1) conv after the depthwise conv; it takes over the original output
        pw_conv_node = helper.make_node(
            'Conv',
            inputs=[intermediate_output, pw_input_1_name, pw_input_2_name],
            outputs=list(node.output),
            name=node.name + '_pw',
            group=1,
            kernel_shape=[1, 1],
            strides=[1, 1],
            pads=[0, 0, 0, 0],
            dilations=[1, 1],
        )

        # Record the node index to replace and the new nodes in topological order:
        # the depthwise conv must precede the pointwise conv that consumes its output
        idxs.append(idx)
        nodes.append([group_conv_node, pw_conv_node])

        # Update the shape of the original weight initializer:
        # 1. the original Conv weight has shape [128, 128, 3, 3]
        # 2. the grouped conv weight has shape [128, 128 / group, 3, 3] = [128, 1, 3, 3]
        # The group-conv-only model built below can be opened in Netron to verify this.
        for initializer in model.graph.initializer:
            if initializer.name == node.input[1]:  # the weight
                shape = list(initializer.dims)
                shape[1] //= new_group
                new_data = np.random.randn(*shape).astype(np.float32)
                initializer.CopyFrom(numpy_helper.from_array(new_data, initializer.name))

# Offset by i: each replacement removes one node and inserts two,
# so every later index shifts by one per earlier replacement
for i, (idx, (dw_node, pw_node)) in enumerate(zip(idxs, nodes)):
    model.graph.node.remove(model.graph.node[idx + i])
    model.graph.node.insert(idx + i, dw_node)
    model.graph.node.insert(idx + i + 1, pw_node)

# Validate and save the updated model
updated_model_path = 'updated_model.onnx'
onnx.checker.check_model(model)
onnx.save_model(model, updated_model_path)
```
```python
# Build a model containing only a grouped (depthwise) conv, to inspect its weight shape in Netron
import torch
from torch import nn

class ConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        # in=128, out=128, kernel 3, stride 1, padding 1, groups=128 -> depthwise conv
        self.conv = nn.Conv2d(128, 128, 3, 1, 1, groups=128)

    def forward(self, x):
        return self.conv(x)

model = ConvModel()
model.eval()

# Export the model
dummy_input = torch.randn(1, 128, 144, 240)
torch.onnx.export(model, dummy_input, "group_conv.onnx", verbose=True)
```