spikingjelly.datasets.utils module#

spikingjelly.datasets.utils.save_as_pic(x: Tensor | ndarray, save_pic_to: str = './', pic_first_name: str = 'pic') None[源代码]#

API Language: 中文 | English


  • 中文

将事件帧 x 保存为一组图片。函数会将 x[:, 0] 写入绿色通道、x[:, 1] 写入蓝色通道, 并按 {pic_first_name}_{t}.png 的格式逐帧保存。

参数:
  • x (Union[torch.Tensor, np.ndarray]) -- 形状为 [T, 2, H, W] 的帧序列

  • save_pic_to (str) -- 图片保存目录

  • pic_first_name (str) -- 图片文件名前缀,保存文件名形如 f"{pic_first_name}_{t}.png"

返回:

None

返回类型:

None


  • English

Save event frames in x as images. The function writes x[:, 0] to the green channel, x[:, 1] to the blue channel, and stores each frame as {pic_first_name}_{t}.png.

参数:
  • x (Union[torch.Tensor, np.ndarray]) -- frames with shape=[T, 2, H, W]

  • save_pic_to (str) -- where to store images

  • pic_first_name (str) -- prefix for image names (stored image names are: f"{pic_first_name}_{t}.png")

返回:

None

返回类型:

None


  • 代码示例 | Example

save_as_pic(frame, './demo', 'first_pic')
spikingjelly.datasets.utils.save_every_frame_of_an_entire_DVS_dataset(dataset: str, dataset_path: str, time_steps: int, save_pic_to: str = './', number_of_threads: int = 4)[源代码]#

API Language: 中文 | English


  • 中文

将指定 DVS 数据集的每个样本按固定帧数加载为帧数据, 并将所有帧逐张保存为图片。

参数:
  • dataset (str) -- 要保存的数据集名称。当前可用的选项有:DVS128Gesture、CIFAR10DVS 和 NCaltech101。

  • dataset_path (str) -- 与加载数据集相同的存储路径。

  • time_steps (int) -- 与加载数据集相同的 T。

  • save_pic_to (str) -- 每一帧图像的保存位置。

  • number_of_threads (int) -- 用于保存图像的线程数。

返回:

None

返回类型:

None

抛出:

ValueError -- 当必要参数为空, 或 dataset 不是 "DVS128Gesture", "CIFAR10DVS", "NCaltech101" 之一时抛出。


  • English

Load every sample from the specified DVS dataset as frame data with a fixed frame count, and save every frame as an image.

参数:
  • dataset (str) -- name of the dataset to be saved. The current available options are: DVS128Gesture, CIFAR10DVS and NCaltech101.

  • dataset_path (str) -- same storage path as loading dataset.

  • time_steps (int) -- same T as loading the dataset.

  • save_pic_to (str) -- where to store each frame's image.

  • number_of_threads (int) -- how many threads are used to save images.

返回:

None

返回类型:

None

抛出:

ValueError -- raised when required arguments are empty, or when dataset is not one of "DVS128Gesture", "CIFAR10DVS", or "NCaltech101".


  • 代码示例 | Example

save_every_frame_of_an_entire_DVS_dataset(dataset='DVS128Gesture', dataset_path="../../datasets/DVS128Gesture",
                                        time_steps=16, save_pic_to='./demo', number_of_threads=20)
save_every_frame_of_an_entire_DVS_dataset(dataset='CIFAR10DVS', dataset_path="../../datasets/cifar10dvs",
                                        time_steps=10, save_pic_to='./demo', number_of_threads=20)
save_every_frame_of_an_entire_DVS_dataset(dataset='NCaltech101', dataset_path="../../datasets/NCaltech101",
                                        time_steps=14, save_pic_to='./demo', number_of_threads=20)
spikingjelly.datasets.utils.play_frame(x: Tensor | ndarray, save_gif_to: str = None) None[源代码]#

API Language: 中文 | English


  • 中文

参数:
  • x (Union[torch.Tensor, np.ndarray]) -- 形状为 shape=[T, 2, H, W] 的帧

  • save_gif_to (str) -- 如果 None,此函数将播放帧。 如果不为 None,此函数将不播放帧,而是将帧保存到路径 save_gif_to 中的 gif 文件

返回:

None

返回类型:

None


  • English

参数:
  • x (Union[torch.Tensor, np.ndarray]) -- frames with shape=[T, 2, H, W]

  • save_gif_to (str) -- If None, this function will play the frames. If not None, this function will not play the frames but save the frames to a gif file in the path save_gif_to

