spikingjelly.activation_based.layer.container 源代码

import logging
from typing import Callable

import torch
import torch.nn as nn
from torch import Tensor

from .. import base, functional


__all__ = [
    "MultiStepContainer",
    "SeqToANNContainer",
    "TLastMultiStepContainer",
    "TLastSeqToANNContainer",
    "StepModeContainer",
    "ElementWiseRecurrentContainer",
    "LinearRecurrentContainer",
]


def _check_step_mode(block: nn.Sequential, caller: str = "MultiStepContainer"):
    for m in block:
        assert not hasattr(m, "step_mode") or m.step_mode == "s"
        if isinstance(m, base.StepModule):
            if "m" in m.supported_step_mode():
                logging.warning(
                    f"{m} supports for step_mode == 'm', "
                    f"which should not be contained by {caller}!"
                )


[文档] class MultiStepContainer(nn.Sequential, base.MultiStepModule): def __init__(self, *args): r""" **API Language:** :ref:`中文 <MultiStepContainer.__init__-cn>` | :ref:`English <MultiStepContainer.__init__-en>` ---- .. _MultiStepContainer.__init__-cn: * **中文** * **中文** :func:`spikingjelly.activation_based.functional.multi_step_forward` 的容器。构造方式与 `torch.nn.Sequential` 一致。 ---- .. _MultiStepContainer.__init__-en: * **English** * **English** Container of :func:`spikingjelly.activation_based.functional.multi_step_forward`. Its constructor signature is the same as `torch.nn.Sequential`. :return: None :rtype: None """ super().__init__(*args) _check_step_mode(self, "MultiStepContainer")
[文档] def forward(self, x_seq: Tensor): """ :param x_seq: with shape ``[T, batch_size, ...]`` :type x_seq: torch.Tensor :return: y_seq with shape ``[T, batch_size, ...]`` :rtype: torch.Tensor """ return functional.multi_step_forward(x_seq, super().forward)
[文档] class SeqToANNContainer(nn.Sequential, base.MultiStepModule): def __init__(self, *args): """ **API Language:** :ref:`中文 <SeqToANNContainer-cn>` | :ref:`English <SeqToANNContainer-en>` ---- .. _SeqToANNContainer-cn: * **中文** * **中文** :func:`spikingjelly.activation_based.functional.seq_to_ann_forward` 的容器。构造方式与 `torch.nn.Sequential` 一致。 ---- .. _SeqToANNContainer-en: * **English** * **English** Container of :func:`spikingjelly.activation_based.functional.seq_to_ann_forward`. Its constructor signature is the same as `torch.nn.Sequential`. :return: None :rtype: None """ super().__init__(*args) _check_step_mode(self, "SeqToANNContainer")
[文档] def forward(self, x_seq: Tensor): """ :param x_seq: with shape ``[T, batch_size, ...]`` :type x_seq: torch.Tensor :return: y_seq with shape ``[T, batch_size, ...]`` :rtype: torch.Tensor """ return functional.seq_to_ann_forward(x_seq, super().forward)
[文档] class TLastMultiStepContainer(nn.Sequential, base.MultiStepModule): def __init__(self, *args): """ See :func:`spikingjelly.activation_based.functional.forward.t_last_multi_step_forward` . :return: None :rtype: None """ super().__init__(*args) _check_step_mode(self, "TLastMultiStepContainer")
[文档] def forward(self, x_seq: Tensor): """ :param x_seq: shape ``[batch_size, ..., T]`` :type x_seq: Tensor :return: y_seq with shape ``[batch_size, ..., T]`` :rtype: Tensor """ return functional.t_last_seq_to_ann_forward(x_seq, super().forward)
[文档] class TLastSeqToANNContainer(nn.Sequential, base.MultiStepModule): def __init__(self, *args): """ See :func:`spikingjelly.activation_based.functional.forward.t_last_seq_to_ann_forward` . :return: None :rtype: None """ super().__init__(*args) _check_step_mode(self, "TLastSeqToANNContainer")
[文档] def forward(self, x_seq: Tensor): """ :param x_seq: with shape ``[batch_size, ..., T]`` :type x_seq: Tensor :return: y_seq with shape ``[batch_size, ..., T]`` :rtype: Tensor """ return functional.t_last_seq_to_ann_forward(x_seq, super().forward)
[文档] class StepModeContainer(nn.Sequential, base.StepModule): def __init__(self, stateful: bool, step_mode: str = "s", *args): """ Call single-step forward, multi-step forward or seq-to-ANN forward according to ``stateful`` and ``step_mode``. :param stateful: 是否是有状态的容器 :type stateful: bool :param step_mode: 步进模式,``\"s\"`` 或 ``\"m\"`` :type step_mode: str :param args: 与 ``torch.nn.Sequential`` 相同的构造参数 :param stateful: Whether the container is stateful :type stateful: bool :param step_mode: Step mode, ``\"s\"`` or ``\"m\"`` :type step_mode: str :param args: Same constructor arguments as ``torch.nn.Sequential`` :type args: tuple :return: None :rtype: None """ super().__init__(*args) self.stateful = stateful _check_step_mode(self, "StepModeContainer") self.step_mode = step_mode
[文档] def forward(self, x: torch.Tensor): if self.step_mode == "s": return super().forward(x) elif self.step_mode == "m": if self.stateful: return functional.multi_step_forward(x, super().forward) else: return functional.seq_to_ann_forward(x, super().forward)
[文档] class ElementWiseRecurrentContainer(base.MemoryModule): def __init__( self, sub_module: nn.Module, element_wise_function: Callable, step_mode="s" ): r""" **API Language:** :ref:`中文 <ElementWiseRecurrentContainer.__init__-cn>` | :ref:`English <ElementWiseRecurrentContainer.__init__-en>` ---- .. _ElementWiseRecurrentContainer.__init__-cn: * **中文** 使用逐元素运算的自连接包装器。记 ``sub_module`` 的输入输出为 :math:`i[t]` 和 :math:`y[t]` (注意 :math:`y[t]` 也是整个模块的输出), 整个模块的输入为 :math:`x[t]`,则 .. math:: i[t] = f(x[t], y[t-1]) 其中 :math:`f` 是用户自定义的逐元素函数。我们默认 :math:`y[-1] = 0`。 .. Note:: ``sub_module`` 输入和输出的尺寸需要相同。 :param sub_module: 被包含的模块 :type sub_module: torch.nn.Module :param element_wise_function: 用户自定义的逐元素函数,应该形如 ``z=f(x, y)`` :type element_wise_function: Callable :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str ---- .. _ElementWiseRecurrentContainer.__init__-en: * **English** A container that use a element-wise recurrent connection. Denote the inputs and outputs of ``sub_module`` as :math:`i[t]` and :math:`y[t]` (Note that :math:`y[t]` is also the outputs of this module), and the inputs of this module as :math:`x[t]`, then .. math:: i[t] = f(x[t], y[t-1]) where :math:`f` is the user-defined element-wise function. We set :math:`y[-1] = 0`. .. admonition:: Note :class: note The shape of inputs and outputs of ``sub_module`` must be the same. :param sub_module: the contained module :type sub_module: torch.nn.Module :param element_wise_function: the user-defined element-wise function, which should have the format ``z=f(x, y)`` :type element_wise_function: Callable :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str ---- * **代码示例 | Example** .. code-block:: python T = 8 net = ElementWiseRecurrentContainer( neuron.IFNode(v_reset=None), element_wise_function=lambda x, y: x + y ) print(net) x = torch.zeros([T]) x[0] = 1.5 for t in range(T): print(t, f"x[t]={x[t]}, s[t]={net(x[t])}") functional.reset_net(net) :return: None :rtype: None """ super().__init__() self.step_mode = step_mode assert not hasattr(sub_module, "step_mode") or sub_module.step_mode == "s" self.sub_module = sub_module self.element_wise_function = element_wise_function self.register_memory("y", None)
[文档] def single_step_forward(self, x: Tensor): if self.y is None: self.y = torch.zeros_like(x.data) self.y = self.sub_module(self.element_wise_function(self.y, x)) return self.y
def extra_repr(self) -> str: return f"element-wise function={self.element_wise_function}, step_mode={self.step_mode}"
[文档] class LinearRecurrentContainer(base.MemoryModule): def __init__( self, sub_module: nn.Module, in_features: int, out_features: int, bias: bool = True, step_mode="s", ) -> None: r""" **API Language:** :ref:`中文 <LinearRecurrentContainer.__init__-cn>` | :ref:`English <LinearRecurrentContainer.__init__-en>` ---- .. _LinearRecurrentContainer.__init__-cn: * **中文** 使用线性层的自连接包装器。记 ``sub_module`` 的输入和输出为 :math:`i[t]` 和 :math:`y[t]` (注意 :math:`y[t]` 也是整个模块的输出), 整个模块的输入记作 :math:`x[t]` ,则 .. math:: i[t] = \begin{pmatrix} x[t] \\ y[t-1]\end{pmatrix} W^{T} + b 其中 :math:`W, b` 是线性层的权重和偏置项。默认 :math:`y[-1] = 0`。 :math:`x[t]` 应该 ``shape = [N, *, in_features]``,:math:`y[t]` 则应该 ``shape = [N, *, out_features]``。 .. Note:: 自连接是由 ``torch.nn.Linear(in_features + out_features, in_features, bias)`` 实现的。 :param sub_module: 被包含的模块 :type sub_module: torch.nn.Module :param in_features: 输入的特征数量 :type in_features: int :param out_features: 输出的特征数量 :type out_features: int :param bias: 若为 ``False``,则线性自连接不会带有可学习的偏执项 :type bias: bool :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str ---- .. _LinearRecurrentContainer.__init__-en: * **English** A container that use a linear recurrent connection. Denote the inputs and outputs of ``sub_module`` as :math:`i[t]` and :math:`y[t]` (Note that :math:`y[t]` is also the outputs of this module), and the inputs of this module as :math:`x[t]`, then .. math:: i[t] = \begin{pmatrix} x[t] \\ y[t-1]\end{pmatrix} W^{T} + b where :math:`W, b` are the weight and bias of the linear connection. We set :math:`y[-1] = 0`. :math:`x[t]` should have the shape ``[N, *, in_features]``, and :math:`y[t]` has the shape ``[N, *, out_features]``. .. admonition:: Note :class: note The recurrent connection is implement by ``torch.nn.Linear(in_features + out_features, in_features, bias)``. :param sub_module: the contained module :type sub_module: torch.nn.Module :param in_features: size of each input sample :type in_features: int :param out_features: size of each output sample :type out_features: int :param bias: If set to ``False``, the linear recurrent layer will not learn an additive bias :type bias: bool :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str ---- * **代码示例 | Example** .. code-block:: python in_features = 4 out_features = 2 T = 8 N = 2 net = LinearRecurrentContainer( nn.Sequential( nn.Linear(in_features, out_features), neuron.LIFNode(), ), in_features, out_features, ) print(net) x = torch.rand([T, N, in_features]) for t in range(T): print(t, net(x[t])) functional.reset_net(net) :return: None :rtype: None """ super().__init__() self.step_mode = step_mode assert not hasattr(sub_module, "step_mode") or sub_module.step_mode == "s" self.sub_module_out_features = out_features self.rc = nn.Linear(in_features + out_features, in_features, bias) self.sub_module = sub_module self.register_memory("y", None)
[文档] def single_step_forward(self, x: Tensor): if self.y is None: if x.ndim == 2: self.y = torch.zeros([x.shape[0], self.sub_module_out_features]).to(x) else: out_shape = [x.shape[0]] out_shape.extend(x.shape[1:-1]) out_shape.append(self.sub_module_out_features) self.y = torch.zeros(out_shape).to(x) x = torch.cat((x, self.y), dim=-1) self.y = self.sub_module(self.rc(x)) return self.y
def extra_repr(self) -> str: return f", step_mode={self.step_mode}"