spikingjelly.activation_based.base module#

spikingjelly.activation_based.base.check_backend_library(backend: str)[源代码]#

API Language: 中文 | English


  • 中文

检查某个后端的python库是否已经安装。若未安装则此函数会报 ImportError

参数:

backend (str) -- 'torch', 'cupy', 'triton''lava'

返回:

None

返回类型:

None

抛出:

ImportError -- 若所请求后端依赖的 Python 库未安装,则抛出 ImportError


  • English

Check whether the python lib for backend is installed. If not, this function will raise an ImportError .

参数:

backend (str) -- 'torch', 'cupy', 'triton' or 'lava'

返回:

None

返回类型:

None

抛出:

ImportError -- Raised when the Python package required by backend is not installed

class spikingjelly.activation_based.base.StepModule[源代码]#

基类:object

API Language: 中文 | English


  • 中文

步进模式接口基类。

实现该接口的模块通过 step_mode 区分单步模式 "s" 与多步模式 "m"


  • English

Base interface for step-mode aware modules.

Modules implementing this interface distinguish single-step mode "s" from multi-step mode "m" through step_mode.

supported_step_mode() Tuple[str][源代码]#

API Language: 中文 | English


  • 中文

返回:

包含支持的步进模式的tuple。"s" 代表单步模式, "m" 代表多步模式。

返回类型:

Tuple[str]


  • English

返回:

a tuple that contains the supported step mode(s). "s" is for single-step mode, and "m" is for multi-step mode.

返回类型:

Tuple[str]

property step_mode: str#

** 中文 | English


  • 中文

返回:

模块当前使用的步进模式

返回类型:

str


  • English

返回:

the current step mode of this module

返回类型:

str

Type:

**API Language

class spikingjelly.activation_based.base.SingleStepModule[源代码]#

基类:StepModule

API Language: 中文 | English


  • 中文

单步模式模块的接口基类。

实现该接口的模块仅支持单步模式 "s"


  • English

Base interface for single-step mode modules.

Modules implementing this interface only support single-step mode "s".

supported_step_mode()[源代码]#

API Language: 中文 | English


  • 中文

返回:

仅包含 "s" 的 tuple

返回类型:

Tuple[str]


  • English

返回:

A tuple containing only "s"

返回类型:

Tuple[str]

class spikingjelly.activation_based.base.MultiStepModule[源代码]#

基类:StepModule

API Language: 中文 | English


  • 中文

多步模式模块的接口基类。

实现该接口的模块仅支持多步模式 "m"


  • English

Base interface for multi-step mode modules.

Modules implementing this interface only support multi-step mode "m".

supported_step_mode()[源代码]#

API Language: 中文 | English


  • 中文

返回:

仅包含 "m" 的 tuple

返回类型:

Tuple[str]


  • English

返回:

A tuple containing only "m"

返回类型:

Tuple[str]

class spikingjelly.activation_based.base.MemoryModule[源代码]#

基类:Module, StepModule

API Language: 中文 | English


  • 中文

SpikingJelly 中所有有状态模块的基类。

MemoryModule 通过 register_memory 注册内部状态变量,并提供 resetdetach、显式 memory 提取与恢复等通用能力。


  • English

Base class of all stateful modules in SpikingJelly.

MemoryModule registers internal state variables via register_memory and provides common utilities such as reset, detach, and explicit memory extraction / restoration.

property supported_backends: Tuple[str]#

** 中文 | English


  • 中文

返回:

支持的后端

返回类型:

Tuple[str]


  • English

返回:

supported backends

返回类型:

Tuple[str]

Type:

**API Language

property backend#

** 中文 | English


  • 中文

返回:

当前后端名称

返回类型:

str


  • English

返回:

the name of the current backend

返回类型:

str

Type:

**API Language

abstractmethod single_step_forward(x: Tensor, *args, **kwargs)[源代码]#

API Language: 中文 | English


  • 中文

本模块的单步的前向传播函数。

参数:

x (torch.Tensor) -- 输入张量,约定 shape = [N, *],其中 N 通常为 batch 维

返回:

单步前向传播的输出

返回类型:

Any


  • English

The single-step forward function for this module.

参数:

x (torch.Tensor) -- Input tensor, conventionally with shape = [N, *] where N is usually the batch dimension

返回:

Output of the single-step forward pass

返回类型:

Any

multi_step_forward(x_seq: Tensor, *args, **kwargs)[源代码]#

API Language: 中文 | English


  • 中文

本模块的多步的前向传播函数,通过调用 Tsingle_step_forward(x[t], *args, **kwargs) 实现

参数:

x_seq (torch.Tensor) -- 输入序列张量,约定 shape = [T, N, *],其中第 0 维为时间维

返回:

按时间堆叠的输出序列

返回类型:

torch.Tensor

抛出:

RuntimeError -- 若某个时间步返回值无法被 torch.stack 堆叠,则底层异常会原样向上传播


  • English

The multi-step forward function for this module, which is implemented by calling single_step_forward(x[t], *args, **kwargs) over T time steps.