返回:

None

返回类型:

None

spikingjelly.datasets.utils.load_aedat_v3(file_name: str | Path) dict[源代码]#

API Language: 中文 | English


  • 中文

此函数参考了 inivation/dv/dv-python 编写。 它可以用于 DVS128 Gesture。

参数:

file_name (Union[str, pathlib.Path]) -- aedat v3 文件的路径

返回:

一个字典,其键为 ['t', 'x', 'y', 'p'],值为 numpy.ndarray

返回类型:

dict


  • English

This function is written by referring to inivation/dv/dv-python . It can be used for DVS128 Gesture.

参数:

file_name (Union[str, pathlib.Path]) -- path of the aedat v3 file

返回:

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

返回类型:

dict

spikingjelly.datasets.utils.load_ATIS_bin(file_name: str | Path) dict[源代码]#

API Language: 中文 | English


  • 中文

此函数参考了 jackd/events-tfds 编写。 每个 ATIS 二进制示例都是一个独立的二进制文件,包含一个事件列表。每个事件占用 40 位,如下所述: 位 39 - 32: X地址(以像素为单位) 位 31 - 24: Y地址(以像素为单位) 位 23: 极性(0 表示 OFF,1 表示 ON) 位 22 - 0: 时间戳(以微秒为单位)

参数:

file_name (Union[str, pathlib.Path]) -- ATIS 二进制文件的路径

返回:

一个字典,其键为 ['t', 'x', 'y', 'p'],值为 numpy.ndarray

返回类型:

dict


  • English

This function is written by referring to jackd/events-tfds . Each ATIS binary example is a separate binary file consisting of a list of events. Each event occupies 40 bits as described below: bit 39 - 32: Xaddress (in pixels) bit 31 - 24: Yaddress (in pixels) bit 23: Polarity (0 for OFF, 1 for ON) bit 22 - 0: Timestamp (in microseconds)

参数:

file_name (Union[str, pathlib.Path]) -- path of the ATIS binary file

返回:

a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

返回类型:

dict

spikingjelly.datasets.utils.load_npz_frames(file_name: str | Path) ndarray[源代码]#

API Language: 中文 | English


  • 中文

参数:

file_name (Union[str, pathlib.Path]) -- 保存帧的 npz 文件的路径

返回:

返回类型:

np.ndarray


  • English

参数:

file_name (Union[str, pathlib.Path]) -- path of the npz file that saves the frames

返回:

frames

返回类型:

np.ndarray

spikingjelly.datasets.utils.integrate_events_segment_to_frame(x: ndarray, y: ndarray, p: ndarray, H: int, W: int, j_l: int = 0, j_r: int = -1) ndarray[源代码]#

API Language: 中文 | English


  • 中文

将双通道帧记为 \(F\),像素 \((p, x, y)\) 的像素值是从索引在 \([j_{l}, j_{r})\) 内的事件数据积分得到的:

\[\begin{split}F(p, x, y) = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\end{split}\]

其中 \(\\lfloor \\cdot \\rfloor\) 是取整运算,\(\\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\) 是指示函数,仅在 \((p, x, y) = (p_{i}, x_{i}, y_{i})\) 时等于 1。

参数:
  • x (numpy.ndarray) -- 事件的 x 坐标

  • y (numpy.ndarray) -- 事件的 y 坐标

  • p (numpy.ndarray) -- 事件的极性

  • H (int) -- 帧的高度

  • W (int) -- 帧的宽度

  • j_l (int) -- 积分区间的起始索引(包含)

  • j_r (int) -- 积分区间的右端索引(不包含)

返回:

单个双通道帧

返回类型:

np.ndarray


  • English

Denote a two channels frame as \(F\) and a pixel at \((p, x, y)\) as \(F(p, x, y)\), the pixel value is integrated from the events data whose indices are in \([j_{l}, j_{r})\):

\[\begin{split}F(p, x, y) = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\end{split}\]

where \(\\lfloor \\cdot \\rfloor\) is the floor operation, \(\\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\) is an indicator function and it equals 1 only when \((p, x, y) = (p_{i}, x_{i}, y_{i})\).

