与 NIR 相互转换#
本教程作者: 黄一凡 (AllenYolk)
English version: Export to and Import from NIR
Neuromorphic intermediate representation (NIR) 是一组计算原语,以图(节点+边)的形式描述了 SNN 的模块和连接,在不同的神经形态框架和技术栈之间通用。目前,NIR 被多个模拟器和硬件平台支持 。SpikingJelly 0.0.0.1.0 引入了 nir_exchange 包,使(满足一定条件的) SpikingJelly 模型能够与 NIR 图互相转换。借助 NIR 这一中间表示,用户可以更加轻松地进行硬件部署和框架迁移。
图片来源: What is the Neuromorphic Intermediate Representation (NIR)?#
SpikingJelly 的 nir_exchange 包提供了两个关键的用户接口:
export_to_nir:将 SpikingJelly 模型导出为 NIR 图;import_from_nir:将 NIR 图导入为 SpikingJelly 模型。
本教程将对这两个函数展开介绍。
备注
nir_exchange 功能依赖于 nir 和 nir_exchange 包。可以用 pip 安装:
pip install nir nir_exchange
从 SpikingJelly 到 NIR#
由于开发者精力有限且 NIR 本身只能表示少数几种模块,故目前 export_to_nir 只支持以下 SpikingJelly / PyTorch 模块的转换:
torch.nn.Linear,layer.Lineartorch.nn.Conv2d,layer.Conv2dtorch.nn.AvgPool2d,layer.AvgPool2dtorch.nn.Flatten,layer.FlattenIFNodeLIFNodeandParametricLIFNode
以下面的 SNN 模型为例:
import torch.nn as nn
from spikingjelly.activation_based import layer, neuron
net = nn.Sequential(
layer.Conv2d(3, 16, 3, 1, 1, step_mode="s"),
neuron.IFNode(),
nn.AvgPool2d((2, 2)),
layer.Flatten(step_mode="s"),
nn.Linear(4096, 10),
neuron.ParametricLIFNode(10., decay_input=False, v_reset=None),
)
为了展示兼容性,这一示例故意混用了原生 PyTorch 的无状态层 nn.AvgPool2d, nn.Linear 和 SpikingJelly 包装后的无状态层 layer.Conv2d, layer.Flatten。此外,本例中还使用了 neuron.IFNode 和 neuron.ParametricLIFNode 两种神经元模型。
调用 export_to_nir ,即可将上述模型转换成 NIR 图并保存为 HDF5 文件:
import torch
from spikingjelly.activation_based import nir_exchange
graph = nir_exchange.export_to_nir(
net,
example_input=torch.rand(8, 3, 32, 32),
save_path="./example.h5",
dt=1e-4
)
print(graph)
export_to_nir 参数的含义为:
net:SpikingJelly 模型;example_input:模型输入的样例,用于确定子模块输入和输出的形状;save_path:HDF5 文件路径,用于保存 NIR 图(若为None,则不保存);dt:NIR 模拟时间步长。建议设置成1e-4以对齐其它支持 NIR 的框架。
运行后,当前目录下出现文件 example.h5 ,内含 NIR 图。终端打印出的结果大致为:
NIRGraph(
nodes={
'input_1': Input(input_type={'input': array([ 3, 32, 32])}, metadata={}),
'_0': Conv2d(input_shape=(32, 32), weight=array(...), stride=(1, 1), padding=(1, 1), dilation=(1, 1), groups=1, bias=array(...), metadata={}),
'_1': IF(r=array(...), v_threshold=array(...), v_reset=array(...), input_type={'input': array([16, 32, 32])}, output_type={'output': array([16, 32, 32])}, metadata={}),
'_2': AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, metadata={}),
'_3': Flatten(input_type={'input': array([16, 16, 16])}, start_dim=0, end_dim=-1, output_type={'output': array([4096])}, metadata={}),
'_4': Affine(weight=array(...), bias=array(...), input_type={'input': array([4096])}, output_type={'output': array([10])}, metadata={}),
'_5': LIF(tau=array(...), r=array(...), v_leak=array(...), v_threshold=array(...), v_reset=array(...), input_type={'input': array([10])}, output_type={'output': array([10])}, metadata={}),
'output': Output(output_type={'output': array([10])}, metadata={})
},
edges=[
('input_1', '_0'), ('_0', '_1'), ('_1', '_2'), ('_2', '_3'),
('_3', '_4'), ('_4', '_5'), ('_5', 'output')
],
input_type={'input_1': array([ 3, 32, 32])},
output_type={'output': array([10])},
metadata={}
)
这里,我们只展示了 NIRGraph 的结构,省略了具体的参数数值。可见,NIR 图由节点 nodes 和边 edges 组成。节点对应 SNN 模块,边指示了节点的输入输出关系。
备注
原模型中的 ParametricLIFNode 被转换成了 nir.LIF 节点。这是合理的,因为一旦膜电位时间常量 tau 固定下来,PLIF 神经元就将变成 LIF 神经元。
备注
不同于 PyTorch 和 SpikingJelly 模型, NIRGraph 中的节点大多蕴含 输入输出形状 信息。例如,上方例子中的 '_3': Flatten(...) 节点指明了输入形状为 [16, 16, 16] ,输出形状为 [4096] ; '_5': LIF(...) 的输入输出形状则都为 [10] 。显然,NIR 图中的形状信息是不包含时间维度 T 和批量维度 B 的;换言之,NIR只 描述单样本、单个时间步上的模型结构 。
PyTorch / SpikingJelly 模型的子模块不含输入输出形状信息,但 NIR 图却需要这些信息。为了获取输入输出形状信息,export_to_nir 要求用户给出 example_input 样例输入。 example_input 可以具有时间或批量维度,具体取决于 PyTorch / SpikingJelly 模型的需求。 export_to_nir 函数内部将调用 PyTorch 的 ShapeProp 功能来获取输入输出形状信息。
从 NIR 到 SpikingJelly#
函数 import_from_nir 可以将已有的 NIR 图转换成 SpikingJelly 模型。以上一节生成的 NIR 图为例:
gm = nir_exchange.import_from_nir(graph="./example.h5", dt=1e-4)
print(gm)
x = torch.rand(9, 3, 32, 32) # [B, C, H, W]
y = gm(x) # forward pass
print("y.shape =", y[0].shape) # y is a tuple; the 2nd element is each layer's state
此处,import_from_nir 参数的含义是:
graph:若给出字符串,则视为保存有 NIR 图的 HDF5 文件路径;函数内部将从该文件中读出NIRGraph。若给出NIRGraph对象,则直接使用该对象。dt:NIR 图的模拟时间步长。与export_to_nir的dt参数一致。
该函数返回一个 torch.fx.GraphModule 对象,可以像 torch.nn.Module 那样直接调用以运行前向传播。前向传播返回一个二元组,第一个元素是模型的输出,第二个元素则是模型内子模块的状态字典(绝大多数情况下无需使用)。上方代码块的终端输出大致为:
GraphModule(
(_0): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=s)
(_1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(_2): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, step_mode=s)
(_3): Flatten(start_dim=1, end_dim=-1, step_mode=s)
(_4): Linear(in_features=4096, out_features=10, bias=True)
(_5): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=10.0
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
def forward(self, input, state : typing_Dict[str,typing_Any] = {'_0': None, '_1': None, '_2': None, '_3': None, '_4': None, '_5': None, 'input_1': None, 'output': None}):
ones = torch.ones(1); ones = None
input_1 = input
_0 = self._0(input_1); input_1 = None
_1 = self._1(_0); _0 = None
_2 = self._2(_1); _1 = None
_3 = self._3(_2); _2 = None
_4 = self._4(_3); _3 = None
_5 = self._5(_4); _4 = None
return (_5, state)
# To see more debug info, please use `graph_module.print_readable()`
y.shape = torch.Size([9, 10])
可见,NIR 图被正确地转换成了 SpikingJelly 模型。模型的无状态层均来自 spikingjelly.activation_based.layer ,可以配置步进模式(见 step_mode 属性)。
目前, import_from_nir 仅支持以下 NIR 节点类型:
nir.Linear,nir.Affinenir.Conv2dnir.AvgPool2dnir.Flattennir.IFnir.LIF
备注
import_from_nir 还提供了 dtype , device 和 step_mode 参数,用于控制所返回的 SpikingJelly 模型的数据类型、设备、步进模式。例如,可以通过以下方式得到多步模式的 SpikingJelly 模型:
gm = nir_exchange.import_from_nir(
"./example.h5", dt=1e-4, step_mode="m"
)
print(gm)
x = torch.rand(7, 9, 3, 32, 32) # [T, B, C, H, W]
y = gm(x)
print("y.shape =", y[0].shape)