Source code for spikingjelly.datasets.shd

from typing import Callable, Dict, Optional, Tuple

import h5py
import numpy as np
from torch.utils.data import Dataset
from torchvision.datasets import utils
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
import math
import bisect
from .. import configure
from ..datasets import np_savez

[docs]def cal_fixed_frames_number_segment_index_shd(events_t: np.ndarray, split_by: str, frames_num: int) -> tuple:
    j_l = np.zeros(shape=[frames_num], dtype=int)
    j_r = np.zeros(shape=[frames_num], dtype=int)
    N = events_t.size
    if split_by == 'number':
        di = N // frames_num
        for i in range(frames_num):
            j_l[i] = i * di
            j_r[i] = j_l[i] + di
        j_r[-1] = N
    elif split_by == 'time':
        dt = (events_t[-1] - events_t[0]) / frames_num
        idx = np.arange(N)
        for i in range(frames_num):
            t_l = dt * i + events_t[0]
            t_r = t_l + dt
            mask = np.logical_and(events_t >= t_l, events_t < t_r)
            idx_masked = idx[mask]
            j_l[i] = idx_masked[0]
            j_r[i] = idx_masked[-1] + 1
        j_r[-1] = N
    else:
        raise NotImplementedError
    return j_l, j_r
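# Editor's sketch (not part of the module): how the two split modes behave on a
# toy timestamp array. The input values below are assumptions for illustration.
def _demo_segment_index():
    t = np.array([0., 1., 2., 10., 11., 12., 20., 21., 22.])
    # split_by='number': every frame holds the same number of events.
    print(cal_fixed_frames_number_segment_index_shd(t, 'number', 3))
    # split_by='time': every frame covers an equal time window; the last
    # frame's right index is clamped to N so trailing events are not dropped.
    print(cal_fixed_frames_number_segment_index_shd(t, 'time', 3))
    # On this input both modes yield j_l = [0 3 6] and j_r = [3 6 9].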
[docs]def integrate_events_segment_to_frame_shd(x: np.ndarray, W: int, j_l: int = 0, j_r: int = -1) -> np.ndarray:
    frame = np.zeros(shape=[W])
    x = x[j_l: j_r].astype(int)  # avoid overflow
    position = x
    events_number_per_pos = np.bincount(position)
    frame[np.arange(events_number_per_pos.size)] += events_number_per_pos
    return frame
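# Editor's sketch (not part of the module): integrating a slice of unit indices
# into a single frame is just a histogram over the W input channels. The toy
# values below are assumptions for illustration.
def _demo_segment_to_frame():
    x = np.array([0, 0, 3, 3, 3, 699])
    frame = integrate_events_segment_to_frame_shd(x, W=700, j_l=0, j_r=6)
    print(frame[0], frame[3], frame[699])  # 2.0 3.0 1.0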
[docs]def integrate_events_by_fixed_frames_number_shd(events: Dict, split_by: str, frames_num: int, W: int) -> np.ndarray:
    t, x = (events[key] for key in ('t', 'x'))
    j_l, j_r = cal_fixed_frames_number_segment_index_shd(t, split_by, frames_num)
    frames = np.zeros([frames_num, W])
    for i in range(frames_num):
        frames[i] = integrate_events_segment_to_frame_shd(x, W, j_l[i], j_r[i])
    return frames
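# Editor's sketch (not part of the module): stacking per-segment histograms
# gives a [frames_num, W] array. The toy events are assumptions for illustration.
def _demo_fixed_frames_number():
    events = {'t': np.array([0., 1., 2., 3.]), 'x': np.array([5, 5, 6, 7])}
    frames = integrate_events_by_fixed_frames_number_shd(events, 'number', 2, W=700)
    print(frames.shape)   # (2, 700)
    print(frames[0, 5])   # 2.0 -> two events on unit 5 fall in the first half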
[docs]def integrate_events_file_to_frames_file_by_fixed_frames_number_shd(h5_file: h5py.File, i: int, output_dir: str, split_by: str, frames_num: int, W: int, print_save: bool = False) -> None:
    events = {'t': h5_file['spikes']['times'][i], 'x': h5_file['spikes']['units'][i]}
    label = h5_file['labels'][i]
    fname = os.path.join(output_dir, str(label), str(i))
    np_savez(fname, frames=integrate_events_by_fixed_frames_number_shd(events, split_by, frames_num, W))
    if print_save:
        print(f'Frames [{fname}] saved.')
[docs]def integrate_events_by_fixed_duration_shd(events: Dict, duration: int, W: int) -> np.ndarray:
    x = events['x']
    t = 1000 * events['t']  # SHD stores timestamps in seconds; convert to milliseconds
    t = t - t[0]
    N = t.size
    frames_num = int(math.ceil(t[-1] / duration))
    frames = np.zeros([frames_num, W])
    frame_index = t // duration
    left = 0
    for i in range(frames_num - 1):
        right = np.searchsorted(frame_index, i + 1, side='left')
        frames[i] = integrate_events_segment_to_frame_shd(x, W, left, right)
        left = right
    frames[-1] = integrate_events_segment_to_frame_shd(x, W, left, N)
    return frames
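# Editor's sketch (not part of the module): with a fixed duration the number of
# frames varies per sample; ``duration`` is in milliseconds since the function
# scales the second-valued timestamps by 1000. Toy values are assumptions.
def _demo_fixed_duration():
    events = {'t': np.array([0.000, 0.001, 0.005, 0.009]), 'x': np.array([1, 2, 3, 4])}
    frames = integrate_events_by_fixed_duration_shd(events, duration=4, W=700)
    print(frames.shape[0])  # ceil(9 ms / 4 ms) = 3 frames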
[docs]def integrate_events_file_to_frames_file_by_fixed_duration_shd(h5_file: h5py.File, i: int, output_dir: str, duration: int, W: int, print_save: bool = False) -> int:
    events = {'t': h5_file['spikes']['times'][i], 'x': h5_file['spikes']['units'][i]}
    label = h5_file['labels'][i]
    fname = os.path.join(output_dir, str(label), str(i))
    frames = integrate_events_by_fixed_duration_shd(events, duration, W)
    np_savez(fname, frames=frames)
    if print_save:
        print(f'Frames [{fname}] saved.')
    return frames.shape[0]
[docs]def custom_integrate_function_example(h5_file: h5py.File, i: int, output_dir: str, W: int):
    events = {'t': h5_file['spikes']['times'][i], 'x': h5_file['spikes']['units'][i]}
    label = h5_file['labels'][i]
    frames = np.zeros([2, W])
    index_split = np.random.randint(low=0, high=events['t'].__len__())
    frames[0] = integrate_events_segment_to_frame_shd(events['x'], W, 0, index_split)
    frames[1] = integrate_events_segment_to_frame_shd(events['x'], W, index_split, events['t'].__len__())
    fname = os.path.join(output_dir, str(label), str(i))
    np_savez(fname, frames=frames)
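# Editor's sketch (not part of the module): a custom integrate function is
# handed to the dataset constructor, which calls it once per sample and stores
# the result under a directory named after the function. The root path below
# is a hypothetical example.
def _demo_custom_integrate():
    dataset = SpikingHeidelbergDigits(
        root='./datasets/SHD',  # hypothetical path
        train=True,
        data_type='frame',
        custom_integrate_function=custom_integrate_function_example,
    )
    frames, label = dataset[0]
    print(frames.shape, label)  # frames.shape == (2, 700)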
[docs]class SpikingHeidelbergDigits(Dataset):
    def __init__(
            self,
            root: str,
            train: bool = None,
            data_type: str = 'event',
            frames_number: int = None,
            split_by: str = None,
            duration: int = None,
            custom_integrate_function: Callable = None,
            custom_integrated_frames_dir_name: str = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        """
        The Spiking Heidelberg Digits (SHD) dataset, which is proposed by `The Heidelberg Spiking Data Sets for the Systematic Evaluation of Spiking Neural Networks <https://doi.org/10.1109/TNNLS.2020.3044364>`_.

        Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about params information.

        .. admonition:: Note
            :class: note

            Events in this dataset are in the format of ``(x, t)`` rather than ``(x, y, t, p)``. Thus, this dataset is not inherited from :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` directly. But their procedures are similar.

        :class:`spikingjelly.datasets.shd.custom_integrate_function_example` is an example of ``custom_integrate_function``, which is similar to the custom function for DVS Gesture in the ``Neuromorphic Datasets Processing`` tutorial.
        """
        super().__init__()
        self.root = root
        self.train = train
        self.data_type = data_type
        self.frames_number = frames_number
        self.split_by = split_by
        self.duration = duration
        self.custom_integrate_function = custom_integrate_function
        self.custom_integrated_frames_dir_name = custom_integrated_frames_dir_name
        self.transform = transform
        self.target_transform = target_transform

        download_root = os.path.join(root, 'download')
        extract_root = os.path.join(root, 'extract')
        if not os.path.exists(extract_root):
            if os.path.exists(download_root):
                print(f'The [{download_root}] directory for saving downloaded files already exists, check files...')
                # check files
                resource_list = self.resource_url_md5()
                for i in range(resource_list.__len__()):
                    file_name, url, md5 = resource_list[i]
                    fpath = os.path.join(download_root, file_name)
                    if not utils.check_integrity(fpath=fpath, md5=md5):
                        print(f'The file [{fpath}] does not exist or is corrupted.')

                        if os.path.exists(fpath):
                            # If the file is corrupted, remove it.
                            os.remove(fpath)
                            print(f'Remove [{fpath}]')

                        if self.downloadable():
                            # If the file does not exist, download it.
                            print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                            utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                        else:
                            raise NotImplementedError(
                                f'This dataset can not be downloaded by SpikingJelly, please download [{file_name}] from [{url}] manually and put files at {download_root}.')
            else:
                os.mkdir(download_root)
                print(f'Mkdir [{download_root}] to save downloaded files.')
                resource_list = self.resource_url_md5()
                if self.downloadable():
                    # download and extract files
                    for i in range(resource_list.__len__()):
                        file_name, url, md5 = resource_list[i]
                        print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                        utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                else:
                    raise NotImplementedError(f'This dataset can not be downloaded by SpikingJelly, '
                                              f'please download files manually and put files at [{download_root}]. '
                                              f'The resources file_name, url, and md5 are: \n{resource_list}')

            os.mkdir(extract_root)
            print(f'Mkdir [{extract_root}].')
            self.extract_downloaded_files(download_root, extract_root)
        else:
            print(f'The directory [{extract_root}] for saving extracted files already exists.\n'
                  f'SpikingJelly will not check the data integrity of extracted files.\n'
                  f'If extracted files are not integrated, please delete [{extract_root}] manually, '
                  f'then SpikingJelly will re-extract files from [{download_root}].')
            # shutil.rmtree(extract_root)
            # print(f'Delete [{extract_root}].')

        if self.data_type == 'event':
            if self.train:
                self.h5_file = h5py.File(os.path.join(extract_root, 'shd_train.h5'))
            else:
                self.h5_file = h5py.File(os.path.join(extract_root, 'shd_test.h5'))
            self.length = self.h5_file['labels'].__len__()
            return

        elif self.data_type == 'frame':
            if frames_number is not None:
                assert frames_number > 0 and isinstance(frames_number, int)
                assert split_by == 'time' or split_by == 'number'
                frames_np_root = os.path.join(root, f'frames_number_{frames_number}_split_by_{split_by}')
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')

                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')

                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')

                    # use multi-threading to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        sub_threads = []
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i, frames_np_train_root, self.split_by, frames_number, self.get_W(), True))

                        h5_file = h5py.File(os.path.join(extract_root, 'shd_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i, frames_np_test_root, self.split_by, frames_number, self.get_W(), True))

                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)

                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root

            elif duration is not None:
                assert duration > 0 and isinstance(duration, int)
                frames_np_root = os.path.join(root, f'duration_{duration}')
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')

                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')

                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')

                    # use multi-threading to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        sub_threads = []
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i, frames_np_train_root, self.duration, self.get_W(), True))

                        h5_file = h5py.File(os.path.join(extract_root, 'shd_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i, frames_np_test_root, self.duration, self.get_W(), True))

                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)

                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root

            elif custom_integrate_function is not None:
                if custom_integrated_frames_dir_name is None:
                    custom_integrated_frames_dir_name = custom_integrate_function.__name__
                frames_np_root = os.path.join(root, custom_integrated_frames_dir_name)
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')

                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')

                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')

                    # use multi-threading to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        sub_threads = []
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i, frames_np_train_root, self.get_W()))

                        h5_file = h5py.File(os.path.join(extract_root, 'shd_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i, frames_np_test_root, self.get_W()))

                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)

                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root

            else:
                raise ValueError('At least one of "frames_number", "duration" and "custom_integrate_function" should not be None.')

            self.frames_path = []
            self.frames_label = []

            if self.train:
                sub_dir = 'train'
            else:
                sub_dir = 'test'

            for i in range(self.classes_number()):
                for fname in os.listdir(os.path.join(self.frames_np_root, sub_dir, str(i))):
                    self.frames_path.append(os.path.join(self.frames_np_root, sub_dir, str(i), fname))
                    self.frames_label.append(i)
            self.length = self.frames_label.__len__()
        else:
            raise NotImplementedError(self.data_type)
[docs]    def classes_number(self):
        return 20
    def __len__(self):
        return self.length

    def __getitem__(self, i: int):
        if self.data_type == 'event':
            events = {'t': self.h5_file['spikes']['times'][i], 'x': self.h5_file['spikes']['units'][i]}
            label = self.h5_file['labels'][i]
            if self.transform is not None:
                events = self.transform(events)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return events, label

        elif self.data_type == 'frame':
            frames = np.load(self.frames_path[i], allow_pickle=True)['frames'].astype(np.float32)
            label = self.frames_label[i]
            if self.transform is not None:
                frames = self.transform(frames)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return frames, label
[docs]    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url`` in which ``url[i]`` is a tuple containing the i-th file's name, download link, and MD5
        :rtype: list
        '''
        return [
            ('shd_train.h5.zip', 'https://zenkelab.org/datasets/shd_train.h5.zip', 'f3252aeb598ac776c1b526422d90eecb'),
            ('shd_test.h5.zip', 'https://zenkelab.org/datasets/shd_test.h5.zip', '1503a5064faa34311c398fb0a1ed0a6f'),
        ]
[docs]    @staticmethod
    def downloadable() -> bool:
        '''
        :return: Whether the dataset can be directly downloaded by python codes. If not, the user has to download it manually
        :rtype: bool
        '''
        return True
[docs]    @staticmethod
    def extract_downloaded_files(download_root: str, extract_root: str):
        '''
        :param download_root: Root directory path which saves downloaded dataset files
        :type download_root: str
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :return: None

        This function defines how to extract downloaded files.
        '''
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
            sub_threads = []
            for zip_file in os.listdir(download_root):
                zip_file = os.path.join(download_root, zip_file)
                print(f'Extract [{zip_file}] to [{extract_root}].')
                sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))

            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    exit(-1)
[docs]    @staticmethod
    def get_W():
        return 700
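# Editor's sketch (not part of the module): typical usage of the SHD dataset in
# event and frame form. The root path is a hypothetical example; on first use
# the constructor downloads, extracts, and (for frames) integrates the data.
def _demo_shd_usage():
    # Raw events: each sample is a dict {'t': ..., 'x': ...} plus its label.
    shd_events = SpikingHeidelbergDigits('./datasets/SHD', train=True, data_type='event')
    events, label = shd_events[0]
    print(events['t'].shape, events['x'].shape, label)

    # Pre-integrated frames: float32 arrays of shape [frames_number, 700].
    shd_frames = SpikingHeidelbergDigits('./datasets/SHD', train=True, data_type='frame',
                                         frames_number=20, split_by='number')
    frames, label = shd_frames[0]
    print(frames.shape, label)  # (20, 700)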
[docs]class SpikingSpeechCommands(Dataset):
    def __init__(
            self,
            root: str,
            split: str = 'train',
            data_type: str = 'event',
            frames_number: int = None,
            split_by: str = None,
            duration: int = None,
            custom_integrate_function: Callable = None,
            custom_integrated_frames_dir_name: str = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        """
        The Spiking Speech Commands (SSC) dataset, which is proposed by `The Heidelberg Spiking Data Sets for the Systematic Evaluation of Spiking Neural Networks <https://doi.org/10.1109/TNNLS.2020.3044364>`_.

        Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about params information.

        .. admonition:: Note
            :class: note

            Events in this dataset are in the format of ``(x, t)`` rather than ``(x, y, t, p)``. Thus, this dataset is not inherited from :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` directly. But their procedures are similar.

        :class:`spikingjelly.datasets.shd.custom_integrate_function_example` is an example of ``custom_integrate_function``, which is similar to the custom function for DVS Gesture in the ``Neuromorphic Datasets Processing`` tutorial.
        """
        super().__init__()
        self.root = root
        self.split = split
        self.data_type = data_type
        self.frames_number = frames_number
        self.split_by = split_by
        self.duration = duration
        self.custom_integrate_function = custom_integrate_function
        self.custom_integrated_frames_dir_name = custom_integrated_frames_dir_name
        self.transform = transform
        self.target_transform = target_transform

        download_root = os.path.join(root, 'download')
        extract_root = os.path.join(root, 'extract')
        if not os.path.exists(extract_root):
            if os.path.exists(download_root):
                print(f'The [{download_root}] directory for saving downloaded files already exists, check files...')
                # check files
                resource_list = self.resource_url_md5()
                for i in range(resource_list.__len__()):
                    file_name, url, md5 = resource_list[i]
                    fpath = os.path.join(download_root, file_name)
                    if not utils.check_integrity(fpath=fpath, md5=md5):
                        print(f'The file [{fpath}] does not exist or is corrupted.')

                        if os.path.exists(fpath):
                            # If the file is corrupted, remove it.
                            os.remove(fpath)
                            print(f'Remove [{fpath}]')

                        if self.downloadable():
                            # If the file does not exist, download it.
                            print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                            utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                        else:
                            raise NotImplementedError(
                                f'This dataset can not be downloaded by SpikingJelly, please download [{file_name}] from [{url}] manually and put files at {download_root}.')
            else:
                os.mkdir(download_root)
                print(f'Mkdir [{download_root}] to save downloaded files.')
                resource_list = self.resource_url_md5()
                if self.downloadable():
                    # download and extract files
                    for i in range(resource_list.__len__()):
                        file_name, url, md5 = resource_list[i]
                        print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                        utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                else:
                    raise NotImplementedError(f'This dataset can not be downloaded by SpikingJelly, '
                                              f'please download files manually and put files at [{download_root}]. '
                                              f'The resources file_name, url, and md5 are: \n{resource_list}')

            os.mkdir(extract_root)
            print(f'Mkdir [{extract_root}].')
            self.extract_downloaded_files(download_root, extract_root)
        else:
            print(f'The directory [{extract_root}] for saving extracted files already exists.\n'
                  f'SpikingJelly will not check the data integrity of extracted files.\n'
                  f'If extracted files are not integrated, please delete [{extract_root}] manually, '
                  f'then SpikingJelly will re-extract files from [{download_root}].')
            # shutil.rmtree(extract_root)
            # print(f'Delete [{extract_root}].')

        if self.data_type == 'event':
            if self.split == 'train':
                self.h5_file = h5py.File(os.path.join(extract_root, 'ssc_train.h5'))
            elif self.split == 'valid':
                self.h5_file = h5py.File(os.path.join(extract_root, 'ssc_valid.h5'))
            else:
                self.h5_file = h5py.File(os.path.join(extract_root, 'ssc_test.h5'))
            self.length = self.h5_file['labels'].__len__()
            return

        elif self.data_type == 'frame':
            if frames_number is not None:
                assert frames_number > 0 and isinstance(frames_number, int)
                assert split_by == 'time' or split_by == 'number'
                frames_np_root = os.path.join(root, f'frames_number_{frames_number}_split_by_{split_by}')
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')

                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')

                    frames_np_valid_root = os.path.join(frames_np_root, 'valid')
                    os.mkdir(frames_np_valid_root)
                    print(f'Mkdir [{frames_np_valid_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_valid_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_valid_root, str(i))}].')

                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')

                    # use multi-threading to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        sub_threads = []
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i, frames_np_train_root, self.split_by, frames_number, self.get_W(), True))

                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_valid.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th valid sample to frames and save to [{frames_np_valid_root}].')
                            # valid samples are saved under the valid directory
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i, frames_np_valid_root, self.split_by, frames_number, self.get_W(), True))

                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i, frames_np_test_root, self.split_by, frames_number, self.get_W(), True))

                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)

                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root

            elif duration is not None:
                assert duration > 0 and isinstance(duration, int)
                frames_np_root = os.path.join(root, f'duration_{duration}')
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')

                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')

                    frames_np_valid_root = os.path.join(frames_np_root, 'valid')
                    os.mkdir(frames_np_valid_root)
                    print(f'Mkdir [{frames_np_valid_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_valid_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_valid_root, str(i))}].')

                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')

                    # use multi-threading to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        sub_threads = []
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i, frames_np_train_root, self.duration, self.get_W(), True))

                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_valid.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th valid sample to frames and save to [{frames_np_valid_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i, frames_np_valid_root, self.duration, self.get_W(), True))

                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i, frames_np_test_root, self.duration, self.get_W(), True))

                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)

                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root

            elif custom_integrate_function is not None:
                if custom_integrated_frames_dir_name is None:
                    custom_integrated_frames_dir_name = custom_integrate_function.__name__
                frames_np_root = os.path.join(root, custom_integrated_frames_dir_name)
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')

                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')

                    frames_np_valid_root = os.path.join(frames_np_root, 'valid')
                    os.mkdir(frames_np_valid_root)
                    print(f'Mkdir [{frames_np_valid_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_valid_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_valid_root, str(i))}].')

                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')

                    # use multi-threading to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        sub_threads = []
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i, frames_np_train_root, self.get_W()))

                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_valid.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th valid sample to frames and save to [{frames_np_valid_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i, frames_np_valid_root, self.get_W()))

                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i, frames_np_test_root, self.get_W()))

                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)

                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root

            else:
                raise ValueError('At least one of "frames_number", "duration" and "custom_integrate_function" should not be None.')

            self.frames_path = []
            self.frames_label = []

            if self.split == 'train':
                sub_dir = 'train'
            elif self.split == 'valid':
                sub_dir = 'valid'
            else:
                sub_dir = 'test'

            for i in range(self.classes_number()):
                for fname in os.listdir(os.path.join(self.frames_np_root, sub_dir, str(i))):
                    self.frames_path.append(os.path.join(self.frames_np_root, sub_dir, str(i), fname))
                    self.frames_label.append(i)
            self.length = self.frames_label.__len__()
        else:
            raise NotImplementedError(self.data_type)
[docs]    def classes_number(self):
        return 35
    def __len__(self):
        return self.length

    def __getitem__(self, i: int):
        if self.data_type == 'event':
            events = {'t': self.h5_file['spikes']['times'][i], 'x': self.h5_file['spikes']['units'][i]}
            label = self.h5_file['labels'][i]
            if self.transform is not None:
                events = self.transform(events)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return events, label

        elif self.data_type == 'frame':
            frames = np.load(self.frames_path[i], allow_pickle=True)['frames'].astype(np.float32)
            label = self.frames_label[i]
            if self.transform is not None:
                frames = self.transform(frames)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return frames, label
[docs]    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url`` in which ``url[i]`` is a tuple containing the i-th file's name, download link, and MD5
        :rtype: list
        '''
        return [
            ('ssc_train.h5.zip', 'https://zenkelab.org/datasets/ssc_train.h5.zip', 'd102be95e7144fcc0553d1f45ba94170'),
            ('ssc_valid.h5.zip', 'https://zenkelab.org/datasets/ssc_valid.h5.zip', 'b4eee3516a4a90dd0c71a6ac23a8ae43'),
            ('ssc_test.h5.zip', 'https://zenkelab.org/datasets/ssc_test.h5.zip', 'a35ff1e9cffdd02a20eb850c17c37748'),
        ]
[docs]    @staticmethod
    def downloadable() -> bool:
        '''
        :return: Whether the dataset can be directly downloaded by python codes. If not, the user has to download it manually
        :rtype: bool
        '''
        return True
[docs]    @staticmethod
    def extract_downloaded_files(download_root: str, extract_root: str):
        '''
        :param download_root: Root directory path which saves downloaded dataset files
        :type download_root: str
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :return: None

        This function defines how to extract downloaded files.
        '''
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
            sub_threads = []
            for zip_file in os.listdir(download_root):
                zip_file = os.path.join(download_root, zip_file)
                print(f'Extract [{zip_file}] to [{extract_root}].')
                sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))

            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    exit(-1)
[docs]    @staticmethod
    def get_W():
        return 700
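# Editor's sketch (not part of the module): SSC follows the same interface as
# SHD but selects its subset with ``split`` ('train'/'valid'/'test') instead of
# ``train``; the root path is a hypothetical example.
def _demo_ssc_usage():
    ssc_valid = SpikingSpeechCommands('./datasets/SSC', split='valid', data_type='frame',
                                      frames_number=20, split_by='number')
    frames, label = ssc_valid[0]
    print(frames.shape, label)  # (20, 700), label in [0, 35)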