参数:
  • x (numpy.ndarray) -- x-coordinate of events

  • y (numpy.ndarray) -- y-coordinate of events

  • p (numpy.ndarray) -- polarity of events

  • H (int) -- height of the frame

  • W (int) -- width of the frame

  • j_l (int) -- the start index of the integral interval, which is included

  • j_r (int) -- the right index of the integral interval, which is not included

返回:

a single two-channel frame

返回类型:

np.ndarray

spikingjelly.datasets.utils.cal_fixed_frames_number_segment_index(events_t: ndarray, split_by: str, frames_num: int) tuple[源代码]#

API Language: 中文 | English


  • 中文

frames_num 记为 \(M\),如果 split_by'time',则

\[\begin{split}\\Delta T & = [\\frac{t_{N-1} - t_{0}}{M}] \\\\ j_{l} & = \\mathop{\\arg\\min}\\limits_{k} \\{t_{k} | t_{k} \\geq t_{0} + \\Delta T \\cdot j\\} \\\\ j_{r} & = \\begin{cases} \\mathop{\\arg\\max}\\limits_{k} \\{t_{k} | t_{k} < t_{0} + \\Delta T \\cdot (j + 1)\\} + 1, & j < M - 1 \\cr N, & j = M - 1 \\end{cases}\end{split}\]

如果 split_by'number',则

\[\begin{split}j_{l} & = [\\frac{N}{M}] \\cdot j \\\\ j_{r} & = \\begin{cases} [\\frac{N}{M}] \\cdot (j + 1), & j < M - 1 \\cr N, & j = M - 1 \\end{cases}\end{split}\]
参数:
  • events_t (numpy.ndarray) -- 事件的 t

  • split_by (str) -- 'time' 或 'number'

  • frames_num (int) -- 帧的数量

返回:

一个元组 (j_l, j_r)

返回类型:

tuple


  • English

Denote frames_num as \(M\), if split_by is 'time', then

\[\begin{split}\\Delta T & = [\\frac{t_{N-1} - t_{0}}{M}] \\\\ j_{l} & = \\mathop{\\arg\\min}\\limits_{k} \\{t_{k} | t_{k} \\geq t_{0} + \\Delta T \\cdot j\\} \\\\ j_{r} & = \\begin{cases} \\mathop{\\arg\\max}\\limits_{k} \\{t_{k} | t_{k} < t_{0} + \\Delta T \\cdot (j + 1)\\} + 1, & j < M - 1 \\cr N, & j = M - 1 \\end{cases}\end{split}\]

If split_by is 'number', then

\[\begin{split}j_{l} & = [\\frac{N}{M}] \\cdot j \\\\ j_{r} & = \\begin{cases} [\\frac{N}{M}] \\cdot (j + 1), & j < M - 1 \\cr N, & j = M - 1 \\end{cases}\end{split}\]
参数:
  • events_t (numpy.ndarray) -- events' t

  • split_by (str) -- 'time' or 'number'

  • frames_num (int) -- the number of frames

返回:

a tuple (j_l, j_r)

返回类型:

tuple

spikingjelly.datasets.utils.integrate_events_by_fixed_frames_number(events: dict, split_by: str, frames_num: int, H: int, W: int) ndarray[源代码]#

API Language: 中文 | English


  • 中文

按固定帧数将事件积分到帧中。 详见 cal_fixed_frames_number_segment_index()integrate_events_segment_to_frame()

参数:
  • events (dict) -- 一个字典,其键为 ['t', 'x', 'y', 'p'],值为 numpy.ndarray

  • split_by (str) -- 'time' 或 'number'

  • frames_num (int) -- 帧的数量

  • H (int) -- 帧的高度

  • W (int) -- 帧的宽度

返回:

返回类型:

np.ndarray


  • English

Integrate events to frames by fixed frames number. See cal_fixed_frames_number_segment_index() and integrate_events_segment_to_frame() for more details.

参数:
  • events (dict) -- a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

  • split_by (str) -- 'time' or 'number'

  • frames_num (int) -- the number of frames

  • H (int) -- the height of frame

  • W (int) -- the width of frame

返回:

frames

返回类型:

np.ndarray

spikingjelly.datasets.utils.integrate_events_file_to_frames_file_by_fixed_frames_number(loader: Callable, events_np_file: str, output_dir: str, split_by: str, frames_num: int, H: int, W: int, print_save: bool = False) None[源代码]#

API Language: 中文 | English


  • 中文

将单个事件文件按固定帧数积分成帧,并将结果保存到 output_dir 下与 events_np_file 同名的 .npz 文件中。保存文件包含键 frames

