import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Tuple
from .. import base
from .net_config import detach_net
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
"fptt_online_training_init_w_ra",
"fptt_online_training",
"ottt_online_training",
]
# [docs]  (Sphinx viewcode link text captured during extraction; not part of the module)
def fptt_online_training_init_w_ra(optimizer: torch.optim.Optimizer) -> list:
    """
    Build the initial ``w_ra`` list consumed by :func:`fptt_online_training`.

    The optimizer's parameter groups are walked in order and, inside every
    group, the entries of ``group["params"]`` are walked in order. For each
    parameter its ``.data`` attribute is placed in the result as-is; no
    explicit copy is made by this function.

    :param optimizer: the optimizer for the network
    :type optimizer: torch.optim.Optimizer
    :return: a list aligned with the optimizer parameter order whose elements
        are the current ``w.data`` tensors
    :rtype: list[torch.Tensor]
    :raises Exception: any exception raised while accessing ``.data`` on an
        optimizer parameter is propagated unchanged
    """
    return [
        param.data
        for group in optimizer.param_groups
        for param in group["params"]
    ]
# [docs]  (Sphinx viewcode link text captured during extraction; not part of the module)
def fptt_online_training(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    x_seq: torch.Tensor,
    target_seq: torch.Tensor,
    f_loss_t: Callable,
    alpha: float,
    w_ra: list,
) -> None:
    """
    Train ``model`` step by step with the FPTT online learning method.

    FPTT was proposed by `Training Recurrent Neural Networks via Forward Propagation Through Time <https://proceedings.mlr.press/v139/kag21a.html>`_ and used for SNN in `Accurate online training of dynamical spiking neural networks through Forward Propagation Through Time <https://arxiv.org/abs/2112.11231>`_ .

    The function iterates over ``x_seq.shape[0]`` time steps. At every step it:

    1. runs a forward pass on ``x_seq[t]`` and evaluates
       ``f_loss_t(y_t, target_seq[t])``;
    2. adds the regulariser ``alpha / 2 * sum_i mse(w_i, w_ra[i] + g_i / (2 * alpha))``,
       where ``g_i`` is the gradient stored at the previous step (``0`` at the
       first step);
    3. back-propagates, calls ``optimizer.step()`` and then ``detach_net(model)``;
    4. for every step except the last, runs a second forward/backward pass
       with the updated weights on the same ``x_seq[t]`` / ``target_seq[t]``
       to refresh ``g_i`` and to update
       ``w_ra[i] = (w_ra[i] + w_i) / 2 - g_i / (2 * alpha)``;
    5. restores the ``_memories`` of every
       :class:`spikingjelly.activation_based.base.MemoryModule` to the deep
       copy taken right after ``detach_net``, so the second forward pass does
       not advance the network state.

    ``x_seq`` and ``target_seq`` are indexed along dimension 0 with the same
    index, so both are expected to be time-major and to share the same
    temporal length. ``w_ra`` should be initialized by
    :func:`fptt_online_training_init_w_ra` and remain aligned with the current
    parameter order of ``optimizer``. ``w_ra`` is updated in place (its
    elements are rebound to new tensors).

    A parameter whose gradient is ``None`` after the second backward pass
    (for example a frozen parameter) is treated as having a zero gradient.

    :param model: the neural network
    :type model: nn.Module
    :param optimizer: the optimizer for the network
    :type optimizer: torch.optim.Optimizer
    :param x_seq: the input sequence, indexed by time step on dimension 0
    :type x_seq: torch.Tensor
    :param target_seq: the target sequence, indexed by time step on dimension 0
    :type target_seq: torch.Tensor
    :param f_loss_t: the loss function, which should have the formulation of
        ``def f_loss_t(y_t, target_t) -> torch.Tensor``
    :type f_loss_t: Callable
    :param alpha: the FPTT hyper-parameter; it is used as a divisor, so it
        must be non-zero
    :type alpha: float
    :param w_ra: the running-average list initialized by
        :func:`fptt_online_training_init_w_ra`, where each element corresponds
        to one optimizer parameter
    :type w_ra: list[torch.Tensor]
    :return: ``None``
    :rtype: None
    :raises IndexError: raised when ``target_seq`` (or ``w_ra``) is shorter
        than what the loop indexes
    :raises Exception: any exception raised during model forward, loss
        computation, backward pass, or optimizer update is propagated unchanged

    ----

    * **Example**

    .. code-block:: python

        from spikingjelly.activation_based import neuron, functional

        net = nn.Sequential(
            nn.Linear(8, 4), neuron.IFNode(), nn.Linear(4, 2), neuron.IFNode()
        )
        optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
        T = 4
        N = 2
        w_ra = functional.fptt_online_training_init_w_ra(optimizer)
        for epoch in range(2):
            x_seq = torch.rand([T, N, 8])
            target_seq = torch.rand([T, N, 2])
            functional.fptt_online_training(
                model=net,
                optimizer=optimizer,
                x_seq=x_seq,
                target_seq=target_seq,
                f_loss_t=F.mse_loss,
                alpha=0.1,
                w_ra=w_ra,
            )
            functional.reset_net(net)
    """
    T = x_seq.shape[0]

    # Flatten the optimizer parameters once; this is the order ``w_ra`` follows.
    params = [w for group in optimizer.param_groups for w in group["params"]]

    # Gradient of the previous step's loss w.r.t. each parameter (0 before the
    # first step, so the regulariser initially pulls ``w`` towards ``w_ra``).
    grad__l_t_last__to__w_t = [0.0] * len(params)

    for t in range(T):
        optimizer.zero_grad()
        y_t = model(x_seq[t])
        loss_t = f_loss_t(y_t, target_seq[t])

        loss_reg = 0.0
        for i, w in enumerate(params):
            loss_reg = loss_reg + F.mse_loss(
                w, w_ra[i] + grad__l_t_last__to__w_t[i] / (2.0 * alpha)
            )
        loss_reg = loss_reg * (alpha / 2.0)

        loss = loss_t + loss_reg
        loss.backward()

        # update params
        optimizer.step()
        # NOTE(review): detach_net is defined elsewhere; presumably it cuts the
        # stored states out of the autograd graph - confirm.
        detach_net(model)

        # store hidden states, so the extra forward pass below can be undone
        memory_modules = [
            m for m in model.modules() if isinstance(m, base.MemoryModule)
        ]
        states = [copy.deepcopy(m._memories) for m in memory_modules]

        # update w_ra
        optimizer.zero_grad()
        if t < T - 1:
            y_t = model(x_seq[t])
            loss_t = f_loss_t(y_t, target_seq[t])
            loss_t.backward()
            with torch.no_grad():
                for i, w in enumerate(params):
                    if w.grad is None:
                        # Parameter did not receive a gradient from loss_t:
                        # treat its gradient as zero instead of crashing.
                        grad__l_t_last__to__w_t[i] = 0.0
                        w_ra[i] = (w_ra[i] + w) / 2.0
                        continue
                    # Keep a private copy: ``optimizer.zero_grad()`` below may
                    # zero ``w.grad`` in place (``set_to_none=False``), which
                    # would silently wipe a stored reference.
                    grad_t = w.grad.detach().clone()
                    grad__l_t_last__to__w_t[i] = grad_t
                    w_ra[i] = (w_ra[i] + w) / 2.0 - grad_t / (2.0 * alpha)

        optimizer.zero_grad()

        # recover hidden states
        for m, saved in zip(memory_modules, states):
            m._memories = saved
