Online Learning Pipelines#

在线学习 的辅助函数。


Auxiliary functions for online learning .

spikingjelly.activation_based.functional.online_learning.fptt_online_training_init_w_ra(optimizer: Optimizer) list[源代码]#

API Language: 中文 | English


  • 中文

初始化 fptt_online_training() 使用的 w_ra 列表。返回列表中的元素顺序与 optimizer.param_groups 中参数的遍历顺序一致,列表元素是各参数当前的 w.data

参数:

optimizer (torch.optim.Optimizer) -- 网络使用的优化器

返回:

与优化器参数顺序对齐的运行平均列表,列表元素为各参数当前的 w.data

返回类型:

list[torch.Tensor]

抛出:

Exception -- 若优化器参数组中存在不可访问 .data 的对象,则底层异常会原样向上传播


  • English

Initialize the w_ra list used by fptt_online_training(). The returned list follows the traversal order of parameters in optimizer.param_groups and stores the current w.data of each parameter.

参数:

optimizer (torch.optim.Optimizer) -- the optimizer for the network

返回:

a list aligned with optimizer parameter order whose elements are the current w.data tensors

返回类型:

list[torch.Tensor]

抛出:

Exception -- Any exception raised while accessing .data of optimizer parameters is propagated unchanged

spikingjelly.activation_based.functional.online_learning.fptt_online_training(model: Module, optimizer: Optimizer, x_seq: Tensor, target_seq: Tensor, f_loss_t: Callable, alpha: float, w_ra: list) None[源代码]#

API Language: 中文 | English


  • 中文

使用 FPTT 在线训练方法沿 x_seq.shape[0] 对应的时间维逐步训练网络。每个时间步都会执行一次 前向、损失计算、参数更新与 detach_net,并对 spikingjelly.activation_based.base.MemoryModule 的内部状态进行保存和恢复。

该函数要求 x_seqtarget_seq 的时间维均位于第 0 维,且长度一致。 w_ra 应由 fptt_online_training_init_w_ra() 初始化,并与 optimizer 当前参数顺序保持一致。

参数:
返回:

None

返回类型:

None

抛出:
  • IndexError -- 若 target_seq 的时间长度小于 x_seq,按时间步索引目标时会抛出异常

  • Exception -- 任何模型前向、损失计算、反向传播或优化器更新异常都会原样向上传播


  • English

The FPTT online learning method proposed by Training Recurrent Neural Networks via Forward Propagation Through Time and used for SNN in Accurate online training of dynamical spiking neural networks through Forward Propagation Through Time . This function iterates over the time dimension x_seq.shape[0] and performs forward, loss computation, parameter update, and detach_net at every time step. It also stores and restores the internal states of spikingjelly.activation_based.base.MemoryModule.

The function expects both x_seq and target_seq to place the time axis at dimension 0 and to share the same temporal length. w_ra should be initialized by fptt_online_training_init_w_ra() and remain aligned with the current parameter order of optimizer.

参数:
  • model (nn.Module) -- the neural network

  • optimizer (torch.optim.Optimizer) -- the optimizer for the network

  • x_seq (torch.Tensor) -- the input sequence

  • target_seq (torch.Tensor) -- the target sequence

  • f_loss_t (Callable) -- the loss function, which should have the formulation of def f_loss_t(y_t, target_t) -> torch.Tensor

  • alpha (float) -- the hyper-parameter

  • w_ra (list[torch.Tensor]) -- the running-average list initialized by fptt_online_training_init_w_ra(), where each element corresponds to one optimizer parameter

返回:

None

返回类型:

None

抛出:
  • IndexError -- Raised when target_seq is shorter than x_seq along the time dimension

  • Exception -- Any exception raised during model forward, loss computation, backward pass, or optimizer update is propagated unchanged


  • 代码示例 | Example

from spikingjelly.activation_based import neuron

net = nn.Sequential(
    nn.Linear(8, 4), neuron.IFNode(), nn.Linear(4, 2), neuron.IFNode()
)

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

