# Source code for spikingjelly.datasets.dvs128_gesture

from .utils import (
import os
import tqdm
import numpy as np
import struct
from torchvision.datasets import utils
import time
import multiprocessing
import torch
# https://www.research.ibm.com/dvsgesture/
# https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794

# Gesture name -> class id, in the order listed in gesture_mapping.csv.
# Note that the ids start at 1, not 0.
labels_dict = {
    gesture: class_id
    for class_id, gesture in enumerate(
        (
            'hand_clapping',
            'right_hand_wave',
            'left_hand_wave',
            'right_arm_clockwise',
            'right_arm_counter_clockwise',
            'left_arm_clockwise',
            'left_arm_counter_clockwise',
            'arm_roll',
            'air_drums',
            'air_guitar',
            'other_gestures',
        ),
        start=1,
    )
}

# [download page url, md5 of DvsGesture.tar.gz]
resource = [
    'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794',
    '8a5c71fb11e24e5ca5b11866ca6c00a1',
]

class DVS128Gesture(EventsFramesDatasetBase):
    """DVS128 Gesture dataset, served either as raw events or as integrated frames.

    See :meth:`__init__` for the constructor arguments and for the directories
    that are created below ``root``.
    """

    @staticmethod
    def get_wh():
        """Return the sensor resolution as ``(width, height)``."""
        return 128, 128

    @staticmethod
    def download_and_extract(download_root: str, extract_root: str):
        """Extract a manually downloaded ``DvsGesture.tar.gz`` into ``extract_root``.

        Nothing is downloaded by this method. If the archive is present in
        ``download_root`` and its md5 matches ``resource[1]`` it is extracted;
        in every other case the user is told where to get the archive.

        :param download_root: directory expected to contain ``DvsGesture.tar.gz``
        :type download_root: str
        :param extract_root: directory the archive is extracted into
        :type extract_root: str
        :raises NotImplementedError: if the archive is missing or fails the md5 check
        """
        file_name = os.path.join(download_root, 'DvsGesture.tar.gz')
        if os.path.exists(file_name):
            print('DvsGesture.tar.gz already exists, check md5')
            if utils.check_md5(file_name, resource[1]):
                print('md5 checked, extracting...')
                utils.extract_archive(file_name, extract_root)
                return
            print(f'{file_name} corrupted.')

        print(f'Please download from {resource[0]} and save to {download_root} manually.')
        raise NotImplementedError

    @staticmethod
    def read_bin(file_name: str):
        """Read the polarity events stored in an ``.aedat`` file.

        The parsing follows https://gitlab.com/inivation/dv/dv-python/.
        Packets whose type is not 1 (polarity events) are skipped.

        :param file_name: path of the ``.aedat`` file
        :type file_name: str
        :return: dict with the keys ``'t'``, ``'x'``, ``'y'`` and ``'p'``, each a
            1-D ``numpy.ndarray`` with one entry per event, in file order
        :rtype: dict
        :raises struct.error: if a packet header or an event is truncated
        """
        # Explicit little-endian layouts, so parsing does not depend on the host
        # byte order. NOTE(review): AEDAT 3.x is documented as little-endian -
        # confirm against the format specification.
        # Packet header fields: type, source, size, offset, tsoverflow,
        # capacity, number, valid (28 bytes in total).
        header_struct = struct.Struct('<HHIIIIII')
        # Event: 32-bit address/polarity word followed by a 32-bit timestamp.
        event_struct = struct.Struct('<II')

        txyp = {'t': [], 'x': [], 'y': [], 'p': []}
        with open(file_name, 'rb') as bin_f:
            # Skip the ASCII header, which ends with the '#!END-HEADER' line.
            line = bin_f.readline()
            while line.startswith(b'#'):
                if line == b'#!END-HEADER\r\n':
                    break
                line = bin_f.readline()

            while True:
                header = bin_f.read(header_struct.size)
                if not header:
                    break
                (e_type, _e_source, e_size, _e_offset, e_tsoverflow,
                 e_capacity, _e_number, _e_valid) = header_struct.unpack(header)

                # Always consume the payload so the next read starts at a header.
                data = bin_f.read(e_capacity * e_size)
                if e_type != 1 or e_size == 0:
                    # Non-polarity event packet: not implemented.
                    continue

                ts_high = e_tsoverflow << 31
                for offset in range(0, len(data), e_size):
                    aer_data, timestamp = event_struct.unpack_from(data, offset)
                    txyp['x'].append((aer_data >> 17) & 0x00007FFF)
                    txyp['y'].append((aer_data >> 2) & 0x00007FFF)
                    txyp['t'].append(timestamp | ts_high)
                    txyp['p'].append((aer_data >> 1) & 0x00000001)

        for key in ('x', 'y', 't', 'p'):
            txyp[key] = np.asarray(txyp[key])
        return txyp

    @staticmethod
    def convert_aedat_dir_to_npy_dir(aedat_data_dir: str, events_npy_train_root: str, events_npy_test_root: str):
        """Cut every ``.aedat`` recording into labelled segments saved as ``.npy`` files.

        The recordings named in ``trials_to_train.txt`` / ``trials_to_test.txt``
        are read from ``aedat_data_dir``. For each row ``(label, t_start, t_end)``
        of the matching ``*_labels.csv`` the events with ``t_start <= t < t_end``
        are saved as a dict to ``{base_name}_{label}_{j}.npy``.

        :param aedat_data_dir: directory holding the ``.aedat`` files, label csv
            files and the two trial lists
        :type aedat_data_dir: str
        :param events_npy_train_root: output directory for the training segments
        :type events_npy_train_root: str
        :param events_npy_test_root: output directory for the test segments
        :type events_npy_test_root: str
        """

        def first_index_not_below(timestamps, start, threshold):
            # Index of the first element of timestamps[start:] that is >= threshold,
            # or len(timestamps) when there is none (window ends after the last event).
            tail_hits = timestamps[start:] >= threshold
            if tail_hits.size == 0:
                return timestamps.shape[0]
            offset = int(np.argmax(tail_hits))
            if not tail_hits[offset]:
                return timestamps.shape[0]
            return start + offset

        def cvt_files_fun(aedat_file_list, output_dir):
            for aedat_file in aedat_file_list:
                base_name = aedat_file[0: -6]  # strip the '.aedat' suffix
                events = DVS128Gesture.read_bin(os.path.join(aedat_data_dir, aedat_file))
                # Read the csv file to get the label and time window of each segment.
                # ndmin=2 keeps a single-row csv two-dimensional.
                events_csv = np.loadtxt(os.path.join(aedat_data_dir, base_name + '_labels.csv'),
                                        dtype=np.uint32, delimiter=',', skiprows=1, ndmin=2)
                timestamps = events['t']
                index = 0
                for row in events_csv:
                    label, t_start, t_end = row[0], row[1], row[2]
                    # The segment is [index_l, index_r): closed left, open right.
                    index_l = first_index_not_below(timestamps, index, t_start)
                    index_r = first_index_not_below(timestamps, index_l, t_end)
                    index = index_r

                    # A recording may contain several segments with the same label,
                    # so pick the first unused suffix j. Each worker handles distinct
                    # base names, so workers never compete for a file name.
                    j = 0
                    while True:
                        file_name = os.path.join(output_dir, f'{base_name}_{label}_{j}.npy')
                        if not os.path.exists(file_name):
                            break
                        j += 1
                    np.save(file=file_name, arr={
                        't': events['t'][index_l:index_r],
                        'x': events['x'][index_l:index_r],
                        'y': events['y'][index_l:index_r],
                        'p': events['p'][index_l:index_r]
                    })

        with open(os.path.join(aedat_data_dir, 'trials_to_train.txt')) as trials_to_train_txt, open(
                os.path.join(aedat_data_dir, 'trials_to_test.txt')) as trials_to_test_txt:
            train_list = [name for name in (line.strip() for line in trials_to_train_txt) if name]
            test_list = [name for name in (line.strip() for line in trials_to_test_txt) if name]

        print('convert events data from aedat to numpy format.')
        # The conversion is slow, so it is spread over several workers.
        thread_num = max(multiprocessing.cpu_count(), 2)
        # Never start more training workers than there are recordings, and deal
        # the recordings out with a stride so that none is left over.
        train_worker_num = max(1, min(thread_num - 1, len(train_list)))
        thread_list = []
        for i in range(train_worker_num):
            thread_list.append(FunctionThread(cvt_files_fun, train_list[i::train_worker_num],
                                              events_npy_train_root))
            print(f'thread {i} start')
            thread_list[-1].start()
        # The test set is converted by one additional worker.
        thread_list.append(FunctionThread(cvt_files_fun, test_list, events_npy_test_root))
        print(f'thread {train_worker_num} start')
        thread_list[-1].start()

        # Progress is the number of .npy segments on disk. The final count is not
        # known up front (one recording yields several segments), so no total is set.
        with tqdm.tqdm(unit='file') as pbar:
            while True:
                working_thread = []
                finished_thread = []
                for i, worker in enumerate(thread_list):
                    if worker.is_alive():
                        working_thread.append(i)
                    else:
                        finished_thread.append(i)
                saved_num = len(utils.list_files(events_npy_train_root, '.npy')) \
                    + len(utils.list_files(events_npy_test_root, '.npy'))
                # tqdm.update() is incremental: advance only by the newly saved files.
                pbar.update(saved_num - pbar.n)
                print('working thread:', working_thread)
                print('finished thread:', finished_thread)
                if len(finished_thread) == len(thread_list):
                    return
                time.sleep(10)

    @staticmethod
    def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num: int, split_by: str,
                              normalization: str or None):
        """Convert a directory of event ``.npy`` files into frame files.

        Delegates to ``convert_events_dir_to_frames_dir`` with the sensor size
        from :meth:`get_wh` and 4 workers.

        :param events_data_dir: directory containing the event ``.npy`` files
        :type events_data_dir: str
        :param frames_data_dir: directory the frames are written to
        :type frames_data_dir: str
        :param frames_num: the number of frames per sample
        :type frames_num: int
        :param split_by: how to split the events, ``'number'`` or ``'time'``
        :type split_by: str
        :param normalization: how to normalize frames, or ``None``
        :type normalization: str or None
        """
        width, height = DVS128Gesture.get_wh()

        def read_fun(file_name):
            # The files hold a pickled dict; they are expected to be the ones
            # written by convert_aedat_dir_to_npy_dir, not untrusted input.
            return np.load(file_name, allow_pickle=True).item()

        convert_events_dir_to_frames_dir(events_data_dir, frames_data_dir, '.npy',
                                         read_fun, height, width, frames_num, split_by, normalization,
                                         thread_num=4)

    @staticmethod
    def get_events_item(file_name):
        """Load one event sample.

        :param file_name: path of a ``{base_name}_{label}_{j}.npy`` events file
        :return: ``(events dict, label)``; the label is parsed from the file name
            and shifted so that it starts at 0
        :rtype: tuple
        """
        # NOTE: allow_pickle=True must only be used on files generated by this class.
        return np.load(file_name, allow_pickle=True).item(), \
            int(os.path.basename(file_name).split('_')[-2]) - 1

    @staticmethod
    def get_frames_item(file_name):
        """Load one frames sample.

        :param file_name: path of a frames ``.npy`` file whose name ends with ``_{label}_{j}.npy``
        :return: ``(float tensor of frames, label)``; the label is parsed from
            the file name and shifted so that it starts at 0
        :rtype: tuple
        """
        return torch.from_numpy(np.load(file_name)).float(), \
            int(os.path.basename(file_name).split('_')[-2]) - 1

    def __init__(self, root: str, train: bool, use_frame=True, frames_num=10, split_by='number',
                 normalization='max'):
        '''
        :param root: root directory of dataset, which should contain `DvsGesture.tar.gz` and
            `gesture_mapping.csv`
        :type root: str
        :param train: whether use the train dataset. If `False`, use the test dataset
        :type train: bool
        :param use_frame: whether use the frames data. If `False`, use the events data
        :type use_frame: bool
        :param frames_num: the number of frames
        :type frames_num: int
        :param split_by: how to split the events, can be ``'number', 'time'``
        :type split_by: str
        :param normalization: how to normalize frames, can be ``None, 'frequency', 'max', 'norm', 'sum'``.
            ``None`` means no normalization; ``'frequency'`` divides each frame by the number of raw
            events accumulated into it; ``'max'`` divides each frame by its maximum value; ``'norm'``
            subtracts the mean of each frame and then divides by its standard deviation
        :type normalization: str or None

        DVS128 Gesture dataset, which is provided by `A Low Power, Fully Event-Based Gesture Recognition
        System <https://openaccess.thecvf.com/content_cvpr_2017/papers/Amir_A_Low_Power_CVPR_2017_paper.pdf>`_,
        contains the gesture recorded by a DVS128 camera. The origin dataset can be downloaded from
        https://www.research.ibm.com/dvsgesture/.

        For more details about converting events to frames, see
        :func:`~spikingjelly.datasets.utils.integrate_events_to_frames`.
        '''
        super().__init__()
        events_npy_root = os.path.join(root, 'events_npy')
        events_npy_train_root = os.path.join(events_npy_root, 'train')
        events_npy_test_root = os.path.join(events_npy_root, 'test')
        if os.path.exists(events_npy_train_root) and os.path.exists(events_npy_test_root):
            print(f'npy format events data root {events_npy_train_root}, {events_npy_test_root} already exists')
        else:
            extracted_root = os.path.join(root, 'extracted')
            if os.path.exists(extracted_root):
                print(f'extracted root {extracted_root} already exists.')
            else:
                self.download_and_extract(root, extracted_root)
            if not os.path.exists(events_npy_root):
                os.mkdir(events_npy_root)
                print(f'mkdir {events_npy_root}')
            # Plain mkdir on purpose: converting into a directory that already
            # holds segments would write them again under a new `_j` suffix.
            os.mkdir(events_npy_train_root)
            print(f'mkdir {events_npy_train_root}')
            os.mkdir(events_npy_test_root)
            print(f'mkdir {events_npy_test_root}')
            print('read events data from *.aedat and save to *.npy...')
            self.convert_aedat_dir_to_npy_dir(os.path.join(extracted_root, 'DvsGesture'),
                                              events_npy_train_root, events_npy_test_root)

        self.file_name = []  # paths of the data files
        self.use_frame = use_frame
        self.data_dir = None
        if use_frame:
            self.normalization = normalization
            # Only 'frequency' is baked into the stored frames; every other
            # normalization is applied on the fly in __getitem__.
            if normalization == 'frequency':
                dir_suffix = normalization
            else:
                dir_suffix = None
            frames_root = os.path.join(
                root, f'frames_num_{frames_num}_split_by_{split_by}_normalization_{dir_suffix}')
            frames_train_root = os.path.join(frames_root, 'train')
            frames_test_root = os.path.join(frames_root, 'test')
            if os.path.exists(frames_root):
                # If frames_root exists, the frames dataset is assumed to be complete.
                print(f'frames data root {frames_root} already exists.')
            else:
                os.mkdir(frames_root)
                os.mkdir(frames_train_root)
                os.mkdir(frames_test_root)
                print(f'mkdir {frames_root}, {frames_train_root}, {frames_test_root}.')
                print('creating frames data..')
                self.create_frames_dataset(events_npy_train_root, frames_train_root, frames_num, split_by,
                                           normalization)
                self.create_frames_dataset(events_npy_test_root, frames_test_root, frames_num, split_by,
                                           normalization)
            if train:
                self.data_dir = frames_train_root
            else:
                self.data_dir = frames_test_root
            self.file_name = utils.list_files(self.data_dir, '.npy', True)
        else:
            if train:
                self.data_dir = events_npy_train_root
            else:
                self.data_dir = events_npy_test_root
            self.file_name = utils.list_files(self.data_dir, '.npy', True)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.file_name)

    def __getitem__(self, index):
        """Return ``(frames, label)`` if ``use_frame`` is set, else ``(events, label)``."""
        if self.use_frame:
            frames, labels = self.get_frames_item(self.file_name[index])
            if self.normalization is not None and self.normalization != 'frequency':
                frames = normalize_frame(frames, self.normalization)
            return frames, labels
        else:
            return self.get_events_item(self.file_name[index])