参数:
  • loader (Callable) -- 从 events_np_file 加载事件字典的函数

  • events_np_file (str) -- 事件文件路径

  • output_dir (str) -- 帧文件输出目录

  • split_by (str) -- 'time''number'

  • frames_num (int) -- 帧数量

  • H (int) -- 帧高度

  • W (int) -- 帧宽度

  • print_save (bool) -- 若为 True,则打印保存路径

返回:

None

返回类型:

None


  • English

Integrate an event file to frames by a fixed frame count and save it. The saved archive contains the frames key. See cal_fixed_frames_number_segment_index() and integrate_events_segment_to_frame() for more details.

参数:
  • loader (Callable) -- a function that can load events from events_np_file

  • events_np_file (str) -- path of the events np file

  • output_dir (str) -- output directory for saving the frames

  • split_by (str) -- 'time' or 'number'

  • frames_num (int) -- the number of frames

  • H (int) -- the height of frame

  • W (int) -- the width of frame

  • print_save (bool) -- If True, this function will print saved files' paths.

返回:

None

返回类型:

None

spikingjelly.datasets.utils.integrate_events_by_fixed_duration(events: dict, duration: int, H: int, W: int) ndarray[源代码]#

API Language: 中文 | English


  • 中文

按每帧固定时间时长将事件积分到帧中。

参数:
  • events (dict) -- 一个字典,其键为 ['t', 'x', 'y', 'p'],值为 numpy.ndarray

  • duration (int) -- 每帧的时间时长

  • H (int) -- 帧的高度

  • W (int) -- 帧的宽度

返回:

返回类型:

np.ndarray


  • English

Integrate events to frames by fixed time duration of each frame.

参数:
  • events (dict) -- a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray

  • duration (int) -- the time duration of each frame

  • H (int) -- the height of frame

  • W (int) -- the width of frame

返回:

frames

返回类型:

np.ndarray

spikingjelly.datasets.utils.integrate_events_file_to_frames_file_by_fixed_duration(loader: Callable, events_np_file: str, output_dir: str, duration: int, H: int, W: int, print_save: bool = False) int[源代码]#

API Language: 中文 | English


  • 中文

按每帧固定时间时长将事件积分到帧中并保存。

参数:
  • loader (Callable) -- 一个可以从 events_np_file 加载事件的函数

  • events_np_file (str) -- 事件的 np 文件的路径

  • output_dir (str) -- 保存帧的输出目录

  • duration (int) -- 每帧的时间时长

  • H (int) -- 帧的高度

  • W (int) -- 帧的宽度

  • print_save (bool) -- 如果 True,此函数将打印保存的文件的路径。

返回:

帧的数量

返回类型:

int


  • English

Integrate events to frames by fixed time duration of each frame.

参数:
  • loader (Callable) -- a function that can load events from events_np_file

  • events_np_file (str) -- path of the events np file

  • output_dir (str) -- output directory for saving the frames

  • duration (int) -- the time duration of each frame

  • H (int) -- the height of frame

  • W (int) -- the weight of frame

  • print_save (bool) -- If True, this function will print saved files' paths.

返回:

number of frames saved

返回类型:

int

spikingjelly.datasets.utils.save_frames_to_npz_and_print(fname: str, frames: ndarray)[源代码]#

API Language: 中文 | English


  • 中文

参数:
  • fname (str) -- 目标 npz 文件的路径

  • frames (np.ndarray) -- 帧对象

返回:

None


  • English

参数:
  • fname (str) -- path of the target npz file

  • frames (np.ndarray) -- frames object

返回:

None

spikingjelly.datasets.utils.create_same_directory_structure(source_dir: str | Path, target_dir: str | Path) None[源代码]#

API Language: 中文 | English


  • 中文

target_dir 中创建与 source_dir 相同的目录结构。

参数:
返回:

None


  • English

Create the same directory structure in target_dir with that of source_dir.

参数:
  • source_dir (Union[str, pathlib.Path]) -- Path of the directory that be copied from

  • target_dir (Union[str, pathlib.Path]) -- Path of the directory that be copied to

返回:

None

spikingjelly.datasets.utils.split_to_train_test_set(train_ratio: float, origin_dataset: Dataset, num_classes: int, random_split: bool = False)[源代码]#

API Language: 中文 | English


  • 中文