参数:

x_seq (torch.Tensor) -- Input sequence tensor, conventionally with shape = [T, N, *] and the time axis at dimension 0

返回:

Output sequence stacked along the time dimension

返回类型:

torch.Tensor

抛出:

RuntimeError -- Any stacking failure raised by torch.stack is propagated unchanged

forward(*args, **kwargs)[源代码]#

API Language: 中文 | English


  • 中文

若为单步模式 step_mode == "s",则调用 self.single_step_forward(...) 。 若为多步模式 step_mode == "m",则调用 self.multi_step_forward(...)

返回:

与当前 step_mode 对应的前向传播结果

返回类型:

Any

抛出:

ValueError -- 当 self.step_mode 既不是 "s" 也不是 "m" 时抛出


  • English

Call self.single_step_forward(...) if step_mode == "s". Call self.multi_step_forward(...) if step_mode == "m".

返回:

Forward result selected according to the current step_mode

返回类型:

Any

抛出:

ValueError -- Raised when self.step_mode is neither "s" nor "m"

register_memory(name: str, value)[源代码]#

API Language: 中文 | English


  • 中文

将变量存入用于保存有状态变量(例如脉冲神经元的膜电位)的字典中。这个变量将被初始化为 value 。 每次调用 self.reset() 函数后, self.name 都会被重置为 value

警告

若状态变量是个 torch.Tensor ,则 不应对其做原地修改操作

参数:
  • name (str) -- 状态变量的名字

  • value (Any) -- 状态变量的初始与重制值

返回:

None

返回类型:

None

抛出:

AssertionError -- 当 name 已经是模块现有成员属性时抛出


  • English

Register the state variable to memory dict, which saves stateful variables (e.g., the membrane potential of a spiking neuron). The variable will be initialized as value . self.name will be set to value after calling self.reset() .

警告

Do not modify the state variable in-place if it's a torch.Tensor .

参数:
  • name (str) -- state variable's name

  • value (Any) -- state variable's initial and reset value

返回:

None

返回类型:

None

抛出:

AssertionError -- Raised when name already exists as an attribute of the module

reset()[源代码]#

API Language: 中文 | English


  • 中文

重置所有有状态变量为重制值。

若当前状态与重制值均为同形状、同 dtype、同 device 的张量,则优先原地恢复; 否则使用复制或重新赋值恢复。

返回:

None

返回类型:

None


  • English

Reset all stateful variables to their reset values.

If both the current state and the reset value are tensors with the same shape, dtype, and device, the state is restored in-place whenever possible; otherwise it falls back to copy or reassignment.

返回:

None

返回类型:

None

set_reset_value(name: str, value)[源代码]#

API Language: 中文 | English


  • 中文

设置状态变量 self.name 的重制值。

参数:
  • name (str) -- 状态变量名称

  • value (Any) -- 新的重制值

返回:

None

返回类型:

None


  • English

Set the reset value of state variable self.name .

参数:
  • name (str) -- Name of the state variable

  • value (Any) -- New reset value

返回:

None

返回类型:

None

memories() Generator[源代码]#

API Language: 中文 | English


  • 中文

返回:

返回一个所有状态变量的生成器

返回类型:

Generator


  • English

返回:

a generator over all stateful variables

返回类型:

Generator

named_memories() Generator[源代码]#

API Language: 中文 | English


  • 中文

返回:

返回一个所有状态变量名称及其值的生成器

返回类型:

Generator


  • English

返回:

a generator over all stateful variables' names and values

返回类型:

Generator

detach()[源代码]#

API Language: 中文 | English


  • 中文

从计算图中分离所有有状态变量。

小技巧

可以使用这个函数实现TBPTT (Truncated Back Propagation Through Time)。

返回:

None

返回类型:

None


  • English

Detach all stateful variables.

Tip

We can use this function to implement TBPTT (Truncated Back Propagation Through Time).

返回:

None

返回类型:

None

spikingjelly.activation_based.base.named_memories(module: Module, prefix: str = '') Generator[源代码]#

API Language: 中文 | English


  • 中文

递归地生成模块树中的所有状态变量。类似于 named_parameters() 方法。

参数:
返回:

状态变量名称和值的生成器

返回类型:

Generator

抛出:

RecursionError -- 若模块树存在异常递归结构,Python 递归遍历时会抛出异常


  • English

Recursively yield all memory variables in a module tree. Similar to named_parameters() .

参数:
返回:

a generator of memory variable names and values

返回类型:

Generator

抛出:

RecursionError -- Raised if traversing the module tree exceeds Python recursion limits

spikingjelly.activation_based.base.memories(module: Module) Generator[源代码]#

API Language: 中文 | English


  • 中文

递归地生成模块树中的所有状态变量值。类似于 parameters() 方法。

参数:

module (nn.Module) -- 目标模块

返回:

状态变量值的生成器

返回类型:

Generator

抛出:

RecursionError -- 若模块树存在异常递归结构,Python 递归遍历时会抛出异常


  • English

