# Source code for spikingjelly.datasets.bullying10k

import multiprocessing
import os
from pathlib import Path
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Tuple

import numpy as np
from torchvision.datasets.utils import extract_archive

from .. import configure
from .base import NeuromorphicDatasetFolder
from . import utils


# Public names exported by this module on ``import *``.
__all__ = [
    "BULLYING10K_CATEGORY_LABEL",
    "Bullying10kClassification",
]


# Maps each category name to its integer class label.
# The position of a name in the tuple below IS its label, so the order must
# never be changed.
BULLYING10K_CATEGORY_LABEL = {
    category: index
    for index, category in enumerate(
        (
            "handshake",
            "slapping",
            "punching",
            "walking",
            "fingerguess",
            "strangling",
            "greeting",
            "pushing",
            "hairgrabs",
            "kicking",
        )
    )
}


def _convert_npy_to_npz(src_path: Path, dst_dir: Path, label: int):
    """Convert one raw ``.npy`` recording into an ``.npz`` event file.

    The loaded object is walked as a two-level nesting (an outer sequence of
    inner sequences of records). From every record the first four entries are
    collected as ``t``, ``x``, ``y`` and ``p``; any further entries are
    ignored. The four arrays are written with ``utils.np_savez`` to
    ``dst_dir / str(label) / "<src stem>.npz"``.

    :param src_path: Path of the source ``.npy`` file
    :type src_path: Path
    :param dst_dir: Directory that contains one sub-directory per label
    :type dst_dir: Path
    :param label: Class label; selects the sub-directory of ``dst_dir``
    :type label: int
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
    # which is only safe for trusted files.
    nested_records = np.load(src_path, allow_pickle=True)

    # Gather the four event fields in a single pass over the nested records.
    ts, xs, ys, ps = [], [], [], []
    for group in nested_records:
        for record in group:
            ts.append(record[0])
            xs.append(record[1])
            ys.append(record[2])
            ps.append(record[3])

    out_path = dst_dir / str(label) / f"{src_path.stem}.npz"
    utils.np_savez(
        out_path,
        t=np.array(ts),
        x=np.array(xs),
        y=np.array(ys),
        p=np.array(ps),
    )
    print(f"[{out_path}] saved.")


class Bullying10kClassification(NeuromorphicDatasetFolder):
    """Bullying10K dataset for action recognition (classification).

    Proposed by `Bullying10K: A Neuromorphic Dataset towards
    Privacy-Preserving Bullying Recognition
    <https://arxiv.org/abs/2306.11546>`_.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        data_type: str = "event",
        frames_number: Optional[int] = None,
        split_by: Optional[str] = None,
        duration: Optional[int] = None,
        custom_integrate_function: Optional[Callable] = None,
        custom_integrated_frames_dir_name: Optional[str] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        """
        Build the Bullying10K classification dataset.

        All arguments are forwarded unchanged to
        :class:`NeuromorphicDatasetFolder <spikingjelly.datasets.NeuromorphicDatasetFolder>`;
        refer to that class for more details about the parameters.

        :param root: Root directory of the dataset
        :type root: Union[str, Path]
        :param train: Whether to use training set or test set
        :type train: Optional[bool]
        :param data_type: ``"event"`` or ``"frame"``
        :type data_type: str
        :param frames_number: Number of frames to integrate
        :type frames_number: Optional[int]
        :param split_by: ``"time"`` or ``"number"``
        :type split_by: Optional[str]
        :param duration: Time duration per frame
        :type duration: Optional[int]
        :param custom_integrate_function: User-defined integrate function
        :type custom_integrate_function: Optional[Callable]
        :param custom_integrated_frames_dir_name: Custom frames directory name
        :type custom_integrated_frames_dir_name: Optional[str]
        :param transform: Transform function
        :type transform: Optional[Callable]
        :param target_transform: Target transform function
        :type target_transform: Optional[Callable]
        :raises ValueError: If ``train`` is ``None``
        :return: None
        :rtype: None
        """
        # This dataset always has a train / test split, so ``train`` must be
        # given explicitly as a boolean.
        if train is None:
            raise ValueError(
                "The argument `train` must be specified as a boolean value."
            )
        super().__init__(
            root,
            train,
            data_type,
            frames_number,
            split_by,
            duration,
            custom_integrate_function,
            custom_integrated_frames_dir_name,
            transform,
            target_transform,
        )

    @classmethod
    def get_H_W(cls) -> Tuple:
        r"""
        Return the fixed ``(H, W)`` pair of this dataset.

        :return: ``(260, 346)``
        :rtype: Tuple[int, int]
        """
        return 260, 346

    @classmethod
    def resource_url_md5(cls) -> List[Tuple[str, str, str]]:
        """
        List the downloadable resources of the dataset.

        :return: One ``(archive file name, download URL, MD5 checksum)`` tuple
            per category archive
        :rtype: List[Tuple[str, str, str]]
        """
        return [
            (
                "handshake.zip",
                "https://figshare.com/ndownloader/files/41268834",
                "681d70f499e736a1e805305284ddc425",
            ),
            (
                "slapping.zip",
                "https://figshare.com/ndownloader/files/41247021",
                "84b41d6805958f9f62f425223916ffc2",
            ),
            (
                "punching.zip",
                "https://figshare.com/ndownloader/files/41263314",
                "40954f480ab210099d448b7b88fc4719",
            ),
            (
                "walking.zip",
                "https://figshare.com/ndownloader/files/41247024",
                "56e4cac9c0814ce701c3b2292c15b6a9",
            ),
            (
                "fingerguess.zip",
                "https://figshare.com/ndownloader/files/41253057",
                "f83114e5b4f0ea57cac86fb080c7e4d7",
            ),
            (
                "strangling.zip",
                "https://figshare.com/ndownloader/files/41261904",
                "8185ecd6f3147e9b609d22f06270aa86",
            ),
            (
                "greeting.zip",
                "https://figshare.com/ndownloader/files/41268792",
                "4a763fad728b04c8356db8544f1121fe",
            ),
            (
                "pushing.zip",
                "https://figshare.com/ndownloader/files/41268951",
                "7986c74ade7149a98672120a89b13ba8",
            ),
            (
                "hairgrabs.zip",
                "https://figshare.com/ndownloader/files/41277855",
                "a9cf690ed0a3305da4a4b8e110f64db1",
            ),
            (
                "kicking.zip",
                "https://figshare.com/ndownloader/files/41278008",
                "6c3218f977de4ac29c84a10b17779c33",
            ),
        ]

    @classmethod
    def downloadable(cls) -> bool:
        """
        :return: ``True``
        :rtype: bool
        """
        return True

    @classmethod
    def extract_downloaded_files(cls, download_root: Path, extract_root: Path):
        """
        Extract every ``.zip`` archive found in ``download_root`` into
        ``extract_root``, using a thread pool of at most 10 workers.

        Every entry of ``download_root`` whose name does not end with
        ``.zip`` is copied into ``extract_root`` as-is.

        :param download_root: Directory containing the downloaded files
        :type download_root: Path
        :param extract_root: Directory that receives the extracted content
        :type extract_root: Path
        :return: None
        :rtype: None
        """
        max_workers = min(multiprocessing.cpu_count(), 10)
        with ThreadPoolExecutor(max_workers=max_workers) as tpe:
            futures = []
            for file_name in os.listdir(download_root):
                src_file = download_root / file_name
                if file_name.endswith(".zip"):
                    print(f"Extract [{src_file}] to [{extract_root}].")
                    futures.append(
                        tpe.submit(extract_archive, src_file, extract_root)
                    )
                else:
                    # Not an archive (e.g. a json file): copy it to
                    # extract_root directly.
                    shutil.copy(src_file, extract_root / file_name)

            # Wait for all extractions; result() re-raises any exception
            # that occurred inside a worker thread.
            for future in futures:
                future.result()

    @classmethod
    def create_raw_from_extracted(cls, extract_root: Path, raw_root: Path):
        """
        Convert the extracted ``.npy`` recordings into per-label ``.npz``
        files under ``raw_root / "train"`` and ``raw_root / "test"``, then
        delete ``extract_root``.

        Files are collected category by category. Every 5th file of the
        collected list (indices 0, 5, 10, ...) goes to the test set and the
        rest to the training set. Directory listings are sorted before the
        split is computed, because ``os.listdir`` / ``os.walk`` return
        entries in arbitrary order and the split would otherwise differ
        between machines. Only ``.npy`` files are considered, so stray files
        inside a category directory are neither counted nor converted.

        :param extract_root: Directory holding one sub-directory per
            category; it is removed at the end
        :type extract_root: Path
        :param raw_root: Existing directory in which ``train`` and ``test``
            are created
        :type raw_root: Path
        :raises FileExistsError: If ``train`` / ``test`` (or a label
            directory inside them) already exists
        :raises KeyError: If a file is found under a directory whose name is
            not a key of ``BULLYING10K_CATEGORY_LABEL``
        :return: None
        :rtype: None
        """
        train_dir = raw_root / "train"
        test_dir = raw_root / "test"
        train_dir.mkdir()
        test_dir.mkdir()
        print(f"Mkdir [{train_dir}] and [{test_dir}].")

        # One sub-directory per class label, derived from the label mapping
        # instead of a hard-coded count.
        for label in BULLYING10K_CATEGORY_LABEL.values():
            (train_dir / str(label)).mkdir()
            (test_dir / str(label)).mkdir()
        print(
            f"Mkdir {os.listdir(train_dir)} in [{train_dir}] "
            f"and {os.listdir(test_dir)} in [{test_dir}]."
        )

        # Category directories: skip the json files copied next to them and
        # hidden entries. Sorted for a deterministic traversal order.
        categories = sorted(
            name
            for name in os.listdir(extract_root)
            if not name.endswith(".json") and not name.startswith(".")
        )

        all_files_labels = []
        for c in categories:
            for dir_path, dir_names, dir_file_names in os.walk(extract_root / c):
                # Sorting dir_names in place makes os.walk descend into the
                # sub-directories in sorted order.
                dir_names.sort()
                dir_path = Path(dir_path)
                for dfn in sorted(dir_file_names):
                    if not dfn.endswith(".npy"):
                        continue
                    all_files_labels.append(
                        (dir_path / dfn, BULLYING10K_CATEGORY_LABEL[c])
                    )

        num_files = len(all_files_labels)
        print(f"Found {num_files} files in total.")
        print(
            "Use the same way to split training / validation sets as the original work: "
            "https://github.com/Brain-Cog-Lab/Bullying10K/blob/main/Bullying10k.py"
        )

        # Indices 0, 5, 10, ... form the test set; everything else is training.
        test_files_labels = all_files_labels[::5]
        train_files_labels = [
            item for index, item in enumerate(all_files_labels) if index % 5 != 0
        ]
        print(
            f"Training set: {len(train_files_labels)} files. "
            f"Test set: {len(test_files_labels)} files."
        )

        t_ckp = time.time()
        max_workers = min(
            multiprocessing.cpu_count(),
            configure.max_threads_number_for_datasets_preprocess,
        )
        with ThreadPoolExecutor(max_workers=max_workers) as tpe:
            print(
                f"Start the ThreadPoolExecutor with max workers = [{max_workers}]."
            )
            futures = [
                tpe.submit(_convert_npy_to_npz, fpath, train_dir, label)
                for fpath, label in train_files_labels
            ]
            futures.extend(
                tpe.submit(_convert_npy_to_npz, fpath, test_dir, label)
                for fpath, label in test_files_labels
            )
            # result() re-raises any exception raised by a conversion task.
            for future in futures:
                future.result()

        print(f"Used time = [{round(time.time() - t_ckp, 2)}s].")
        print(
            f"All npy files have been converted into npz files "
            f"and into [{train_dir, test_dir}]."
        )

        # remove the extracted files, since they're too large
        print(f"Remove the directory [{extract_root}].")
        shutil.rmtree(extract_root)