与 NIR 相互转换#

本教程作者: 黄一凡 (AllenYolk)

English version: Export to and Import from NIR

Neuromorphic intermediate representation (NIR) 是一组计算原语,以图(节点+边)的形式描述了 SNN 的模块和连接,在不同的神经形态框架和技术栈之间通用。目前,NIR 被多个模拟器和硬件平台支持 。SpikingJelly 0.0.0.1.0 引入了 nir_exchange 包,使(满足一定条件的) SpikingJelly 模型能够与 NIR 图互相转换。借助 NIR 这一中间表示,用户可以更加轻松地进行硬件部署和框架迁移。

../../_images/nir-schema.png

图片来源: What is the Neuromorphic Intermediate Representation (NIR)?#

SpikingJelly 的 nir_exchange 包提供了两个关键的用户接口:

本教程将对这两个函数展开介绍。

备注

nir_exchange 功能依赖于 nirnir_exchange 包。可以用 pip 安装:

pip install nir nir_exchange

从 SpikingJelly 到 NIR#

由于开发者精力有限且 NIR 本身只能表示少数几种模块,故目前 export_to_nir 只支持以下 SpikingJelly / PyTorch 模块的转换:

  • torch.nn.Linear, layer.Linear

  • torch.nn.Conv2d, layer.Conv2d

  • torch.nn.AvgPool2d, layer.AvgPool2d

  • torch.nn.Flatten, layer.Flatten

  • IFNode

  • LIFNode and ParametricLIFNode

以下面的 SNN 模型为例:

import torch.nn as nn
from spikingjelly.activation_based import layer, neuron

net = nn.Sequential(
    layer.Conv2d(3, 16, 3, 1, 1, step_mode="s"),
    neuron.IFNode(),
    nn.AvgPool2d((2, 2)),
    layer.Flatten(step_mode="s"),
    nn.Linear(4096, 10),
    neuron.ParametricLIFNode(10., decay_input=False, v_reset=None),
)

为了展示兼容性,这一示例故意混用了原生 PyTorch 的无状态层 nn.AvgPool2d, nn.Linear 和 SpikingJelly 包装后的无状态层 layer.Conv2d, layer.Flatten。此外,本例中还使用了 neuron.IFNodeneuron.ParametricLIFNode 两种神经元模型。

调用 export_to_nir ,即可将上述模型转换成 NIR 图并保存为 HDF5 文件:

import torch
from spikingjelly.activation_based import nir_exchange

graph = nir_exchange.export_to_nir(
    net,
    example_input=torch.rand(8, 3, 32, 32),
    save_path="./example.h5",
    dt=1e-4
)
print(graph)

export_to_nir 参数的含义为:

  • net :SpikingJelly 模型;

  • example_input :模型输入的样例,用于确定子模块输入和输出的形状;

  • save_path :HDF5 文件路径,用于保存 NIR 图(若为 None ,则不保存);

  • dt :NIR 模拟时间步长。建议设置成 1e-4 以对齐其它支持 NIR 的框架。

运行后,当前目录下出现文件 example.h5 ,内含 NIR 图。终端打印出的结果大致为:

NIRGraph(
    nodes={
        'input_1': Input(input_type={'input': array([ 3, 32, 32])}, metadata={}),

        '_0': Conv2d(input_shape=(32, 32), weight=array(...), stride=(1, 1), padding=(1, 1), dilation=(1, 1), groups=1, bias=array(...), metadata={}),

        '_1': IF(r=array(...), v_threshold=array(...), v_reset=array(...), input_type={'input': array([16, 32, 32])}, output_type={'output': array([16, 32, 32])}, metadata={}),

        '_2': AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, metadata={}),

        '_3': Flatten(input_type={'input': array([16, 16, 16])}, start_dim=0, end_dim=-1, output_type={'output': array([4096])}, metadata={}),

        '_4': Affine(weight=array(...), bias=array(...), input_type={'input': array([4096])}, output_type={'output': array([10])}, metadata={}),

        '_5': LIF(tau=array(...), r=array(...), v_leak=array(...), v_threshold=array(...), v_reset=array(...), input_type={'input': array([10])}, output_type={'output': array([10])}, metadata={}),

        'output': Output(output_type={'output': array([10])}, metadata={})
    },

    edges=[
        ('input_1', '_0'), ('_0', '_1'), ('_1', '_2'), ('_2', '_3'),
        ('_3', '_4'), ('_4', '_5'), ('_5', 'output')
    ],

    input_type={'input_1': array([ 3, 32, 32])},
    output_type={'output': array([10])},
    metadata={}
)