Recursively yield all memory variables in a module tree. Similar to parameters() .

参数:

module (nn.Module) -- the target module

返回:

a generator of memory variable values

返回类型:

Generator

抛出:

RecursionError -- Raised if traversing the module tree exceeds Python recursion limits

spikingjelly.activation_based.base.extract_memories(module: Module) list[源代码]#

API Language: 中文 | English


  • 中文

提取模块中所有的状态变量值并返回列表。

参数:

module (torch.nn.Module) -- 目标模块

返回:

状态变量值的列表

返回类型:

list

抛出:

RecursionError -- 若模块树存在异常递归结构,Python 递归遍历时会抛出异常


  • English

Extract all memory variable values from the module and return as a list.

参数:

module (torch.nn.Module) -- the target module

返回:

a list of memory variable values

返回类型:

list

抛出:

RecursionError -- Raised if traversing the module tree exceeds Python recursion limits

spikingjelly.activation_based.base.load_memories(module: Module, memory_list: list)[源代码]#

API Language: 中文 | English


  • 中文

将状态变量列表加载到模块中。

参数:
返回:

None

返回类型:

None

抛出:

ValueError -- 当 memory_list 的长度与 module 当前状态变量数量不一致时抛出


  • English

Load memory variables from a list into the module.

参数:
  • module (torch.nn.Module) -- the target module

  • memory_list (list) -- list of memory variable values

返回:

None

返回类型:

None

抛出:

ValueError -- Raised when the length of memory_list does not match the number of current memory variables in module

spikingjelly.activation_based.base.to_functional_forward(module: Module, fn: Callable | None = None)[源代码]#

API Language: 中文 | English


  • 中文

给定一个可能包含隐式状态变量(记忆,memory)的模块,获取其显式状态的前向传播函数。

对于包含状态的模块,返回的函数签名为 (*inputs, *states) -> (*outputs, *new_states) , 其中:

  • inputs 为原始 forward 所需的常规输入参数;

  • states 为当前模块中所有状态变量的值,其顺序与 extract_memories(module) 一致;

  • outputs 为原始 forward 的输出结果;

  • new_states 为执行前向传播后更新得到的状态变量。

若模块中不存在任何状态变量,则直接返回 module.forward 本身。

备注

该函数通过在调用过程中 临时替换模块内部状态 的方式实现功能转换, 并在执行结束后 恢复原始状态 , 因此对模块本身不产生副作用。

警告

如果某个状态变量为 torch.Tensor ,则其不应在 module.forward 中被原地修改。否则, 会导致输入给前向传播函数的状态变量被修改,导致意想不到的错误。

参数:
  • module (torch.nn.Module) -- 目标模块

  • fn (Optional[Callable]) -- 含隐式状态的前向传播函数。若为 None ,则默认使用 module.forward 。 该参数可用于指定特殊的前向传播函数(如, module 的父类的 forward )。默认值 为 None

返回:

带有显式输入输出状态的前向传播函数

返回类型:

Callable

抛出:

ValueError -- 若后续调用时提供的显式状态数量与 module 当前 memory 布局不一致,则相关 helper 可能抛出异常


  • English

Given a module that may contain implicit state variables, get the forward function with explicit state variables.

For a stateful module, the returned function has the following signature (*inputs, *states) -> (*outputs, *new_states) where:

  • inputs are the regular input arguments required by the original forward;

  • states are the current memory variable values, in the same order as returned by extract_memories(module);

  • outputs are the outputs of the original forward method;

  • new_states are the updated memory variables after the forward pass.

If the module does not contain any memory variables, module.forward is returned directly.

备注

The conversion is implemented by temporarily loading the provided states into the module, executing the original forward pass, extracting the updated states, and finally restoring the original internal states. Therefore, this operation has no side effects on the module itself.

警告

If a state variable is a torch.Tensor, it should not be modified in-place in module.forward. Otherwise, the provided states will be modified, which may lead to unexpected errors.

参数:
  • module (torch.nn.Module) -- the target module

  • fn (Optional[Callable]) -- the forward function to be used. If None, module.forward is used by default. This argument can be used to explicitly specify another forward function (e.g., the forward method of module's parent class). Defaults to None.

返回:

a functional-style forward function with explicit and flattened states

返回类型:

Callable


  • 代码示例 | Example

import torch
import torch.nn as nn
from spikingjelly.activation_based import base

class StatefulModule(base.MemoryModule):
    def __init__(self):
        super().__init__()
        self.register_memory("counter", torch.tensor(0.0))
        self.linear = nn.Linear(10, 5)

    def single_step_forward(self, x):
        self.counter = self.counter + 1.0
        return self.linear(x)

module = StatefulModule()
f_forward = base.to_functional_forward(module)
x = torch.randn(3, 10)
initial_state = torch.tensor(0.0)
output, new_state = f_forward(x, initial_state)

assert torch.equal(output, module.linear(x))
assert torch.equal(new_state, initial_state + 1.0)