from typing import Callable, Union
import torch
import torch.nn as nn
from torch import Tensor
# Public API of this module: the names exported by a star-import.
__all__ = [
"multi_step_forward",
"t_last_multi_step_forward",
"chunk_multi_step_forward",
"seq_to_ann_forward",
"t_last_seq_to_ann_forward",
]
# [docs] (source-link marker from the rendered documentation page; not code)
def multi_step_forward(
    x_seq: Tensor,
    single_step_module: Union[
        nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable
    ],
):
    """
    Run one or more single-step modules over every time-step of a sequence.

    The leading dimension of ``x_seq`` is treated as time. For each index
    ``t`` along that dimension the slice ``x_seq[t]`` is fed through
    ``single_step_module``; when a ``list``, ``tuple`` or ``nn.Sequential``
    is given, its members are applied one after another at each time-step.
    The per-step outputs are stacked along a new leading dimension.

    :param x_seq: the input tensor with ``shape=[T, batch_size, ...]``
    :type x_seq: torch.Tensor
    :param single_step_module: one or many single-step modules
    :type single_step_module: Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]
    :return: the output tensor with ``shape=[T, batch_size, ...]``
    :rtype: torch.Tensor
    :raises Exception: any exception raised by an underlying module at any
        time-step is propagated unchanged
    """
    if isinstance(single_step_module, (list, tuple, nn.Sequential)):
        stages = single_step_module

        def step(frame):
            # The container is iterated member by member; an nn.Sequential
            # is therefore not invoked as a whole.
            for stage in stages:
                frame = stage(frame)
            return frame
    else:
        step = single_step_module

    # Time-steps are processed strictly in order so stateful modules see
    # t = 0, 1, ..., T - 1.
    outputs = [step(x_seq[t]) for t in range(x_seq.shape[0])]
    return torch.stack(outputs)
# [docs] (source-link marker from the rendered documentation page; not code)
def t_last_multi_step_forward(
    x_seq: Tensor,
    single_step_module: Union[
        nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable
    ],
):
    """
    Run one or more single-step modules over a sequence whose time axis is last.

    This helper is the counterpart of a time-first multi-step forward for
    tensors laid out as ``shape=[batch_size, ..., T]``. For each index ``t``
    along the last dimension the slice ``x_seq[..., t]`` is fed through
    ``single_step_module``; when a ``list``, ``tuple`` or ``nn.Sequential``
    is given, its members are applied one after another at each time-step.
    The per-step outputs are stacked along a new last dimension.

    :param x_seq: the input tensor with ``shape=[batch_size, ..., T]``
    :type x_seq: torch.Tensor
    :param single_step_module: one or many single-step modules
    :type single_step_module: Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]
    :return: the output tensor with ``shape=[batch_size, ..., T]``
    :rtype: torch.Tensor
    :raises Exception: any exception raised by an underlying module at any
        time-step is propagated unchanged
    """
    if isinstance(single_step_module, (list, tuple, nn.Sequential)):
        stages = single_step_module

        def step(frame):
            # The container is iterated member by member; an nn.Sequential
            # is therefore not invoked as a whole.
            for stage in stages:
                frame = stage(frame)
            return frame
    else:
        step = single_step_module

    # Time-steps are processed strictly in order along the last dimension.
    outputs = [step(x_seq[..., t]) for t in range(x_seq.shape[-1])]
    return torch.stack(outputs, dim=-1)
# [docs] (source-link marker from the rendered documentation page; not code)
def chunk_multi_step_forward(
    split_size: int, x_seq: Tensor, multi_step_module: nn.Module
):
    """
    Feed a long sequence through a multi-step module one chunk at a time.

    ``x_seq`` with ``shape = [T, *]`` is split along ``dim=0`` into chunks of
    ``shape = [split_size, *]`` (when ``T % split_size != 0`` the last chunk
    has a smaller ``shape[0]``). Each chunk is passed to
    ``multi_step_module`` in order, and the outputs are concatenated back
    along ``dim=0``, so the output keeps the original leading length ``T``.

    This can be used for gradient-free inference with a large ``T``
    (e.g., ANN2SNN) to reduce memory consumption.

    :param split_size: the split size
    :type split_size: int
    :param x_seq: the input tensor
    :type x_seq: torch.Tensor
    :param multi_step_module: a network in multi-step mode
    :type multi_step_module: torch.nn.Module
    :return: the output tensor
    :rtype: torch.Tensor
    :raises Exception: any exception raised by ``multi_step_module`` on a
        chunk is propagated unchanged

    **Example**

    .. code-block:: python

        import torch
        import torch.nn as nn
        from spikingjelly.activation_based import neuron, layer, functional

        net = nn.Sequential(
            layer.Linear(8, 4),
            neuron.IFNode(step_mode="m"),
            layer.Linear(4, 2),
            neuron.IFNode(step_mode="m"),
        )
        x_seq = torch.rand([1024, 8])
        with torch.no_grad():
            y_seq = functional.chunk_multi_step_forward(16, x_seq, net)
        print(y_seq.shape)
        # torch.Size([1024, 2])
    """
    # Chunks are evaluated in sequence order so a stateful module carries
    # its state from one chunk to the next.
    chunk_outputs = [
        multi_step_module(chunk) for chunk in torch.split(x_seq, split_size)
    ]
    return torch.cat(chunk_outputs, dim=0)
