onnx入门和更改onnx模型应用示例

 

辨析onnx的相关基础概念,并通过一个更改onnx模型的应用示例来加深理解。

onnx入门

Tensor + Node = Graph

Graph + Metadata = Model

由上面两个公式可以看出,想要创造一个onnx模型,需要两个部分:Tensor和Node。Tensor是数据的载体,Node是对数据的操作。这两个部分组合在一起就构成了一个Graph。而Graph再加上一些元数据就构成了一个Model。

这个过程涉及到了四个make函数

  1. make_tensor_value_info: 给定一个Tensor的名字、数据类型和形状,创建一个Tensor。
  2. make_node: 给定一个Node的名字、操作符的名字、输入和输出,以及可能存在的属性(Attribute),创建一个Node。
  3. make_graph: 给定一个Graph的名字,使用上面两个函数构造的对象,创建一个Graph。
  4. make_model: 给定一个Model的名字、Graph和元数据,创建一个Model。

以Node为单元辨析关键概念

  1. Input: 输入的Tensor。
  2. Output: 输出的Tensor。
  3. Initializer: 初始化的Tensor, 一般用于存储模型的参数。
  4. Attribute: Node的属性,用于存储一些额外的信息,例如卷积算子:kernel_shapedilationspadsstrides等。

当输入永远不会改变时,可以将其设置为Initializer

这些信息都存储在Graph中,这里贴出onnx的Graph的定义(在文件onnx_ml_pb2.pyi可以找到,这个文件需要满足onnx<=1.14.0版本):

class GraphProto(Message):
    name = ... # type: str
    doc_string = ... # type: str
    
    @property
    def node(self) -> RepeatedCompositeFieldContainer[NodeProto]: ...
    
    @property
    def initializer(self) -> RepeatedCompositeFieldContainer[TensorProto]: ...
    
    @property
    def sparse_initializer(self) -> RepeatedCompositeFieldContainer[SparseTensorProto]: ...
    
    @property
    def input(self) -> RepeatedCompositeFieldContainer[ValueInfoProto]: ...
    
    @property
    def output(self) -> RepeatedCompositeFieldContainer[ValueInfoProto]: ...
    
    @property
    def value_info(self) -> RepeatedCompositeFieldContainer[ValueInfoProto]: ...
    
    @property
    def quantization_annotation(self) -> RepeatedCompositeFieldContainer[TensorAnnotation]: ...
    
    def __init__(self,
        node : OptionalType[Iterable[NodeProto]] = None,
        name : OptionalType[str] = None,
        initializer : OptionalType[Iterable[TensorProto]] = None,
        sparse_initializer : OptionalType[Iterable[SparseTensorProto]] = None,
        doc_string : OptionalType[str] = None,
        input : OptionalType[Iterable[ValueInfoProto]] = None,
        output : OptionalType[Iterable[ValueInfoProto]] = None,
        value_info : OptionalType[Iterable[ValueInfoProto]] = None,
        quantization_annotation : OptionalType[Iterable[TensorAnnotation]] = None,
        ) -> None: ...
    @classmethod
    def FromString(cls, s: bytes) -> GraphProto: ...
    def MergeFrom(self, other_msg: Message) -> None: ...
    def CopyFrom(self, other_msg: Message) -> None: ...

ps: sparse_initializerquantization_annotation 不在本文讨论范围内。

可以看到,所有的成员变量都是RepeatedCompositeFieldContainer,要想对它们进行操作,那就来看看RepeatedCompositeFieldContainer部分定义(具体的实现不重要,这里略去):

class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, list-like container for holding repeated composite fields."""

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """
    Note that we pass in a descriptor instead of the generated directly,
    since at the time we construct a _RepeatedCompositeFieldContainer we
    haven't yet necessarily initialized the type that will be contained in the
    container.

    Args:
      message_listener: A MessageListener implementation.
        The RepeatedCompositeFieldContainer will call this object's
        Modified() method when it is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container.  We'll use the
        _concrete_class field of this descriptor when the client calls add().
    """

  def add(self, **kwargs: Any) -> _T:
    """Adds a new element at the end of the list and returns it. Keyword
    arguments may be used to initialize the element.
    """

  def append(self, value: _T) -> None:
    """Appends one element by copying the message."""

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position by copying."""

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given sequence of elements of the same type

    as this one, copying each individual message.
    """

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one, copying each individual message.
    """

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""

