spikingjelly.datasets package

Submodules

spikingjelly.datasets.asl_dvs module

class spikingjelly.datasets.asl_dvs.ASLDVS(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]

Bases: spikingjelly.datasets.NeuromorphicDatasetFolder

Parameters
  • root (str) – root path of the dataset

  • data_type (str) – 'event' or 'frame'

  • frames_number (int) – the number of integrated frames

  • split_by (str) – 'time' or 'number'

  • duration (int) – the time duration of each frame

  • transform (callable) – a function/transform that takes in a sample and returns a transformed version, e.g., transforms.RandomCrop for images.

  • target_transform (callable) – a function/transform that takes in the target and transforms it.

If data_type == 'event'

the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray.

If data_type == 'frame' and frames_number is not None

events will be integrated into a fixed number of frames. split_by defines how the events are split. See cal_fixed_frames_number_segment_index for more details.

If data_type == 'frame', frames_number is None, and duration is not None

events will be integrated into frames of a fixed time duration.

static resource_url_md5() list[source]
Returns

A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5.

Return type

list

static downloadable() bool[source]
Returns

Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually.

Return type

bool

static extract_downloaded_files(download_root: str, extract_root: str)[source]
Parameters
  • download_root (str) – Root directory path where the downloaded dataset files are saved

  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

Returns

None

This function defines how to extract the downloaded files.

static load_origin_data(file_name: str) Dict[source]
Parameters

file_name (str) – path of the events file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

This function defines how to read the original binary data.

static get_H_W() Tuple[source]
Returns

A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.

Return type

tuple

static read_mat_save_to_np(mat_file: str, np_file: str)[source]
static create_events_np_files(extract_root: str, events_np_root: str)[source]
Parameters
  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

  • events_np_root (str) – Root directory path where the events files in the npz format are saved

Returns

None

This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.

spikingjelly.datasets.cifar10_dvs module

spikingjelly.datasets.cifar10_dvs.read_bits(arr, mask=None, shift=None)[source]
spikingjelly.datasets.cifar10_dvs.skip_header(fp)[source]
spikingjelly.datasets.cifar10_dvs.load_raw_events(fp, bytes_skip=0, bytes_trim=0, filter_dvs=False, times_first=False)[source]
spikingjelly.datasets.cifar10_dvs.parse_raw_address(addr, x_mask=4190208, x_shift=12, y_mask=2143289344, y_shift=22, polarity_mask=2048, polarity_shift=11)[source]
spikingjelly.datasets.cifar10_dvs.load_events(fp, filter_dvs=False, **kwargs)[source]
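The functions above are listed without descriptions. Judging from the default arguments of parse_raw_address, each field seems to be pulled out of a raw address word as (addr & mask) >> shift, which gives a 10-bit x (bits 12 - 21), a 9-bit y (bits 22 - 30) and a 1-bit polarity (bit 11). The pure-Python sketch below decodes a single address with those defaults. The helper names are hypothetical, and the sketch does not reproduce any extra processing (such as filtering or coordinate flipping) that load_events may perform.

```python
# Default masks/shifts copied from the signature of parse_raw_address.
X_MASK, X_SHIFT = 4190208, 12               # 0x3FF000: bits 12-21
Y_MASK, Y_SHIFT = 2143289344, 22            # 0x7FC00000: bits 22-30
POLARITY_MASK, POLARITY_SHIFT = 2048, 11    # 0x800: bit 11


def read_bits_sketch(value: int, mask: int, shift: int) -> int:
    """Hypothetical helper: keep the bits selected by mask, then shift them down."""
    return (value & mask) >> shift


def parse_address_sketch(addr: int) -> tuple:
    """Hypothetical helper: decode (x, y, polarity) from one raw 32-bit address."""
    x = read_bits_sketch(addr, X_MASK, X_SHIFT)
    y = read_bits_sketch(addr, Y_MASK, Y_SHIFT)
    polarity = read_bits_sketch(addr, POLARITY_MASK, POLARITY_SHIFT)
    return x, y, polarity


# Build an address for x=37, y=100, polarity=1 and decode it again.
addr = (100 << Y_SHIFT) | (37 << X_SHIFT) | (1 << POLARITY_SHIFT)
print(parse_address_sketch(addr))  # (37, 100, 1)
```

Bits outside the three masks (for example bit 31) are simply ignored by this decoding.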
class spikingjelly.datasets.cifar10_dvs.CIFAR10DVS(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]

Bases: spikingjelly.datasets.NeuromorphicDatasetFolder

Parameters
  • root (str) – root path of the dataset

  • data_type (str) – 'event' or 'frame'

  • frames_number (int) – the number of integrated frames

  • split_by (str) – 'time' or 'number'

  • duration (int) – the time duration of each frame

  • transform (callable) – a function/transform that takes in a sample and returns a transformed version, e.g., transforms.RandomCrop for images.

  • target_transform (callable) – a function/transform that takes in the target and transforms it.

If data_type == 'event'

the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray.

If data_type == 'frame' and frames_number is not None

events will be integrated into a fixed number of frames. split_by defines how the events are split. See cal_fixed_frames_number_segment_index for more details.

If data_type == 'frame', frames_number is None, and duration is not None

events will be integrated into frames of a fixed time duration.

static resource_url_md5() list[source]
Returns

A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5.

Return type

list

static downloadable() bool[source]
Returns

Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually.

Return type

bool

static extract_downloaded_files(download_root: str, extract_root: str)[source]
Parameters
  • download_root (str) – Root directory path where the downloaded dataset files are saved

  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

Returns

None

This function defines how to extract the downloaded files.

static load_origin_data(file_name: str) Dict[source]
Parameters

file_name (str) – path of the events file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

This function defines how to read the original binary data.

static get_H_W() Tuple[source]
Returns

A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.

Return type

tuple

static read_aedat_save_to_np(bin_file: str, np_file: str)[source]
static create_events_np_files(extract_root: str, events_np_root: str)[source]
Parameters
  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

  • events_np_root (str) – Root directory path where the events files in the npz format are saved

Returns

None

This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.

spikingjelly.datasets.dvs128_gesture module

class spikingjelly.datasets.dvs128_gesture.DVS128Gesture(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]

Bases: spikingjelly.datasets.NeuromorphicDatasetFolder

Parameters
  • root (str) – root path of the dataset

  • train (bool) – whether to use the training set

  • data_type (str) – 'event' or 'frame'

  • frames_number (int) – the number of integrated frames

  • split_by (str) – 'time' or 'number'

  • duration (int) – the time duration of each frame

  • transform (callable) – a function/transform that takes in a sample and returns a transformed version, e.g., transforms.RandomCrop for images.

  • target_transform (callable) – a function/transform that takes in the target and transforms it.

If data_type == 'event'

the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray.

If data_type == 'frame' and frames_number is not None

events will be integrated into a fixed number of frames. split_by defines how the events are split. See cal_fixed_frames_number_segment_index for more details.

If data_type == 'frame', frames_number is None, and duration is not None

events will be integrated into frames of a fixed time duration.

static resource_url_md5() list[source]
Returns

A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5.

Return type

list

static downloadable() bool[source]
Returns

Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually.

Return type

bool

static extract_downloaded_files(download_root: str, extract_root: str)[source]
Parameters
  • download_root (str) – Root directory path where the downloaded dataset files are saved

  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

Returns

None

This function defines how to extract the downloaded files.

static load_origin_data(file_name: str) Dict[source]
Parameters

file_name (str) – path of the events file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

This function defines how to read the original binary data.

static split_aedat_files_to_np(fname: str, aedat_file: str, csv_file: str, output_dir: str)[source]
static create_events_np_files(extract_root: str, events_np_root: str)[source]
Parameters
  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

  • events_np_root (str) – Root directory path where the events files in the npz format are saved

Returns

None

This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.

static get_H_W() Tuple[source]
Returns

A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.

Return type

tuple

spikingjelly.datasets.n_caltech101 module

class spikingjelly.datasets.n_caltech101.NCaltech101(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]

Bases: spikingjelly.datasets.NeuromorphicDatasetFolder

Parameters
  • root (str) – root path of the dataset

  • data_type (str) – 'event' or 'frame'

  • frames_number (int) – the number of integrated frames

  • split_by (str) – 'time' or 'number'

  • duration (int) – the time duration of each frame

  • transform (callable) – a function/transform that takes in a sample and returns a transformed version, e.g., transforms.RandomCrop for images.

  • target_transform (callable) – a function/transform that takes in the target and transforms it.

If data_type == 'event'

the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray.

If data_type == 'frame' and frames_number is not None

events will be integrated into a fixed number of frames. split_by defines how the events are split. See cal_fixed_frames_number_segment_index for more details.

If data_type == 'frame', frames_number is None, and duration is not None

events will be integrated into frames of a fixed time duration.

static resource_url_md5() list[source]
Returns

A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5.

Return type

list

static downloadable() bool[source]
Returns

Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually.

Return type

bool

static extract_downloaded_files(download_root: str, extract_root: str)[source]
Parameters
  • download_root (str) – Root directory path where the downloaded dataset files are saved

  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

Returns

None

This function defines how to extract the downloaded files.

static load_origin_data(file_name: str) Dict[source]
Parameters

file_name (str) – path of the events file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

This function defines how to read the original binary data.

static get_H_W() Tuple[source]
Returns

A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.

Return type

tuple

static read_bin_save_to_np(bin_file: str, np_file: str)[source]
static create_events_np_files(extract_root: str, events_np_root: str)[source]
Parameters
  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

  • events_np_root (str) – Root directory path where the events files in the npz format are saved

Returns

None

This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.

spikingjelly.datasets.n_mnist module

class spikingjelly.datasets.n_mnist.NMNIST(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]

Bases: spikingjelly.datasets.NeuromorphicDatasetFolder

Parameters
  • root (str) – root path of the dataset

  • train (bool) – whether to use the training set

  • data_type (str) – 'event' or 'frame'

  • frames_number (int) – the number of integrated frames

  • split_by (str) – 'time' or 'number'

  • duration (int) – the time duration of each frame

  • transform (callable) – a function/transform that takes in a sample and returns a transformed version, e.g., transforms.RandomCrop for images.

  • target_transform (callable) – a function/transform that takes in the target and transforms it.

If data_type == 'event'

the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray.

If data_type == 'frame' and frames_number is not None

events will be integrated into a fixed number of frames. split_by defines how the events are split. See cal_fixed_frames_number_segment_index for more details.

If data_type == 'frame', frames_number is None, and duration is not None

events will be integrated into frames of a fixed time duration.

static resource_url_md5() list[source]
Returns

A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5.

Return type

list

static downloadable() bool[source]
Returns

Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually.

Return type

bool

static extract_downloaded_files(download_root: str, extract_root: str)[source]
Parameters
  • download_root (str) – Root directory path where the downloaded dataset files are saved

  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

Returns

None

This function defines how to extract the downloaded files.

static load_origin_data(file_name: str) Dict[source]
Parameters

file_name (str) – path of the events file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

This function defines how to read the original binary data.

static get_H_W() Tuple[source]
Returns

A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.

Return type

tuple

static read_bin_save_to_np(bin_file: str, np_file: str)[source]
static create_events_np_files(extract_root: str, events_np_root: str)[source]
Parameters
  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

  • events_np_root (str) – Root directory path where the events files in the npz format are saved

Returns

None

This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.

spikingjelly.datasets.speechcommands module

spikingjelly.datasets.speechcommands.load_speechcommands_item(relpath: str, path: str) Tuple[torch.Tensor, int, str, str, int][source]
class spikingjelly.datasets.speechcommands.SPEECHCOMMANDS(label_dict: Dict, root: str, silence_cnt: Optional[int] = 0, silence_size: Optional[int] = 16000, transform: Optional[Callable] = None, url: Optional[str] = 'speech_commands_v0.02', split: Optional[str] = 'train', folder_in_archive: Optional[str] = 'SpeechCommands', download: Optional[bool] = False)[source]

Bases: torch.utils.data.dataset.Dataset

Parameters
  • label_dict (Dict) – a dict that maps labels to classes

  • root (str) – root directory of the dataset

  • silence_cnt (int, optional) – number of Silence samples

  • silence_size (int, optional) – size of each Silence sample

  • transform (Callable, optional) – a function/transform that takes in raw audio and returns a transformed version

  • url (str, optional) – dataset version; defaults to v0.02

  • split (str, optional) – dataset split, which can be "train", "test", or "val"; defaults to "train"

  • folder_in_archive (str, optional) – name of the directory after extraction; defaults to "SpeechCommands"

  • download (bool, optional) – whether to download the data; defaults to False

The SpeechCommands speech dataset, from Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition. It is split according to the provided lists of test and validation files, and comes in two versions, v0.01 and v0.02.

The dataset contains audio of three categories of words:

  1. Command words, 10 in total: "Yes", "No", "Up", "Down", "Left", "Right", "On", "Off", "Stop", "Go". v0.02 adds 5 more: "Forward", "Backward", "Follow", "Learn", "Visual".

  2. The digits 0 to 9, 10 in total: "Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine".

  3. Auxiliary words, which can be regarded as distractors, 10 in total: "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow".

v0.01 contains 30 classes with 64,727 audio clips in total, and v0.02 contains 35 classes with 105,829 audio clips in total. See the paper above and the dataset's README for more details.

The implementation is based on torchaudio with extended functionality, and also refers to the implementation of the original paper.
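As a quick consistency check of the counts above, the sketch below rebuilds the word lists and turns them into a label_dict. Mapping each word to its own integer class index is only one possible choice and an assumption of this sketch, as is the lowercase spelling (assumed to match the folder names in the released archive); the label_dict used with SPEECHCOMMANDS in practice may group words differently.

```python
# Word lists from the description above, written in lowercase (assumption).
COMMANDS_V1 = ["yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go"]
COMMANDS_V2_EXTRA = ["forward", "backward", "follow", "learn", "visual"]
DIGITS = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
AUXILIARY = ["bed", "bird", "cat", "dog", "happy", "house", "marvin", "sheila", "tree", "wow"]


def build_label_dict(version: str) -> dict:
    """Hypothetical helper: map every word of a dataset version to a class index."""
    words = COMMANDS_V1 + DIGITS + AUXILIARY
    if version == "v0.02":
        words = words + COMMANDS_V2_EXTRA  # 5 extra command words in v0.02
    return {word: index for index, word in enumerate(words)}


print(len(build_label_dict("v0.01")))  # 30
print(len(build_label_dict("v0.02")))  # 35
```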

Module contents

spikingjelly.datasets.play_frame(x: torch.Tensor, save_gif_to: Optional[str] = None) None[source]
Parameters
  • x (torch.Tensor or np.ndarray) – frames with shape=[T, 2, H, W]

  • save_gif_to (str) – If None, this function plays the frames. Otherwise, it does not play the frames but saves them to a gif file at save_gif_to

Returns

None

spikingjelly.datasets.load_matlab_mat(file_name: str) Dict[source]
Parameters

file_name (str) – path of the MATLAB mat file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

spikingjelly.datasets.load_aedat_v3(file_name: str) Dict[source]
Parameters

file_name (str) – path of the aedat v3 file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

This function is written by referring to https://gitlab.com/inivation/dv/dv-python . It can be used for the DVS128 Gesture dataset.

spikingjelly.datasets.load_ATIS_bin(file_name: str) Dict[source]
Parameters

file_name (str) – path of the ATIS binary file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

This function is written by referring to https://github.com/jackd/events-tfds .

Each ATIS binary example is a separate binary file consisting of a list of events. Each event occupies 40 bits, as described below:
  • bits 39 - 32: X address (in pixels)

  • bits 31 - 24: Y address (in pixels)

  • bit 23: polarity (0 for OFF, 1 for ON)

  • bits 22 - 0: timestamp (in microseconds)
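The 40-bit layout above can be decoded without NumPy as follows. This sketch assumes each event is stored as 5 consecutive bytes in big-endian order, so that bit 39 is the most significant bit of the first byte. decode_atis_events and encode_atis_event are hypothetical helpers, and they return Python lists rather than the numpy.ndarray values produced by load_ATIS_bin.

```python
def decode_atis_events(raw: bytes) -> dict:
    """Hypothetical pure-Python decoder for the 40-bit ATIS event layout."""
    if len(raw) % 5 != 0:
        raise ValueError("ATIS data must be a whole number of 5-byte events")
    events = {"t": [], "x": [], "y": [], "p": []}
    for offset in range(0, len(raw), 5):
        # Read one 40-bit event as a big-endian integer (bit 39 = MSB of byte 0).
        word = int.from_bytes(raw[offset:offset + 5], "big")
        events["x"].append(word >> 32)            # bits 39 - 32
        events["y"].append((word >> 24) & 0xFF)   # bits 31 - 24
        events["p"].append((word >> 23) & 0x1)    # bit 23
        events["t"].append(word & 0x7FFFFF)       # bits 22 - 0
    return events


def encode_atis_event(x: int, y: int, p: int, t: int) -> bytes:
    """Hypothetical inverse of the decoder, used here only to build test data."""
    word = (x << 32) | (y << 24) | (p << 23) | t
    return word.to_bytes(5, "big")


raw = encode_atis_event(10, 20, 1, 123456) + encode_atis_event(33, 7, 0, 123500)
print(decode_atis_events(raw))
# {'t': [123456, 123500], 'x': [10, 33], 'y': [20, 7], 'p': [1, 0]}
```

With 8 bits per coordinate and 23 bits for the timestamp, a single event can address at most 256 x 256 pixels and about 8.4 seconds (2^23 microseconds) of time.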

spikingjelly.datasets.load_npz_frames(file_name: str) numpy.ndarray[source]
Parameters

file_name (str) – path of the npz file that stores the frames

Returns

frames

Return type

np.ndarray

spikingjelly.datasets.integrate_events_segment_to_frame(events: Dict, H: int, W: int, j_l: int = 0, j_r: int = -1) numpy.ndarray[source]
Parameters
  • events (Dict) – a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

  • H (int) – height of the frame

  • W (int) – width of the frame

  • j_l (int) – the start index of the integral interval, which is included

  • j_r (int) – the end index of the integral interval, which is not included

Returns

frames

Return type

np.ndarray

Denote a two-channel frame as \(F\) and a pixel at \((p, x, y)\) as \(F(p, x, y)\). The pixel value is integrated from the events whose indices are in \([j_{l}, j_{r})\):

\[F(p, x, y) = \sum_{i = j_{l}}^{j_{r} - 1} \mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\]

where \(\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\) is an indicator function that equals 1 only when \((p, x, y) = (p_{i}, x_{i}, y_{i})\), and 0 otherwise.
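The sum above is a per-pixel event count for each polarity. The sketch below computes it with nested lists. integrate_segment_sketch is a hypothetical name, and storing the frame as F[p][y][x] (y as the row index, giving shape [2, H, W]) is an assumption about the layout rather than documented behavior.

```python
def integrate_segment_sketch(events: dict, H: int, W: int, j_l: int, j_r: int) -> list:
    """Hypothetical pure-Python version of the sum above.

    Returns nested lists F[p][y][x] with shape [2, H, W]; assumes p is 0 or 1
    and that j_r is exclusive, as in the formula.
    """
    frame = [[[0] * W for _ in range(H)] for _ in range(2)]
    for i in range(j_l, j_r):
        # The indicator is 1 only for the pixel matching event i, so each event adds 1 there.
        frame[events["p"][i]][events["y"][i]][events["x"][i]] += 1
    return frame


events = {"t": [0, 1, 2, 3], "x": [0, 1, 1, 0], "y": [0, 0, 0, 1], "p": [1, 0, 0, 1]}
print(integrate_segment_sketch(events, H=2, W=2, j_l=0, j_r=4))
# [[[0, 2], [0, 0]], [[1, 0], [1, 0]]]
```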

spikingjelly.datasets.cal_fixed_frames_number_segment_index(events_t: numpy.ndarray, split_by: str, frames_num: int) tuple[source]
Parameters
  • events_t (numpy.ndarray) – the timestamps t of the events

  • split_by (str) – 'time' or 'number'

  • frames_num (int) – the number of frames

Returns

a tuple (j_l, j_r) that holds, for each frame, the start index (included) and the end index (not included) of its events

Return type

tuple

Denote frames_num as \(M\) and the number of events as \(N\). For the j-th frame, \(j = 0, 1, \ldots, M - 1\), if split_by is 'time', then

\[\begin{split}\Delta T & = \left\lfloor \frac{t_{N-1} - t_{0}}{M} \right\rfloor \\ j_{l} & = \mathop{\arg\min}\limits_{k} \{t_{k} \mid t_{k} \geq t_{0} + \Delta T \cdot j\} \\ j_{r} & = \begin{cases} \mathop{\arg\max}\limits_{k} \{t_{k} \mid t_{k} < t_{0} + \Delta T \cdot (j + 1)\} + 1, & j < M - 1 \\ N, & j = M - 1 \end{cases}\end{split}\]

If split_by is 'number', then

\[\begin{split}j_{l} & = \left\lfloor \frac{N}{M} \right\rfloor \cdot j \\ j_{r} & = \begin{cases} \left\lfloor \frac{N}{M} \right\rfloor \cdot (j + 1), & j < M - 1 \\ N, & j = M - 1 \end{cases}\end{split}\]
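Both split modes fit in a few lines of Python. In the sketch below (segment_index_sketch is a hypothetical name), bisect_left returns the first index whose timestamp is at least a given bound. When events_t is sorted ascending, which this sketch assumes, that index equals both the argmin in the j_l formula and the argmax + 1 in the j_r formula.

```python
from bisect import bisect_left


def segment_index_sketch(events_t: list, split_by: str, frames_num: int) -> tuple:
    """Hypothetical pure-Python version of the formulas above.

    Assumes events_t is sorted ascending and holds integer timestamps.
    """
    N, M = len(events_t), frames_num
    j_l, j_r = [0] * M, [0] * M
    if split_by == "time":
        dt = (events_t[-1] - events_t[0]) // M  # Delta T, floored
        for j in range(M):
            # First k with t_k >= t_0 + dt * j.
            j_l[j] = bisect_left(events_t, events_t[0] + dt * j)
            # (Last k with t_k < t_0 + dt * (j + 1)) + 1, or N for the final frame.
            j_r[j] = N if j == M - 1 else bisect_left(events_t, events_t[0] + dt * (j + 1))
    elif split_by == "number":
        di = N // M  # floor(N / M) events per frame
        for j in range(M):
            j_l[j] = di * j
            j_r[j] = N if j == M - 1 else di * (j + 1)
    else:
        raise ValueError("split_by must be 'time' or 'number'")
    return j_l, j_r


t = [0, 0, 1, 1, 1, 2, 8, 9, 9, 9]
print(segment_index_sketch(t, "time", 3))    # ([0, 6, 6], [6, 6, 10])
print(segment_index_sketch(t, "number", 3))  # ([0, 3, 6], [3, 6, 10])
```

The example shows the practical difference: splitting by 'time' can leave a frame empty (the middle one here) when events are bursty, while splitting by 'number' gives every frame except the last exactly floor(N / M) events.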

spikingjelly.datasets.integrate_events_by_fixed_frames_number(events: Dict, split_by: str, frames_num: int, H: int, W: int) numpy.ndarray[source]
Parameters
  • events (Dict) – a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

  • split_by (str) – 'time' or 'number'

  • frames_num (int) – the number of frames

  • H (int) – the height of the frame

  • W (int) – the width of the frame

Returns

frames

Return type

np.ndarray

Integrate events into a fixed number of frames. See cal_fixed_frames_number_segment_index and integrate_events_segment_to_frame for more details.

spikingjelly.datasets.integrate_events_file_to_frames_file_by_fixed_frames_number(events_np_file: str, output_dir: str, split_by: str, frames_num: int, H: int, W: int, print_save: bool = False) None[source]
Parameters
  • events_np_file (str) – path of the events numpy file

  • output_dir (str) – output directory for saving the frames

  • split_by (str) – 'time' or 'number'

  • frames_num (int) – the number of frames

  • H (int) – the height of the frame

  • W (int) – the width of the frame

  • print_save (bool) – If True, this function will print the paths of the saved files.

Returns

None

Integrate an events file into a fixed number of frames and save them. See cal_fixed_frames_number_segment_index and integrate_events_segment_to_frame for more details.

spikingjelly.datasets.integrate_events_by_fixed_duration(events: Dict, duration: int, H: int, W: int) numpy.ndarray[source]
Parameters
  • events (Dict) – a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

  • duration (int) – the time duration of each frame

  • H (int) – the height of the frame

  • W (int) – the width of the frame

Returns

frames

Return type

np.ndarray

Integrate events into frames, each of which covers a fixed time duration.
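The grouping step behind a fixed-duration integration can be sketched as below. Assigning event i to frame (t_i - t_0) // duration, and keeping a trailing, partially filled window, are assumptions of this sketch, since the text above does not say how the last window is handled. split_by_duration_sketch is a hypothetical helper that returns the event indices of each frame; each group could then be counted into a frame as in integrate_events_segment_to_frame.

```python
def split_by_duration_sketch(events_t: list, duration: int) -> list:
    """Hypothetical helper: group event indices into frames of a fixed time duration.

    Assumes events_t is sorted ascending; frame j covers
    [t_0 + j * duration, t_0 + (j + 1) * duration).
    """
    frames = []
    for i, t in enumerate(events_t):
        j = (t - events_t[0]) // duration  # index of the frame this event falls into
        while len(frames) <= j:            # open empty frames up to index j
            frames.append([])
        frames[j].append(i)
    return frames


t = [0, 3, 4, 10, 11, 25]
print(split_by_duration_sketch(t, 10))  # [[0, 1, 2], [3, 4], [5]]
```

Because the number of frames depends on the length of the recording, samples produced this way generally have different lengths, which is the case pad_sequence_collate is designed for.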

spikingjelly.datasets.integrate_events_file_to_frames_file_by_fixed_duration(events_np_file: str, output_dir: str, duration: int, H: int, W: int, print_save: bool = False) None[source]
Parameters
  • events_np_file (str) – path of the events numpy file

  • output_dir (str) – output directory for saving the frames

  • duration (int) – the time duration of each frame

  • H (int) – the height of the frame

  • W (int) – the width of the frame

  • print_save (bool) – If True, this function will print the paths of the saved files.

Returns

None

Integrate an events file into frames, each of which covers a fixed time duration, and save them.

spikingjelly.datasets.create_same_directory_structure(source_dir: str, target_dir: str) None[source]
Parameters
  • source_dir (str) – path of the directory whose structure is copied

  • target_dir (str) – path of the directory in which the structure is created

Returns

None

Create the same directory structure in target_dir as that of source_dir.
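A minimal sketch of the same idea with the standard library, assuming that only directories (and no files) are reproduced. mirror_directory_structure is a hypothetical name, not the library function.

```python
import os
import tempfile


def mirror_directory_structure(source_dir: str, target_dir: str) -> None:
    """Hypothetical sketch: recreate every sub-directory of source_dir under target_dir.

    Files are not copied; only directories are created.
    """
    for dirpath, _dirnames, _filenames in os.walk(source_dir):
        relative = os.path.relpath(dirpath, source_dir)  # '.' for source_dir itself
        os.makedirs(os.path.join(target_dir, relative), exist_ok=True)


with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst:
    os.makedirs(os.path.join(src, "train", "class_0"))
    os.makedirs(os.path.join(src, "test", "class_0"))
    open(os.path.join(src, "train", "class_0", "sample.npz"), "w").close()
    mirror_directory_structure(src, dst)
    print(sorted(os.listdir(dst)))                            # ['test', 'train']
    print(os.listdir(os.path.join(dst, "train", "class_0")))  # []
```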

spikingjelly.datasets.split_to_train_test_set(train_ratio: float, origin_dataset: torch.utils.data.dataset.Dataset, num_classes: int, random_split: bool = False)[source]
Parameters
  • train_ratio (float) – the fraction of the samples of the original dataset that goes into the training set

  • origin_dataset (torch.utils.data.Dataset) – the original dataset

  • num_classes (int) – total number of classes, e.g., 10 for the MNIST dataset

  • random_split (bool) – If False, the first train_ratio of the samples in each class will be included in the training set, while the rest will be included in the test set. If True, this function splits the samples in each class randomly. The randomness is controlled by numpy.random.seed

Returns

a tuple (train_set, test_set)

Return type

tuple
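The per-class split described by random_split can be sketched on plain label lists. split_indices_sketch is hypothetical: it returns index lists instead of datasets (these could be wrapped with torch.utils.data.Subset), uses Python's random.Random in place of numpy.random.seed, and rounds the per-class cut with math.ceil, which is an assumption.

```python
import math
import random


def split_indices_sketch(labels: list, train_ratio: float, num_classes: int,
                         random_split: bool = False, seed: int = 0) -> tuple:
    """Hypothetical per-class split that returns (train_indices, test_indices)."""
    per_class = [[] for _ in range(num_classes)]
    for index, label in enumerate(labels):
        per_class[label].append(index)
    rng = random.Random(seed)  # stands in for numpy.random.seed in the library
    train, test = [], []
    for indices in per_class:
        if random_split:
            rng.shuffle(indices)
        cut = math.ceil(len(indices) * train_ratio)  # rounding rule is an assumption
        train.extend(indices[:cut])  # first part of each class -> training set
        test.extend(indices[cut:])   # the rest -> test set
    return train, test


labels = [0, 1, 0, 1, 0, 1, 0, 1]
print(split_indices_sketch(labels, 0.5, 2))  # ([0, 2, 1, 3], [4, 6, 5, 7])
```

Splitting inside each class keeps the class proportions of both subsets close to those of the original dataset, which a single global split would not guarantee.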

spikingjelly.datasets.pad_sequence_collate(batch: list)[source]
Parameters

batch (list) – a list of samples (x, y), where x.shape=[T, *] and y is the label

Returns

the batched samples (x, y, x_len), where the sequences x are padded to the same length and stacked with shape [T, N, *] (T is the maximum length in the batch), y holds the labels, and x_len holds the original lengths

Return type

tuple

This function can be used as the collate_fn of a DataLoader to process datasets whose samples have variable lengths, e.g., a NeuromorphicDatasetFolder that integrates events into frames with a fixed duration.

Here is an example:

import random

import torch
from spikingjelly.datasets import pad_sequence_collate


class RandomLengthDataset(torch.utils.data.Dataset):
    # Each sample has a random length T in [1, 10] and a random label in [0, 10].
    def __init__(self, n=1000):
        super().__init__()
        self.n = n

    def __getitem__(self, i):
        return torch.rand([random.randint(1, 10), 28, 28]), random.randint(0, 10)

    def __len__(self):
        return self.n


loader = torch.utils.data.DataLoader(RandomLengthDataset(n=32), batch_size=16, collate_fn=pad_sequence_collate)

for x, y, x_len in loader:
    # x: [T_max, N, 28, 28], y: [N] labels, x_len: [N] original lengths
    print(x.shape, y.shape, x_len)

A possible output (the lengths are random, so they differ between runs):

torch.Size([10, 16, 28, 28]) torch.Size([16]) tensor([ 1,  9,  3,  4,  1,  2,  9,  7,  2,  1,  5,  7,  4, 10,  9,  5])
torch.Size([10, 16, 28, 28]) torch.Size([16]) tensor([ 1,  8,  7, 10,  3, 10,  6,  7,  5,  9, 10,  5,  9,  6,  7,  6])

spikingjelly.datasets.padded_sequence_mask(sequence_len: torch.Tensor, T=None)[source]
Parameters
  • sequence_len (torch.Tensor) – a tensor with shape = [N] that contains the sequence length of each batch element

  • T (int) – the maximum length of the sequences. If None, the maximum element in sequence_len will be used as T

Returns

a bool mask with shape = [T, N], where the padded positions are False

Return type

torch.Tensor

Here is an example:

import torch
from spikingjelly.datasets import padded_sequence_mask

x1 = torch.rand([2, 6])
x2 = torch.rand([3, 6])
x3 = torch.rand([4, 6])
x = torch.nn.utils.rnn.pad_sequence([x1, x2, x3])  # [T, N, *]
print('x.shape=', x.shape)
x_len = torch.as_tensor([x1.shape[0], x2.shape[0], x3.shape[0]])
mask = padded_sequence_mask(x_len)
print('mask.shape=', mask.shape)
print('mask=\n', mask)

And the outputs are:

x.shape= torch.Size([4, 3, 6])
mask.shape= torch.Size([4, 3])
mask=
 tensor([[ True,  True,  True],
        [ True,  True,  True],
        [False,  True,  True],
        [False, False,  True]])

class spikingjelly.datasets.NeuromorphicDatasetFolder(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]

Bases: torchvision.datasets.folder.DatasetFolder

Parameters
  • root (str) – root path of the dataset

  • train (bool) – whether to use the training set. Set True or False for datasets that provide a train/test division, e.g., the DVS128 Gesture dataset. If the dataset does not provide a train/test division, e.g., CIFAR10-DVS, set None and use the split_to_train_test_set function to get the train/test sets

  • data_type (str) – 'event' or 'frame'

  • frames_number (int) – the number of integrated frames

  • split_by (str) – 'time' or 'number'

  • duration (int) – the time duration of each frame

  • transform (callable) – a function/transform that takes in a sample and returns a transformed version, e.g., transforms.RandomCrop for images.

  • target_transform (callable) – a function/transform that takes in the target and transforms it.

The base class for neuromorphic datasets. Users can define a new dataset by inheriting this class and implementing all abstract methods; spikingjelly.datasets.dvs128_gesture.DVS128Gesture can serve as a reference.

If data_type == 'event'

the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray.

If data_type == 'frame' and frames_number is not None

events will be integrated into a fixed number of frames. split_by defines how the events are split. See cal_fixed_frames_number_segment_index for more details.

If data_type == 'frame', frames_number is None, and duration is not None

events will be integrated into frames of a fixed time duration.
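The three cases above depend only on data_type, frames_number, split_by and duration. The hypothetical helper below mirrors that decision in plain Python; raising ValueError for other combinations is an assumption of this sketch, not documented behavior.

```python
from typing import Optional


def describe_sample_mode(data_type: str, frames_number: Optional[int] = None,
                         split_by: Optional[str] = None,
                         duration: Optional[int] = None) -> str:
    """Hypothetical helper mirroring the three documented cases."""
    if data_type == "event":
        return "dict of numpy arrays with keys t, x, y, p"
    if data_type == "frame" and frames_number is not None:
        # Fixed number of frames; split_by selects how events are divided.
        if split_by not in ("time", "number"):
            raise ValueError("split_by must be 'time' or 'number'")
        return f"{frames_number} frames, events split by {split_by}"
    if data_type == "frame" and duration is not None:
        # Fixed duration per frame; the number of frames varies per sample.
        return f"variable number of frames, each covering {duration} time units"
    raise ValueError("unsupported combination of data_type, frames_number and duration")


print(describe_sample_mode("frame", frames_number=16, split_by="number"))
# 16 frames, events split by number
```

For example, DVS128Gesture(root, train=True, data_type='frame', frames_number=16, split_by='number') falls into the second case, so each sample is integrated into 16 frames.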

abstract static load_origin_data(file_name: str) Dict[source]
Parameters

file_name (str) – path of the events file

Returns

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

Return type

Dict

This function defines how to read the original binary data.

abstract static resource_url_md5() list[source]
Returns

A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5.

Return type

list
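When downloadable() returns False, the files listed by resource_url_md5 have to be placed in the download directory by hand, and checking them against the documented (file name, download link, MD5) tuples is a natural first step. The sketch below does this with hashlib. check_resources and the example.com URLs are hypothetical, and the library's own verification code may differ.

```python
import hashlib
import os
import tempfile


def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_resources(download_root: str, resources: list) -> list:
    """Hypothetical helper: names of files that are missing or whose MD5 differs.

    resources follows the documented layout: a list of (file_name, url, md5) tuples.
    """
    bad = []
    for file_name, _url, md5 in resources:
        path = os.path.join(download_root, file_name)
        if not os.path.isfile(path) or file_md5(path) != md5:
            bad.append(file_name)
    return bad


with tempfile.TemporaryDirectory() as root:
    with open(os.path.join(root, "a.bin"), "wb") as f:
        f.write(b"hello")
    resources = [
        ("a.bin", "https://example.com/a.bin", hashlib.md5(b"hello").hexdigest()),
        ("b.bin", "https://example.com/b.bin", "0" * 32),  # never downloaded
    ]
    print(check_resources(root, resources))  # ['b.bin']
```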

abstract static downloadable() bool[source]
Returns

Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually.

Return type

bool

abstract static extract_downloaded_files(download_root: str, extract_root: str)[source]
Parameters
  • download_root (str) – Root directory path where the downloaded dataset files are saved

  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

Returns

None

This function defines how to extract the downloaded files.

abstract static create_events_np_files(extract_root: str, events_np_root: str)[source]
Parameters
  • extract_root (str) – Root directory path where the files extracted from the downloaded files are saved

  • events_np_root (str) – Root directory path where the events files in the npz format are saved

Returns

None

This function defines how to convert the original binary data in extract_root to the npz format and save the converted files in events_np_root.

abstract static get_H_W() Tuple[source]
Returns

A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.

Return type

tuple