spikingjelly.activation_based.base module#
- spikingjelly.activation_based.base.check_backend_library(backend: str)[源代码]#
-
中文
检查某个后端的python库是否已经安装。若未安装则此函数会报
ImportError。- 参数:
backend (str) --
'torch','cupy','triton'或'lava'- 返回:
None- 返回类型:
None
- 抛出:
ImportError -- 若所请求后端依赖的 Python 库未安装,则抛出
ImportError
English
Check whether the python lib for backend is installed. If not, this function will raise an
ImportError.- 参数:
backend (str) --
'torch','cupy','triton'or'lava'- 返回:
None- 返回类型:
None
- 抛出:
ImportError -- Raised when the Python package required by
backendis not installed
- class spikingjelly.activation_based.base.StepModule[源代码]#
基类:
object
中文
步进模式接口基类。
实现该接口的模块通过
step_mode区分单步模式"s"与多步模式"m"。
English
Base interface for step-mode aware modules.
Modules implementing this interface distinguish single-step mode
"s"from multi-step mode"m"throughstep_mode.
- class spikingjelly.activation_based.base.SingleStepModule[源代码]#
基类:
StepModule
中文
单步模式模块的接口基类。
实现该接口的模块仅支持单步模式
"s"。
English
Base interface for single-step mode modules.
Modules implementing this interface only support single-step mode
"s".
- class spikingjelly.activation_based.base.MultiStepModule[源代码]#
基类:
StepModule
中文
多步模式模块的接口基类。
实现该接口的模块仅支持多步模式
"m"。
English
Base interface for multi-step mode modules.
Modules implementing this interface only support multi-step mode
"m".
- class spikingjelly.activation_based.base.MemoryModule[源代码]#
基类:
Module,StepModule
中文
SpikingJelly 中所有有状态模块的基类。
MemoryModule通过register_memory注册内部状态变量,并提供reset、detach、显式 memory 提取与恢复等通用能力。
English
Base class of all stateful modules in SpikingJelly.
MemoryModuleregisters internal state variables viaregister_memoryand provides common utilities such asreset,detach, and explicit memory extraction / restoration.- property backend#
-
中文
- 返回:
当前后端名称
- 返回类型:
English
- abstractmethod single_step_forward(x: Tensor, *args, **kwargs)[源代码]#
-
中文
本模块的单步的前向传播函数。
- 参数:
x (torch.Tensor) -- 输入张量,约定
shape = [N, *],其中N通常为 batch 维- 返回:
单步前向传播的输出
- 返回类型:
Any
English
The single-step forward function for this module.
- 参数:
x (torch.Tensor) -- Input tensor, conventionally with
shape = [N, *]whereNis usually the batch dimension- 返回:
Output of the single-step forward pass
- 返回类型:
Any
- multi_step_forward(x_seq: Tensor, *args, **kwargs)[源代码]#
-
中文
本模块的多步的前向传播函数,通过调用
T次single_step_forward(x[t], *args, **kwargs)实现- 参数:
x_seq (torch.Tensor) -- 输入序列张量,约定
shape = [T, N, *],其中第 0 维为时间维- 返回:
按时间堆叠的输出序列
- 返回类型:
- 抛出:
RuntimeError -- 若某个时间步返回值无法被
torch.stack堆叠,则底层异常会原样向上传播
English
The multi-step forward function for this module, which is implemented by calling
single_step_forward(x[t], *args, **kwargs)overTtime steps.- 参数:
x_seq (torch.Tensor) -- Input sequence tensor, conventionally with
shape = [T, N, *]and the time axis at dimension 0- 返回:
Output sequence stacked along the time dimension
- 返回类型:
- 抛出:
RuntimeError -- Any stacking failure raised by
torch.stackis propagated unchanged
- forward(*args, **kwargs)[源代码]#
-
中文
若为单步模式
step_mode == "s",则调用self.single_step_forward(...)。 若为多步模式step_mode == "m",则调用self.multi_step_forward(...)。- 返回:
与当前
step_mode对应的前向传播结果- 返回类型:
Any
- 抛出:
ValueError -- 当
self.step_mode既不是"s"也不是"m"时抛出
English
Call
self.single_step_forward(...)ifstep_mode == "s". Callself.multi_step_forward(...)ifstep_mode == "m".- 返回:
Forward result selected according to the current
step_mode- 返回类型:
Any
- 抛出:
ValueError -- Raised when
self.step_modeis neither"s"nor"m"
- register_memory(name: str, value)[源代码]#
-
中文
将变量存入用于保存有状态变量(例如脉冲神经元的膜电位)的字典中。这个变量将被初始化为
value。 每次调用self.reset()函数后,self.name都会被重置为value。警告
若状态变量是个
torch.Tensor,则 不应对其做原地修改操作 。- 参数:
name (str) -- 状态变量的名字
value (Any) -- 状态变量的初始与重制值
- 返回:
None- 返回类型:
None
- 抛出:
AssertionError -- 当
name已经是模块现有成员属性时抛出
English
Register the state variable to memory dict, which saves stateful variables (e.g., the membrane potential of a spiking neuron). The variable will be initialized as
value.self.namewill be set tovalueafter callingself.reset().警告
Do not modify the state variable in-place if it's a
torch.Tensor.- 参数:
name (str) -- state variable's name
value (Any) -- state variable's initial and reset value
- 返回:
None- 返回类型:
None
- 抛出:
AssertionError -- Raised when
namealready exists as an attribute of the module
- reset()[源代码]#
-
中文
重置所有有状态变量为重制值。
若当前状态与重制值均为同形状、同 dtype、同 device 的张量,则优先原地恢复; 否则使用复制或重新赋值恢复。
- 返回:
None- 返回类型:
None
English
Reset all stateful variables to their reset values.
If both the current state and the reset value are tensors with the same shape, dtype, and device, the state is restored in-place whenever possible; otherwise it falls back to copy or reassignment.
- 返回:
None- 返回类型:
None
- set_reset_value(name: str, value)[源代码]#
-
中文
设置状态变量
self.name的重制值。- 参数:
name (str) -- 状态变量名称
value (Any) -- 新的重制值
- 返回:
None- 返回类型:
None
English
Set the reset value of state variable
self.name.- 参数:
name (str) -- Name of the state variable
value (Any) -- New reset value
- 返回:
None- 返回类型:
None
- memories() Generator[源代码]#
-
中文
- 返回:
返回一个所有状态变量的生成器
- 返回类型:
Generator
English
- 返回:
a generator over all stateful variables
- 返回类型:
Generator
- spikingjelly.activation_based.base.named_memories(module: Module, prefix: str = '') Generator[源代码]#
-
中文
递归地生成模块树中的所有状态变量。类似于
named_parameters()方法。- 参数:
module (torch.nn.Module) -- 目标模块
prefix (str) -- 名称前缀
- 返回:
状态变量名称和值的生成器
- 返回类型:
Generator
- 抛出:
RecursionError -- 若模块树存在异常递归结构,Python 递归遍历时会抛出异常
English
Recursively yield all memory variables in a module tree. Similar to
named_parameters().- 参数:
module (torch.nn.Module) -- the target module
prefix (str) -- name prefix
- 返回:
a generator of memory variable names and values
- 返回类型:
Generator
- 抛出:
RecursionError -- Raised if traversing the module tree exceeds Python recursion limits
- spikingjelly.activation_based.base.memories(module: Module) Generator[源代码]#
-
中文
递归地生成模块树中的所有状态变量值。类似于
parameters()方法。- 参数:
module (nn.Module) -- 目标模块
- 返回:
状态变量值的生成器
- 返回类型:
Generator
- 抛出:
RecursionError -- 若模块树存在异常递归结构,Python 递归遍历时会抛出异常
English
Recursively yield all memory variables in a module tree. Similar to
parameters().- 参数:
module (nn.Module) -- the target module
- 返回:
a generator of memory variable values
- 返回类型:
Generator
- 抛出:
RecursionError -- Raised if traversing the module tree exceeds Python recursion limits
- spikingjelly.activation_based.base.extract_memories(module: Module) list[源代码]#
-
中文
提取模块中所有的状态变量值并返回列表。
- 参数:
module (torch.nn.Module) -- 目标模块
- 返回:
状态变量值的列表
- 返回类型:
- 抛出:
RecursionError -- 若模块树存在异常递归结构,Python 递归遍历时会抛出异常
English
Extract all memory variable values from the module and return as a list.
- 参数:
module (torch.nn.Module) -- the target module
- 返回:
a list of memory variable values
- 返回类型:
- 抛出:
RecursionError -- Raised if traversing the module tree exceeds Python recursion limits
- spikingjelly.activation_based.base.load_memories(module: Module, memory_list: list)[源代码]#
-
中文
将状态变量列表加载到模块中。
- 参数:
module (torch.nn.Module) -- 目标模块
memory_list (list) -- 状态变量值列表
- 返回:
None- 返回类型:
None
- 抛出:
ValueError -- 当
memory_list的长度与module当前状态变量数量不一致时抛出
English
Load memory variables from a list into the module.
- 参数:
module (torch.nn.Module) -- the target module
memory_list (list) -- list of memory variable values
- 返回:
None- 返回类型:
None
- 抛出:
ValueError -- Raised when the length of
memory_listdoes not match the number of current memory variables inmodule
- spikingjelly.activation_based.base.to_functional_forward(module: Module, fn: Callable | None = None)[源代码]#
-
中文
给定一个可能包含隐式状态变量(记忆,memory)的模块,获取其显式状态的前向传播函数。
对于包含状态的模块,返回的函数签名为
(*inputs, *states) -> (*outputs, *new_states), 其中:inputs为原始forward所需的常规输入参数;states为当前模块中所有状态变量的值,其顺序与extract_memories(module)一致;outputs为原始forward的输出结果;new_states为执行前向传播后更新得到的状态变量。
若模块中不存在任何状态变量,则直接返回
module.forward本身。备注
该函数通过在调用过程中 临时替换模块内部状态 的方式实现功能转换, 并在执行结束后 恢复原始状态 , 因此对模块本身不产生副作用。
警告
如果某个状态变量为
torch.Tensor,则其不应在module.forward中被原地修改。否则, 会导致输入给前向传播函数的状态变量被修改,导致意想不到的错误。- 参数:
module (torch.nn.Module) -- 目标模块
fn (Optional[Callable]) -- 含隐式状态的前向传播函数。若为
None,则默认使用module.forward。 该参数可用于指定特殊的前向传播函数(如,module的父类的forward)。默认值 为None。
- 返回:
带有显式输入输出状态的前向传播函数
- 返回类型:
Callable
- 抛出:
ValueError -- 若后续调用时提供的显式状态数量与
module当前 memory 布局不一致,则相关 helper 可能抛出异常
English
Given a module that may contain implicit state variables, get the forward function with explicit state variables.
For a stateful module, the returned function has the following signature
(*inputs, *states) -> (*outputs, *new_states)where:inputsare the regular input arguments required by the originalforward;statesare the current memory variable values, in the same order as returned byextract_memories(module);outputsare the outputs of the originalforwardmethod;new_statesare the updated memory variables after the forward pass.
If the module does not contain any memory variables,
module.forwardis returned directly.备注
The conversion is implemented by temporarily loading the provided states into the module, executing the original forward pass, extracting the updated states, and finally restoring the original internal states. Therefore, this operation has no side effects on the module itself.
警告
If a state variable is a
torch.Tensor, it should not be modified in-place inmodule.forward. Otherwise, the provided states will be modified, which may lead to unexpected errors.- 参数:
module (torch.nn.Module) -- the target module
fn (Optional[Callable]) -- the forward function to be used. If
None,module.forwardis used by default. This argument can be used to explicitly specify another forward function (e.g., theforwardmethod ofmodule's parent class). Defaults toNone.
- 返回:
a functional-style forward function with explicit and flattened states
- 返回类型:
Callable
代码示例 | Example
import torch import torch.nn as nn from spikingjelly.activation_based import base class StatefulModule(base.MemoryModule): def __init__(self): super().__init__() self.register_memory("counter", torch.tensor(0.0)) self.linear = nn.Linear(10, 5) def single_step_forward(self, x): self.counter = self.counter + 1.0 return self.linear(x) module = StatefulModule() f_forward = base.to_functional_forward(module) x = torch.randn(3, 10) initial_state = torch.tensor(0.0) output, new_state = f_forward(x, initial_state) assert torch.equal(output, module.linear(x)) assert torch.equal(new_state, initial_state + 1.0)