Forward Functions#
SpikingJelly 的 前向传播函数 实现了 SNN 的多步前向传播逻辑。
SpikingJelly's forward functions provide multi-step forward propagation logic for SNNs.
- spikingjelly.activation_based.functional.forward.multi_step_forward(x_seq: Tensor, single_step_module: Module | list[Module] | tuple[Module] | Sequential | Callable)[源代码]#
-
中文
在单步模块
single_step_module上使用多步前向传播。函数内部将执行一个for循环, 执行T次单步前向传播。若single_step_module为多个模块,则每个时间步都会按顺序依次执行这些模块。- 参数:
x_seq (torch.Tensor) --
shape=[T, batch_size, ...]的输入tensorsingle_step_module (Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]) -- 一个或多个单步模块
- 返回:
shape=[T, batch_size, ...]的输出tensor- 返回类型:
- 抛出:
Exception -- 任何底层模块在某个时间步前向传播时抛出的异常都会原样向上传播
English
Applies multi-step forward on
single_step_module. The function runs a for loop to execute single-step forward forTtimes. Ifsingle_step_modulecontains multiple modules, they are applied sequentially at each time-step.- 参数:
x_seq (torch.Tensor) -- the input tensor with
shape=[T, batch_size, ...]single_step_module (Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]) -- one or many single-step modules
- 返回:
the output tensor with
shape=[T, batch_size, ...]- 返回类型:
- 抛出:
Exception -- Any exception raised by an underlying module at any time step is propagated unchanged
- spikingjelly.activation_based.functional.forward.t_last_multi_step_forward(x_seq: Tensor, single_step_module: Module | list[Module] | tuple[Module] | Sequential | Callable)[源代码]#
-
中文
在单步模块
single_step_module上使用多步前向传播。此函数适用于时间维位于最后一维的序列张量,即
shape=[batch_size, ..., T]。 它会沿最后一维逐个时间步取出切片,并在每个时间步顺序执行单步模块。- 参数:
- 返回:
shape=[batch_size, ..., T]的输出tensor- 返回类型:
- 抛出:
Exception -- 任何底层模块在某个时间步前向传播时抛出的异常都会原样向上传播
English
Apply multi-step forward on
single_step_module.This helper is intended for sequence tensors whose time axis is the last dimension, i.e.
shape=[batch_size, ..., T]. It slices along the last dimension and applies the single-step module(s) at each time step.- 参数:
x_seq (torch.Tensor) -- the input tensor with
shape=[batch_size, ..., T]single_step_module (Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]) -- one or many single-step modules
- 返回:
the output tensor with
shape=[batch_size, ..., T]- 返回类型:
- 抛出:
Exception -- Any exception raised by an underlying module at any time step is propagated unchanged
- spikingjelly.activation_based.functional.forward.chunk_multi_step_forward(split_size: int, x_seq: Tensor, multi_step_module: Module)[源代码]#
-
中文
将
shape = [T, *]的输入x_seq拆分成多个shape = [split_size, *]的小tensor(若T % split_size != 0,最后一个tensor的shape[0]会小于split_size),然后逐个输入到multi_step_module中,再沿着dim=0将输出重新拼接,因此输出的首维长度仍为T。chunk_multi_step_forward可以在使用很大的T进行不带梯度的推理(例如ANN2SNN)时使用,能够减少内存消耗量。- 参数:
split_size (int) -- 分割的尺寸
x_seq (torch.Tensor) -- 输入
multi_step_module (torch.nn.Module) -- 一个使用多步传播模式的网络
- 返回:
输出
- 返回类型:
- 抛出:
Exception -- 任何
multi_step_module在某个分块上的前向传播异常都会原样向上传播
English
Splits the input
x_seqwithshape = [T, *]to many tensor chunks withshape = [split_size, *](ifT % split_size != 0,shape[0]of the last tensor chunk will be smaller thansplit_size), and sends chunks tomulti_step_module, then concatenates the outputs back alongdim=0, so the output keeps the original leading lengthT.chunk_multi_step_forwardcan be used for inference with a largeT(e.g., ANN2SNN) to reduce the memory consumption.- 参数:
split_size (int) -- the split size
x_seq (Tensor) -- the input tensor
multi_step_module (nn.Module) -- a network in multi-step mode
- 返回:
the output tensor
- 返回类型:
Tensor
- 抛出:
Exception -- Any exception raised by
multi_step_moduleon a chunk is propagated unchanged
代码示例 | Example
import torch import torch.nn as nn from spikingjelly.activation_based import neuron, layer, functional net = nn.Sequential( layer.Linear(8, 4), neuron.IFNode(step_mode="m"), layer.Linear(4, 2), neuron.IFNode(step_mode="m"), ) x_seq = torch.rand([1024, 8]) with torch.no_grad(): y_seq = functional.chunk_multi_step_forward(16, x_seq, net) print(y_seq.shape) # torch.Size([1024, 2])
- spikingjelly.activation_based.functional.forward.seq_to_ann_forward(x_seq: Tensor, stateless_module: Module | list | tuple | Sequential | Callable)[源代码]#
-
中文
使用无状态层进行多步前向传播。输入
x_seq的时间和批量维度将被展平,得到[T*batch_size, ...]形状的张量;随后,输入到无状态层中;最后,将输出张量恢复到序列形式[T, batch_size, ...]。- 参数:
x_seq (torch.Tensor) --
shape=[T, batch_size, ...]的输入tensorstateless_module (Union[torch.nn.Module, list, tuple, torch.nn.Sequential, Callable]) -- 单个或多个无状态网络层
- 返回:
shape=[T, batch_size, ...]的输出tensor- 返回类型:
- 抛出:
Exception -- 任何底层无状态模块在前向传播时抛出的异常都会原样向上传播
English
Applied forward on stateless modules. Flatten the time and batch dimensions of
x_seqso thatshape=[T*batch_size, ...], feed the reshaped tensor to the stateless module(s), and reshape the output back to the sequence formshape=[T, batch_size, ...].- 参数:
x_seq (torch.Tensor) -- the input tensor with
shape=[T, batch_size, ...]stateless_module (Union[torch.nn.Module, list, tuple, torch.nn.Sequential, Callable]) -- one or many stateless modules
- 返回:
the output tensor with
shape=[T, batch_size, ...]- 返回类型:
- 抛出:
Exception -- Any exception raised by an underlying stateless module is propagated unchanged
- spikingjelly.activation_based.functional.forward.t_last_seq_to_ann_forward(x_seq: Tensor, stateless_module: Module | list | tuple | Sequential | Callable)[源代码]#
-
中文
使用无状态层进行多步前向传播。
备注
SpikingJelly中默认序列数据形状为
shape=[T, batch_size, ...]。 但此函数是用于另一种格式,即shape=[batch_size, ..., T]。 当torch.vmap可用时,此函数会直接调用torch.vmap(stateless_module, in_dims=-1, out_dims=-1)并行地对时间维执行前向传播。 因此此路径要求stateless_module是可直接调用的对象,例如nn.Module或nn.Sequential。普通的list或tuple仅在torch.vmap不可用、退化到t_last_multi_step_forward()时才可用。备注
不能用于BN层,因为BN层的running mean/var是输入依赖的。 对于BN层,只需要输入被当作是
shape = [N, C, ..]即可并行计算,需要用户手动实现。- 参数:
x_seq (torch.Tensor) --
shape=[batch_size, ..., T]的输入tensorstateless_module (Union[torch.nn.Module, list, tuple, torch.nn.Sequential, Callable]) -- 单个或多个无状态网络层
- 返回:
shape=[batch_size, ..., T]的输出tensor- 返回类型:
- 抛出:
English
Applied forward on stateless modules.
Note
The default shape of sequence data in SpikingJelly is
shape=[T, batch_size, ...]. However, this function is used for the other data format whereshape=[batch_size, ..., T]. Whentorch.vmapis available, this function callstorch.vmap(stateless_module, in_dims=-1, out_dims=-1)to apply the forward pass over the time dimension in parallel. Therefore, this path requiresstateless_moduleto be directly callable, e.g.nn.Moduleornn.Sequential. Plainlistandtupleinputs only work on the fallback path that usest_last_multi_step_forward()whentorch.vmapis unavailable.Note
This function can not be applied to wrap BN because its running mean/var depends on inputs. The BN can be computed in parallel as long as the input is regarded as
shape = [N, C, ..], which can be implemented by user manually.- 参数:
x_seq (torch.Tensor) -- the input tensor with
shape=[batch_size, ..., T]stateless_module (Union[torch.nn.Module, list, tuple, torch.nn.Sequential, Callable]) -- one or many stateless modules
- 返回:
the output tensor with
shape=[batch_size, ..., T]- 返回类型:
- 抛出: