spikingjelly.datasets.transform 源代码

from typing import Union

import numpy as np
import torch


__all__ = ["random_temporal_delete", "RandomTemporalDelete"]


[文档] def random_temporal_delete( x_seq: Union[torch.Tensor, np.ndarray], T_remain: int, batch_first: bool ): r""" **API Language:** :ref:`中文 <random_temporal_delete-cn>` | :ref:`English <random_temporal_delete-en>` ---- .. _random_temporal_delete-cn: * **中文** 在 `Deep Residual Learning in Spiking Neural Networks <https://arxiv.org/abs/2102.04159>`_ 中使用的随机时间删除数据增强。 :param x_seq: 一个序列, 其 `shape = [T, N, *]`, 其中 `T` 是序列长度, `N` 是批次大小 :type x_seq: Union[torch.Tensor, np.ndarray] :param T_remain: 剩余的长度 :type T_remain: int :param batch_first: 如果 `True`, `x_seq` 将被视为 `shape = [N, T, *]` :type batch_first: bool :return: 长度为 `T_remain` 的序列, 通过随机移除 `T - T_remain` 个切片获得 :rtype: Union[torch.Tensor, np.ndarray] :raises ValueError: 当 ``T_remain`` 为负数, 或大于当前时间维长度时由 ``numpy.random.choice`` 抛出。 ---- .. _random_temporal_delete-en: * **English** The random temporal delete data augmentation used in `Deep Residual Learning in Spiking Neural Networks <https://arxiv.org/abs/2102.04159>`_. :param x_seq: a sequence with `shape = [T, N, *]`, where `T` is the sequence length and `N` is the batch size :type x_seq: Union[torch.Tensor, np.ndarray] :param T_remain: the remained length :type T_remain: int :param batch_first: if `True`, `x_seq` will be regarded as `shape = [N, T, *]` :type batch_first: bool :return: the sequence with length `T_remain`, which is obtained by randomly removing `T - T_remain` slices :rtype: Union[torch.Tensor, np.ndarray] :raises ValueError: raised by ``numpy.random.choice`` when ``T_remain`` is negative or larger than the current time dimension length. ---- * **代码示例 | Example** .. code-block:: python import torch from spikingjelly.datasets import random_temporal_delete T = 8 T_remain = 5 N = 4 x_seq = torch.arange(0, N * T).view([N, T]) print("x_seq=\n", x_seq) print( "random_temporal_delete(x_seq)=\n", random_temporal_delete(x_seq, T_remain, batch_first=True), ) Outputs: .. code-block:: shell x_seq= tensor([[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31]]) random_temporal_delete(x_seq)= tensor([[ 0, 1, 4, 6, 7], [ 8, 9, 12, 14, 15], [16, 17, 20, 22, 23], [24, 25, 28, 30, 31]]) """ if batch_first: sec_list = np.random.choice(x_seq.shape[1], T_remain, replace=False) else: sec_list = np.random.choice(x_seq.shape[0], T_remain, replace=False) sec_list.sort() if batch_first: return x_seq[:, sec_list] else: return x_seq[sec_list]
[文档] class RandomTemporalDelete(torch.nn.Module): r""" **API Language:** :ref:`中文 <RandomTemporalDelete-cn>` | :ref:`English <RandomTemporalDelete-en>` ---- .. _RandomTemporalDelete-cn: * **中文** :func:`random_temporal_delete` 的 ``torch.nn.Module`` 封装。前向传播时会使用构造时给定的 ``T_remain`` 和 ``batch_first`` 调用 :func:`random_temporal_delete`。 ---- .. _RandomTemporalDelete-en: * **English** A ``torch.nn.Module`` wrapper around :func:`random_temporal_delete`. During ``forward``, it calls :func:`random_temporal_delete` with the ``T_remain`` and ``batch_first`` values provided at construction time. """ def __init__(self, T_remain: int, batch_first: bool): r""" **API Language:** :ref:`中文 <RandomTemporalDelete.__init__-cn>` | :ref:`English <RandomTemporalDelete.__init__-en>` ---- .. _RandomTemporalDelete.__init__-cn: * **中文** * **中文** 在 `Deep Residual Learning in Spiking Neural Networks <https://arxiv.org/abs/2102.04159>`_ 中使用的随机时间删除数据增强。 详见 :func:`random_temporal_delete`。 :param T_remain: 剩余的长度 :type T_remain: int :param batch_first: 如果 `True`, `x_seq` 将被视为 `shape = [N, T, *]` :type batch_first: bool ---- .. _RandomTemporalDelete.__init__-en: * **English** * **English** The random temporal delete data augmentation used in `Deep Residual Learning in Spiking Neural Networks <https://arxiv.org/abs/2102.04159>`_. Refer to :func:`random_temporal_delete` for more details. :param T_remain: the remained length :type T_remain: int :param batch_first: if `True`, `x_seq` will be regarded as `shape = [N, T, *]` :type batch_first: bool :return: None :rtype: None """ super().__init__() self.T_remain = T_remain self.batch_first = batch_first
[文档] def forward( self, x_seq: Union[torch.Tensor, np.ndarray] ) -> Union[torch.Tensor, np.ndarray]: r""" **API Language:** :ref:`中文 <RandomTemporalDelete.forward-cn>` | :ref:`English <RandomTemporalDelete.forward-en>` ---- .. _RandomTemporalDelete.forward-cn: * **中文** 使用当前模块保存的 ``T_remain`` 和 ``batch_first`` 配置, 对输入序列执行 :func:`random_temporal_delete`。 :param x_seq: 输入序列。其时间维布局由 ``batch_first`` 决定。 :type x_seq: Union[torch.Tensor, np.ndarray] :return: 随机删除时间切片后的序列。 :rtype: Union[torch.Tensor, np.ndarray] :raises ValueError: 当 ``self.T_remain`` 非法时, 由 :func:`random_temporal_delete` 内部的 ``numpy.random.choice`` 抛出 ---- .. _RandomTemporalDelete.forward-en: * **English** Apply :func:`random_temporal_delete` to the input sequence with the ``T_remain`` and ``batch_first`` configuration stored in this module. :param x_seq: input sequence. The time-dimension layout is determined by ``batch_first``. :type x_seq: Union[torch.Tensor, np.ndarray] :return: sequence after random temporal deletion. :rtype: Union[torch.Tensor, np.ndarray] :raises ValueError: raised by ``numpy.random.choice`` inside :func:`random_temporal_delete` when ``self.T_remain`` is invalid """ return random_temporal_delete(x_seq, self.T_remain, self.batch_first)