spikingjelly.datasets.transform module#
- spikingjelly.datasets.transform.random_temporal_delete(x_seq: Tensor | ndarray, T_remain: int, batch_first: bool)[源代码]#
-
中文
在 Deep Residual Learning in Spiking Neural Networks 中使用的随机时间删除数据增强。
- 参数:
x_seq (Union[torch.Tensor, np.ndarray]) -- 一个序列, 其 shape = [T, N, *], 其中 T 是序列长度, N 是批次大小
T_remain (int) -- 剩余的长度
batch_first (bool) -- 如果 True, x_seq 将被视为 shape = [N, T, *]
- 返回:
长度为 T_remain 的序列, 通过随机移除 T - T_remain 个切片获得
- 返回类型:
Union[torch.Tensor, np.ndarray]
- 抛出:
ValueError -- 当
T_remain为负数, 或大于当前时间维长度时由numpy.random.choice抛出。
English
The random temporal delete data augmentation used in Deep Residual Learning in Spiking Neural Networks.
- 参数:
x_seq (Union[torch.Tensor, np.ndarray]) -- a sequence with shape = [T, N, *], where T is the sequence length and N is the batch size
T_remain (int) -- the remained length
batch_first (bool) -- if True, x_seq will be regarded as shape = [N, T, *]
- 返回:
the sequence with length T_remain, which is obtained by randomly removing T - T_remain slices
- 返回类型:
Union[torch.Tensor, np.ndarray]
- 抛出:
ValueError -- raised by
numpy.random.choicewhenT_remainis negative or larger than the current time dimension length.
代码示例 | Example
import torch from spikingjelly.datasets import random_temporal_delete T = 8 T_remain = 5 N = 4 x_seq = torch.arange(0, N * T).view([N, T]) print("x_seq=\n", x_seq) print( "random_temporal_delete(x_seq)=\n", random_temporal_delete(x_seq, T_remain, batch_first=True), )
Outputs:
x_seq= tensor([[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31]]) random_temporal_delete(x_seq)= tensor([[ 0, 1, 4, 6, 7], [ 8, 9, 12, 14, 15], [16, 17, 20, 22, 23], [24, 25, 28, 30, 31]])
- class spikingjelly.datasets.transform.RandomTemporalDelete(T_remain: int, batch_first: bool)[源代码]#
基类:
Module
中文
random_temporal_delete()的torch.nn.Module封装。前向传播时会使用构造时给定的T_remain和batch_first调用random_temporal_delete()。
English
A
torch.nn.Modulewrapper aroundrandom_temporal_delete(). Duringforward, it callsrandom_temporal_delete()with theT_remainandbatch_firstvalues provided at construction time.
中文
中文
在 Deep Residual Learning in Spiking Neural Networks 中使用的随机时间删除数据增强。 详见
random_temporal_delete()。
English
English
The random temporal delete data augmentation used in Deep Residual Learning in Spiking Neural Networks. Refer to
random_temporal_delete()for more details.- 参数:
- 返回:
None
- 返回类型:
None
- forward(x_seq: Tensor | ndarray) Tensor | ndarray[源代码]#
-
中文
使用当前模块保存的
T_remain和batch_first配置, 对输入序列执行random_temporal_delete()。- 参数:
x_seq (Union[torch.Tensor, np.ndarray]) -- 输入序列。其时间维布局由
batch_first决定。- 返回:
随机删除时间切片后的序列。
- 返回类型:
Union[torch.Tensor, np.ndarray]
- 抛出:
ValueError -- 当
self.T_remain非法时, 由random_temporal_delete()内部的numpy.random.choice抛出
English
Apply
random_temporal_delete()to the input sequence with theT_remainandbatch_firstconfiguration stored in this module.- 参数:
x_seq (Union[torch.Tensor, np.ndarray]) -- input sequence. The time-dimension layout is determined by
batch_first.- 返回:
sequence after random temporal deletion.
- 返回类型:
Union[torch.Tensor, np.ndarray]
- 抛出:
ValueError -- raised by
numpy.random.choiceinsiderandom_temporal_delete()whenself.T_remainis invalid