spikingjelly.datasets package
Submodules
spikingjelly.datasets.asl_dvs module
- class spikingjelly.datasets.asl_dvs.ASLDVS(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
The ASL-DVS dataset, which is proposed in Graph-based Object Classification for Neuromorphic Vision Sensing.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about the parameters.
- static resource_url_md5() → list
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5
- Return type
list
- static downloadable() → bool
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)
- Parameters
download_root (str) – Root directory path which saves the downloaded files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static load_origin_data(file_name: str) → Dict
- Parameters
file_name (str) – path of the events file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
- Return type
Dict
This function defines how to read the original binary data.
- static get_H_W() → Tuple
- Returns
A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.
- Return type
Tuple
- static create_events_np_files(extract_root: str, events_np_root: str)
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.
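The methods above pass events around as a dict with the keys ['t', 'x', 'y', 'p'] and store them as npz files. The sketch below builds a small events dict of that shape and saves it to a .npz file with plain NumPy, then loads it back. The file name and the use of np.savez are assumptions made for illustration. It does not show how create_events_np_files actually names or writes its files.

```python
import os
import tempfile

import numpy as np

# A tiny, hypothetical events dict in the layout described above:
# t = timestamps, x/y = pixel coordinates, p = polarity.
events = {
    't': np.array([10, 25, 40, 55], dtype=np.int64),
    'x': np.array([0, 3, 3, 1], dtype=np.int64),
    'y': np.array([2, 2, 0, 1], dtype=np.int64),
    'p': np.array([1, 0, 1, 1], dtype=np.int64),
}

def save_events_npz(path: str, events: dict) -> None:
    # store one named array per key so each field can be read back by name
    np.savez(path, t=events['t'], x=events['x'], y=events['y'], p=events['p'])

def load_events_npz(path: str) -> dict:
    # indexing an NpzFile reads the array into memory, so it outlives the file handle
    with np.load(path) as data:
        return {key: data[key] for key in ('t', 'x', 'y', 'p')}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'sample_0.npz')
    save_events_npz(path, events)
    loaded = load_events_npz(path)

print({k: v.tolist() for k, v in loaded.items()})
```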
spikingjelly.datasets.cifar10_dvs module
- spikingjelly.datasets.cifar10_dvs.load_raw_events(fp, bytes_skip=0, bytes_trim=0, filter_dvs=False, times_first=False)
- spikingjelly.datasets.cifar10_dvs.parse_raw_address(addr, x_mask=4190208, x_shift=12, y_mask=2143289344, y_shift=22, polarity_mask=2048, polarity_shift=11)
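parse_raw_address has no description here. Its default arguments suggest that it pulls each field out of a packed 32-bit address word with a mask and a shift. In hex, x_mask = 0x3FF000, y_mask = 0x7FC00000 and polarity_mask = 0x800. The sketch below applies that mask-and-shift technique with the same defaults. It illustrates the bit layout implied by the signature and is not the library's code. The function names are hypothetical.

```python
import numpy as np

def read_bits(arr, mask, shift):
    # keep only the bits selected by mask, then move them down to bit 0
    return (arr & mask) >> shift

def parse_address_sketch(addr,
                         x_mask=4190208, x_shift=12,             # 0x3FF000: bits 12-21
                         y_mask=2143289344, y_shift=22,          # 0x7FC00000: bits 22-30
                         polarity_mask=2048, polarity_shift=11):  # 0x800: bit 11
    addr = np.asarray(addr, dtype=np.int64)
    x = read_bits(addr, x_mask, x_shift)
    y = read_bits(addr, y_mask, y_shift)
    polarity = read_bits(addr, polarity_mask, polarity_shift).astype(bool)
    return x, y, polarity

# pack two hypothetical events by hand, then decode them
addr = np.array([(5 << 12) | (7 << 22) | (1 << 11),
                 (100 << 12) | (64 << 22)], dtype=np.int64)
x, y, p = parse_address_sketch(addr)
print(x.tolist(), y.tolist(), p.tolist())  # [5, 100] [7, 64] [True, False]
```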
- class spikingjelly.datasets.cifar10_dvs.CIFAR10DVS(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
The CIFAR10-DVS dataset, which is proposed in CIFAR10-DVS: An Event-Stream Dataset for Object Classification (https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full).
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about the parameters.
- static resource_url_md5() → list
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5
- Return type
list
- static downloadable() → bool
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)
- Parameters
download_root (str) – Root directory path which saves the downloaded files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static load_origin_data(file_name: str) → Dict
- Parameters
file_name (str) – path of the events file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
- Return type
Dict
This function defines how to read the original binary data.
- static get_H_W() → Tuple
- Returns
A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.
- Return type
Tuple
- static create_events_np_files(extract_root: str, events_np_root: str)
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.
spikingjelly.datasets.dvs128_gesture module
- class spikingjelly.datasets.dvs128_gesture.DVS128Gesture(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
The DVS128 Gesture dataset, which is proposed in A Low Power, Fully Event-Based Gesture Recognition System.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about the parameters.
- static resource_url_md5() → list
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5
- Return type
list
- static downloadable() → bool
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)
- Parameters
download_root (str) – Root directory path which saves the downloaded files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static load_origin_data(file_name: str) → Dict
- Parameters
file_name (str) – path of the events file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
- Return type
Dict
This function defines how to read the original binary data.
- static create_events_np_files(extract_root: str, events_np_root: str)
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.
spikingjelly.datasets.es_imagenet module
- class spikingjelly.datasets.es_imagenet.ESImageNet(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
The ES-ImageNet dataset, which is proposed in ES-ImageNet: A Million Event-Stream Classification Dataset for Spiking Neural Networks.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about the parameters.
- static resource_url_md5() → list
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5
- Return type
list
- static downloadable() → bool
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)
- Parameters
download_root (str) – Root directory path which saves the downloaded files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static create_events_np_files(extract_root: str, events_np_root: str)
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.
spikingjelly.datasets.n_caltech101 module
- class spikingjelly.datasets.n_caltech101.NCaltech101(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
The N-Caltech101 dataset, which is proposed in Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about the parameters.
- static resource_url_md5() → list
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5
- Return type
list
- static downloadable() → bool
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)
- Parameters
download_root (str) – Root directory path which saves the downloaded files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static load_origin_data(file_name: str) → Dict
- Parameters
file_name (str) – path of the events file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
- Return type
Dict
This function defines how to read the original binary data.
- static get_H_W() → Tuple
- Returns
A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.
- Return type
Tuple
- static create_events_np_files(extract_root: str, events_np_root: str)
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.
spikingjelly.datasets.n_mnist module
- class spikingjelly.datasets.n_mnist.NMNIST(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
The N-MNIST dataset, which is proposed in Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about the parameters.
- static resource_url_md5() → list
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5
- Return type
list
- static downloadable() → bool
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)
- Parameters
download_root (str) – Root directory path which saves the downloaded files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static load_origin_data(file_name: str) → Dict
- Parameters
file_name (str) – path of the events file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
- Return type
Dict
This function defines how to read the original binary data.
- static get_H_W() → Tuple
- Returns
A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.
- Return type
Tuple
- static create_events_np_files(extract_root: str, events_np_root: str)
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.
spikingjelly.datasets.speechcommands module
- spikingjelly.datasets.speechcommands.load_speechcommands_item(relpath: str, path: str) → Tuple[Tensor, int, str, str, int]
- class spikingjelly.datasets.speechcommands.SPEECHCOMMANDS(label_dict: Dict, root: str, silence_cnt: Optional[int] = 0, silence_size: Optional[int] = 16000, transform: Optional[Callable] = None, url: Optional[str] = 'speech_commands_v0.02', split: Optional[str] = 'train', folder_in_archive: Optional[str] = 'SpeechCommands', download: Optional[bool] = False)
Bases: Dataset
- Parameters
label_dict (Dict) – a dict mapping labels to classes
root (str) – root directory of the dataset
silence_cnt (int, optional) – number of Silence samples
silence_size (int, optional) – size of each Silence sample
transform (Callable, optional) – A function/transform that takes in a raw audio clip and returns a transformed version
url (str, optional) – dataset version, defaults to v0.02
split (str, optional) – dataset split, one of "train", "test", "val"; defaults to "train"
folder_in_archive (str, optional) – name of the extracted directory, defaults to "SpeechCommands"
download (bool, optional) – whether to download the data, defaults to False
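The SpeechCommands speech dataset, from Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition. It is split into train, validation, and test sets according to the provided validation and test lists, and it comes in two versions, v0.01 and v0.02.
The dataset contains audio clips of three groups of words:
Command words, 10 in total: "Yes", "No", "Up", "Down", "Left", "Right", "On", "Off", "Stop", "Go". v0.02 adds 5 more: "Forward", "Backward", "Follow", "Learn", "Visual".
Digits 0-9, 10 in total: "Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine".
Auxiliary words, which can be treated as distractors, 10 in total: "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow".
v0.01 contains 30 classes with 64,727 audio clips in total, and v0.02 contains 35 classes with 105,829 audio clips in total. For more details, see the paper above and the dataset's README.
The implementation is based on torchaudio with extended functionality, and it also draws on the implementation from the original paper.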
Module contents
- spikingjelly.datasets.play_frame(x: Tensor, save_gif_to: Optional[str] = None) → None
- Parameters
x (torch.Tensor or np.ndarray) – frames with shape=[T, 2, H, W]
save_gif_to (str) – If None, this function will play the frames. Otherwise, this function will not play the frames but will save them to a gif file at save_gif_to
- Returns
None
- spikingjelly.datasets.load_matlab_mat(file_name: str) → Dict
- Parameters
file_name (str) – path of the MATLAB .mat file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
- Return type
Dict
- spikingjelly.datasets.load_aedat_v3(file_name: str) → Dict
- Parameters
file_name (str) – path of the aedat v3 file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
- Return type
Dict
This function is written with reference to https://gitlab.com/inivation/dv/dv-python. It can be used for the DVS128 Gesture dataset.
- spikingjelly.datasets.load_ATIS_bin(file_name: str) → Dict
- Parameters
file_name (str) – path of the ATIS binary file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
- Return type
Dict
This function is written with reference to https://github.com/jackd/events-tfds. Each ATIS binary example is a separate binary file consisting of a list of events. Each event occupies 40 bits, laid out as follows:
bit 39 - 32: X address (in pixels)
bit 31 - 24: Y address (in pixels)
bit 23: polarity (0 for OFF, 1 for ON)
bit 22 - 0: timestamp (in microseconds)
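The 40-bit layout above can be decoded with plain byte arithmetic. The sketch below assumes that each event is stored as 5 consecutive bytes with the most significant byte first, so bits 39-32 sit in the first byte. That byte order is an assumption to check against real files, and this is not the code of load_ATIS_bin. The encoder is a hypothetical helper that exists only to build test input.

```python
import numpy as np

def decode_atis_bytes(raw: bytes) -> dict:
    # assumed layout: 5 bytes per event, most significant byte first
    data = np.frombuffer(raw, dtype=np.uint8).astype(np.int64)
    if data.size % 5:
        raise ValueError('expected a whole number of 5-byte events')
    x = data[0::5]                # bits 39-32
    y = data[1::5]                # bits 31-24
    p = (data[2::5] & 0x80) >> 7  # bit 23
    t = ((data[2::5] & 0x7F) << 16) | (data[3::5] << 8) | data[4::5]  # bits 22-0
    return {'t': t, 'x': x, 'y': y, 'p': p}

def encode_atis_event(x: int, y: int, p: int, t: int) -> bytes:
    # inverse of the layout above; t must fit in 23 bits
    word = (x << 32) | (y << 24) | (p << 23) | t
    return word.to_bytes(5, byteorder='big')

raw = encode_atis_event(12, 30, 1, 123456) + encode_atis_event(33, 2, 0, 7)
events = decode_atis_bytes(raw)
print({k: v.tolist() for k, v in events.items()})
```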
- spikingjelly.datasets.load_npz_frames(file_name: str) → ndarray
- Parameters
file_name (str) – path of the npz file that saves the frames
- Returns
frames
- Return type
np.ndarray
- spikingjelly.datasets.integrate_events_segment_to_frame(x: ndarray, y: ndarray, p: ndarray, H: int, W: int, j_l: int = 0, j_r: int = -1) → ndarray
- Parameters
x (numpy.ndarray) – x-coordinate of events
y (numpy.ndarray) – y-coordinate of events
p (numpy.ndarray) – polarity of events
H (int) – height of the frame
W (int) – width of the frame
j_l (int) – the start index of the integral interval, which is included
j_r (int) – the end index of the integral interval, which is not included
- Returns
frames
- Return type
np.ndarray
Denote a two-channel frame as \(F\) and a pixel at \((p, x, y)\) as \(F(p, x, y)\). The pixel value is integrated from the events whose indices are in \([j_{l}, j_{r})\):
\[F(p, x, y) = \sum_{i = j_{l}}^{j_{r} - 1} \mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\]
where \(\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\) is an indicator function that equals 1 only when \((p, x, y) = (p_{i}, x_{i}, y_{i})\).
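The indicator sum above counts, for each (p, x, y), how many events with an index in [j_l, j_r) fall on that pixel with that polarity. The NumPy sketch below does that counting with np.add.at, which adds once per event even when several events hit the same pixel. It makes three assumptions. First, p takes the values 0 and 1. Second, the output has shape [2, H, W] and is indexed [p, y, x], with y as the row and x as the column. Third, j_r=None means "up to the last event", which differs from the documented default j_r = -1. It illustrates the formula and is not the library's implementation.

```python
import numpy as np

def integrate_segment_sketch(x, y, p, H, W, j_l=0, j_r=None):
    # count events with index in [j_l, j_r) per (polarity, row, column)
    if j_r is None:
        j_r = len(x)
    xs = np.asarray(x[j_l:j_r], dtype=np.int64)
    ys = np.asarray(y[j_l:j_r], dtype=np.int64)
    ps = np.asarray(p[j_l:j_r], dtype=np.int64)
    frame = np.zeros([2, H, W])
    # unbuffered add: repeated (p, y, x) triples each contribute 1
    np.add.at(frame, (ps, ys, xs), 1)
    return frame

x = np.array([0, 1, 1, 1, 2])
y = np.array([0, 0, 0, 1, 1])
p = np.array([1, 0, 0, 1, 1])
frame = integrate_segment_sketch(x, y, p, H=2, W=3, j_l=1, j_r=4)
print(frame)
```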
- spikingjelly.datasets.cal_fixed_frames_number_segment_index(events_t: ndarray, split_by: str, frames_num: int) → tuple
- Parameters
events_t (numpy.ndarray) – events' t
split_by (str) – 'time' or 'number'
frames_num (int) – the number of frames
- Returns
a tuple (j_l, j_r)
- Return type
tuple
Denote frames_num as \(M\) and the number of events as \(N\). For the \(j\)-th frame, \(j = 0, 1, \ldots, M - 1\): if split_by is 'time', then
\[\begin{aligned} \Delta T & = \lfloor \frac{t_{N-1} - t_{0}}{M} \rfloor \\ j_{l} & = \mathop{\arg\min}\limits_{k} \{t_{k} | t_{k} \geq t_{0} + \Delta T \cdot j\} \\ j_{r} & = \begin{cases} \mathop{\arg\max}\limits_{k} \{t_{k} | t_{k} < t_{0} + \Delta T \cdot (j + 1)\} + 1, & j < M - 1 \\ N, & j = M - 1 \end{cases} \end{aligned}\]
If split_by is 'number', then
\[\begin{aligned} j_{l} & = \lfloor \frac{N}{M} \rfloor \cdot j \\ j_{r} & = \begin{cases} \lfloor \frac{N}{M} \rfloor \cdot (j + 1), & j < M - 1 \\ N, & j = M - 1 \end{cases} \end{aligned}\]
where \(\lfloor \cdot \rfloor\) is the floor operation.
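The two cases above can be computed for all frames at once. In the sketch below, np.searchsorted(..., side='left') gives the first index k with t_k >= threshold. That is exactly j_l, and it also equals (last index with t_k < threshold) + 1, which is j_r. The sketch assumes that events_t is sorted in ascending order. An empty time window yields j_l == j_r here, a case the formula leaves undefined. The function name is hypothetical and this is not the library's code.

```python
import numpy as np

def segment_index_sketch(events_t, split_by: str, frames_num: int):
    # returns arrays j_l, j_r; frame j integrates events [j_l[j], j_r[j])
    t = np.asarray(events_t)
    N, M = t.size, frames_num
    j = np.arange(M)
    if split_by == 'number':
        di = N // M
        j_l = di * j
        j_r = di * (j + 1)
    elif split_by == 'time':
        dT = (t[-1] - t[0]) // M
        # first index with t_k >= t_0 + dT * j
        j_l = np.searchsorted(t, t[0] + dT * j, side='left')
        # (last index with t_k < t_0 + dT * (j + 1)) + 1
        j_r = np.searchsorted(t, t[0] + dT * (j + 1), side='left')
    else:
        raise ValueError("split_by must be 'time' or 'number'")
    j_r[-1] = N  # the last frame always runs to the final event
    return j_l, j_r

t = np.array([0, 1, 2, 3, 10, 11, 12, 20, 30, 40])
print(segment_index_sketch(t, 'number', 3))
print(segment_index_sketch(t, 'time', 4))
```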
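- spikingjelly.datasets.integrate_events_by_fixed_frames_number(events: Dict, split_by: str, frames_num: int, H: int, W: int) → ndarray
- Parameters
events (Dict) – a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
split_by (str) – 'time' or 'number'
frames_num (int) – the number of frames
H (int) – the height of frame
W (int) – the width of frame
- Returns
frames
- Return type
np.ndarray
Integrate events to frames by fixed frames number. See cal_fixed_frames_number_segment_index and integrate_events_segment_to_frame for more details.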
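Putting the segment split and the per-pixel count together gives the self-contained sketch below of fixed-frames-number integration for split_by='number'. It re-implements both steps with plain NumPy under three assumptions: p is in {0, 1}, x is the column, and y is the row. It stacks the M frames into an array of shape [M, 2, H, W] and leaves out 'time' splitting for brevity. The function name is hypothetical.

```python
import numpy as np

def integrate_by_frames_number_sketch(events: dict, frames_num: int, H: int, W: int):
    # split_by='number': frame j takes events [j * (N // M), (j + 1) * (N // M)),
    # and the last frame also takes the remainder up to N
    x, y, p = (np.asarray(events[k], dtype=np.int64) for k in ('x', 'y', 'p'))
    N, M = x.size, frames_num
    di = N // M
    frames = np.zeros([M, 2, H, W])
    for j in range(M):
        j_l = di * j
        j_r = N if j == M - 1 else di * (j + 1)
        sl = slice(j_l, j_r)
        # frames[j] is a view, so np.add.at writes into frames directly
        np.add.at(frames[j], (p[sl], y[sl], x[sl]), 1)
    return frames

events = {
    't': np.arange(7),
    'x': np.array([0, 1, 0, 1, 0, 1, 1]),
    'y': np.array([0, 0, 1, 1, 0, 0, 1]),
    'p': np.array([1, 1, 0, 0, 1, 0, 1]),
}
frames = integrate_by_frames_number_sketch(events, frames_num=3, H=2, W=2)
print(frames.shape, frames.reshape(3, -1).sum(axis=1))
```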
- spikingjelly.datasets.integrate_events_file_to_frames_file_by_fixed_frames_number(loader: Callable, events_np_file: str, output_dir: str, split_by: str, frames_num: int, H: int, W: int, print_save: bool = False) → None
- Parameters
loader (Callable) – a function that can load events from events_np_file
events_np_file (str) – path of the events np file
output_dir (str) – output directory for saving the frames
split_by (str) – 'time' or 'number'
frames_num (int) – the number of frames
H (int) – the height of frame
W (int) – the width of frame
print_save (bool) – If True, this function will print the paths of the saved files.
- Returns
None
Integrate an events file to frames by fixed frames number and save it. See cal_fixed_frames_number_segment_index and integrate_events_segment_to_frame for more details.
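- spikingjelly.datasets.integrate_events_by_fixed_duration(events: Dict, duration: int, H: int, W: int) → ndarray
- Parameters
events (Dict) – a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray
duration (int) – the time duration of each frame
H (int) – the height of frame
W (int) – the width of frame
- Returns
frames
- Return type
np.ndarray
Integrate events to frames by fixed time duration of each frame.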
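The description above does not say where each window starts. The sketch below assumes one plausible reading. A window starts at the first event not yet used and collects every event with t < t_start + duration. Those events are integrated into one frame, and the process repeats until all events are used. Under this reading the number of frames depends on the recording length, so different samples can yield different numbers of frames. Both the windowing rule and the function name are assumptions, not the library's code.

```python
import numpy as np

def integrate_by_duration_sketch(events: dict, duration: int, H: int, W: int):
    # assumes events['t'] is sorted ascending and duration > 0
    if duration <= 0:
        raise ValueError('duration must be positive')
    t = np.asarray(events['t'])
    x, y, p = (np.asarray(events[k], dtype=np.int64) for k in ('x', 'y', 'p'))
    frames = []
    left = 0
    while left < t.size:
        # first index at or beyond the end of the window [t[left], t[left] + duration)
        right = int(np.searchsorted(t, t[left] + duration, side='left'))
        frame = np.zeros([2, H, W])
        np.add.at(frame, (p[left:right], y[left:right], x[left:right]), 1)
        frames.append(frame)
        left = right
    # shape [T, 2, H, W], where T depends on the data
    return np.stack(frames) if frames else np.zeros([0, 2, H, W])

events = {
    't': np.array([0, 3, 9, 10, 25, 26]),
    'x': np.array([0, 1, 1, 0, 0, 1]),
    'y': np.array([0, 0, 1, 1, 0, 1]),
    'p': np.array([0, 1, 1, 0, 1, 1]),
}
frames = integrate_by_duration_sketch(events, duration=10, H=2, W=2)
print(frames.shape, frames.reshape(len(frames), -1).sum(axis=1))
```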
- spikingjelly.datasets.integrate_events_file_to_frames_file_by_fixed_duration(loader: Callable, events_np_file: str, output_dir: str, duration: int, H: int, W: int, print_save: bool = False) → None
- Parameters
loader (Callable) – a function that can load events from events_np_file
events_np_file (str) – path of the events np file
output_dir (str) – output directory for saving the frames
duration (int) – the time duration of each frame
H (int) – the height of frame
W (int) – the width of frame
print_save (bool) – If True, this function will print the paths of the saved files.
- Returns
None
Integrate an events file to frames by fixed time duration of each frame and save it.
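- spikingjelly.datasets.create_same_directory_structure(source_dir: str, target_dir: str) → None
- Parameters
source_dir (str) – the directory whose structure is copied
target_dir (str) – the directory in which the structure is created
- Returns
None
Create the same directory structure in target_dir as that of source_dir.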
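The sketch below does the same job with os.walk. It assumes that only directories are recreated and that files are not copied, which is how "directory structure" reads. The function name and the example layout are hypothetical.

```python
import os
import tempfile

def create_same_directory_structure_sketch(source_dir: str, target_dir: str) -> None:
    # recreate every directory found under source_dir inside target_dir; files are skipped
    for dirpath, dirnames, _ in os.walk(source_dir):
        rel = os.path.relpath(dirpath, source_dir)
        for name in dirnames:
            os.makedirs(os.path.join(target_dir, rel, name), exist_ok=True)

with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst:
    os.makedirs(os.path.join(src, 'train', 'class_0'))
    os.makedirs(os.path.join(src, 'test', 'class_1'))
    open(os.path.join(src, 'train', 'class_0', 'sample.npz'), 'wb').close()
    create_same_directory_structure_sketch(src, dst)
    created = sorted(os.path.relpath(d, dst) for d, _, _ in os.walk(dst))
    copied_files = [f for _, _, fs in os.walk(dst) for f in fs]

print(created, copied_files)
```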
- spikingjelly.datasets.split_to_train_test_set(train_ratio: float, origin_dataset: Dataset, num_classes: int, random_split: bool = False)
- Parameters
train_ratio (float) – the ratio of the origin dataset to be used as the train set
origin_dataset (torch.utils.data.Dataset) – the origin dataset
num_classes (int) – total number of classes, e.g., 10 for the MNIST dataset
random_split (bool) – If False, the first train_ratio of the samples in each class will be included in the train set, while the rest will be included in the test set. If True, this function will split the samples in each class randomly. The randomness is controlled by numpy.random.seed
- Returns
a tuple (train_set, test_set)
- Return type
tuple
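The per-class split described above can be sketched on plain label lists. The sketch returns index lists instead of datasets. Wrapping them with torch.utils.data.Subset(origin_dataset, indices) would give dataset objects. The function name is hypothetical, and rounding the per-class train count up with math.ceil is an assumption.

```python
import math

import numpy as np

def split_indices_sketch(train_ratio, labels, num_classes, random_split=False):
    # group sample indices by their class label
    per_class = [[] for _ in range(num_classes)]
    for i, label in enumerate(labels):
        per_class[label].append(i)
    train_idx, test_idx = [], []
    for idx in per_class:
        if random_split:
            # shuffle within the class; reproducible via np.random.seed
            idx = [int(i) for i in np.random.permutation(idx)]
        n_train = math.ceil(len(idx) * train_ratio)  # rounding up is an assumption
        train_idx.extend(idx[:n_train])  # front part of the class goes to train
        test_idx.extend(idx[n_train:])   # the rest of the class goes to test
    return train_idx, test_idx

labels = [0, 1, 0, 1, 0, 1, 0, 1, 2, 2]
train_idx, test_idx = split_indices_sketch(0.5, labels, num_classes=3)
print(train_idx, test_idx)
```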
- spikingjelly.datasets.pad_sequence_collate(batch: list)
- Parameters
batch (list) – a list of samples that contains (x, y), where x is a list containing sequences with different lengths and y is the label
- Returns
batched samples (x_p, y, x_len), where x_p is the padded x with the same length, y is the label, and x_len is the length of each sequence in x
- Return type
tuple
This function can be used as the collate_fn for DataLoader to process a dataset with variable length, e.g., a NeuromorphicDatasetFolder that integrates events to frames with fixed duration. Here is an example:
.. code-block:: python

    import torch
    from spikingjelly.datasets import pad_sequence_collate

    class VariableLengthDataset(torch.utils.data.Dataset):
        def __init__(self, n=1000):
            super().__init__()
            self.n = n

        def __getitem__(self, i):
            # the i-th sample is a sequence of length i + 1
            return torch.rand([i + 1, 2]), self.n - i - 1

        def __len__(self):
            return self.n

    loader = torch.utils.data.DataLoader(VariableLengthDataset(n=32), batch_size=2, collate_fn=pad_sequence_collate, shuffle=True)
    for i, (x_p, label, x_len) in enumerate(loader):
        print(f'x_p.shape={x_p.shape}, label={label}, x_len={x_len}')
        if i == 2:
            break

And the outputs are (they vary between runs because shuffle=True):
.. code-block:: bash

    x_p.shape=torch.Size([2, 18, 2]), label=tensor([14, 30]), x_len=tensor([18, 2])
    x_p.shape=torch.Size([2, 29, 2]), label=tensor([3, 6]), x_len=tensor([29, 26])
    x_p.shape=torch.Size([2, 23, 2]), label=tensor([ 9, 23]), x_len=tensor([23, 9])
- spikingjelly.datasets.padded_sequence_mask(sequence_len: Tensor, T=None)
- Parameters
sequence_len (torch.Tensor) – a tensor with shape = [N] that contains the sequence length of each batch element
T (int) – The maximum length of sequences. If None, the maximum element in sequence_len will be used as T
- Returns
a bool mask with shape = [T, N], where the padded positions are False
- Return type
torch.Tensor
Here is an example:
.. code-block:: python

    import torch
    from spikingjelly.datasets import padded_sequence_mask

    x1 = torch.rand([2, 6])
    x2 = torch.rand([3, 6])
    x3 = torch.rand([4, 6])
    x = torch.nn.utils.rnn.pad_sequence([x1, x2, x3])  # [T, N, *]
    print('x.shape=', x.shape)
    x_len = torch.as_tensor([x1.shape[0], x2.shape[0], x3.shape[0]])
    mask = padded_sequence_mask(x_len)
    print('mask.shape=', mask.shape)
    print('mask=\n', mask)

And the outputs are:
.. code-block:: bash

    x.shape= torch.Size([4, 3, 6])
    mask.shape= torch.Size([4, 3])
    mask=
     tensor([[ True,  True,  True],
            [ True,  True,  True],
            [False,  True,  True],
            [False, False,  True]])
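The mask rule is that position t of sequence n is valid exactly when t < sequence_len[n]. The NumPy sketch below expresses that rule with broadcasting. It illustrates the rule and is not the library's torch implementation. With the lengths [2, 3, 4] from the example, it gives the same [T, N] mask.

```python
import numpy as np

def padded_sequence_mask_sketch(sequence_len, T=None):
    sequence_len = np.asarray(sequence_len)
    if T is None:
        T = int(sequence_len.max())
    # compare a column of time steps [T, 1] with a row of lengths [1, N]
    return np.arange(T)[:, None] < sequence_len[None, :]  # shape [T, N]

mask = padded_sequence_mask_sketch([2, 3, 4])
print(mask.shape)
print(mask)
```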
- class spikingjelly.datasets.NeuromorphicDatasetFolder(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
Bases: DatasetFolder
- Parameters
root (str) – root path of the dataset
train (bool) – whether to use the train set. Set True or False for datasets that provide a train/test division, e.g., the DVS128 Gesture dataset. If the dataset does not provide a train/test division, e.g., CIFAR10-DVS, set None and use the split_to_train_test_set function to get the train/test sets
data_type (str) – event or frame
frames_number (int) – the integrated frame number
split_by (str) – time or number
duration (int) – the time duration of each frame
custom_integrate_function (Callable) – a user-defined function whose inputs are events, H, W. events is a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray. H is the height of the data and W is the width of the data. For example, H=128 and W=128 for the DVS128 Gesture dataset. The user should define how to integrate events to frames, and return the frames.
custom_integrated_frames_dir_name (str or None) – The name of the directory for saving frames integrated by custom_integrate_function. If custom_integrated_frames_dir_name is None, it will be set to custom_integrate_function.__name__
transform (callable) – a function/transform that takes in a sample and returns a transformed version, e.g., transforms.RandomCrop for images.
target_transform (callable) – a function/transform that takes in the target and transforms it.
The base class for neuromorphic datasets. Users can define a new dataset by inheriting this class and implementing all abstract methods. Users can refer to spikingjelly.datasets.dvs128_gesture.DVS128Gesture.
- If data_type == 'event': the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray.
- If data_type == 'frame' and frames_number is not None: events will be integrated to frames with a fixed frames number. split_by defines how to split the events. See cal_fixed_frames_number_segment_index for more details.
- If data_type == 'frame', frames_number is None, and duration is not None: events will be integrated to frames with a fixed time duration.
- If data_type == 'frame', frames_number is None, duration is None, and custom_integrate_function is not None: events will be integrated by the user-defined function and saved to the custom_integrated_frames_dir_name directory in the root directory. Here is an example from SpikingJelly's tutorials:
.. code-block:: python

    from typing import Dict

    import numpy as np
    import spikingjelly.datasets as sjds
    from spikingjelly.datasets import play_frame
    from spikingjelly.datasets.dvs128_gesture import DVS128Gesture

    def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
        # pick a random split point and integrate each part into one frame
        index_split = np.random.randint(low=0, high=len(events['t']))
        frames = np.zeros([2, 2, H, W])
        t, x, y, p = (events[key] for key in ('t', 'x', 'y', 'p'))
        frames[0] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, 0, index_split)
        frames[1] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, index_split, len(events['t']))
        return frames

    root_dir = 'D:/datasets/DVS128Gesture'
    train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)
    frame, label = train_set[500]
    play_frame(frame)
- abstract static resource_url_md5() → list
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5
- Return type
list
- abstract static downloadable() → bool
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- abstract static extract_downloaded_files(download_root: str, extract_root: str)
- Parameters
download_root (str) – Root directory path which saves the downloaded files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- abstract static create_events_np_files(extract_root: str, events_np_root: str)
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.
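resource_url_md5 returns (file name, download link, MD5) tuples. The stdlib sketch below shows how such tuples can be used to check downloaded files. The helper names and the chunk size are hypothetical, and the example.com URLs are placeholders, not the dataset's real resources.

```python
import hashlib
import os
import tempfile

def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    # hash in chunks so large archives do not need to fit in memory
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()

def check_resources(download_root: str, resource_list: list) -> list:
    # return the names of files that are missing or whose MD5 does not match
    bad = []
    for file_name, _url, expected_md5 in resource_list:
        path = os.path.join(download_root, file_name)
        if not os.path.isfile(path) or file_md5(path) != expected_md5:
            bad.append(file_name)
    return bad

with tempfile.TemporaryDirectory() as root:
    with open(os.path.join(root, 'part_0.bin'), 'wb') as f:
        f.write(b'hello')
    resources = [
        ('part_0.bin', 'https://example.com/part_0.bin', hashlib.md5(b'hello').hexdigest()),
        ('part_1.bin', 'https://example.com/part_1.bin', '0' * 32),  # never downloaded
    ]
    bad = check_resources(root, resources)

print(bad)  # ['part_1.bin']
```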
- spikingjelly.datasets.random_temporal_delete(x_seq: Tensor, T_remain: int, batch_first)
- Parameters
x_seq (torch.Tensor or np.ndarray) – a sequence with shape = [T, N, *], where T is the sequence length and N is the batch size
T_remain (int) – the remaining length
batch_first (bool) – if True, x_seq will be regarded as shape = [N, T, *]
- Returns
the sequence with length T_remain, which is obtained by randomly removing T - T_remain slices
- Return type
torch.Tensor or np.ndarray
The random temporal delete data augmentation used in Deep Residual Learning in Spiking Neural Networks. Code example:
.. code-block:: python

    import torch
    from spikingjelly.datasets import random_temporal_delete

    T = 8
    T_remain = 5
    N = 4
    x_seq = torch.arange(0, N * T).view([N, T])
    print('x_seq=\n', x_seq)
    print('random_temporal_delete(x_seq)=\n', random_temporal_delete(x_seq, T_remain, batch_first=True))

Outputs (the deleted time steps are chosen at random, so they vary between runs):
.. code-block:: shell

    x_seq=
     tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
            [ 8,  9, 10, 11, 12, 13, 14, 15],
            [16, 17, 18, 19, 20, 21, 22, 23],
            [24, 25, 26, 27, 28, 29, 30, 31]])
    random_temporal_delete(x_seq)=
     tensor([[ 0,  1,  4,  6,  7],
            [ 8,  9, 12, 14, 15],
            [16, 17, 20, 22, 23],
            [24, 25, 28, 30, 31]])
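In the example output, the same time steps (2, 3 and 5) are removed from every sample in the batch, and the remaining steps keep their order. The NumPy sketch below is consistent with that behavior. It picks T_remain distinct time indices, sorts them, and indexes along the time axis. The rng parameter is a hypothetical addition for reproducibility and is not part of the documented signature.

```python
import numpy as np

def random_temporal_delete_sketch(x_seq, T_remain, batch_first, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    T = x_seq.shape[1] if batch_first else x_seq.shape[0]
    # choose T_remain distinct time indices and keep them in temporal order
    keep = np.sort(rng.choice(T, size=T_remain, replace=False))
    return x_seq[:, keep] if batch_first else x_seq[keep]

x_seq = np.arange(4 * 8).reshape(4, 8)  # [N, T] with N=4, T=8
out = random_temporal_delete_sketch(x_seq, 5, batch_first=True, rng=np.random.default_rng(0))
print(out)
```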
- class spikingjelly.datasets.RandomTemporalDelete(T_remain: int, batch_first: bool)
Bases: Module
- Parameters
T_remain (int) – the remaining length
batch_first (bool) – if True, x_seq will be regarded as shape = [N, T, *]
The random temporal delete data augmentation used in Deep Residual Learning in Spiking Neural Networks. Refer to random_temporal_delete for more details.
- spikingjelly.datasets.create_sub_dataset(source_dir: str, target_dir: str, ratio: float, use_soft_link=True, randomly=False)
- Parameters
source_dir (str) – the directory path of the origin dataset
target_dir (str) – the directory path of the sub dataset
ratio (float) – the ratio of samples that the sub dataset will copy from the origin dataset
use_soft_link (bool) – if True, the sub dataset will use soft links instead of copies; otherwise, the sub dataset will copy the files
randomly (bool) – if True, the files copied from the origin dataset will be picked randomly. The randomness is controlled by numpy.random.seed
Create a sub dataset by copying a ratio of the samples from the origin dataset.
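The sketch below shows one way to build such a sub dataset. It rests on three assumptions. First, the dataset is a directory tree whose leaf directories hold the sample files, as in DatasetFolder-style layouts. Second, the ratio is applied per directory with int() truncation. Third, files are taken in sorted order unless randomly is True. The function name is hypothetical. The example copies files because os.symlink can need extra privileges on some platforms.

```python
import os
import shutil
import tempfile

import numpy as np

def create_sub_dataset_sketch(source_dir, target_dir, ratio, use_soft_link=True, randomly=False):
    # visit every directory that holds files and take the same fraction of them
    for root, _, files in os.walk(source_dir):
        if not files:
            continue
        files = sorted(files)
        if randomly:
            np.random.shuffle(files)  # controlled by np.random.seed
        n_keep = int(ratio * len(files))
        out_dir = os.path.join(target_dir, os.path.relpath(root, source_dir))
        os.makedirs(out_dir, exist_ok=True)
        for name in files[:n_keep]:
            src = os.path.abspath(os.path.join(root, name))
            dst = os.path.join(out_dir, name)
            if use_soft_link:
                os.symlink(src, dst)
            else:
                shutil.copyfile(src, dst)

with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst:
    for cls in ('class_0', 'class_1'):
        os.makedirs(os.path.join(src, cls))
        for i in range(10):
            open(os.path.join(src, cls, f'{i}.npz'), 'wb').close()
    create_sub_dataset_sketch(src, dst, ratio=0.5, use_soft_link=False)
    counts = {cls: len(os.listdir(os.path.join(dst, cls))) for cls in ('class_0', 'class_1')}

print(counts)  # {'class_0': 5, 'class_1': 5}
```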