# [docs]  (Sphinx viewcode link text captured during extraction; not part of the module)
def ottt_online_training(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    x_seq: torch.Tensor,
    target_seq: torch.Tensor,
    f_loss_t: Callable,
    online: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Train ``model`` on one batch with the OTTT online training method.

    OTTT is proposed by `Online Training Through Time for Spiking Neural Networks <https://openreview.net/forum?id=Siv3nHYHheI>`_.
    This function can also be used for the SLTT training method proposed by `Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks <https://openaccess.thecvf.com/content/ICCV2023/html/Meng_Towards_Memory-_and_Time-Efficient_Backpropagation_for_Training_Spiking_Neural_Networks_ICCV_2023_paper.html>`_ .

    ``x_seq`` and ``target_seq`` are first transposed from ``[B, T, ...]`` to
    ``[T, B, ...]``. The function then runs one forward pass and one backward
    pass per time step. With ``online=True`` the gradients are cleared and
    ``optimizer.step()`` is called at every time step. With ``online=False``
    the gradients are cleared once before the loop, accumulated across all
    time steps, and applied with a single ``optimizer.step()`` after the loop.

    Both tensors are expected to use batch and time as their first two
    dimensions and to share the same sizes on those dimensions.

    :param model: the neural network
    :type model: nn.Module
    :param optimizer: the optimizer for the network
    :type optimizer: torch.optim.Optimizer
    :param x_seq: the input sequence with ``shape=[B, T, ...]``
    :type x_seq: torch.Tensor
    :param target_seq: the target sequence with ``shape=[B, T, ...]``
    :type target_seq: torch.Tensor
    :param f_loss_t: the loss function, which should have the formulation of
        ``def f_loss_t(y_t, target_t) -> torch.Tensor``
    :type f_loss_t: Callable
    :param online: whether to update parameters online at each time step or to
        accumulate gradients through time steps
    :type online: bool
    :return: ``(batch_loss, y_all)``, where ``batch_loss`` is the sum of the
        per-step losses and ``y_all`` is the detached per-step output stacked
        on dimension 1, i.e. ``shape=[B, T, ...]``
    :rtype: tuple[torch.Tensor, torch.Tensor]
    :raises IndexError: raised when ``target_seq`` is shorter than ``x_seq``
        on the time dimension
    :raises Exception: any exception raised during model forward, loss
        computation, backward pass, or optimizer update is propagated unchanged

    ----

    * **Example**

    .. code-block:: python

        from spikingjelly.activation_based import neuron, layer, functional

        net = layer.OTTTSequential(
            nn.Linear(8, 4), neuron.OTTTLIFNode(), nn.Linear(4, 2), neuron.LIFNode()
        )
        optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
        T = 4
        N = 2
        online = True
        for epoch in range(2):
            x_seq = torch.rand([N, T, 8])
            target_seq = torch.rand([N, T, 2])
            functional.ottt_online_training(
                model=net,
                optimizer=optimizer,
                x_seq=x_seq,
                target_seq=target_seq,
                f_loss_t=F.mse_loss,
                online=online,
            )
            functional.reset_net(net)
    """
    # Callers pass [B, T, ...]; the loop below walks time-major views.
    inputs_tb = x_seq.transpose(0, 1)
    targets_tb = target_seq.transpose(0, 1)
    num_steps = inputs_tb.shape[0]
    accumulate = not online

    total_loss = 0.0
    outputs = []

    if accumulate:
        # Gradients from every step are summed into one update.
        optimizer.zero_grad()

    for step in range(num_steps):
        if online:
            optimizer.zero_grad()

        out_t = model(inputs_tb[step])
        step_loss = f_loss_t(out_t, targets_tb[step].contiguous())
        step_loss.backward()

        if online:
            # update params immediately with this step's gradient
            optimizer.step()

        total_loss += step_loss.data
        outputs.append(out_t.detach())

    if accumulate:
        optimizer.step()

    # Stack on dim 1 so the result is batch-major again: [B, T, ...]
    stacked = torch.stack(outputs, dim=1)
    return total_loss, stacked