Forward Functions#

SpikingJelly 的 前向传播函数 实现了 SNN 的多步前向传播逻辑。


SpikingJelly's forward functions provide multi-step forward propagation logic for SNNs.

spikingjelly.activation_based.functional.forward.multi_step_forward(x_seq: Tensor, single_step_module: Module | list[Module] | tuple[Module] | Sequential | Callable)[源代码]#

API Language: 中文 | English


  • 中文

在单步模块 single_step_module 上使用多步前向传播。函数内部将执行一个for循环, 执行 T 次单步前向传播。若 single_step_module 为多个模块,则每个时间步都会按顺序依次执行这些模块。

参数:
  • x_seq (torch.Tensor) -- shape=[T, batch_size, ...] 的输入tensor

  • single_step_module (Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]) -- 一个或多个单步模块

返回:

shape=[T, batch_size, ...] 的输出tensor

返回类型:

torch.Tensor

抛出:

Exception -- 任何底层模块在某个时间步前向传播时抛出的异常都会原样向上传播


  • English

Applies multi-step forward on single_step_module. The function runs a for loop to execute single-step forward for T times. If single_step_module contains multiple modules, they are applied sequentially at each time-step.

参数:
  • x_seq (torch.Tensor) -- the input tensor with shape=[T, batch_size, ...]

  • single_step_module (Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]) -- one or many single-step modules

返回:

the output tensor with shape=[T, batch_size, ...]

返回类型:

torch.Tensor

抛出:

Exception -- Any exception raised by an underlying module at any time step is propagated unchanged

spikingjelly.activation_based.functional.forward.t_last_multi_step_forward(x_seq: Tensor, single_step_module: Module | list[Module] | tuple[Module] | Sequential | Callable)[源代码]#

API Language: 中文 | English


  • 中文

在单步模块 single_step_module 上使用多步前向传播。

此函数适用于时间维位于最后一维的序列张量,即 shape=[batch_size, ..., T]。 它会沿最后一维逐个时间步取出切片,并在每个时间步顺序执行单步模块。

参数:
  • x_seq (Tensor) -- shape=[batch_size, ..., T] 的输入tensor

  • single_step_module (Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]) -- 一个或多个单步模块

返回:

shape=[batch_size, ..., T] 的输出tensor

返回类型:

torch.Tensor

抛出:

Exception -- 任何底层模块在某个时间步前向传播时抛出的异常都会原样向上传播


  • English

Apply multi-step forward on single_step_module.

This helper is intended for sequence tensors whose time axis is the last dimension, i.e. shape=[batch_size, ..., T]. It slices along the last dimension and applies the single-step module(s) at each time step.

参数:
  • x_seq (torch.Tensor) -- the input tensor with shape=[batch_size, ..., T]

  • single_step_module (Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]) -- one or many single-step modules

返回:

the output tensor with shape=[batch_size, ..., T]

返回类型:

torch.Tensor

抛出:

Exception -- Any exception raised by an underlying module at any time step is propagated unchanged

spikingjelly.activation_based.functional.forward.chunk_multi_step_forward(split_size: int, x_seq: Tensor, multi_step_module: Module)[源代码]#

API Language: 中文 | English


  • 中文

shape = [T, *] 的输入 x_seq 拆分成多个 shape = [split_size, *] 的小tensor(若 T % split_size != 0,最后一个tensor的 shape[0] 会小于 split_size),然后逐个输入到 multi_step_module 中,再沿着 dim=0 将输出重新拼接,因此输出的首维长度仍为 T

chunk_multi_step_forward 可以在使用很大的 T 进行不带梯度的推理(例如ANN2SNN)时使用,能够减少内存消耗量。

参数:
返回:

输出

返回类型:

torch.Tensor

抛出:

Exception -- 任何 multi_step_module 在某个分块上的前向传播异常都会原样向上传播


  • English

Splits the input x_seq with shape = [T, *] to many tensor chunks with shape = [split_size, *] (if T % split_size != 0, shape[0] of the last tensor chunk will be smaller than split_size), and sends chunks to multi_step_module, then concatenates the outputs back along dim=0, so the output keeps the original leading length T.

chunk_multi_step_forward can be used for inference with a large T (e.g., ANN2SNN) to reduce the memory consumption.

