Online Learning Pipelines
Auxiliary functions for online learning.
- spikingjelly.activation_based.functional.online_learning.fptt_online_training_init_w_ra(optimizer: Optimizer) -> list
Initialize the w_ra list used by fptt_online_training(). The returned list follows the traversal order of parameters in optimizer.param_groups and stores the current w.data of each parameter.
- Parameters:
optimizer (torch.optim.Optimizer) -- the optimizer for the network
- Returns:
a list aligned with the optimizer's parameter order whose elements are the current w.data tensors
- Return type:
list
- Raises:
Exception -- any exception raised while accessing .data of the optimizer's parameters is propagated unchanged
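The traversal order described above can be sketched in plain PyTorch. init_w_ra_sketch is a hypothetical helper that restates the documented behaviour; it is not the library source. With two parameter groups, the tensors of the first group come first in the list.

```python
import torch
import torch.nn as nn


def init_w_ra_sketch(optimizer: torch.optim.Optimizer) -> list:
    # Hypothetical restatement of the documented contract: visit the
    # parameter groups in order, then the parameters inside each group.
    w_ra = []
    for group in optimizer.param_groups:
        for w in group["params"]:
            w_ra.append(w.data)  # .data shares storage with the parameter
    return w_ra


fc1, fc2 = nn.Linear(8, 4), nn.Linear(4, 2)
# Two groups with fc2 listed first, so fc2's tensors come first in w_ra.
optimizer = torch.optim.SGD(
    [{"params": fc2.parameters()}, {"params": fc1.parameters(), "lr": 0.01}],
    lr=0.1,
)
w_ra = init_w_ra_sketch(optimizer)
print([tuple(t.shape) for t in w_ra])  # [(2, 4), (2,), (4, 8), (4,)]
```

Adding a parameter group after building w_ra would break this alignment, which is why the list should be rebuilt whenever the optimizer's parameters change.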
- spikingjelly.activation_based.functional.online_learning.fptt_online_training(model: Module, optimizer: Optimizer, x_seq: Tensor, target_seq: Tensor, f_loss_t: Callable, alpha: float, w_ra: list) -> None
Train the network with the FPTT online learning method, which was proposed in "Training Recurrent Neural Networks via Forward Propagation Through Time" and applied to SNNs in "Accurate online training of dynamical spiking neural networks through Forward Propagation Through Time". This function iterates over the time dimension, whose length is x_seq.shape[0], and at every time step performs a forward pass, loss computation, parameter update, and detach_net. It also stores and restores the internal states of spikingjelly.activation_based.base.MemoryModule instances.
Both x_seq and target_seq must place the time axis at dimension 0 and have the same temporal length. w_ra should be initialized by fptt_online_training_init_w_ra() and remain aligned with the current parameter order of optimizer.
- Parameters:
model (nn.Module) -- the neural network
optimizer (torch.optim.Optimizer) -- the optimizer for the network
x_seq (torch.Tensor) -- the input sequence, with time at dimension 0
target_seq (torch.Tensor) -- the target sequence, with time at dimension 0
f_loss_t (Callable) -- the per-time-step loss function, called as f_loss_t(y_t, target_t) -> torch.Tensor
alpha (float) -- the hyper-parameter used by FPTT
w_ra (list[torch.Tensor]) -- the running-average list initialized by fptt_online_training_init_w_ra(), where each element corresponds to one optimizer parameter
- Returns:
None
- Return type:
None
- Raises:
IndexError -- raised when target_seq is shorter than x_seq along the time dimension, because targets are indexed per time step
Exception -- any exception raised during the model forward pass, loss computation, backward pass, or optimizer update is propagated unchanged
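Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import neuron, functional
from spikingjelly.activation_based.functional.online_learning import (
    fptt_online_training,
    fptt_online_training_init_w_ra,
)

net = nn.Sequential(
    nn.Linear(8, 4),
    neuron.IFNode(),
    nn.Linear(4, 2),
    neuron.IFNode(),
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

T = 4  # number of time steps
N = 2  # batch size
# Initialize the running averages once, after the optimizer is built.
w_ra = fptt_online_training_init_w_ra(optimizer)

for epoch in range(2):
    # FPTT expects time at dimension 0: [T, N, ...]
    x_seq = torch.rand([T, N, 8])
    target_seq = torch.rand([T, N, 2])
    fptt_online_training(
        model=net,
        optimizer=optimizer,
        x_seq=x_seq,
        target_seq=target_seq,
        f_loss_t=F.mse_loss,
        alpha=0.1,
        w_ra=w_ra,
    )
    # Clear neuron states before the next sequence.
    functional.reset_net(net)
```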
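Because a too-short target_seq only surfaces as an IndexError when the missing time step is indexed, it can help to check the documented preconditions before calling fptt_online_training(). The sketch below assumes plain PyTorch; check_fptt_inputs is a hypothetical helper, and comparing each w_ra entry's shape with its parameter is an extra assumption based on w_ra holding each parameter's w.data. It also shows a per-step loss with the documented call form f_loss_t(y_t, target_t).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def check_fptt_inputs(x_seq, target_seq, optimizer, w_ra):
    # Hypothetical helper: time axis at dim 0 with equal lengths.
    if x_seq.shape[0] != target_seq.shape[0]:
        raise ValueError(
            f"time lengths differ: x_seq has {x_seq.shape[0]}, "
            f"target_seq has {target_seq.shape[0]}"
        )
    # w_ra must follow the parameter traversal order of optimizer.param_groups.
    params = [w for group in optimizer.param_groups for w in group["params"]]
    if len(params) != len(w_ra):
        raise ValueError("w_ra length does not match the optimizer's parameters")
    for w, ra in zip(params, w_ra):
        if w.shape != ra.shape:
            raise ValueError("a w_ra entry does not match its parameter's shape")


def f_loss_t(y_t: torch.Tensor, target_t: torch.Tensor) -> torch.Tensor:
    # Per-time-step loss with the documented call form f_loss_t(y_t, target_t).
    return F.mse_loss(y_t, target_t)


net = nn.Linear(8, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
# Same ordering as the documented w_ra list: current w.data of each parameter.
w_ra = [w.data for group in optimizer.param_groups for w in group["params"]]

T, N = 4, 2
check_fptt_inputs(torch.rand(T, N, 8), torch.rand(T, N, 2), optimizer, w_ra)
```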
- spikingjelly.activation_based.functional.online_learning.ottt_online_training(model: Module, optimizer: Optimizer, x_seq: Tensor, target_seq: Tensor, f_loss_t: Callable, online: bool) -> Tuple[Tensor, Tensor]
Train the network with the OTTT online training method proposed in "Online Training Through Time for Spiking Neural Networks". This function can also be used for the SLTT training method proposed in "Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks". It first transposes x_seq and target_seq from [B, T, ...] to [T, B, ...] and then runs forward and backward passes step by step along the time dimension. If online is True, the optimizer updates the parameters at every time step; otherwise, gradients are accumulated over the whole sequence and applied once at the end.
Both x_seq and target_seq must use batch and time as their first two dimensions and have the same sizes on those dimensions.
- Parameters:
model (nn.Module) -- the neural network
optimizer (torch.optim.Optimizer) -- the optimizer for the network
x_seq (torch.Tensor) -- the input sequence with shape=[B, T, ...]
target_seq (torch.Tensor) -- the target sequence with shape=[B, T, ...]
f_loss_t (Callable) -- the per-time-step loss function, called as f_loss_t(y_t, target_t) -> torch.Tensor
online (bool) -- if True, update the parameters at every time step; if False, accumulate gradients over all time steps and update once after the whole sequence
- Returns:
(batch_loss, y_all), where batch_loss is the sum of the per-step losses and y_all is the detached output stacked over time, with shape=[B, T, ...]
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- Raises:
IndexError -- raised when target_seq and x_seq do not match on the time dimension, because targets are indexed per time step
Exception -- any exception raised during the model forward pass, loss computation, backward pass, or optimizer update is propagated unchanged
Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import neuron, layer, functional

net = layer.OTTTSequential(
    nn.Linear(8, 4),
    neuron.OTTTLIFNode(),
    nn.Linear(4, 2),
    neuron.LIFNode(),
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

T = 4  # number of time steps
N = 2  # batch size
online = True  # update parameters at every time step

for epoch in range(2):
    # OTTT expects batch first: [N, T, ...]
    x_seq = torch.rand([N, T, 8])
    target_seq = torch.rand([N, T, 2])
    batch_loss, y_all = functional.ottt_online_training(
        model=net,
        optimizer=optimizer,
        x_seq=x_seq,
        target_seq=target_seq,
        f_loss_t=F.mse_loss,
        online=online,
    )
    # Clear neuron states before the next sequence.
    functional.reset_net(net)
```
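To make the online/offline update schedule concrete, here is a minimal control-flow sketch in plain PyTorch. ottt_schedule_sketch is a hypothetical name, and the model is a stateless nn.Linear stand-in: the sketch only mirrors the steps described above (transpose, per-step forward and backward, per-step or end-of-sequence update, summed loss, stacked detached outputs). It does not implement OTTT's gradient computation, which in the example above relies on OTTT-aware modules such as layer.OTTTSequential and neuron.OTTTLIFNode. A stateless model is used because calling backward at every step on an ordinary stateful network would try to backpropagate through earlier steps whose graphs were already freed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def ottt_schedule_sketch(model, optimizer, x_seq, target_seq, f_loss_t, online):
    # Hypothetical control-flow sketch; no OTTT gradient logic is included.
    x_seq = x_seq.transpose(0, 1)  # [B, T, ...] -> [T, B, ...]
    target_seq = target_seq.transpose(0, 1)
    batch_loss = 0.0
    y_all = []
    if not online:
        optimizer.zero_grad()  # gradients accumulate over the whole sequence
    for t in range(x_seq.shape[0]):
        if online:
            optimizer.zero_grad()
        y_t = model(x_seq[t])
        loss = f_loss_t(y_t, target_seq[t])
        loss.backward()  # backward pass at every time step
        if online:
            optimizer.step()  # one update per time step
        batch_loss = batch_loss + loss.detach()  # sum of per-step losses
        y_all.append(y_t.detach())
    if not online:
        optimizer.step()  # single update after the last time step
    return batch_loss, torch.stack(y_all, dim=1)  # outputs as [B, T, ...]


torch.manual_seed(0)
B, T = 2, 4
model = nn.Linear(8, 2)  # stateless stand-in, not an OTTT network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss, y_all = ottt_schedule_sketch(
    model, optimizer, torch.rand(B, T, 8), torch.rand(B, T, 2), F.mse_loss, online=False
)
print(y_all.shape)  # torch.Size([2, 4, 2])
```

Stacking along dim=1 undoes the initial transpose, so the returned outputs are batch-first again, matching the documented shape=[B, T, ...].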