# Source code for spikingjelly.activation_based.layer.attention

import torch
import torch.nn as nn
from einops import rearrange

from .. import base, neuron
from .container import SeqToANNContainer


# Public names exported by this module: the attention layers defined below.
__all__ = [
    "TemporalWiseAttention",
    "MultiDimensionalAttention",
    "SpikingSelfAttention",
    "QKAttention",
    "TokenQKAttention",
    "ChannelQKAttention",
]


class TemporalWiseAttention(nn.Module, base.MultiStepModule):
    def __init__(self, T: int, reduction: int = 16, dimension: int = 4):
        """
        .. _TemporalWiseAttention.__init__-en:

        The TemporalWiseAttention layer proposed in `Temporal-Wise Attention
        Spiking Neural Networks for Event Streams Classification
        <https://openaccess.thecvf.com/content/ICCV2021/html/Yao_Temporal-Wise_Attention_Spiking_Neural_Networks_for_Event_Streams_Classification_ICCV_2021_paper.html>`_.

        It should be placed after the convolution layer and before the spiking
        neurons, e.g.,

        ``Conv2d -> TemporalWiseAttention -> LIF``

        The input has shape ``[T, N, C, H, W]`` or ``[T, N, L]``; the output has
        the same shape as the input.

        ``reduction`` is the reduction ratio, which is :math:`r` in the paper.

        :param T: number of time steps of the input
        :type T: int
        :param reduction: reduction ratio of the hidden layer of the shared MLP
        :type reduction: int
        :param dimension: ``4`` for ``[T, N, C, H, W]`` inputs, ``2`` for
            ``[T, N, L]`` inputs
        :type dimension: int
        :raises ValueError: if ``dimension`` is neither 2 nor 4, if ``reduction``
            is not positive, or if ``reduction`` is greater than ``T``
        :return: None
        :rtype: None
        """
        super().__init__()
        self.step_mode = "m"

        # Explicit raises instead of ``assert`` so that validation is still
        # performed when Python runs with ``-O``.
        if dimension not in (2, 4):
            raise ValueError("dimension must be 4 or 2")
        if reduction < 1:
            # Guards the ``T // reduction`` below against division by zero and
            # against a negative hidden size.
            raise ValueError("reduction must be a positive integer")
        if T < reduction:
            raise ValueError("reduction cannot be greater than T")
        self.dimension = dimension

        # Squeeze: pool every non-batch, non-time axis down to a single value.
        if self.dimension == 2:
            self.avg_pool = nn.AdaptiveAvgPool1d(1)
            self.max_pool = nn.AdaptiveMaxPool1d(1)
        else:
            self.avg_pool = nn.AdaptiveAvgPool3d(1)
            self.max_pool = nn.AdaptiveMaxPool3d(1)

        # Excitation: bottleneck MLP over the time axis, shared by both
        # pooling branches.
        self.sharedMLP = nn.Sequential(
            nn.Linear(T, T // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(T // reduction, T, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """
        Re-weight every time step of ``x_seq`` by a learned score in ``(0, 1)``.

        :param x_seq: ``shape=[T, N, L]`` (``dimension=2``) or
            ``shape=[T, N, C, H, W]`` (``dimension=4``)
        :type x_seq: torch.Tensor
        :raises ValueError: if ``x_seq`` is not 3D/5D or its rank does not match
            the ``dimension`` given at construction
        :return: tensor with the same shape as ``x_seq``
        :rtype: torch.Tensor
        """
        if x_seq.dim() not in (3, 5):
            raise ValueError(
                f"expected 3D or 5D input with shape [T, N, M] or [T, N, C, H, W], but got input with shape {x_seq.shape}"
            )
        if x_seq.dim() != self.dimension + 1:
            # The pooling layers were chosen from ``dimension``; report the
            # mismatch here rather than letting the pooling layer fail.
            raise ValueError(
                f"layer was built with dimension={self.dimension} and expects a "
                f"{self.dimension + 1}D input, but got input with shape {x_seq.shape}"
            )

        x_seq = x_seq.transpose(0, 1)  # [N, T, ...]
        n, t = x_seq.shape[0], x_seq.shape[1]

        avgout = self.sharedMLP(self.avg_pool(x_seq).view([n, t]))
        maxout = self.sharedMLP(self.max_pool(x_seq).view([n, t]))
        scores = self.sigmoid(avgout + maxout)  # [N, T]

        # Append singleton axes so the scores broadcast over the trailing dims.
        scores = scores.reshape(n, t, *([1] * (x_seq.dim() - 2)))
        y_seq = x_seq * scores
        return y_seq.transpose(0, 1)
class _TimeAttention(nn.Module):
    """Temporal attention: maps ``[N, T, C, H, W]`` to scores ``[N, T, 1, 1, 1]``."""

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        # 1x1x1 convolutions act as a bottleneck MLP over the time axis.
        self.sharedMLP = nn.Sequential(
            nn.Conv3d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(in_planes // ratio, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class _ChannelAttention(nn.Module):
    """Channel attention: maps ``[N, T, C, H, W]`` to scores ``[N, 1, C, 1, 1]``."""

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.sharedMLP = nn.Sequential(
            nn.Conv3d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(in_planes // ratio, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Move channels to axis 1 so pooling and Conv3d operate over them.
        x = rearrange(x, "b f c h w -> b c f h w")
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        out = self.sigmoid(avgout + maxout)
        return rearrange(out, "b c f h w -> b f c h w")


class _SpatialAttention(nn.Module):
    """Spatial attention: maps ``[N, T, C, H, W]`` to scores ``[N, 1, 1, H, W]``."""

    def __init__(self, kernel_size=3):
        super().__init__()
        if kernel_size not in (3, 7):
            raise ValueError("kernel size must be 3 or 7")
        padding = 3 if kernel_size == 7 else 1  # keeps H and W unchanged
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Merge time and channel axes, then summarise them by mean and max.
        x = rearrange(x, "b f c h w -> b (f c) h w")
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        x = x.unsqueeze(1)
        return self.sigmoid(x)


class MultiDimensionalAttention(nn.Module, base.MultiStepModule):
    def __init__(
        self,
        T: int,
        C: int,
        reduction_t: int = 16,
        reduction_c: int = 16,
        kernel_size=3,
    ):
        """
        .. _MultiStepMultiDimensionalAttention.__init__-en:

        The multi-dimensional attention layer of the MA-SNN model proposed in
        `Attention Spiking Neural Networks
        <https://ieeexplore.ieee.org/document/10032591>`_.

        Example projects of MA-SNN can be found at:

        - https://github.com/MA-SNN/MA-SNN
        - https://github.com/ridgerchu/SNN_Attention_VGG

        The input has shape ``[T, N, C, H, W]``; the output has the same shape.

        :param T: number of time steps of the input
        :type T: int
        :param C: number of channels of the input
        :type C: int
        :param reduction_t: temporal reduction ratio
        :type reduction_t: int
        :param reduction_c: channel reduction ratio
        :type reduction_c: int
        :param kernel_size: convolution kernel size of the spatial attention,
            must be 3 or 7
        :type kernel_size: int
        :raises ValueError: if ``reduction_t > T``, ``reduction_c > C`` or
            ``kernel_size`` is not 3 or 7
        :return: None
        :rtype: None
        """
        super().__init__()

        # Explicit raises instead of ``assert`` so that validation is still
        # performed when Python runs with ``-O``.
        if T < reduction_t:
            raise ValueError("reduction_t cannot be greater than T")
        if C < reduction_c:
            raise ValueError("reduction_c cannot be greater than C")

        # The sub-modules are module-level classes rather than classes defined
        # inside this constructor, so instances of this layer can be pickled.
        self.ta = _TimeAttention(T, reduction_t)
        self.ca = _ChannelAttention(C, reduction_c)
        self.sa = _SpatialAttention(kernel_size)

        # ``sigmoid`` is not used by ``forward``; it is kept so the set of
        # attributes/sub-modules stays the same as before.
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply temporal, channel and spatial attention in sequence, then ReLU.

        :param x: ``shape=[T, N, C, H, W]``
        :type x: torch.Tensor
        :raises ValueError: if ``x`` is not 5D
        :return: ``shape=[T, N, C, H, W]``
        :rtype: torch.Tensor
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected 5D input with shape [T, N, C, H, W], but got input with shape {x.shape}"
            )
        x = x.transpose(0, 1)  # [N, T, C, H, W]
        out = self.ta(x) * x
        out = self.ca(out) * out
        out = self.sa(out) * out
        out = self.relu(out)
        return out.transpose(0, 1)
class SpikingSelfAttention(nn.Module, base.MultiStepModule):
    def __init__(self, dim, num_heads=8, backend: str = "torch"):
        """
        .. _SpikingSelfAttention.__init__-en:

        Spiking Self-Attention layer proposed in `Spikformer: When Spiking
        Neural Network Meets Transformer
        <https://openreview.net/forum?id=frE4fUwz_h>`_. This module is
        implemented based on the `Spikformer source code
        <https://github.com/ZK-Zhou/spikformer/blob/main/imagenet/model.py>`_
        with several improvements that significantly enhance efficiency. For
        more details about Spikformer and this implementation, see the tutorial
        :doc:`../tutorials/en/spikformer`.

        The input is a spike tensor of shape ``[T, N, C, L]``, where ``T`` is
        the number of time steps, ``N`` the batch size, ``C`` the number of
        channels and ``L`` the number of tokens (for vision tasks,
        ``L = H * W``). The output is a spike tensor of the same shape.

        :param dim: number of channels
        :type dim: int
        :param num_heads: number of attention heads. Default: ``8``
        :type num_heads: int
        :param backend: backend used by the internal neurons. Default: ``torch``
        :type backend: str
        :raises ValueError: if ``dim`` is not divisible by ``num_heads``
        :return: None
        :rtype: None
        """
        super().__init__()
        heads_fit = dim % num_heads == 0
        if not heads_fit:
            raise ValueError(f"dim {dim} should be divided by num_heads {num_heads}.")

        def make_lif(**overrides):
            # All spiking neurons of this layer share these settings.
            return neuron.LIFNode(
                tau=2.0, detach_reset=True, step_mode="m", backend=backend, **overrides
            )

        self.dim, self.num_heads = dim, num_heads
        self.head_dim = dim // num_heads
        self.scale = 0.125
        self._backend = backend

        # Sub-modules are created in a fixed order: q/k/v projection, its
        # neuron, the attention neuron, the output projection and its neuron.
        qkv_channels = dim * 3
        self.qkv_conv_bn = SeqToANNContainer(
            nn.Conv1d(dim, qkv_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(qkv_channels),
        )
        self.qkv_lif = make_lif()
        self.attn_lif = make_lif(v_threshold=0.5)
        self.proj_conv_bn = SeqToANNContainer(
            nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(dim),
        )
        self.proj_lif = make_lif()

    @property
    def backend(self):
        """
        Once set, the backend of all the neurons in this module will also be changed.
        """
        return self._backend

    @backend.setter
    def backend(self, value: str):
        self._backend = value
        for lif in (self.qkv_lif, self.attn_lif, self.proj_lif):
            lif.backend = value

    @staticmethod
    def _ssa_kernel_torch(qkv, scale):
        # TODO: add triton implementation
        # qkv: [T, N, 3, NUM_HEADS, Cph, L]
        fused = qkv.flatten(2, 3)
        # q, k, v: [T, N, NUM_HEADS, Cph, L]
        q, k, v = torch.chunk(fused, 3, dim=2)
        # v @ k^T is [Cph, Cph] per head; multiplying by q restores [Cph, L].
        vk = torch.matmul(v, k.transpose(-2, -1))
        return torch.matmul(vk, q) * scale  # [T, N, NUM_HEADS, Cph, L]

    def forward(self, x_seq: torch.Tensor):
        """
        :param x_seq: ``shape=[T, N, C, L]``
        :type x_seq: torch.Tensor
        :return: ``shape=[T, N, C, L]``
        :rtype: torch.Tensor
        """
        if x_seq.ndim != 4:
            raise ValueError(
                f"expected 4D input with shape [T, N, C, L], "
                f"but got input with shape {x_seq.shape}"
            )
        T, N, _, L = x_seq.shape

        spikes = self.qkv_lif(self.qkv_conv_bn(x_seq))  # [T, N, 3*C, L]
        local_dim = spikes.shape[2] // 3
        heads = self.num_heads
        if local_dim % heads:
            raise ValueError(
                f"local qkv dim {local_dim} is not divisible by num_heads={self.num_heads}."
            )

        qkv = spikes.reshape(T, N, 3, heads, local_dim // heads, L)
        attn = self._ssa_kernel_torch(qkv, self.scale)
        attn = self.attn_lif(attn).reshape(T, N, local_dim, L)
        return self.proj_lif(self.proj_conv_bn(attn))  # [T, N, C, L]

    def extra_repr(self):
        fields = (
            ("dim", self.dim),
            ("num_heads", self.num_heads),
            ("backend", self.backend),
        )
        return ", ".join(f"{key}={val}" for key, val in fields)
class QKAttention(nn.Module, base.MultiStepModule):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qka_type: str = "token",
        backend: str = "torch",
    ):
        """
        .. _QKAttention.__init__-en:

        Q-K Attention layer proposed in `QKFormer: Hierarchical Spiking
        Transformer using Q-K Attention
        <https://openreview.net/forum?id=AVd7DpiooC>`_. This module is
        implemented based on the `QKFormer source code
        <https://github.com/zhouchenlin2096/QKFormer/blob/master/imagenet/qkformer.py>`_,
        with several improvements that significantly enhance efficiency. The
        improvement strategy is similar to that used in Spikformer; see the
        tutorial :doc:`../tutorials/en/spikformer` for details.

        The input is a spike tensor of shape ``[T, N, C, L]``, where ``T`` is
        the number of time steps, ``N`` the batch size, ``C`` the number of
        channels and ``L`` the number of tokens (for vision tasks,
        ``L = H * W``). The output is a spike tensor of the same shape.

        :param dim: number of channels
        :type dim: int
        :param num_heads: number of attention heads. Default: ``8``
        :type num_heads: int
        :param qka_type: ``token`` or ``channel``. The default ``token``
            generates a token-wise mask
        :type qka_type: str
        :param backend: backend used by the internal neurons. Default: ``torch``
        :type backend: str
        :raises ValueError: if ``dim`` is not divisible by ``num_heads`` or
            ``qka_type`` is not one of the two supported values
        :return: None
        :rtype: None
        """
        super().__init__()
        if dim % num_heads:
            raise ValueError(f"dim {dim} should be divided by num_heads {num_heads}.")
        # Axis of the [T, N, NUM_HEADS, Cph, L] query that is summed away:
        # channels-per-head for a token mask, tokens for a channel mask.
        reduce_axis = {"token": 3, "channel": 4}
        if qka_type not in reduce_axis:
            raise ValueError(
                f"qka_type should be either 'token' or 'channel', but got {qka_type}."
            )

        def make_lif(**overrides):
            # All spiking neurons of this layer share these settings.
            return neuron.LIFNode(
                tau=2.0, detach_reset=True, step_mode="m", backend=backend, **overrides
            )

        self.dim, self.num_heads = dim, num_heads
        self.head_dim = dim // num_heads
        self._qka_type = qka_type
        self._backend = backend
        self.sum_dim = reduce_axis[qka_type]

        qk_channels = dim * 2
        self.qk_conv_bn = SeqToANNContainer(
            nn.Conv1d(dim, qk_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(qk_channels),
        )
        self.qk_lif = make_lif()
        self.attn_lif = make_lif(v_threshold=0.5)
        self.proj_conv_bn = SeqToANNContainer(
            nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(dim),
        )
        self.proj_lif = make_lif()

    @property
    def backend(self):
        """
        Once set, the backend of all the neurons in this module will also be changed.
        """
        return self._backend

    @backend.setter
    def backend(self, value: str):
        self._backend = value
        for lif in (self.qk_lif, self.attn_lif, self.proj_lif):
            lif.backend = value

    @property
    def qka_type(self):
        """
        Read-only. Set when constructing, and cannot be modified afterwards.
        """
        return self._qka_type

    def _qka_forward_torch(self, qk):
        # qk: [T, N, 2, NUM_HEADS, Cph, L]
        fused = qk.flatten(2, 3)
        # query, key: [T, N, NUM_HEADS, Cph, L]
        query, key = torch.chunk(fused, 2, dim=2)
        # Summed query is [T, N, NUM_HEADS, 1, L] for "token" and
        # [T, N, NUM_HEADS, Cph, 1] for "channel".
        summed = query.sum(dim=self.sum_dim, keepdim=True)
        mask = self.attn_lif(summed)
        return mask * key  # [T, N, NUM_HEADS, Cph, L]

    def forward(self, x_seq):
        """
        :param x_seq: ``shape=[T, N, C, L]``
        :type x_seq: torch.Tensor
        :return: ``shape=[T, N, C, L]``
        :rtype: torch.Tensor
        """
        if x_seq.ndim != 4:
            raise ValueError(
                f"expected 4D input with shape [T, N, C, L], "
                f"but got input with shape {x_seq.shape}"
            )
        T, N, C, L = x_seq.shape
        heads = self.num_heads

        spikes = self.qk_lif(self.qk_conv_bn(x_seq))  # [T, N, 2*C, L]
        qk = spikes.reshape(T, N, 2, heads, C // heads, L)
        masked = self._qka_forward_torch(qk).flatten(2, 3)  # [T, N, C, L]
        return self.proj_lif(self.proj_conv_bn(masked))

    def extra_repr(self):
        fields = (
            ("dim", self.dim),
            ("num_heads", self.num_heads),
            ("qka_type", self.qka_type),
            ("backend", self.backend),
        )
        return ", ".join(f"{key}={val}" for key, val in fields)
class TokenQKAttention(QKAttention):
    def __init__(self, dim: int, num_heads: int = 8, backend: str = "torch"):
        """
        ``QKAttention(..., qka_type="token")`` . See :class:`QKAttention` .

        :param dim: input dimension
        :type dim: int
        :param num_heads: number of attention heads
        :type num_heads: int
        :param backend: backend
        :type backend: str
        :return: None
        :rtype: None
        """
        # Fix the attention type; every other argument is passed through.
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            qka_type="token",
            backend=backend,
        )
class ChannelQKAttention(QKAttention):
    def __init__(self, dim: int, num_heads: int = 8, backend: str = "torch"):
        """
        ``QKAttention(..., qka_type="channel")`` . See :class:`QKAttention` .

        :param dim: input dimension
        :type dim: int
        :param num_heads: number of attention heads
        :type num_heads: int
        :param backend: backend
        :type backend: str
        :return: None
        :rtype: None
        """
        # Fix the attention type; every other argument is passed through.
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            qka_type="channel",
            backend=backend,
        )