# [docs] (source-link marker from the rendered documentation page; not code)
def seq_to_ann_forward(
    x_seq: Tensor,
    stateless_module: Union[nn.Module, list, tuple, nn.Sequential, Callable],
):
    """
    Apply stateless module(s) to a whole sequence in a single pass.

    The first two dimensions of ``x_seq`` (time and batch) are flattened so
    that ``shape=[T*batch_size, ...]``. The flattened tensor is fed through
    the stateless module(s); a ``list``, ``tuple`` or ``nn.Sequential`` has
    its members applied one after another. The result is then reshaped back
    to the sequence form ``shape=[T, batch_size, ...]``, keeping whatever
    trailing dimensions the module(s) produced.

    :param x_seq: the input tensor with ``shape=[T, batch_size, ...]``
    :type x_seq: torch.Tensor
    :param stateless_module: one or many stateless modules
    :type stateless_module: Union[torch.nn.Module, list, tuple, torch.nn.Sequential, Callable]
    :return: the output tensor with ``shape=[T, batch_size, ...]``
    :rtype: torch.Tensor
    :raises Exception: any exception raised by an underlying stateless
        module is propagated unchanged
    """
    # Remember the leading sizes before they are merged by the flatten.
    t_steps = x_seq.shape[0]
    batch = x_seq.shape[1]
    merged = x_seq.flatten(0, 1)

    # Treat a single callable as a one-element pipeline so that both cases
    # share the same loop.
    if isinstance(stateless_module, (list, tuple, nn.Sequential)):
        stages = stateless_module
    else:
        stages = (stateless_module,)
    for stage in stages:
        merged = stage(merged)

    # Split the merged leading dimension back into (T, batch_size).
    return merged.view([t_steps, batch, *merged.shape[1:]])
# [docs] (source-link marker from the rendered documentation page; not code)
def t_last_seq_to_ann_forward(
    x_seq: Tensor,
    stateless_module: Union[nn.Module, list, tuple, nn.Sequential, Callable],
):
    """
    Apply stateless module(s) to a sequence whose time axis is the last one.

    .. admonition:: Note
        :class: note

        The default shape of sequence data in SpikingJelly is
        ``shape=[T, batch_size, ...]``. This function is for the other data
        format, where ``shape=[batch_size, ..., T]``.

        When ``torch.vmap`` is available, the forward pass is mapped over
        the last dimension with
        ``torch.vmap(..., in_dims=-1, out_dims=-1)``. Otherwise the function
        falls back to :func:`t_last_multi_step_forward`, which loops over
        the time-steps. A plain ``list`` or ``tuple`` of modules is accepted
        on both paths: its members are applied one after another. An
        ``nn.Module`` (including ``nn.Sequential``) or any other callable is
        handed to ``torch.vmap`` as is.

    .. admonition:: Note
        :class: note

        This function can not be applied to wrap BN because its running
        mean/var depends on inputs. The BN can be computed in parallel as
        long as the input is regarded as ``shape = [N, C, ..]``, which can
        be implemented by user manually.

    :param x_seq: the input tensor with ``shape=[batch_size, ..., T]``
    :type x_seq: torch.Tensor
    :param stateless_module: one or many stateless modules
    :type stateless_module: Union[torch.nn.Module, list, tuple, torch.nn.Sequential, Callable]
    :return: the output tensor with ``shape=[batch_size, ..., T]``
    :rtype: torch.Tensor
    :raises Exception: any exception raised by an underlying stateless
        module on either the vmap or the fallback path is propagated
        unchanged
    """
    if not hasattr(torch, "vmap"):
        # Older torch without vmap: loop over the time-steps instead.
        return t_last_multi_step_forward(x_seq, stateless_module)

    if isinstance(stateless_module, (list, tuple)):
        # torch.vmap needs a callable, and a plain list/tuple is not one.
        # Chain the members into a single function so that both paths
        # accept the same inputs.
        stages = stateless_module

        def per_step(x):
            for stage in stages:
                x = stage(x)
            return x
    else:
        per_step = stateless_module

    return torch.vmap(per_step, in_dims=-1, out_dims=-1)(x_seq)