参数:
  • train_ratio (float) -- 将原始数据集按此比例划分为训练集

  • origin_dataset (torch.utils.data.Dataset) -- 原始数据集

  • num_classes (int) -- 总类别数,例如 MNIST 数据集为 10

  • random_split (bool) -- 如果 False,每个类的前半部分样本将包含在训练集中,其余部分包含在测试集中。 如果 True,此函数将随机划分每个类别的样本。 随机性由 numpy.random.seed 控制

返回:

一个元组 (train_set, test_set), 二者均为基于 origin_dataset 构造的 torch.utils.data.Subset

返回类型:

tuple[torch.utils.data.Subset, torch.utils.data.Subset]


  • English

参数:
  • train_ratio (float) -- split the ratio of the origin dataset as the train set

  • origin_dataset (torch.utils.data.Dataset) -- the origin dataset

  • num_classes (int) -- total classes number, e.g., 10 for the MNIST dataset

  • random_split (bool) -- If False, the front ratio of samples in each classes will be included in train set, while the reset will be included in test set. If True, this function will split samples in each classes randomly. The randomness is controlled by numpy.random.seed

返回:

a tuple (train_set, test_set), where both elements are torch.utils.data.Subset instances built from origin_dataset

返回类型:

tuple[torch.utils.data.Subset, torch.utils.data.Subset]

spikingjelly.datasets.utils.fast_split_to_train_test_set(train_ratio: float, origin_dataset: Dataset, num_classes: int, random_split: bool = False, batch_size: int = 16)[源代码]#

API Language: 中文 | English


  • 中文

参数:
  • train_ratio (float) -- 将原始数据集按此比例划分为训练集

  • origin_dataset (torch.utils.data.Dataset) -- 原始数据集

  • num_classes (int) -- 总类别数,例如 MNIST 数据集为 10

  • random_split (bool) -- 如果 False,每个类的前半部分样本将包含在训练集中,其余部分包含在测试集中。 如果 True,此函数将随机划分每个类别的样本。随机性由 numpy.random.seed 控制

  • batch_size (int) -- 每个批次处理的样本数量

返回:

一个元组 (train_set, test_set), 二者均为基于 origin_dataset 构造的 torch.utils.data.Subset

返回类型:

tuple[torch.utils.data.Subset, torch.utils.data.Subset]


  • English

参数:
  • train_ratio (float) -- split the ratio of the origin dataset as the train set

  • origin_dataset (torch.utils.data.Dataset) -- the origin dataset

  • num_classes (int) -- total classes number, e.g., 10 for the MNIST dataset

  • random_split (bool) -- If False, the front ratio of samples in each classes will be included in train set, while the reset will be included in test set. If True, this function will split samples in each classes randomly. The randomness is controlled by numpy.random.seed

  • batch_size (int) -- the number of samples to process in each batch

返回:

a tuple (train_set, test_set), where both elements are torch.utils.data.Subset instances built from origin_dataset

返回类型:

tuple[torch.utils.data.Subset, torch.utils.data.Subset]

spikingjelly.datasets.utils.pad_sequence_collate(batch: list)[源代码]#

API Language: 中文 | English


  • 中文

  • 中文

可作为 DataLoadercollate_fn 处理变长序列样本,例如按固定时长积分为帧的 NeuromorphicDatasetFolder。函数会将每个 x 转成 torch.Tensor,再用 torch.nn.utils.rnn.pad_sequence(..., batch_first=True) 补齐。

参数:

batch (list) -- 样本列表,每个样本形如 (x, y),其中 x 是长度可变的序列,y 是标签

返回:

(x_p, y, x_len), 其中 x_p 是按相同长度补齐后的批数据, y 是标签张量, x_len 是各样本原始长度张量

返回类型:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]


  • English

  • English

This function can be used as the collate_fn for DataLoader to process the dataset with variable length, e.g., a NeuromorphicDatasetFolder with fixed duration to integrate events to frames.

参数:

batch (list) -- a list of samples (x, y), where x is a variable-length sequence and y is the label

返回:

batched samples (x_p, y, x_len), where x_p is padded x to the same length, y is the label, and x_len is the length of x

返回类型:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]


  • 代码示例 | Example

class VariableLengthDataset(torch.utils.data.Dataset):
    def __init__(self, n=1000):
        super().__init__()
        self.n = n

    def __getitem__(self, i):
        return torch.rand([i + 1, 2]), self.n - i - 1

    def __len__(self):
        return self.n


