spikingjelly.activation_based.nir_exchange package#

Quote

Neuromorphic intermediate representation (NIR) 是一组计算原语,在不同的神经形态框架和技术栈之间通用。目前,NIR 被多个模拟器和硬件平台支持,使用户能够在这些平台之间无缝迁移。

Neuromorphic intermediate representation (NIR) is a set of computational primitives, shared across different neuromorphic frameworks and technology stacks. NIR is currently supported by multiple simulators and hardware platforms, allowing users to seamlessly move between any of these platforms.

备注

本页面的所有函数都可通过 spikingjelly.activation_based.nir_exchange 命名空间直接访问。

The functions are available in the spikingjelly.activation_based.nir_exchange namespace.

Supported Modules#

Supported SpikingJelly / PyTorch Modules:

  • torch.nn.Linear, layer.Linear

  • torch.nn.Conv2d, layer.Conv2d

  • torch.nn.AvgPool2d, layer.AvgPool2d

  • torch.nn.Flatten, layer.Flatten

  • IFNode

  • LIFNode and ParametricLIFNode

Supported NIR Nodes:

  • nir.Linear, nir.Affine

  • nir.Conv2d

  • nir.AvgPool2d

  • nir.Flatten

  • nir.IF

  • nir.LIF

备注

我们将在后续更新中逐渐完善对其他模块的支持。

We will add support for more modules in future updates.

SpikingJelly to NIR#

spikingjelly.activation_based.nir_exchange.to_nir.export_to_nir(net: Module, example_input: Tensor, save_path: str | Path | None = None, dt: float = 0.0001)[源代码]#

API Language: 中文 | English


  • 中文

将 SpikingJelly 的模型转换为 NIR(Neuromorphic Intermediate Representation) 图, 以供后续转换到其它框架或部署到神经形态芯片上。本函数会自动通过示例输入 example_input 推导每个模块的输入输出形状,将 SpikingJelly 或 PyTorch 模块转换为对应的 NIR 节点。

参数:
  • net (torch.nn.Module) -- 需要转换的 SpikingJelly / PyTorch 模型

  • example_input (torch.Tensor) -- 用于推导 net 中各个子模块输入输出形状的示例输入张量

  • save_path (Optional[Union[str, Path]]) -- 转换后的 NIR 图保存路径。如果不为 None,函数会将 NIR 图写入指定的 HDF5 文件。默认为 None ,即不保存 NIR 图

  • dt (float) -- 网络时间步长,单位为秒,用于计算 NIR 神经元节点的时间常量等超参数。默认值为 1e-4, 与大多数兼容 NIR 的框架一致

返回:

转换得到的 NIRGraph 对象

返回类型:

nir.NIRGraph


  • English

Convert a SpikingJelly model to a NIR (Neuromorphic Intermediate Representation) graph for conversion to other frameworks or deployment on neuromorphic hardware. This function automatically infers the input and output shapes of each submodule using example_input, and converts SpikingJelly or PyTorch modules to the corresponding NIR nodes.

参数:
  • net (torch.nn.Module) -- the SpikingJelly / PyTorch model to convert

  • example_input (torch.Tensor) -- an example input tensor used to infer the input and output shapes of each submodule in net

  • save_path (Optional[Union[str, Path]]) -- the path to save the converted NIR graph. If not None, the NIR graph will be written to the specified HDF5 file. Defaults to None, which means the NIR graph will not be saved

  • dt (float) -- simulation time step in seconds, used to compute time constants and other hyperparameters for NIR neuron nodes. The default value is 1e-4, consistent with other frameworks that support NIR

返回:

the converted NIRGraph object

返回类型:

nir.NIRGraph

NIR to SpikingJelly#

spikingjelly.activation_based.nir_exchange.from_nir.import_from_nir(graph: nir.NIRGraph | str, dt: float = 0.0001, device: str = 'cpu', dtype: dtype = torch.float32, step_mode: str = 's') GraphModule[源代码]#

API Language: 中文 | English


  • 中文

NIR(Neuromorphic Intermediate Representation) 图 转换为 SpikingJelly 神经网络模型。函数会根据 NIR 节点类型自动映射为对应的 SpikingJelly 模块(如 Linear、Conv2d、IF/LIF 神经元等),并返回可直接运行的 fx.GraphModule 对象。

参数:
  • graph (Union[nir.NIRGraph, str]) -- NIR 图,或存储 NIR 图的 HDF5 文件路径

  • dt (float) -- 网络时间步长,单位为秒,用于重构 IF/LIF 节点的时间常量等超参数。默认值为 1e-4,与大多数兼容 NIR 的框架一致

  • device (str) -- 模型运行设备,如 'cpu''cuda'

  • dtype (torch.dtype) -- 模型张量数据类型,通常为 torch.float32torch.float64

  • step_mode (str) -- 步进模式,可选 's' (单步) 或 'm' (多步)。NIR 图将首先转换到单步模式的 SpikingJelly 模型, 随后统一改变模型中所有子模块的步进模式

返回:

转换得到的 fx.GraphModule 对象

返回类型:

torch.fx.GraphModule


  • English

Convert a NIR(Neuromorphic Intermediate Representation) graph to a SpikingJelly model. The function automatically maps NIR nodes to corresponding SpikingJelly modules (e.g., Linear, Conv2d, IF/LIF neurons) and returns an runnable fx.GraphModule object.

参数:
  • graph (Union[nir.NIRGraph, str]) -- NIR graph, or the path to the HDF5 file storing the NIR graph

  • dt (float) -- simulation time step in seconds, used to reconstruct time constant and other neuronal hyperparameters. Default is 1e-4, which is consistent with most frameworks that support NIR

  • device (str) -- device on which the model will run, e.g., 'cpu' or 'cuda'

  • dtype (torch.dtype) -- data type of model tensors, usually torch.float32 or torch.float64

  • step_mode (str) -- step mode, either 's' (single-step) or 'm' (multi-step). NIR graph will first be converted to a single-step SpikingJelly model. Then, all the submodules will be set to the specified step mode.

返回:

the converted SpikingJelly fx.GraphModule object

返回类型:

torch.fx.GraphModule