T = 4
N = 2
w_ra = fptt_online_training_init_w_ra(optimizer)
for epoch in range(2):
    x_seq = torch.rand([T, N, 8])
    target_seq = torch.rand([T, N, 2])

    fptt_online_training(
        model=net,
        optimizer=optimizer,
        x_seq=x_seq,
        target_seq=target_seq,
        f_loss_t=F.mse_loss,
        alpha=0.1,
        w_ra=w_ra,
    )
    functional.reset_net(net)
spikingjelly.activation_based.functional.online_learning.ottt_online_training(model: Module, optimizer: Optimizer, x_seq: Tensor, target_seq: Tensor, f_loss_t: Callable, online: bool) Tuple[Tensor, Tensor][源代码]#

API Language: 中文 | English


  • 中文

使用 OTTT 在线训练方法训练网络,也可用于文献中提到的 SLTT 训练。函数会先将 x_seqtarget_seq[B, T, ...] 转置为 [T, B, ...],然后沿时间维逐步执行 前向与反向传播。若 onlineTrue,则每个时间步都会执行一次参数更新;否则先累积整段序列的梯度, 再在最后统一更新。

该函数要求 x_seqtarget_seq 的前两维分别表示 batch 和 time,且 两者在这两维上的长度一致。

参数:
  • model (nn.Module) -- 神经网络

  • optimizer (torch.optim.Optimizer) -- 网络使用的优化器

  • x_seq (torch.Tensor) -- 输入序列,形状为 [B, T, ...]

  • target_seq (torch.Tensor) -- 目标序列,形状为 [B, T, ...]

  • f_loss_t (Callable) -- 单个时间步的损失函数,调用形式应为 f_loss_t(y_t, target_t) -> torch.Tensor

  • online (bool) -- 是否在每个时间步在线更新参数;若为 False,则仅在整段序列结束后更新一次

返回:

(batch_loss, y_all),其中 batch_loss 是各时间步损失之和, y_all 是形状为 [B, T, ...] 的按时间堆叠且已 detach 的输出

返回类型:

tuple[torch.Tensor, torch.Tensor]

抛出:
  • IndexError -- 若 target_seqx_seq 在时间维长度不一致,则按时间步索引时会抛出异常

  • Exception -- 任何模型前向、损失计算、反向传播或优化器更新异常都会原样向上传播


  • English

The OTTT online training method is proposed by Online Training Through Time for Spiking Neural Networks. This function can also be used for SLTT training method proposed by Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks . It first transposes x_seq and target_seq from [B, T, ...] to [T, B, ...] and then runs forward and backward passes step by step along the time dimension. If online is True, the optimizer updates parameters at every time step; otherwise, gradients are accumulated through the whole sequence and applied once at the end.

The function expects x_seq and target_seq to use batch and time as the first two dimensions and to share the same sizes on those dimensions.

参数:
  • model (nn.Module) -- the neural network

  • optimizer (torch.optim.Optimizer) -- the optimizer for the network

  • x_seq (torch.Tensor) -- the input sequence with shape=[B, T, ...]

  • target_seq (torch.Tensor) -- the target sequence with shape=[B, T, ...]

  • f_loss_t (Callable) -- the loss function, which should have the formulation of def f_loss_t(y_t, target_t) -> torch.Tensor

  • online (bool) -- whether to update parameters online at each time step or to accumulate gradients through time steps

返回:

(batch_loss, y_all), where batch_loss is the sum of per-step losses and y_all is the detached stacked output with shape=[B, T, ...]

返回类型:

tuple[torch.Tensor, torch.Tensor]

抛出:
  • IndexError -- Raised when target_seq and x_seq do not match on the time dimension

  • Exception -- Any exception raised during model forward, loss computation, backward pass, or optimizer update is propagated unchanged


  • 代码示例 | Example

from spikingjelly.activation_based import neuron, layer, functional

net = layer.OTTTSequential(
    nn.Linear(8, 4), neuron.OTTTLIFNode(), nn.Linear(4, 2), neuron.LIFNode()
)

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

T = 4
N = 2
online = True
for epoch in range(2):
    x_seq = torch.rand([N, T, 8])
    target_seq = torch.rand([N, T, 2])

    functional.ottt_online_training(
        model=net,
        optimizer=optimizer,
        x_seq=x_seq,
        target_seq=target_seq,
        f_loss_t=F.mse_loss,
        online=online,
    )
    functional.reset_net(net)