loader = torch.utils.data.DataLoader(
    VariableLengthDataset(n=32),
    batch_size=2,
    collate_fn=pad_sequence_collate,
    shuffle=True,
)

for i, (x_p, label, x_len) in enumerate(loader):
    print(f"x_p.shape={x_p.shape}, label={label}, x_len={x_len}")
    if i == 2:
        break

Outputs:

x_p.shape=torch.Size([2, 18, 2]), label=tensor([14, 30]), x_len=tensor([18,  2])
x_p.shape=torch.Size([2, 29, 2]), label=tensor([3, 6]), x_len=tensor([29, 26])
x_p.shape=torch.Size([2, 23, 2]), label=tensor([ 9, 23]), x_len=tensor([23,  9])
返回:

None

返回类型:

None

spikingjelly.datasets.utils.padded_sequence_mask(sequence_len: Tensor, T: int | None = None)[源代码]#

API Language: 中文 | English


  • 中文

  • 中文

根据每个样本的有效序列长度生成形状为 [T, N] 的布尔掩码。若 TNone, 则使用 sequence_len 中的最大值。若 sequence_len 位于 CUDA 且可用 cupy, 则调用自定义 CuPy kernel;否则使用 PyTorch 广播计算。

参数:
  • sequence_len (torch.Tensor) -- 形状为 [N] 的张量,包含每个 batch 元素的序列长度

  • T (Optional[int]) -- 序列最大长度。若为 None,则取 sequence_len 中的最大值

返回:

形状为 [T, N] 的布尔掩码,填充位置为 False

返回类型:

torch.Tensor


  • English

  • English

Generate a bool mask with shape [T, N] from the valid sequence length of each sample. If T is None, the maximum element in sequence_len will be used. When sequence_len is on CUDA and cupy is available, this function uses the custom CuPy kernel; otherwise it falls back to PyTorch broadcasting.

参数:
  • sequence_len (torch.Tensor) -- a tensor shape = [N] that contains sequence lengths of each batch element

  • T (Optional[int]) -- the maximum length of sequences. If None, the maximum element in sequence_len will be used as T

返回:

a bool mask with shape = [T, N], where the padded position is False

返回类型:

torch.Tensor


  • 代码示例 | Example

x1 = torch.rand([2, 6])
x2 = torch.rand([3, 6])
x3 = torch.rand([4, 6])
x = torch.nn.utils.rnn.pad_sequence([x1, x2, x3])  # [T, N, *]
print("x.shape=", x.shape)
x_len = torch.as_tensor([x1.shape[0], x2.shape[0], x3.shape[0]])
mask = padded_sequence_mask(x_len)
print("mask.shape=", mask.shape)
print("mask=\n", mask)

Outputs:

x.shape= torch.Size([4, 3, 6])
mask.shape= torch.Size([4, 3])
mask=
 tensor([[ True,  True,  True],
        [ True,  True,  True],
        [False,  True,  True],
        [False, False,  True]])
返回:

None

返回类型:

None

spikingjelly.datasets.utils.create_sub_dataset(source_dir: str, target_dir: str, ratio: float, use_soft_link=True, randomly=False)[源代码]#

API Language: 中文 | English


  • 中文

从原始数据集中按类别子目录结构复制一个子数据集。每个叶子目录会按 ratio 选择样本; 若 use_soft_linkTrue 则创建软链接,否则复制文件。若 randomlyTrue, 则使用 numpy.random.shuffle 随机打乱待选文件顺序。

参数:
  • source_dir (str) -- 原始数据集目录

  • target_dir (str) -- 子数据集目录

  • ratio (float) -- 子数据集从原始数据集中复制的样本比例

  • use_soft_link (bool) -- 若为 True,则使用软链接;否则直接复制文件

  • randomly (bool) -- 若为 True,则随机选择复制的文件。随机性由 numpy.random.seed 控制

返回:

None

返回类型:

None


  • English

Create a sub dataset with copy ratio of samples from the origin dataset.

参数:
  • source_dir (str) -- the directory path of the origin dataset

  • target_dir (str) -- the directory path of the sub dataset

  • ratio (float) -- the ratio of samples the sub dataset will copy from the origin dataset

  • use_soft_link (bool) -- if True, the sub dataset will use soft links; otherwise, it will copy files

  • randomly (bool) -- if True, the files copied from the origin dataset will be picked up randomly. The randomness is controlled by numpy.random.seed

返回:

None

返回类型:

None