# Source code for spikingjelly.datasets.asl_dvs

from .utils import (
    EventsFramesDatasetBase, 
    convert_events_dir_to_frames_dir,
    FunctionThread,
    normalize_frame,
)
import os
import tqdm
import numpy as np
from torchvision.datasets import utils
import multiprocessing
import shutil
import scipy.io
import torch
# Maps each lower-case sign letter to its integer class index.
# The letters run from 'a' to 'y' with 'j' left out, giving 24 classes (0..23).
labels_dict = {
    letter: class_index
    for class_index, letter in enumerate('abcdefghiklmnopqrstuvwxy')
}  # gesture_mapping.csv
# Two-element list: [0] is the download page of the raw archive,
# [1] is the MD5 digest the downloaded archive is checked against.
resource = [
    'https://www.dropbox.com/sh/ibq0jsicatn7l6r/AACNrNELV56rs1YInMWUs9CAa?dl=0',  # url
    '8b46191acf6c1760ad3f2d2cb4380e24',  # md5
]


class ASLDVS(EventsFramesDatasetBase):
    '''
    ASL-DVS dataset of American Sign Language letters, readable either as raw
    events (``.mat`` files) or as frames integrated from the events (``.npz`` files).
    See :meth:`__init__` for the constructor arguments.
    '''

    @staticmethod
    def get_wh():
        '''
        :return: the ``(width, height)`` used for this dataset, i.e. ``(240, 180)``
        :rtype: tuple
        '''
        return 240, 180

    @staticmethod
    def download_and_extract(download_root: str, extract_root: str):
        '''
        :param download_root: directory expected to contain ``ICCV2019_DVS_dataset.zip``
        :type download_root: str
        :param extract_root: directory into which the inner zip archives are extracted
        :type extract_root: str
        :raises NotImplementedError: if the archive is missing or fails the md5 check;
            the dataset cannot be fetched automatically and must be downloaded by hand

        Verify the manually downloaded archive and unpack it. The outer archive is
        unpacked into a scratch directory, then every ``.zip`` file found directly
        inside that directory is unpacked into ``extract_root``.
        '''
        # Local import: tempfile is only needed by this method.
        import tempfile

        file_name = os.path.join(download_root, 'ICCV2019_DVS_dataset.zip')
        if os.path.exists(file_name):
            print('ICCV2019_DVS_dataset.zip already exists, check md5')
            if utils.check_md5(file_name, resource[1]):
                print('md5 checked, extracting...')
                # A uniquely named scratch directory cannot collide with leftovers of
                # an interrupted run, and the context manager removes it even when
                # extraction raises.
                with tempfile.TemporaryDirectory(prefix='temp_extract_', dir=download_root) as temp_extract_root:
                    utils.extract_archive(file_name, temp_extract_root)
                    for zip_file in tqdm.tqdm(utils.list_files(temp_extract_root, '.zip')):
                        utils.extract_archive(os.path.join(temp_extract_root, zip_file), extract_root)
                return
            else:
                print(f'{file_name} corrupted.')

        # Reached when the archive is missing or corrupted.
        print(f'Please download from {resource[0]} and save to {download_root} manually.')
        raise NotImplementedError

    @staticmethod
    def read_bin(file_name: str):
        '''
        :param file_name: path of a ``.mat`` events file
        :type file_name: str
        :return: a dict with keys ``'t'``, ``'x'``, ``'y'`` and ``'p'`` holding the squeezed
            ``ts``, ``x``, ``y`` and ``pol`` variables of the file
        :rtype: dict
        '''
        events = scipy.io.loadmat(file_name)
        return {
            't': events['ts'].squeeze(),
            'x': events['x'].squeeze(),
            'y': events['y'].squeeze(),
            'p': events['pol'].squeeze()
        }

    @staticmethod
    def get_events_item(file_name):
        '''
        :param file_name: path of a ``.mat`` events file
        :return: ``(events, label)``, where ``events`` is the dict returned by
            :meth:`read_bin` and ``label`` is looked up from the first character of the
            file's base name
        :rtype: tuple
        '''
        base_name = os.path.basename(file_name)
        return ASLDVS.read_bin(file_name), labels_dict[base_name[0]]

    @staticmethod
    def get_frames_item(file_name):
        '''
        :param file_name: path of a ``.npz`` frames file
        :return: ``(frames, label)``, where ``frames`` is the ``arr_0`` array of the file as
            a float tensor and ``label`` is looked up from the first character of the
            file's base name
        :rtype: tuple
        '''
        base_name = os.path.basename(file_name)
        frames = torch.from_numpy(np.load(file_name)['arr_0']).float()
        return frames, labels_dict[base_name[0]]

    @staticmethod
    def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num: int, split_by: str, normalization: str or None):
        '''
        :param events_data_dir: directory containing one sub-directory of events files per class
        :type events_data_dir: str
        :param frames_data_dir: directory in which the matching frames sub-directories are created
        :type frames_data_dir: str
        :param frames_num: number of frames after conversion
        :type frames_num: int
        :param split_by: how events are accumulated into frames, ``'time'`` or ``'number'``
        :type split_by: str
        :param normalization: normalization method forwarded to the conversion function
        :type normalization: str or None

        Convert every sub-directory of ``events_data_dir`` to frames, using one
        ``FunctionThread`` per sub-directory and running the threads in batches.
        '''
        width, height = ASLDVS.get_wh()
        thread_list = []
        for source_dir in utils.list_dir(events_data_dir):
            abs_source_dir = os.path.join(events_data_dir, source_dir)
            abs_target_dir = os.path.join(frames_data_dir, source_dir)
            if not os.path.exists(abs_target_dir):
                os.mkdir(abs_target_dir)
                print(f'mkdir {abs_target_dir}')
            print(f'thread {len(thread_list)} convert events data in {abs_source_dir} to {abs_target_dir}')
            # There are many files and they are large, so a compressed format is used.
            thread_list.append(
                FunctionThread(convert_events_dir_to_frames_dir, abs_source_dir, abs_target_dir, '.mat',
                               ASLDVS.read_bin, height, width, frames_num, split_by, normalization, 1, True))

        # This dataset has 24 class directories. Starting 24 threads at once is
        # considered too many, so only max_running_threads threads run at a time
        # (the CPU count, but never fewer than 8).
        max_running_threads = max(multiprocessing.cpu_count(), 8)
        threads_total = len(thread_list)
        for batch_start in range(0, threads_total, max_running_threads):
            batch_end = min(batch_start + max_running_threads, threads_total)
            for j in range(batch_start, batch_end):
                thread_list[j].start()
                print(f'thread {j} start')
            # Wait for the whole batch before starting the next one.
            for j in range(batch_start, batch_end):
                print('thread', j, 'join')
                thread_list[j].join()
                print('thread', j, 'finished')

    def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, frames_num=10, split_by='number', normalization='max'):
        '''
        :param root: root directory in which the dataset is stored
        :type root: str
        :param train: whether to use the training set
        :type train: bool
        :param split_ratio: split ratio. Within each class, the first ``split_ratio`` of the samples
            is used as the training set and the remaining samples as the test set
        :type split_ratio: float
        :param use_frame: whether to convert the event data to frame data
        :type use_frame: bool
        :param frames_num: number of frames after conversion
        :type frames_num: int
        :param split_by: how events are accumulated into frames, ``'time'`` or ``'number'``
        :type split_by: str
        :param normalization: normalization method. ``None`` means no normalization;
            ``'frequency'`` divides each frame by the number of raw data accumulated in that frame;
            ``'max'`` divides each frame by the maximum value in that frame;
            ``'norm'`` subtracts the mean of each frame and then divides by the standard deviation
        :type normalization: str or None

        The ASL-DVS dataset, from `Graph-Based Object Classification for Neuromorphic Vision Sensing <https://arxiv.org/abs/1908.06648>`_,
        contains 24 letters (A to Y, excluding J) of American Sign Language (ASL). More information is
        available at https://github.com/PIX2NVS/NVS2Graph, and the raw data can be downloaded from
        https://www.dropbox.com/sh/ibq0jsicatn7l6r/AACNrNELV56rs1YInMWUs9CAa?dl=0.

        For details about the conversion to frames, see :func:`~spikingjelly.datasets.utils.integrate_events_to_frames`.
        '''
        super().__init__()
        self.train = train
        events_root = os.path.join(root, 'events')
        if os.path.exists(events_root):
            # The events directory already exists under root, so extraction is skipped.
            print(f'events data root {events_root} already exists.')
        else:
            self.download_and_extract(root, events_root)

        self.file_name = []  # paths of the data files, without the file extension
        self.use_frame = use_frame
        self.data_dir = None
        if use_frame:
            self.normalization = normalization
            # Only 'frequency' appears in the directory name; every other setting
            # shares the "normalization_None" directory and is applied in __getitem__.
            dir_suffix = normalization if normalization == 'frequency' else None
            frames_root = os.path.join(root, f'frames_num_{frames_num}_split_by_{split_by}_normalization_{dir_suffix}')
            if os.path.exists(frames_root):
                # The frames are treated as already generated as soon as the directory
                # exists; its contents are not inspected.
                print(f'frames data root {frames_root} already exists.')
            else:
                os.mkdir(frames_root)
                self.create_frames_dataset(events_root, frames_root, frames_num, split_by, normalization)
            self.data_dir = frames_root
        else:
            self.data_dir = events_root

        # The split is computed over 4200 samples per class.
        # The "+ 1" is there because file indices start from 1.
        split_point = int(split_ratio * 4200)
        if train:
            index = np.arange(0, split_point) + 1
        else:
            index = np.arange(split_point, 4200) + 1

        for class_name in labels_dict:
            class_dir = os.path.join(self.data_dir, class_name)
            for i in index:
                self.file_name.append(os.path.join(class_dir, class_name + '_' + str(i).zfill(4)))

    def __len__(self):
        '''
        :return: number of samples in the selected split
        :rtype: int
        '''
        return len(self.file_name)

    def __getitem__(self, index):
        '''
        :param index: position of the sample in the selected split
        :return: ``(frames, label)`` when ``use_frame`` is set, otherwise ``(events, label)``
        :rtype: tuple
        '''
        if not self.use_frame:
            return self.get_events_item(self.file_name[index] + '.mat')

        frames, labels = self.get_frames_item(self.file_name[index] + '.npz')
        # 'frequency' is not applied here; every other non-None method is applied on load.
        if self.normalization is not None and self.normalization != 'frequency':
            frames = normalize_frame(frames, self.normalization)
        return frames, labels