可以看到,RepeatedCompositeFieldContainer 是一个类似于list的容器,可以通过addappendinsertextendMergeFrom等方法对其进行操作。

更改onnx模型应用示例

有了上面的基础,我们可以通过一个更改onnx模型的应用示例来加深理解。

应用描述:

目前有一个模型,其中有两个卷积层,输入128通道,输出128通道,卷积核3*3,因为参数量太大,希望使用 DepthwiseConv + PointwiseConv 来替代。

两个卷积层的输入NCHW=(1, 128, 144, 240)

import onnx
import numpy as np

# 加载 ONNX 模型
model_path = 'model.onnx'
model = onnx.load(model_path)

new_group = 128

choose_node_name = ['node1', 'node2']

idxs = []
nodes = []

# 遍历模型图
for idx, node in enumerate(model.graph.node):
    if node.op_type == 'Conv':
        if node.name in choose_node_name:
            # 创建一个中间张量
            intermidate_output = node.name + '/inter'
            inter_tensor = onnx.helper.make_tensor_value_info(intermidate_output, onnx.TensorProto.FLOAT, [1, 128, 144, 240])
            model.graph.value_info.append(inter_tensor)

            # 创建分组卷积节点
            group_conv_node = onnx.helper.make_node(
                'Conv',
                inputs=node.input,
                outputs=[intermidate_output],
                name=node.name,
                group=new_group,
                kernel_shape=node.attribute[2].ints,
                strides=node.attribute[4].ints,
                pads=node.attribute[3].ints,
                dilations=node.attribute[0].ints
            )
            # attribute 的 index 对应的具体属性可以通过打印 node.attribute 来查看

            # 为 pixel-wise 节点增加两个initializer
            pw_input_1_name = node.input[1] + '_pw'
            pw_input_2_name = node.input[2] + '_pw'
            pw_input_1 = onnx.numpy_helper.from_array(np.random.randn(128, 128, 1, 1).astype(np.float32), pw_input_1_name)
            pw_input_2 = onnx.numpy_helper.from_array(np.random.randn(128).astype(np.float32), pw_input_2_name)
            model.graph.initializer.append(pw_input_1)
            model.graph.initializer.append(pw_input_2)

            # 再此节点后面再增加一个pixel-wise的卷积节点
            pw_conv_node = onnx.helper.make_node(
                'Conv',
                inputs=[intermidate_output, pw_input_1_name, pw_input_2_name],
                outputs=node.output,
                name=node.name + '_pw',
                group=1,
                kernel_shape=[1, 1],
                strides=[1, 1],
                pads=[0, 0, 0, 0],
                dilations=[1, 1]
            )

            # 记录需要删除的节点index和新增的节点
            idxs.append(idx)
            nodes.append([pw_conv_node, group_conv_node])


            # 更新原Node的输入的initializer的形状
            # 解释:
            # 1. 原先的Node是Conv,其weight的shape是[128, 128, 3, 3]
            # 2. 分组卷积的weight的shape是[128, 1, 3, 3]
            # 后面有提供一个创建仅有一个分组卷积的模型的代码,可以通过netron来查看此模型的weight的shape来进行验证
            for initializer in model.graph.initializer:
                if initializer.name == node.input[1]: # 对应的是weight
                    # 更新输入张量的形状
                    shape = list(initializer.dims)
                    shape[1] //= new_group  # 更新通道数等相应维度
                    initializer.dims[:] = shape

                    new_data = np.random.randn(*shape).astype(np.float32)

                    initializer.CopyFrom(onnx.numpy_helper.from_array(new_data, initializer.name))

# +i 是因为,每次删除一个节点后,增加了两个节点,后续的节点的index都要增加1
for i, (idx, node) in enumerate(zip(idxs, nodes)):
    model.graph.node.remove(model.graph.node[idx+i])
    model.graph.node.insert(idx+i, node[0])
    model.graph.node.insert(idx+i+1, node[1])

# 保存更新后的模型
updated_model_path = 'updated_model.onnx'
onnx.checker.check_model(model)
onnx.save_model(model, updated_model_path)


# 创建一个仅有分组卷积的模型

import onnx
from torch import nn
import torch

class ConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(128, 128, 3, 1, 1, groups=128)

    def forward(self, x):
        return self.conv(x)
    
model = ConvModel()
model.eval()

# 导出模型
dummy_input = torch.randn(1, 128, 144, 240)
torch.onnx.export(model, dummy_input, "group_conv.onnx", verbose=True)

参考资料

  1. onnx 官方文档