Export to and Import from NIR
Author: Yifan Huang (AllenYolk)
Chinese version: Converting to and from NIR (与 NIR 相互转换)
The Neuromorphic Intermediate Representation (NIR) is a set of computational primitives that describes spiking neural network (SNN) modules and their connections as a graph of nodes and edges. It is designed to be shared across neuromorphic frameworks and technology stacks, and it is currently supported by multiple simulators and hardware platforms. SpikingJelly 0.0.0.1.0 introduces the nir_exchange package, which enables bidirectional conversion between SpikingJelly models and NIR graphs for a supported subset of modules. With NIR as an intermediate representation, users can more easily deploy models to hardware and migrate them between frameworks.
(Figure omitted. Image source: What is the Neuromorphic Intermediate Representation (NIR)?)
The nir_exchange package provides two key interfaces:
- export_to_nir: exports a SpikingJelly model to an NIR graph;
- import_from_nir: converts an NIR graph to a SpikingJelly model.
This tutorial provides a detailed introduction to these two functions.
Note
The nir_exchange package depends on nir and nirtorch. Install them using pip:
pip install nir nirtorch
From SpikingJelly to NIR
Due to limited development resources and the fact that NIR itself can only represent a small number of module types, the current export_to_nir function supports conversion only for the following SpikingJelly / PyTorch modules:
- torch.nn.Linear, layer.Linear
- torch.nn.Conv2d, layer.Conv2d
- torch.nn.AvgPool2d, layer.AvgPool2d
- torch.nn.Flatten, layer.Flatten
- IFNode
- LIFNode and ParametricLIFNode
Consider the following SNN model as an example:
import torch.nn as nn
from spikingjelly.activation_based import layer, neuron
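# shapes in the comments are per sample; the batch dimension is omitted
net = nn.Sequential(
    layer.Conv2d(3, 16, 3, 1, 1, step_mode="s"),  # [3, 32, 32] -> [16, 32, 32]
    neuron.IFNode(),
    nn.AvgPool2d((2, 2)),                         # [16, 32, 32] -> [16, 16, 16]
    layer.Flatten(step_mode="s"),                 # [16, 16, 16] -> [4096]
    nn.Linear(4096, 10),                          # [4096] -> [10]
    neuron.ParametricLIFNode(10., decay_input=False, v_reset=None),  # init_tau=10, soft reset
)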
To demonstrate compatibility, this example deliberately mixes native PyTorch stateless layers (nn.AvgPool2d, nn.Linear) with SpikingJelly-wrapped stateless layers (layer.Conv2d, layer.Flatten). It also uses two neuron models: neuron.IFNode and neuron.ParametricLIFNode.
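The in_features=4096 of nn.Linear is not arbitrary: it follows from PyTorch's standard output-size formula for convolution and pooling, out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. A minimal sketch of the arithmetic for the model above; conv_out is a hypothetical helper, not part of SpikingJelly:

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial dimension (PyTorch Conv2d/AvgPool2d formula, ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

h = conv_out(32, kernel=3, stride=1, padding=1)  # layer.Conv2d(3, 16, 3, 1, 1): 32 -> 32
h = conv_out(h, kernel=2, stride=2)              # nn.AvgPool2d((2, 2)): stride defaults to kernel_size, 32 -> 16
in_features = 16 * h * h                         # layer.Flatten: 16 channels x 16 x 16
print(in_features)  # 4096
```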
By calling export_to_nir, the above model can be converted into an NIR graph and saved as an HDF5 file:
import torch
from spikingjelly.activation_based import nir_exchange
graph = nir_exchange.export_to_nir(
    net,
    example_input=torch.rand(8, 3, 32, 32),  # [B, C, H, W]; single-step model, so no T dimension
    save_path="./example.h5",
    dt=1e-4,
)
print(graph)
The parameters of export_to_nir are:
- net: the SpikingJelly model;
- example_input: an example input to the model, used to determine the input and output shapes of submodules;
- save_path: the path of the HDF5 file in which the NIR graph is saved (if None, the graph is not saved);
- dt: the simulation time step used in NIR. Setting it to 1e-4 is recommended for consistency with other frameworks that support NIR.
After execution, a file named example.h5 will appear in the current directory, containing the NIR graph. The output printed in the terminal is roughly as follows:
NIRGraph(
nodes={
'input_1': Input(input_type={'input': array([ 3, 32, 32])}, metadata={}),
'_0': Conv2d(input_shape=(32, 32), weight=array(...), stride=(1, 1), padding=(1, 1), dilation=(1, 1), groups=1, bias=array(...), metadata={}),
'_1': IF(r=array(...), v_threshold=array(...), v_reset=array(...), input_type={'input': array([16, 32, 32])}, output_type={'output': array([16, 32, 32])}, metadata={}),
'_2': AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, metadata={}),
'_3': Flatten(input_type={'input': array([16, 16, 16])}, start_dim=0, end_dim=-1, output_type={'output': array([4096])}, metadata={}),
'_4': Affine(weight=array(...), bias=array(...), input_type={'input': array([4096])}, output_type={'output': array([10])}, metadata={}),
'_5': LIF(tau=array(...), r=array(...), v_leak=array(...), v_threshold=array(...), v_reset=array(...), input_type={'input': array([10])}, output_type={'output': array([10])}, metadata={}),
'output': Output(output_type={'output': array([10])}, metadata={})
},
edges=[
('input_1', '_0'), ('_0', '_1'), ('_1', '_2'), ('_2', '_3'),
('_3', '_4'), ('_4', '_5'), ('_5', 'output')
],
input_type={'input_1': array([ 3, 32, 32])},
output_type={'output': array([10])},
metadata={}
)
For brevity, parameter arrays are elided as array(...), so only the structure of the NIRGraph is shown. An NIR graph consists of nodes (the nodes field) and edges (the edges field). Nodes correspond to SNN modules, while each edge (source, target) indicates that the output of the source node is fed to the target node as input.
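Because every edge points from a producer to a consumer, a valid execution order for the nodes is any topological order of the edge list. A minimal sketch using Python's standard graphlib module; execution_order is a hypothetical helper, and the edge list is copied from the printout above:

```python
from graphlib import TopologicalSorter

# (source node, target node) pairs copied from the NIRGraph printout
edges = [
    ("input_1", "_0"), ("_0", "_1"), ("_1", "_2"), ("_2", "_3"),
    ("_3", "_4"), ("_4", "_5"), ("_5", "output"),
]

def execution_order(edges):
    """Hypothetical helper: order nodes so that each one comes after all of its inputs."""
    predecessors = {}
    for src, dst in edges:
        predecessors.setdefault(src, set())
        predecessors.setdefault(dst, set()).add(src)
    return list(TopologicalSorter(predecessors).static_order())

print(execution_order(edges))
# ['input_1', '_0', '_1', '_2', '_3', '_4', '_5', 'output']
```

For this purely sequential model the order is unique; graphs with branches or skip connections admit several valid orders.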
Note
The ParametricLIFNode in the original model is converted into a nir.LIF node. This is reasonable: once the membrane time constant tau is fixed, a PLIF neuron reduces to an LIF neuron.
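NIR describes neuron dynamics in continuous time, whereas SpikingJelly neurons update in discrete steps, so converting between the two requires dt. One plausible mapping, assuming the NIR LIF equation has the form tau * dv/dt = (v_leak - v) + r * i and is discretized with forward Euler, is sketched below. The helpers sj_charge, nir_lif_euler and to_nir_params are hypothetical; we have not verified that nir_exchange uses exactly this mapping. The sketch only illustrates why the time constant depends on dt.

```python
def sj_charge(v, x, tau, decay_input, v_reset):
    """One charge step of a SpikingJelly-style discrete LIF (v_reset=None: leak toward 0)."""
    vr = 0.0 if v_reset is None else v_reset
    if decay_input:
        return v + (x - (v - vr)) / tau
    return v - (v - vr) / tau + x

def nir_lif_euler(v, i, tau, r, v_leak, dt):
    """Forward-Euler step of the assumed continuous LIF: tau * dv/dt = (v_leak - v) + r * i."""
    return v + dt / tau * ((v_leak - v) + r * i)

def to_nir_params(tau_sj, decay_input, v_reset, dt):
    """Hypothetical mapping that makes the two updates above agree (an assumption)."""
    tau_nir = tau_sj * dt                  # so that dt / tau_nir == 1 / tau_sj
    r = 1.0 if decay_input else tau_sj     # input gain dt * r / tau_nir: 1/tau_sj or 1
    v_leak = 0.0 if v_reset is None else v_reset
    return tau_nir, r, v_leak

# The PLIF neuron from the example: tau = 10, decay_input=False, v_reset=None, dt = 1e-4
dt = 1e-4
tau_nir, r, v_leak = to_nir_params(10.0, decay_input=False, v_reset=None, dt=dt)
v, x = 0.3, 0.5
print(sj_charge(v, x, 10.0, False, None), nir_lif_euler(v, x, tau_nir, r, v_leak, dt))
```

Under this assumption, importing with a different time step dt' would recover tau_nir / dt' = tau_sj * dt / dt' instead of tau_sj, which is consistent with the requirement that the dt passed to import_from_nir match the one used in export_to_nir.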
Note
Unlike PyTorch and SpikingJelly models, most nodes in an NIRGraph explicitly record their input and output shapes. For example, the '_3': Flatten(...) node above has an input shape of [16, 16, 16] and an output shape of [4096], and the '_5': LIF(...) node has input and output shapes of [10]. Note that these shapes include neither the time dimension T nor the batch dimension B: NIR describes the model structure for a single sample at a single time step only.
Submodules in PyTorch / SpikingJelly models do not carry input and output shape information, whereas NIR graphs require it. To obtain this information, export_to_nir requires the user to provide example_input, which may include a time or batch dimension, in whatever layout the PyTorch / SpikingJelly model expects. Internally, export_to_nir uses PyTorch's ShapeProp utility (torch.fx.passes.shape_prop.ShapeProp) to infer the input and output shapes.
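The shape-inference idea can be illustrated with torch.fx directly. The sketch below is not nir_exchange's implementation: it traces a plain-PyTorch stand-in for the example model (nn.ReLU replaces the spiking neurons, an assumption made so that symbolic_trace is guaranteed to work), runs ShapeProp, and drops the batch dimension to obtain per-sample shapes like those stored in NIR nodes. per_sample_shapes is a hypothetical helper.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

# Hypothetical stand-in for the example net (ReLU instead of spiking neurons)
stand_in = nn.Sequential(
    nn.Conv2d(3, 16, 3, 1, 1),
    nn.ReLU(),
    nn.AvgPool2d((2, 2)),
    nn.Flatten(),
    nn.Linear(4096, 10),
)

def per_sample_shapes(model: nn.Module, example_input: torch.Tensor) -> dict:
    """Map each submodule name to (input_shape, output_shape) with the batch dim removed."""
    gm = symbolic_trace(model)
    ShapeProp(gm).propagate(example_input)  # stores shapes in node.meta["tensor_meta"]
    shapes = {}
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        in_shape = node.args[0].meta["tensor_meta"].shape
        out_shape = node.meta["tensor_meta"].shape
        # NIR shapes describe a single sample, so strip the leading batch dimension
        shapes[node.target] = (tuple(in_shape[1:]), tuple(out_shape[1:]))
    return shapes

shapes = per_sample_shapes(stand_in, torch.rand(8, 3, 32, 32))
for name, (i, o) in shapes.items():
    print(name, i, "->", o)
```

The flatten layer (submodule "3") comes out as (16, 16, 16) -> (4096,), matching the input_type and output_type of the '_3' node in the NIRGraph printout.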
From NIR to SpikingJelly
The function import_from_nir converts an existing NIR graph into a SpikingJelly model. Using the NIR graph generated in the previous section as an example:
gm = nir_exchange.import_from_nir(graph="./example.h5", dt=1e-4)
print(gm)
x = torch.rand(9, 3, 32, 32) # [B, C, H, W]
y = gm(x) # forward pass
print("y.shape =", y[0].shape) # y is a tuple; the 2nd element is each layer's state
The arguments of import_from_nir are:
- graph: if a string is provided, it is interpreted as the path to an HDF5 file storing an NIR graph, and the NIRGraph is loaded from this file; if an NIRGraph object is provided, it is used directly;
- dt: the simulation time step of the NIR graph. It should be the same as the dt argument passed to export_to_nir.
This function returns a torch.fx.GraphModule object, which can be called like a torch.nn.Module to run forward propagation. The forward pass returns a tuple: the first element is the model output, and the second is a dictionary of the submodules' internal states (which can be ignored in most cases). The terminal output of the above code block is approximately:
GraphModule(
(_0): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), step_mode=s)
(_1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(_2): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, step_mode=s)
(_3): Flatten(start_dim=1, end_dim=-1, step_mode=s)
(_4): Linear(in_features=4096, out_features=10, bias=True)
(_5): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=10.0
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
def forward(self, input, state : typing_Dict[str,typing_Any] = {'_0': None, '_1': None, '_2': None, '_3': None, '_4': None, '_5': None, 'input_1': None, 'output': None}):
ones = torch.ones(1); ones = None
input_1 = input
_0 = self._0(input_1); input_1 = None
_1 = self._1(_0); _0 = None
_2 = self._2(_1); _1 = None
_3 = self._3(_2); _2 = None
_4 = self._4(_3); _3 = None
_5 = self._5(_4); _4 = None
return (_5, state)
# To see more debug info, please use `graph_module.print_readable()`
y.shape = torch.Size([9, 10])
As shown above, the NIR graph is correctly converted into a SpikingJelly model. All stateless layers in the model are instantiated from classes in spikingjelly.activation_based.layer and support configurable step modes (see the step_mode attribute).
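Since stateless layers keep no state between time steps, applying one to a [T, B, ...] sequence is equivalent to merging T into the batch dimension, running the layer once, and splitting T back out. To our understanding, this is how SpikingJelly's layer classes handle multi-step mode; the sketch below demonstrates the equivalence with a plain torch.nn.Conv2d and is not SpikingJelly code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 16, 3, 1, 1)
x_seq = torch.rand(7, 9, 3, 32, 32)  # [T, B, C, H, W]
T, B = x_seq.shape[:2]

with torch.no_grad():
    # Option 1: apply the layer once per time step and stack the results along T
    y_loop = torch.stack([conv(x_seq[t]) for t in range(T)])

    # Option 2: fold T into the batch dimension, run once, then restore [T, B, ...]
    y_flat = conv(x_seq.flatten(0, 1))            # [T * B, 16, 32, 32]
    y_fold = y_flat.reshape(T, B, *y_flat.shape[1:])

print(y_fold.shape, torch.allclose(y_loop, y_fold, atol=1e-5))
```

Folding is usually preferred because a single large batched call is more efficient than T small ones.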
Currently, import_from_nir supports only the following NIR node types:
- nir.Linear, nir.Affine
- nir.Conv2d
- nir.AvgPool2d
- nir.Flatten
- nir.IF
- nir.LIF
Note
import_from_nir also provides the dtype, device, and step_mode arguments, which control the data type, device, and step mode of the returned SpikingJelly model. For example, a multi-step SpikingJelly model can be obtained as follows.
gm = nir_exchange.import_from_nir(
    "./example.h5", dt=1e-4, step_mode="m"  # multi-step: inputs are [T, B, ...]
)
print(gm)
x = torch.rand(7, 9, 3, 32, 32) # [T, B, C, H, W]
y = gm(x)
print("y.shape =", y[0].shape)
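Stateful neurons behave differently: their membrane potential carries over from one time step to the next, so a multi-step neuron must iterate over T rather than fold it into the batch dimension. The sketch below uses a hypothetical, minimal IF neuron (hard reset, no surrogate gradient; if_multi_step is not a SpikingJelly function) to show that folding T into B would change the result.

```python
import torch

def if_multi_step(x_seq: torch.Tensor, v_threshold: float = 1.0, v_reset: float = 0.0) -> torch.Tensor:
    """Hypothetical minimal IF neuron run over an input of shape [T, ...]."""
    v = torch.full_like(x_seq[0], v_reset)   # membrane potential persists across time steps
    spikes = []
    for x in x_seq:                          # iterate over the leading time dimension T
        v = v + x                            # charge: integrate input without leak
        s = (v >= v_threshold).to(x.dtype)   # fire where the threshold is reached
        v = v * (1.0 - s) + v_reset * s      # hard reset where a spike occurred
        spikes.append(s)
    return torch.stack(spikes)               # [T, ...]

x_seq = torch.full((4, 1), 0.6)              # T=4 steps, B=1, constant input 0.6
print(if_multi_step(x_seq).flatten().tolist())  # [0.0, 1.0, 0.0, 1.0]

# Folding T into B treats every step as an independent sample starting from v_reset,
# so the potential never accumulates past 0.6 and no spike is ever emitted.
print(if_multi_step(x_seq.reshape(1, 4)).sum().item())
```

With the single-step model returned by default, the same time loop has to be written by the caller, invoking gm once per time step; between independent input sequences, SpikingJelly neuron states are usually cleared with spikingjelly.activation_based.functional.reset_net.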