这里,我们只展示了 NIRGraph 的结构,省略了具体的参数数值。可见,NIR 图由节点 nodes 和边 edges 组成。节点对应 SNN 模块,边指示了节点的输入输出关系。

备注

原模型中的 ParametricLIFNode 被转换成了 nir.LIF 节点。这是合理的,因为一旦膜电位时间常量 tau 固定下来,PLIF 神经元就将变成 LIF 神经元。

备注

不同于 PyTorch 和 SpikingJelly 模型, NIRGraph 中的节点大多蕴含 输入输出形状 信息。例如,上方例子中的 '_3': Flatten(...) 节点指明了输入形状为 [16, 16, 16] ,输出形状为 [4096]'_5': LIF(...) 的输入输出形状则都为 [10] 。显然,NIR 图中的形状信息是不包含时间维度 T 和批量维度 B 的;换言之,NIR只 描述单样本、单个时间步上的模型结构

PyTorch / SpikingJelly 模型的子模块不含输入输出形状信息,但 NIR 图却需要这些信息。为了获取输入输出形状信息,export_to_nir 要求用户给出 example_input 样例输入。 example_input 可以具有时间或批量维度,具体取决于 PyTorch / SpikingJelly 模型的需求。 export_to_nir 函数内部将调用 PyTorch 的 ShapeProp 功能来获取输入输出形状信息。

从 NIR 到 SpikingJelly#

函数 import_from_nir 可以将已有的 NIR 图转换成 SpikingJelly 模型。以上一节生成的 NIR 图为例:

gm = nir_exchange.import_from_nir(graph="./example.h5", dt=1e-4)
print(gm)
x = torch.rand(9, 3, 32, 32) # [B, C, H, W]
y = gm(x) # forward pass
print("y.shape =", y[0].shape) # y is a tuple; the 2nd element is each layer's state

此处,import_from_nir 参数的含义是:

  • graph :若给出字符串,则视为保存有 NIR 图的 HDF5 文件路径;函数内部将从该文件中读出 NIRGraph 。若给出 NIRGraph 对象,则直接使用该对象。

  • dt :NIR 图的模拟时间步长。与 export_to_nirdt 参数一致。

该函数返回一个 torch.fx.GraphModule 对象,可以像 torch.nn.Module 那样直接调用以运行前向传播。前向传播返回一个二元组,第一个元素是模型的输出,第二个元素则是模型内子模块的状态字典(绝大多数情况下无需使用)。上方代码块的终端输出大致为:

GraphModule(
  (_0): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=s)
  (_1): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
  (_2): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, step_mode=s)
  (_3): Flatten(start_dim=1, end_dim=-1, step_mode=s)
  (_4): Linear(in_features=4096, out_features=10, bias=True)
  (_5): LIFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=10.0
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
)

def forward(self, input, state : typing_Dict[str,typing_Any] = {'_0': None, '_1': None, '_2': None, '_3': None, '_4': None, '_5': None, 'input_1': None, 'output': None}):
    ones = torch.ones(1);  ones = None
    input_1 = input
    _0 = self._0(input_1);  input_1 = None
    _1 = self._1(_0);  _0 = None
    _2 = self._2(_1);  _1 = None
    _3 = self._3(_2);  _2 = None
    _4 = self._4(_3);  _3 = None
    _5 = self._5(_4);  _4 = None
    return (_5, state)

# To see more debug info, please use `graph_module.print_readable()`

y.shape = torch.Size([9, 10])

可见,NIR 图被正确地转换成了 SpikingJelly 模型。模型的无状态层均来自 spikingjelly.activation_based.layer ,可以配置步进模式(见 step_mode 属性)。

目前, import_from_nir 仅支持以下 NIR 节点类型:

  • nir.Linear, nir.Affine

  • nir.Conv2d

  • nir.AvgPool2d

  • nir.Flatten

  • nir.IF

  • nir.LIF

备注

import_from_nir 还提供了 dtypedevicestep_mode 参数,用于控制所返回的 SpikingJelly 模型的数据类型、设备、步进模式。例如,可以通过以下方式得到多步模式的 SpikingJelly 模型:

gm = nir_exchange.import_from_nir(
    "./example.h5", dt=1e-4, step_mode="m"
)
print(gm)
x = torch.rand(7, 9, 3, 32, 32) # [T, B, C, H, W]
y = gm(x)
print("y.shape =", y[0].shape)