# spikingjelly.datasets.n_mnist source code

from typing import Callable, Optional, Tuple, Union
import os
from pathlib import Path
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time

from torchvision.datasets.utils import extract_archive

from .. import configure
from . import utils
from .base import NeuromorphicDatasetFolder


__all__ = ["NMNIST"]


def _read_bin_save_to_np(bin_file: Union[str, Path], np_file: Union[str, Path]):
    """Convert one ATIS binary event file into a NumPy archive.

    The events are read with ``utils.load_ATIS_bin``. Their ``t``, ``x``, ``y``
    and ``p`` fields are written to ``np_file`` through ``utils.np_savez``, and
    the conversion is then logged to stdout.

    :param bin_file: path of the source ``.bin`` event file
    :type bin_file: Union[str, Path]
    :param np_file: path of the ``.npz`` file to write
    :type np_file: Union[str, Path]
    :return: None
    """
    loaded = utils.load_ATIS_bin(bin_file)
    # Fields are pulled in the same t, x, y, p order and forwarded as keywords.
    arrays = {name: loaded[name] for name in ("t", "x", "y", "p")}
    utils.np_savez(np_file, **arrays)
    print(f"Save [{bin_file}] to [{np_file}].")


class NMNIST(NeuromorphicDatasetFolder):
    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        data_type: Optional[str] = "event",
        frames_number: Optional[int] = None,
        split_by: Optional[str] = None,
        duration: Optional[int] = None,
        custom_integrate_function: Optional[Callable] = None,
        custom_integrated_frames_dir_name: Optional[str] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        .. _NMNIST.__init__-cn:

        .. _NMNIST.__init__-en:

        The N-MNIST dataset, proposed in `Converting Static Image Datasets to
        Spiking Neuromorphic Datasets Using Saccades
        <https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full>`_.

        For more details about the parameters, refer to
        :class:`NeuromorphicDatasetFolder <spikingjelly.datasets.base.NeuromorphicDatasetFolder>`.

        :param root: root directory of the dataset
        :type root: Union[str, Path]
        :param train: ``True`` to use the training set, ``False`` to use the test set
        :type train: bool
        :param data_type: ``"event"`` or ``"frame"``
        :type data_type: Optional[str]
        :param frames_number: number of frames to integrate
        :type frames_number: Optional[int]
        :param split_by: ``"time"`` or ``"number"``
        :type split_by: Optional[str]
        :param duration: time duration of each frame
        :type duration: Optional[int]
        :param custom_integrate_function: user-defined integrate function
        :type custom_integrate_function: Optional[Callable]
        :param custom_integrated_frames_dir_name: custom name of the directory
            holding the integrated frames
        :type custom_integrated_frames_dir_name: Optional[str]
        :param transform: transform applied to the data
        :type transform: Optional[Callable]
        :param target_transform: transform applied to the label
        :type target_transform: Optional[Callable]
        :raises ValueError: if ``train`` is ``None`` (N-MNIST only has fixed
            train/test splits)
        :return: None
        :rtype: None
        """
        if train is None:
            raise ValueError("`train` must be `True` or `False`")
        super().__init__(
            root,
            train,
            data_type,
            frames_number,
            split_by,
            duration,
            custom_integrate_function,
            custom_integrated_frames_dir_name,
            transform,
            target_transform,
        )

    @classmethod
    def get_H_W(cls) -> Tuple[int, int]:
        r"""
        .. _n_mnist.get_H_W-cn:

        .. _n_mnist.get_H_W-en:

        Return the frame height and width ``(H, W)`` of N-MNIST.

        :return: ``(34, 34)``
        :rtype: Tuple[int, int]
        """
        return 34, 34

    @classmethod
    def resource_url_md5(cls) -> list:
        r"""
        .. _n_mnist.resource_url_md5-cn:

        .. _n_mnist.resource_url_md5-en:

        List the archives that make up the N-MNIST dataset.

        :return: a list of ``(file_name, url, md5)`` tuples, one per archive
            (``Train.zip`` and ``Test.zip``)
        :rtype: list
        """
        url = "https://www.garrickorchard.com/datasets/n-mnist"
        return [
            ("Train.zip", url, "20959b8e626244a1b502305a9e6e2031"),
            ("Test.zip", url, "69ca8762b2fe404d9b9bad1103e97832"),
        ]

    @classmethod
    def downloadable(cls) -> bool:
        r"""
        .. _n_mnist.downloadable-cn:

        .. _n_mnist.downloadable-en:

        The N-MNIST dataset does not provide automatic download due to copyright
        restrictions. Users need to download it manually.

        :return: ``False``
        :rtype: bool
        """
        return False

    @classmethod
    def extract_downloaded_files(cls, download_root: Path, extract_root: Path):
        r"""
        .. _n_mnist.extract_downloaded_files-cn:

        .. _n_mnist.extract_downloaded_files-en:

        Extract every ``.zip`` archive found directly in ``download_root`` into
        ``extract_root``, using at most two worker threads.

        Entries that are not ``.zip`` files (stray files, OS metadata,
        sub-directories) are skipped rather than handed to ``extract_archive``,
        which would fail on them.

        :param download_root: directory containing the downloaded files
        :type download_root: Path
        :param extract_root: directory to extract into
        :type extract_root: Path
        :return: None
        :rtype: None
        """
        # Only archives are extracted; sorting keeps the log order deterministic.
        zip_files = sorted(
            download_root / name
            for name in os.listdir(download_root)
            if name.lower().endswith(".zip")
        )
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
            futures = []
            for zip_file in zip_files:
                print(f"Extract [{zip_file}] to [{extract_root}].")
                futures.append(tpe.submit(extract_archive, zip_file, extract_root))
            # Re-raise any extraction error in the calling thread.
            for future in futures:
                future.result()

    @classmethod
    def create_raw_from_extracted(cls, extract_root: Path, raw_root: Path):
        r"""
        .. _n_mnist.create_raw_from_extracted-cn:

        .. _n_mnist.create_raw_from_extracted-en:

        Convert the extracted ATIS binary files to ``.npz`` files, split by
        train/test.

        ``extract_root/Train`` and ``extract_root/Test`` are mirrored into
        ``raw_root/train`` and ``raw_root/test``. Each class sub-directory is
        recreated, and each ``.bin`` file inside it is converted to a ``.npz``
        file with the same stem, in parallel on a thread pool. Non-directory
        entries next to the class folders and non-``.bin`` files inside them
        are ignored.

        :param extract_root: directory containing the extracted files
        :type extract_root: Path
        :param raw_root: directory to save the raw dataset into; the ``train``
            and ``test`` sub-directories must not already exist
        :type raw_root: Path
        :return: None
        :rtype: None
        """
        t_ckp = time.time()
        max_workers = min(
            multiprocessing.cpu_count(),
            configure.max_threads_number_for_datasets_preprocess,
        )
        with ThreadPoolExecutor(max_workers=max_workers) as tpe:
            futures = []
            for train_test_dir in ["Train", "Test"]:
                source_dir = extract_root / train_test_dir
                target_dir = raw_root / train_test_dir.lower()
                target_dir.mkdir()
                print(f"Mkdir [{target_dir}].")
                for class_name in sorted(os.listdir(source_dir)):
                    bin_dir = source_dir / class_name
                    # A stray file here would otherwise create a bogus class
                    # directory and then make os.listdir() raise.
                    if not bin_dir.is_dir():
                        continue
                    np_dir = target_dir / class_name
                    np_dir.mkdir()
                    print(f"Mkdir [{np_dir}].")
                    for bin_file in sorted(os.listdir(bin_dir)):
                        stem, ext = os.path.splitext(bin_file)
                        # Only ATIS .bin recordings can be parsed.
                        if ext.lower() != ".bin":
                            continue
                        source_file = bin_dir / bin_file
                        target_file = np_dir / (stem + ".npz")
                        print(f"Start to convert [{source_file}] to [{target_file}].")
                        futures.append(
                            tpe.submit(_read_bin_save_to_np, source_file, target_file)
                        )
            # Re-raise any conversion error in the calling thread.
            for future in futures:
                future.result()
        print(f"Used time = [{round(time.time() - t_ckp, 2)}s].")