参数:
  • split_size (int) -- the split size

  • x_seq (Tensor) -- the input tensor

  • multi_step_module (nn.Module) -- a network in multi-step mode

返回:

the output tensor

返回类型:

Tensor

抛出:

Exception -- Any exception raised by multi_step_module on a chunk is propagated unchanged


  • 代码示例 | Example

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, functional

net = nn.Sequential(
    layer.Linear(8, 4),
    neuron.IFNode(step_mode="m"),
    layer.Linear(4, 2),
    neuron.IFNode(step_mode="m"),
)

x_seq = torch.rand([1024, 8])
with torch.no_grad():
    y_seq = functional.chunk_multi_step_forward(16, x_seq, net)
    print(y_seq.shape)
    # torch.Size([1024, 2])
spikingjelly.activation_based.functional.forward.seq_to_ann_forward(x_seq: Tensor, stateless_module: Module | list | tuple | Sequential | Callable)[源代码]#

API Language: 中文 | English


  • 中文

使用无状态层进行多步前向传播。输入 x_seq 的时间和批量维度将被展平,得到 [T*batch_size, ...] 形状的张量;随后,输入到无状态层中;最后,将输出张量恢复到序列形式 [T, batch_size, ...]

参数:
返回:

shape=[T, batch_size, ...] 的输出tensor

返回类型:

torch.Tensor

抛出:

Exception -- 任何底层无状态模块在前向传播时抛出的异常都会原样向上传播


  • English

Applied forward on stateless modules. Flatten the time and batch dimensions of x_seq so that shape=[T*batch_size, ...], feed the reshaped tensor to the stateless module(s), and reshape the output back to the sequence form shape=[T, batch_size, ...].

参数:
返回:

the output tensor with shape=[T, batch_size, ...]

返回类型:

torch.Tensor

抛出:

Exception -- Any exception raised by an underlying stateless module is propagated unchanged

spikingjelly.activation_based.functional.forward.t_last_seq_to_ann_forward(x_seq: Tensor, stateless_module: Module | list | tuple | Sequential | Callable)[源代码]#

API Language: 中文 | English


  • 中文

使用无状态层进行多步前向传播。

备注

SpikingJelly中默认序列数据形状为 shape=[T, batch_size, ...]。 但此函数是用于另一种格式,即 shape=[batch_size, ..., T]。 当 torch.vmap 可用时,此函数会直接调用 torch.vmap(stateless_module, in_dims=-1, out_dims=-1) 并行地对时间维执行前向传播。 因此此路径要求 stateless_module 是可直接调用的对象,例如 nn.Modulenn.Sequential。普通的 listtuple 仅在 torch.vmap 不可用、退化到 t_last_multi_step_forward() 时才可用。

备注

不能用于BN层,因为BN层的running mean/var是输入依赖的。 对于BN层,只需要输入被当作是 shape = [N, C, ..] 即可并行计算,需要用户手动实现。

参数:
返回:

shape=[batch_size, ..., T] 的输出tensor

返回类型:

torch.Tensor

抛出:
  • TypeError -- 当 torch.vmap 可用但 stateless_module 不是可直接调用对象时,torch.vmap 路径可能抛出类型错误

  • Exception -- 任何底层无状态模块在 vmap 或 fallback 路径中抛出的异常都会原样向上传播


  • English

Applied forward on stateless modules.

Note

The default shape of sequence data in SpikingJelly is shape=[T, batch_size, ...]. However, this function is used for the other data format where shape=[batch_size, ..., T]. When torch.vmap is available, this function calls torch.vmap(stateless_module, in_dims=-1, out_dims=-1) to apply the forward pass over the time dimension in parallel. Therefore, this path requires stateless_module to be directly callable, e.g. nn.Module or nn.Sequential. Plain list and tuple inputs only work on the fallback path that uses t_last_multi_step_forward() when torch.vmap is unavailable.

Note

This function can not be applied to wrap BN because its running mean/var depends on inputs. The BN can be computed in parallel as long as the input is regarded as shape = [N, C, ..], which can be implemented by user manually.

参数:
返回:

the output tensor with shape=[batch_size, ..., T]

返回类型:

torch.Tensor

抛出:
  • TypeError -- When torch.vmap is available, the vmap path may raise a type error if stateless_module is not directly callable

  • Exception -- Any exception raised by an underlying stateless module on either the vmap or fallback path is propagated unchanged