spikingjelly.datasets.transform module#

spikingjelly.datasets.transform.random_temporal_delete(x_seq: Tensor | ndarray, T_remain: int, batch_first: bool)[源代码]#

API Language: 中文 | English


  • 中文

Deep Residual Learning in Spiking Neural Networks 中使用的随机时间删除数据增强。

参数:
  • x_seq (Union[torch.Tensor, np.ndarray]) -- 一个序列, 其 shape = [T, N, *], 其中 T 是序列长度, N 是批次大小

  • T_remain (int) -- 剩余的长度

  • batch_first (bool) -- 如果 True, x_seq 将被视为 shape = [N, T, *]

返回:

长度为 T_remain 的序列, 通过随机移除 T - T_remain 个切片获得

返回类型:

Union[torch.Tensor, np.ndarray]

抛出:

ValueError -- 当 T_remain 为负数, 或大于当前时间维长度时由 numpy.random.choice 抛出。


  • English

The random temporal delete data augmentation used in Deep Residual Learning in Spiking Neural Networks.

参数:
  • x_seq (Union[torch.Tensor, np.ndarray]) -- a sequence with shape = [T, N, *], where T is the sequence length and N is the batch size

  • T_remain (int) -- the remained length

  • batch_first (bool) -- if True, x_seq will be regarded as shape = [N, T, *]

返回:

the sequence with length T_remain, which is obtained by randomly removing T - T_remain slices

返回类型:

Union[torch.Tensor, np.ndarray]

抛出:

ValueError -- raised by numpy.random.choice when T_remain is negative or larger than the current time dimension length.


  • 代码示例 | Example

import torch
from spikingjelly.datasets import random_temporal_delete

T = 8
T_remain = 5
N = 4
x_seq = torch.arange(0, N * T).view([N, T])
print("x_seq=\n", x_seq)
print(
    "random_temporal_delete(x_seq)=\n",
    random_temporal_delete(x_seq, T_remain, batch_first=True),
)

Outputs:

x_seq=
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31]])
random_temporal_delete(x_seq)=
 tensor([[ 0,  1,  4,  6,  7],
        [ 8,  9, 12, 14, 15],
        [16, 17, 20, 22, 23],
        [24, 25, 28, 30, 31]])
class spikingjelly.datasets.transform.RandomTemporalDelete(T_remain: int, batch_first: bool)[源代码]#

基类:Module

API Language: 中文 | English


  • 中文

random_temporal_delete()torch.nn.Module 封装。前向传播时会使用构造时给定的 T_remainbatch_first 调用 random_temporal_delete()


  • English

A torch.nn.Module wrapper around random_temporal_delete(). During forward, it calls random_temporal_delete() with the T_remain and batch_first values provided at construction time.

API Language: 中文 | English


  • 中文

  • 中文

Deep Residual Learning in Spiking Neural Networks 中使用的随机时间删除数据增强。 详见 random_temporal_delete()

参数:
  • T_remain (int) -- 剩余的长度

  • batch_first (bool) -- 如果 True, x_seq 将被视为 shape = [N, T, *]


  • English

  • English

The random temporal delete data augmentation used in Deep Residual Learning in Spiking Neural Networks. Refer to random_temporal_delete() for more details.

参数:
  • T_remain (int) -- the remained length

  • batch_first (bool) -- if True, x_seq will be regarded as shape = [N, T, *]

返回:

None

返回类型:

None

forward(x_seq: Tensor | ndarray) Tensor | ndarray[源代码]#

API Language: 中文 | English


  • 中文

使用当前模块保存的 T_remainbatch_first 配置, 对输入序列执行 random_temporal_delete()

参数:

x_seq (Union[torch.Tensor, np.ndarray]) -- 输入序列。其时间维布局由 batch_first 决定。

返回:

随机删除时间切片后的序列。

返回类型:

Union[torch.Tensor, np.ndarray]

抛出:

ValueError -- 当 self.T_remain 非法时, 由 random_temporal_delete() 内部的 numpy.random.choice 抛出


  • English

Apply random_temporal_delete() to the input sequence with the T_remain and batch_first configuration stored in this module.

参数:

x_seq (Union[torch.Tensor, np.ndarray]) -- input sequence. The time-dimension layout is determined by batch_first.

返回:

sequence after random temporal deletion.

返回类型:

Union[torch.Tensor, np.ndarray]

抛出:

ValueError -- raised by numpy.random.choice inside random_temporal_delete() when self.T_remain is invalid